O StableHLO é um conjunto de operações para operações de alto nível (HLO) em modelos de machine learning (ML). O StableHLO funciona como uma camada de portabilidade entre diferentes frameworks e compiladores de ML: os frameworks que produzem programas StableHLO são compatíveis com os compiladores que consomem esses programas.
Nosso objetivo é simplificar e acelerar o desenvolvimento de ML criando mais interoperabilidade entre vários frameworks de ML (como TensorFlow, JAX e PyTorch) e compiladores de ML (como XLA e IREE). Para isso, este documento fornece uma especificação para a linguagem de programação StableHLO.
Esta especificação contém três seções principais. Primeiro, a seção Programas descreve a estrutura dos programas StableHLO, que consistem em funções StableHLO, que consistem em operações StableHLO. Nessa estrutura, a seção Ops especifica a semântica de operações individuais. A seção Execução fornece semântica para todas essas operações executadas juntas em um programa. Por fim, a seção Notação aborda a notação usada em toda a especificação.
Para conferir a especificação de uma versão anterior do StableHLO, abra o repositório na versão marcada de interesse. Por exemplo, a especificação do StableHLO v0.19.0. Para conferir as mudanças que ocorreram em cada aumento de versão secundária do StableHLO, consulte o registro de versão em VhloDialect.td.
Programas
Program ::= {Func}
Os programas do StableHLO consistem em um número arbitrário de funções do StableHLO.
Confira abaixo um exemplo de programa com uma função @main que tem três entradas (%image, %weights e %bias) e uma saída. O corpo da função tem seis operações.
func.func @main(
%image: tensor<28x28xf32>,
%weights: tensor<784x10xf32>,
%bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
%0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
%1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
%2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
%3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
%4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
"func.return"(%4): (tensor<1x10xf32>) -> ()
}
Funções
Func ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput ::= ValueType
FuncBody ::= {Op}
As funções StableHLO (também chamadas de funções nomeadas) têm um identificador, entradas/saídas e um corpo. No futuro, planejamos introduzir metadados adicionais para funções e alcançar uma compatibilidade melhor com HLO (#425, #626, #740, #744).
Identificadores
FuncId ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
| '%' letter {letter | digit}
letter ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit ::= '0' | ... | '9'
Os identificadores StableHLO são semelhantes aos identificadores em muitas linguagens de programação, com duas peculiaridades: 1) todos os identificadores têm sigilos que distinguem diferentes tipos de identificadores; 2) os identificadores de valor podem ser completamente numéricos para simplificar a geração de programas StableHLO.
Tipos
Type ::= ValueType | NonValueType
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType | BufferType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
Os tipos do StableHLO são categorizados em tipos de valor (também chamados de tipos de primeira classe), que representam valores do StableHLO, e tipos não de valor, que descrevem outros elementos do programa. Os tipos do StableHLO são semelhantes aos de muitas linguagens de programação. A principal peculiaridade é a natureza específica do domínio do StableHLO, que resulta em alguns resultados incomuns (por exemplo, tipos escalares não são tipos de valor).
TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'
Os tipos de tensor representam tensores, ou seja, matrizes multidimensionais. Eles têm um formato e um tipo de elemento, em que um formato representa tamanhos de dimensão não negativos ou desconhecidos na ordem crescente das dimensões correspondentes (também chamadas de eixos) numeradas de 0 a R-1. O número de dimensões R é chamado de cardinalidade. Por exemplo, tensor<2x3xf32> é um tipo de tensor com formato 2x3 e tipo de elemento f32. Ela tem duas dimensões (ou seja, dois eixos): a 0ª e a 1ª, cujos tamanhos são 2 e 3. A classificação é 2.
As dimensões podem ser parcialmente ou totalmente desconhecidas (dinâmicas). Por exemplo, tensor<?x2xf64> é parcialmente desconhecida e tensor<?x?xf64> é totalmente desconhecida. Os tamanhos de dimensões dinâmicas são representados usando um ?. Não é possível remover a classificação das formas.
No futuro, planejamos estender os tipos de tensor além dos tamanhos de dimensão e tipos de elementos, por exemplo, para incluir layouts (#629) e esparsidade (#1078).
QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
QuantizationStorageType
['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
':' QuantizationExpressedType
[':' QuantizationDimension]
',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerLiteral
QuantizationStorageMax ::= IntegerLiteral
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerLiteral
QuantizationParameters ::= QuantizationParameter
| '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale [':' QuantizationZeroPoint]
QuantizationScale ::= FloatLiteral
QuantizationZeroPoint ::= IntegerLiteral
| Nome | Tipo | Restrições |
|---|---|---|
storage_type |
tipo inteiro | (C1-C3), (C8) |
storage_min |
constante de número inteiro | (C1), (C3), (C7) |
storage_max |
constante de número inteiro | (C2), (C3), (C7) |
expressed_type |
tipo de ponto flutuante | (C4) |
quantization_dimension |
constante inteira opcional | (C10-C12) |
scales |
número variádico de constantes de ponto flutuante | (C4-C6), (C9), (C10), (C13) |
zero_points |
número variádico de constantes inteiras | (C7-C9) |
Tipos de elementos quantizados representam valores inteiros de um tipo de armazenamento no intervalo de storage_min a storage_max (inclusive) que correspondem a valores de ponto flutuante de um tipo expresso. Para um determinado valor inteiro i, o valor de ponto flutuante correspondente f pode ser calculado como f = (i - zero_point) * scale, em que scale e zero_point são chamados de parâmetros de quantização. Os storage_min e storage_max são opcionais na gramática, mas têm valores padrão de min_value(storage_type) e max_value(storage_type), respectivamente. Os tipos de elementos quantizados têm as seguintes restrições:
- (C1)
type(storage_min) = storage_type. - (C2)
type(storage_max) = storage_type. - (C3)
min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type). - (C4)
type(scales...) = expressed_type. - (C5)
0 < scales. - (C6)
is_finite(scales...). - (C7)
storage_min <= zero_points <= storage_max. - (C8)
type(zero_points...) = storage_type. - (C9)
size(scales) = size(zero_points). - (C10) Se
is_empty(quantization_dimension), entãosize(scales) = 1. - (C11)
0 <= quantization_dimension.
No momento, QuantizationScale é uma constante de ponto flutuante, mas há um grande interesse em escalas baseadas em números inteiros, representadas com multiplicadores e deslocamentos. Estamos planejando explorar isso em breve
(#1404).
Há uma discussão em andamento sobre a semântica de QuantizationZeroPoint,
incluindo o tipo, os valores e se pode haver apenas um ou
potencialmente vários pontos zero em um tipo de tensor quantizado. Com base nos resultados dessa discussão, a especificação sobre pontos zero pode mudar no futuro (#1405).
Outra discussão em andamento envolve a semântica de QuantizationStorageMin
e QuantizationStorageMax para determinar se alguma restrição deve ser
imposta a esses valores e aos valores de tensores quantizados
(#1406).
Por fim, planejamos explorar a representação de escalas desconhecidas e pontos zero, assim como planejamos explorar a representação de tamanhos de dimensões desconhecidas (#1407).
Os tipos de tensores quantizados representam tensores com elementos quantizados. Esses tensores são exatamente iguais aos tensores regulares, exceto que os elementos deles têm tipos de elementos quantizados, em vez de tipos de elementos regulares.
Em tensores quantizados, a quantização pode ser por tensor, ou seja, ter um scale e um zero_point para todo o tensor, ou por eixo, ou seja, ter vários scales e zero_points, um par por fração de uma dimensão específica quantization_dimension. Mais formalmente, em um tensor t com quantização por eixo, há dim(t, quantization_dimension) slices de quantization_dimension: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :] etc. Todos os elementos no iº slice usam scales[i] e zero_points[i] como parâmetros de quantização. Os tipos de tensores quantizados têm as seguintes restrições:
- Para quantização por tensor:
- Sem outras restrições.
- Para quantização por eixo:
- (C12)
quantization_dimension < rank(self). - (C13)
dim(self, quantization_dimension) = size(scales).
- (C12)
TokenType ::= 'token'
Os tipos de token representam tokens, ou seja, valores opacos produzidos e consumidos por algumas operações. Os tokens são usados para impor a ordem de execução das operações, conforme descrito na seção Execução.
TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]
Tipos de buffer representam buffers. Por exemplo, no XLA, os buffers são matrizes multidimensionais com armazenamento consistente. Assim como os tipos de tensor, os tipos de buffer têm uma forma e um tipo de elemento, em que uma forma representa tamanhos de dimensão não negativos ou desconhecidos na ordem crescente das dimensões correspondentes (também chamadas de eixos) numeradas de 0 a R-1. O número de dimensões R é chamado de classificação. Por exemplo, memref<2x3xf32> é um tipo de buffer com formato 2x3 e tipo de elemento f32. Ela tem duas dimensões (ou seja, dois eixos): a dimensão 0 e a dimensão 1, cujos tamanhos são 2 e 3. A classificação é 2.
Os buffers podem ser alocados usando um custom_call para CreateBuffer ou Pin e desalocados por um custom_call para Unpin. Somente operações custom_call podem ler e gravar o conteúdo dentro dos buffers. Consulte custom_call para mais detalhes.
Tipos de tupla representam tuplas, ou seja, listas heterogêneas. As tuplas são um recurso legado que existe apenas para compatibilidade com HLO. Em HLO, as tuplas são usadas para representar entradas e saídas variadas. No StableHLO, entradas e saídas variádicas são compatíveis de forma nativa, e o único uso de tuplas é para representar de forma abrangente a ABI do HLO, em que, por exemplo, T, tuple<T> e tuple<tuple<T>> podem ser materialmente diferentes dependendo de uma implementação específica. No futuro, planejamos fazer mudanças na ABI do HLO, o que pode permitir remover tipos de tupla do StableHLO (#598).
TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3'
| 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2'
| 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
Os tipos de elementos representam elementos de tipos de tensor. Ao contrário de muitas linguagens de programação, esses tipos não são de primeira classe no StableHLO. Isso significa que
os programas StableHLO não podem representar diretamente valores desses tipos. Como resultado,
é idiomático representar valores escalares do tipo T com valores de tensor
de dimensão zero do tipo tensor<T>.
- O tipo booleano representa os valores booleanos
trueefalse. - Tipos de números inteiros podem ser assinados (
si) ou não assinados (ui) e têm uma das larguras de bits compatíveis (2,4,8,16,32ou64). Os tipossiNassinados representam valores inteiros de-2^(N-1)a2^(N-1)-1, inclusive, e os tiposuiNnão assinados representam valores inteiros de0a2^N-1, inclusive. - Os tipos de ponto flutuante podem ser um dos seguintes:
- Números de ponto flutuante de 8 bits
f8E3M4,f8E4M3ef8E5M2seguindo as convenções IEEE-754. - Tipos
f8E4M3FNef8E5M2correspondentes, respectivamente, às codificaçõesE4M3eE5M2do formato FP8 descritas em Formatos FP8 para aprendizado profundo. - Tipos
f8E4M3FNUZef8E5M2FNUZcorrespondentes às codificaçõesE4M3eE5M2dos formatos FP8 descritos em Formatos numéricos de 8 bits para redes neurais profundas. - Tipo
f8E4M3B11FNUZcorrespondente à codificaçãoE4M3dos formatos FP8 descritos em Treinamento e inferência de ponto flutuante híbrido de 8 bits (HFP8) para redes neurais profundas. - Tipo
bf16correspondente ao formatobfloat16descrito em BFloat16: o segredo do alto desempenho em Cloud TPUs. - Os tipos
f16,f32ef64correspondem, respectivamente, aos formatosbinary16("meia precisão"),binary32("precisão simples") ebinary64("precisão dupla") descritos no padrão IEEE 754. - O tipo
tf32corresponde ao formato TensorFloat32 e tem suporte limitado no StableHLO. f4E2M1FN,f6E2M3FN,f6E3M2FNef8E8M0FNUMX (microescalonamento) descritos na Especificação de formatos de microescalonamento do OCP.
- Números de ponto flutuante de 8 bits
- Tipos complexos representam valores complexos que têm uma parte real e uma parte imaginária do mesmo tipo de elemento. Os tipos complexos compatíveis são
complex<f32>(ambas as partes são do tipof32) ecomplex<f64>(ambas as partes são do tipof64).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]
Os tipos de função representam funções nomeadas e anônimas. Elas têm tipos de entrada (a lista de tipos no lado esquerdo de ->) e tipos de saída (a lista de tipos no lado direito de ->). Em muitas linguagens de programação, os tipos de função são de primeira classe, mas não no StableHLO.
StringType ::= 'string'
O tipo de string representa sequências de bytes. Ao contrário de muitas linguagens de programação, o tipo de string não é de primeira classe no StableHLO e é usado apenas para especificar metadados estáticos para elementos de programa.
Operações
As operações StableHLO (também chamadas de ops) representam um conjunto fechado de operações de alto nível em modelos de machine learning. Conforme discutido acima, a sintaxe do StableHLO é muito inspirada no MLIR, que não é necessariamente a alternativa mais ergonômica, mas é provavelmente a mais adequada para o objetivo do StableHLO de criar mais interoperabilidade entre frameworks e compiladores de ML.
Op ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic ::= 'abs' | 'add' | ...
As operações do StableHLO (também chamadas de ops) têm um nome, entradas/saídas e uma assinatura. O nome consiste no prefixo stablehlo. e
um mnemônico que identifica exclusivamente uma das operações compatíveis. Confira abaixo uma lista completa de todas as operações compatíveis.
OpInputs ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue ::= ValueId
OpInputFuncs ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs ::= [OpOutput {',' OpOutput} '=']
OpOutput ::= ValueId
As operações consomem entradas e produzem saídas. As entradas são categorizadas em valores de entrada (calculados durante a execução), funções de entrada (fornecidas estaticamente, porque no StableHLO as funções não são valores de primeira classe) e atributos de entrada (também fornecidos estaticamente). O tipo de entradas e saídas consumidas e produzidas por uma operação depende do mnemônico dela. Por exemplo, a operação add
consome dois valores de entrada e produz um valor de saída. Em comparação, a operação select_and_scatter consome três valores de entrada, duas funções de entrada e três atributos de entrada.
OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused ::= '^' digit {digit}
| '^' letter {letter | digit}
As funções de entrada (também chamadas de funções anônimas) são muito parecidas com as funções nomeadas, exceto que: 1) elas não têm um identificador (daí o nome "anônimas"); 2) elas não declaram tipos de saída (os tipos de saída são inferidos da operação return na função).
A sintaxe das funções de entrada inclui uma parte não usada no momento (consulte a produção Unused acima), que está lá para compatibilidade com MLIR. No MLIR, existe um conceito mais geral de "regiões", que podem ter vários "blocos" de operações conectados por operações de salto. Esses blocos têm IDs que correspondem
à produção Unused, para que possam ser diferenciados.
O StableHLO não tem operações de salto, então a parte correspondente da sintaxe MLIR não é usada, mas ainda está lá.
OpInputAttr ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName ::= letter {letter | digit}
OpInputAttrValue ::= Constant
Os atributos de entrada têm um nome e um valor que é uma das constantes compatíveis. Eles são a principal maneira de especificar metadados estáticos para elementos
de programa. Por exemplo, a operação concatenate usa o atributo dimension para especificar a dimensão ao longo da qual os valores de entrada são concatenados. Da mesma forma, a operação slice usa vários atributos, como start_indices e limit_indices, para especificar os limites usados para dividir o valor de entrada.
No momento, os programas StableHLO em uso às vezes contêm atributos que não são descritos neste documento. No futuro, planejamos incorporar esses atributos ao conjunto de operações do StableHLO ou proibir que eles apareçam em programas do StableHLO. Enquanto isso, confira a lista de atributos:
layout(#629).mhlo.frontend_attributes(#628).mhlo.sharding(#619).output_operand_aliases(#740).- Metadados de local (#594).
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'
A assinatura da operação consiste nos tipos de todos os valores de entrada (a lista de tipos no lado esquerdo de ->) e nos tipos de todos os valores de saída (a lista de tipos no lado direito de ->). Falando estritamente, os tipos de entrada são redundantes, e os tipos de saída também são quase sempre redundantes, porque, para a maioria das operações do StableHLO, os tipos de saída podem ser inferidos das entradas. No entanto, a assinatura da operação faz parte da sintaxe do StableHLO para compatibilidade com MLIR.
Confira abaixo um exemplo de operação cujo mnemônico é select_and_scatter. Ele consome três valores de entrada (%operand, %source e %init_value), duas funções de entrada e três atributos de entrada (window_dimensions, window_strides e padding). Observe como a assinatura da operação inclui apenas os tipos dos valores de entrada, mas não os tipos de funções e atributos de entrada, que são fornecidos inline.
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>
Constantes
Constant ::= BooleanConstant
| IntegerConstant
| FloatConstant
| ComplexConstant
| TensorConstant
| QuantizedTensorConstant
| StringConstant
| EnumConstant
As constantes do StableHLO têm um literal e um tipo que representam juntos um valor do StableHLO. Em geral, o tipo faz parte da sintaxe constante, exceto quando não há ambiguidade. Por exemplo, uma constante booleana tem o tipo i1, enquanto uma constante de número inteiro pode ter vários tipos possíveis.
BooleanConstant ::= BooleanLiteral
BooleanLiteral ::= 'true' | 'false'
As constantes booleanas representam os valores booleanos true e false. As constantes booleanas têm o tipo i1.
IntegerConstant ::= IntegerLiteral ':' IntegerType
IntegerLiteral ::= ['-' | '+'] DecimalDigits
| ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit ::= '0' | ... | '9'
hexadecimalDigit ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'
Constantes de números inteiros representam valores inteiros usando strings que usam notação decimal ou hexadecimal. Outras bases, como binária ou octal, não são compatíveis. As constantes inteiras têm as seguintes restrições:
- (C1)
is_wellformed(integer_literal, integer_type).
FloatConstant ::= FloatLiteral ':' FloatType
FloatLiteral ::= SignPart IntegerPart FractionalPart ScientificPart
| '0x' [HexadecimalDigits]
SignPart ::= ['-' | '+']
IntegerPart ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]
As constantes de ponto flutuante representam valores de ponto flutuante usando strings que usam notação decimal ou científica. Além disso, a notação hexadecimal pode ser usada para especificar diretamente os bits subjacentes no formato de ponto flutuante do tipo correspondente. As constantes de ponto flutuante têm as seguintes restrições:
- (C1) Se uma notação não hexadecimal for usada,
is_wellformed(float_literal, float_type). - (C2) Se a notação hexadecimal for usada,
size(hexadecimal_digits) = num_bits(float_type) / 4.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral ::= '(' RealPart ',' ImaginaryPart ')'
RealPart ::= FloatLiteral
ImaginaryPart ::= FloatLiteral
Constantes complexas representam valores complexos usando listas de uma parte real (primeiro) e uma parte imaginária (segundo). Por exemplo, (1.0, 0.0) : complex<f32> representa 1.0 + 0.0i e (0.0, 1.0) : complex<f32> representa 0.0 + 1.0i. A ordem em que essas partes são armazenadas na memória é definida pela implementação. As constantes complexas têm as seguintes restrições:
- (C1)
is_wellformed(real_part, complex_element_type(complex_type)). - (C2)
is_wellformed(imaginary_part, complex_element_type(complex_type)).
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral
As constantes de tensor representam valores de tensor usando listas aninhadas especificadas pela notação do NumPy. Por exemplo, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> representa um valor de tensor com o seguinte mapeamento de índices para elementos: {0, 0} => 1, {0, 1} => 2, {0, 2} => 3, {1, 0} => 4, {1, 1} => 5, {1, 2} => 6. A ordem em que esses elementos são armazenados na memória é definida pela implementação. As constantes de tensor têm as seguintes restrições:
- (C1)
has_syntax(tensor_literal, element_type(tensor_type)), em que:has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type).has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type).
- (C2)
has_shape(tensor_literal, shape(tensor_type)), em que:has_shape(element_literal: Syntax, []) = true.has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:]).- Caso contrário,
false.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
As constantes de tensor quantizado representam valores de tensor quantizado usando a mesma notação das constantes de tensor, com elementos especificados como constantes do tipo de armazenamento. As constantes de tensor quantizadas têm as seguintes restrições:
- (C1)
has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type)). - (C2)
has_shape(quantized_tensor_literal, shape(quantized_tensor_type)).
StringConstant ::= StringLiteral
StringLiteral ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))
Os literais de string consistem em bytes especificados usando caracteres ASCII e sequências de escape. Eles são independentes de codificação, então a interpretação desses bytes é definida pela implementação. Os literais de string têm o tipo string.
Operações
abs
Semântica
Executa uma operação abs elemento a elemento no tensor operand e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:
- Para números inteiros com sinal: módulo inteiro.
- Para pontos flutuantes:
absdo IEEE-754. - Para números complexos: módulo complexo.
- Para tipos quantizados:
dequantize_op_quantize(abs, operand, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor de número inteiro com sinal, ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1-C2) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de número inteiro com sinal ou de ponto flutuante ou tensor quantizado por tensor | (C1-C2) |
Restrições
- (C1)
shape(result) = shape(operand). - (C2)
baseline_element_type(result)é definido como:complex_element_type(element_type(operand))seis_complex(operand).baseline_element_type(operand)caso contrário.
Exemplos
// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand)< : (t>ens>or3xi32<) - t>ensor3xi32
// %result: [2, 0, 2]
adicionar
Semântica
Realiza a adição de elementos de dois tensores lhs e rhs e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:
- Para booleanos: OR lógico.
- Para números inteiros: adição de números inteiros.
- Para pontos flutuantes:
additiondo IEEE-754. - Para números complexos: adição complexa.
- Para tipos quantizados:
dequantize_op_quantize(add, lhs, rhs, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | lhs |
tensor ou tensor quantizado | (C1 a C6) |
| (I2) | rhs |
tensor ou tensor quantizado | (C1-C5), (C7) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor ou tensor quantizado | (C1 a C7) |
Restrições
- Se a operação usar tensores não quantizados:
- (C1)
type(lhs) = type(rhs) = type(result).
- (C1)
- Se a operação usar tensores quantizados:
- (C2)
is_quantized(lhs) and is_quantized(rhs) and is_quantized(result). - (C3)
storage_type(lhs) = storage_type(rhs) = storage_type(result). - (C4)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result). - (C5)
(is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result). - (C6) Se
is_per_axis_quantized(lhs), entãoquantization_dimension(lhs) = quantization_dimension(result). - (C7) Se
is_per_axis_quantized(rhs), entãoquantization_dimension(rhs) = quantization_dimension(result).
- (C2)
Exemplos
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[6, 8], [10, 12]]
after_all
Semântica
Garante que as operações que produzem o inputs sejam executadas antes de qualquer
operação que dependa de result. A execução dessa operação não faz nada. Ela só existe para estabelecer dependências de dados de result para inputs.
Entradas
| Rótulo | Nome | Tipo |
|---|---|---|
| (I1) | inputs |
número variádico de token |
Saídas
| Nome | Tipo |
|---|---|
result |
token |
Exemplos
// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehl>o.token) - !stablehlo.token
all_gather
Semântica
Em cada grupo de processos na grade de processos do StableHLO, concatena os valores dos tensores operands de cada processo ao longo de all_gather_dim e produz tensores results.
A operação divide a grade de processo do StableHLO em process_groups, que é definida da seguinte forma:
cross_replica(replica_groups)sechannel_id <= 0 and use_global_device_ids = false.cross_replica_and_partition(replica_groups)sechannel_id > 0 and use_global_device_ids = false.flattened_ids(replica_groups)sechannel_id > 0 and use_global_device_ids = true.
Depois, em cada process_group:
operands...@receiver = [operand@sender for sender in process_group]para todos osreceiveremprocess_group.results...@process = concatenate(operands...@process, all_gather_dim)para todos osprocessemprocess_group.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operands |
número variádico de tensores ou tensores quantizados por tensor | (C1), (C6) |
| (I2) | all_gather_dim |
constante do tipo si64 |
(C1), (C6) |
| (I3) | replica_groups |
Constante de tensor bidimensional do tipo si64 |
(C2-C4) |
| (I4) | channel_id |
constante do tipo si64 |
(C5) |
| (I5) | use_global_device_ids |
constante do tipo i1 |
(C5) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
results |
número variádico de tensores ou tensores quantizados por tensor | (C6) |
Restrições
- (C1)
0 <= all_gather_dim < rank(operands...). - (C2)
is_unique(replica_groups). - (C3)
size(replica_groups)é definido como:num_replicassecross_replicafor usado.num_replicassecross_replica_and_partitionfor usado.num_processesseflattened_idsfor usado.
- (C4)
0 <= replica_groups < size(replica_groups). - (C5) Se
use_global_device_ids = true, entãochannel_id > 0. - (C6)
type(results...) = type(operands...), exceto:dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1).
Exemplos
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
all_gather_dim = 1 : i64,
replica_grou<ps = den>se[[0, 1]<] : ten>sor1x2xi64,
// channel_id = 0
channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 0
// use_global_device_ids = false
}< : (ten>sor2x2xi<64, ten>sor>2x2xi64)< - (ten>sor2x4xi<64, ten>sor2x4xi64)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
all_reduce
Semântica
Em cada grupo de processos na grade de processos do StableHLO, aplica uma função de redução computation aos valores dos tensores operands de cada processo e produz tensores results.
A operação divide a grade de processo do StableHLO em process_groups, que é definida da seguinte forma:
cross_replica(replica_groups)sechannel_id <= 0 and use_global_device_ids = false.cross_replica_and_partition(replica_groups)sechannel_id > 0 and use_global_device_ids = false.flattened_ids(replica_groups)sechannel_id > 0 and use_global_device_ids = true.
Depois, em cada process_group:
results...@process[result_index] = exec(schedule)para alguma árvore bináriascheduleem que:exec(node)=computation(exec(node.left), exec(node.right)).exec(leaf)=leaf.value.
scheduleé uma árvore binária definida pela implementação cuja travessia em ordem éto_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0])).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operands |
número variádico de tensores ou tensores quantizados por tensor | (C5), (C6) |
| (I2) | replica_groups |
número variádico de constantes de tensor unidimensionais do tipo si64 |
(C1-C3) |
| (I3) | channel_id |
constante do tipo si64 |
(C4) |
| (I4) | use_global_device_ids |
constante do tipo i1 |
(C4) |
| (I5) | computation |
função | (C5) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
results |
número variádico de tensores ou tensores quantizados por tensor | (C6-C7) |
Restrições
- (C1)
is_unique(replica_groups). - (C2)
size(replica_groups)é definido como:num_replicassecross_replicafor usado.num_replicassecross_replica_and_partitionfor usado.num_processesseflattened_idsfor usado.
- (C3)
0 <= replica_groups < size(replica_groups). - (C4) Se
use_global_device_ids = true, entãochannel_id > 0. - (C5)
computationtem o tipo(tensor<E>, tensor<E>) -> (tensor<E>)em queis_promotable(element_type(operand), E). - (C6)
shape(results...) = shape(operands...). - (C7)
element_type(results...) = E.
Exemplos
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
"stable<hlo>.re>turn"(%0) : (tensori64) - ()<
}) {
>replica_g<roups => dense[[0, 1]] : tensor1x2xi64,
// channel_id = 0
channel_hand<le = #stablehlo.chan>nel_handlehandle = 0, type = 0
// use_global_<devic>e_ids = <false>
} >: (tenso<r4xi6>4, tenso<r4xi6>4) - (tensor4xi64, tensor4xi64)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]
all_to_all
Semântica
Em cada grupo de processos na grade de processos do StableHLO, divide os valores dos tensores operands ao longo de split_dimension em partes, dispersa as partes divididas entre os processos, concatena as partes dispersas ao longo de concat_dimension e produz tensores results.
A operação divide a grade de processo do StableHLO em process_groups, que é definida da seguinte forma:
cross_replica(replica_groups)sechannel_id <= 0.cross_partition(replica_groups)sechannel_id > 0.
Depois, em cada process_group:
split_parts...@sender = split(operands...@sender, split_count, split_dimension)para todos ossenderemprocess_group.scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group]em quereceiver_index = process_group.index(receiver).results...@process = concatenate(scattered_parts...@process, concat_dimension).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operands |
número variádico de tensores ou tensores quantizados por tensor | (C1-C3), (C9) |
| (I2) | split_dimension |
constante do tipo si64 |
(C1), (C2), (C9) |
| (I3) | concat_dimension |
constante do tipo si64 |
(C3), (C9) |
| (I4) | split_count |
constante do tipo si64 |
(C2), (C4), (C8), (C9) |
| (I5) | replica_groups |
Constante de tensor bidimensional do tipo si64 |
(C5-C8) |
| (I6) | channel_id |
constante do tipo si64 |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
results |
número variádico de tensores ou tensores quantizados por tensor | (C9) |
Restrições
- (C1)
0 <= split_dimension < rank(operands...). - (C2)
dim(operands..., split_dimension) % split_count = 0. - (C3)
0 <= concat_dimension < rank(operands...). - (C4)
0 < split_count. - (C5)
is_unique(replica_groups). - (C6)
size(replica_groups)é definido como:num_replicassecross_replicafor usado.num_partitionssecross_partitionfor usado.
- (C7)
0 <= replica_groups < size(replica_groups). - (C8)
dim(replica_groups, 1) = split_count. - (C9)
type(results...) = type(operands...), exceto sesplit_dimension != concat_dimension:dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count.dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count.
Exemplos
// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
// [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
// [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_grou<ps = den>se[[0, 1]<] : ten>sor1x2xi64
// channel_id = 0
}< : (ten>sor2x4xi<64, ten>sor>2x4xi64)< - (ten>sor4x2xi<64, ten>sor4x2xi64)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]
e
Semântica
Executa a operação AND bit a bit de dois tensores lhs e rhs e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:
- Para booleanos: AND lógico.
- Para números inteiros: AND bit a bit.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | lhs |
tensor de tipo booleano ou inteiro | (C1) |
| (I2) | rhs |
tensor de tipo booleano ou inteiro | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de tipo booleano ou inteiro | (C1) |
Restrições
- (C1)
type(lhs) = type(rhs) = type(result).
Exemplos
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[1, 2], [3, 0]]
atan2
Semântica
Executa a operação atan2 com elementos nos tensores lhs e rhs e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:
- Para pontos flutuantes:
atan2do IEEE-754. - Para números complexos: atan2 complexo.
- Para tipos quantizados:
dequantize_op_quantize(atan2, lhs, rhs, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | lhs |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
| (I2) | rhs |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
Exemplos
// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs)< : (t>ensor3xf<64, t>ens>or3xf64<) - t>ensor3xf64
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]
batch_norm_grad
Semântica
Calcula gradientes de várias entradas de batch_norm_training usando backpropagation de grad_output e produz tensores grad_operand, grad_scale e grad_offset. De maneira mais formal, essa operação pode ser expressa como uma decomposição em operações StableHLO atuais usando a sintaxe do Python da seguinte maneira:
def compute_sum(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
return sum
def compute_mean(operand, feature_index):
sum = compute_sum(operand, feature_index)
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
# Broadcast inputs to type(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance`
# Intermediate values will be useful for computing gradients
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
# Use the implementation from batchnorm_expander.cc in XLA
# Temporary variables have exactly the same names as in the C++ code
elements_per_feature = broadcast_in_dim(
constant(divide(size(operand), dim(operand, feature_index)),
element_type(grad_output)),
[], type(operand))
i1 = multiply(grad_output, elements_per_feature)
i2 = broadcast_in_dim(
compute_sum(grad_output, feature_index), [feature_index], type(operand))
i3 = broadcast_in_dim(
compute_sum(multiply(grad_output, centered_operand), feature_index),
[feature_index], type(operand))
i4 = multiply(i3, centered_operand)
i5 = divide(i4, add(variance_bcast, epsilon_bcast))
i6 = subtract(subtract(i1, i2), i5)
grad_operand =
multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
grad_scale =
compute_sum(multiply(grad_output, normalized_operand), feature_index)
grad_offset = compute_sum(grad_output, feature_index)
return grad_operand, grad_scale, grad_offset
Para tipos quantizados, realiza
dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean,
variance, grad_output: batch_norm_grad(operand, scale, mean, variance,
grad_output, epsilon, feature_index), operand, scale, mean, variance,
grad_output, type(grad_operand), type(grad_scale), type(feature_index)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C1-C3), (C5) |
| (I2) | scale |
Tensor unidimensional de ponto flutuante ou tipo quantizado por tensor | (C2), (C4), (C5) |
| (I3) | mean |
Tensor unidimensional de ponto flutuante ou tipo quantizado por tensor | (C2), (C4) |
| (I4) | variance |
Tensor unidimensional de ponto flutuante ou tipo quantizado por tensor | (C2), (C4) |
| (I5) | grad_output |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C2), (C3) |
| (I6) | epsilon |
constante do tipo f32 |
|
| (I7) | feature_index |
constante do tipo si64 |
(C1), (C5) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
grad_operand |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C2), (C3) |
grad_scale |
Tensor unidimensional de ponto flutuante ou tipo quantizado por tensor | (C2), (C4) |
grad_offset |
Tensor unidimensional de ponto flutuante ou tipo quantizado por tensor | (C2), (C4) |
Restrições
- (C1)
0 <= feature_index < rank(operand). - (C2)
operand,scale,mean,variance,grad_output,grad_operand,grad_scaleegrad_offsettêm o mesmobaseline_element_type. - (C3)
operand,grad_outputegrad_operandtêm o mesmo formato. - (C4)
scale,mean,variance,grad_scaleegrad_offsettêm a mesma forma. - (C5)
size(scale) = dim(operand, feature_index).
Exemplos
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
// [[0.1, 0.1], [0.1, 0.1]],
// [[0.1, 0.1], [0.1, 0.1]]
// ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
}< : (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf<64, t>ensor2xf64,
< tenso>r2x>2x2xf64)< - (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf64)
// %grad_operand: [
// [[0.0, 0.0], [0.0, 0.0]],
// [[0.0, 0.0], [0.0, 0.0]]
// ]
// %grad_scale: [0.0, 0.0]
// %grad_offset: [0.4, 0.4]
batch_norm_inference
Semântica
Normaliza o tensor operand em todas as dimensões, exceto a feature_index, e produz um tensor result. De maneira mais formal, essa operação pode ser expressa como uma decomposição em operações StableHLO atuais usando a sintaxe do Python da seguinte maneira:
def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
# Broadcast inputs to shape(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance` instead of
# computing them like `batch_norm_training` does.
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
return add(multiply(scale_bcast, normalized_operand), offset_bcast)
Para tipos quantizados, realiza
dequantize_op_quantize(lambda operand, scale, offset, mean, variance:
batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index), operand, scale, offset, mean, variance, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C1 a C7) |
| (I2) | scale |
Tensor unidimensional de ponto flutuante ou tipo quantizado por tensor | (C2), (C3) |
| (I3) | offset |
Tensor unidimensional de ponto flutuante ou tipo quantizado por tensor | (C2), (C4) |
| (I4) | mean |
Tensor unidimensional de ponto flutuante ou tipo quantizado por tensor | (C5) |
| (I5) | variance |
Tensor unidimensional de ponto flutuante ou tipo quantizado por tensor | (C2), (C6) |
| (I6) | epsilon |
constante do tipo f32 |
|
| (I7) | feature_index |
constante do tipo si64 |
(C1), (C3-C6) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C2), (C7) |
Restrições
- (C1)
0 <= feature_index < rank(operand). - (C2)
operand,scale,offset,mean,varianceeresulttêm o mesmobaseline_element_type. - (C3)
size(scale) = dim(operand, feature_index). - (C4)
size(offset) = dim(operand, feature_index). - (C5)
size(mean) = dim(operand, feature_index). - (C6)
size(variance) = dim(operand, feature_index). - (C7)
baseline_type(operand) = baseline_type(result).
Exemplos
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
}< : (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf<64, t>ensor2xf<64, t>ens>or2xf64<) - tenso>r2x2x2xf64
// %result: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
batch_norm_training
Semântica
Calcula a média e a variância em todas as dimensões, exceto a feature_index, e normaliza o tensor operand, produzindo os tensores output, batch_mean e batch_var. De maneira mais formal, essa operação pode ser expressa como uma decomposição em operações StableHLO atuais usando a sintaxe do Python da seguinte maneira:
def compute_mean(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def compute_variance(operand, feature_index):
mean = compute_mean(operand, feature_index)
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
centered_operand = subtract(operand, mean_bcast)
return compute_mean(mul(centered_operand, centered_operand), feature_index)
def batch_norm_training(operand, scale, offset, epsilon, feature_index):
mean = compute_mean(operand, feature_index)
variance = compute_variance(operand, feature_index)
return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index),
mean, variance
Para tipos quantizados, realiza
dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset:
batch_norm_training(operand, scale, offset, epsilon, feature_index), operand,
scale, offset, type(output), type(batch_mean), type(batch_var)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C1) |
| (I2) | scale |
Tensor unidimensional de ponto flutuante ou quantizado por tensor | (C2), (C3) |
| (I3) | offset |
Tensor unidimensional de ponto flutuante ou quantizado por tensor | (C2), (C4) |
| (I4) | epsilon |
constante do tipo f32 |
(C1), (C3-C6) |
| (I5) | feature_index |
constante do tipo si64 |
(C1), (C3-C6) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
output |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C7) |
batch_mean |
Tensor unidimensional de ponto flutuante ou quantizado por tensor | (C2), (C5) |
batch_var |
Tensor unidimensional de ponto flutuante ou quantizado por tensor | (C2), (C6) |
Restrições
- (C1)
0 <= feature_index < rank(operand). - (C2)
operand,scale,offset,batch_mean,batch_vareoutputtêm o mesmobaseline_element_type. - (C3)
size(scale) = dim(operand, feature_index). - (C4)
size(offset) = dim(operand, feature_index). - (C5)
size(batch_mean) = dim(operand, feature_index). - (C6)
size(batch_var) = dim(operand, feature_index). - (C7)
baseline_type(output) = baseline_type(operand).
Exemplos
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
}< : (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ens>or2xf64) -
< (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf64)
// %output: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]
bitcast_convert
Semântica
Realiza uma operação bitcast no tensor operand e produz um tensor result em que os bits de todo o tensor operand são reinterpretados usando o tipo do tensor result.
Mais formalmente, considerando E = element_type(operand), E' = element_type(result) e R = rank(operand):
- Se
num_bits(E') < num_bits(E),bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1]). - Se
num_bits(E') > num_bits(E),bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :]). - Se
num_bits(E') = num_bits(E),bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1]).
bits retorna a representação na memória de um determinado valor, e o comportamento dele é definido pela implementação porque a representação exata de tensores e tipos de elementos também é definida pela implementação.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor ou tensor quantizado | (C1-C2) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor ou tensor quantizado | (C1-C2) |
Restrições
- (C1) Considerando
E = is_quantized(operand) ? storage_type(operand) : element_type(operand),E' = is_quantized(result) ? storage_type(result) : element_type(result)eR = rank(operand):- Se
num_bits(E') = num_bits(E),shape(result) = shape(operand). - Se
num_bits(E') < num_bits(E): rank(result) = R + 1.dim(result, i) = dim(operand, i)para todos os0 <= i < R.dim(result, R) * num_bits(E') = num_bits(E).- Se
num_bits(E') > num_bits(E): rank(result) = R - 1.dim(result, i) = dim(operand, i)para todos os0 <= i < R.dim(operand, R - 1) * num_bits(E) = num_bits(E').
- Se
- (C2) Se
is_complex(operand) or is_complex(result), entãois_complex(operand) and is_complex(result).
Exemplos
// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand)< : >(te>nsorf64<) - t>ensor4xf16
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation
broadcast_in_dim
Semântica
Expande as dimensões e/ou a classificação de um tensor de entrada duplicando os dados no tensor operand e produz um tensor result. De maneira mais formal, result[result_index] = operand[operand_index], em que para todo d em axes(operand):
operand_index[d] = 0sedim(operand, d) = 1.operand_index[d] = result_index[broadcast_dimensions[d]]caso contrário.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor ou tensor quantizado | (C1-C2), (C5-C6) |
| (I2) | broadcast_dimensions |
Constante de tensor unidimensional do tipo si64 |
(C2-C6) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor ou tensor quantizado | (C1), (C3), (C5-C6) |
Restrições
- (C1)
element_type(result)é dado por:element_type(operand), se!is_per_axis_quantized(operand).element_type(operand), exceto quequantization_dimension(operand),scales(operand)ezero_points(operand)podem ser diferentes dequantization_dimension(result),scales(result)ezero_points(result), respectivamente, caso contrário.
- (C2)
size(broadcast_dimensions) = rank(operand). - (C3)
0 <= broadcast_dimensions < rank(result). - (C4)
is_unique(broadcast_dimensions). - (C5) Para todos os
demaxes(operand):dim(operand, d) = 1oudim(operand, d) = dim(result, broadcast_dimensions[d]).
- (C6) Se
is_per_axis_quantized(result):quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].- Se
dim(operand, quantization_dimension(operand)) = 1, entãoscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).
Exemplos
// %operand: [
// [1, 2, 3]
// ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
broadcast_dimensio<ns = arra>yi64: 2, 1
}< : (ten>sor>1x3xi32<) - tenso>r2x3x2xi32
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
caso
Semântica
Produz a saída da execução de exatamente uma função de branches, dependendo do valor de index. De maneira mais formal, result = selected_branch()
em que:
selected_branch = branches[index]se0 <= index < size(branches).selected_branch = branches[-1]caso contrário.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | index |
Tensor de dimensão 0 do tipo si32 |
|
| (I2) | branches |
número variádico de funções | (C1 a C4) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
results |
número variádico de tensores, tensores quantizados ou tokens | (C4) |
Restrições
- (C1)
0 < size(branches). - (C2)
input_types(branches...) = []. - (C3)
same(output_types(branches...)). - (C4)
type(results...) = output_types(branches[0]).
Exemplos
// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
"stablehlo.return"(%result_branch0, %resul<t_bra>nch0) : <(tens>or2>xi64, tensor2xi64) - ()
}, {
"stablehlo.return"(%result_branc<h1, %>result_b<ranch>1) >: (tensor2xi64, <ten>sor>2xi64) -< ()
}>) : (ten<sori3>2) - (tensor2xi64, tensor2xi64)
// %result0: [1, 1]
// %result1: [1, 1]
cbrt
Semântica
Executa a operação de raiz cúbica elemento a elemento no tensor operand e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:
- Para pontos flutuantes:
rootn(x, 3)do IEEE-754. - Para números complexos: raiz cúbica complexa.
- Para tipos quantizados:
dequantize_op_quantize(cbrt, operand, type(result))
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result).
Exemplos
// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand)< : (t>ens>or4xf64<) - t>ensor4xf64
// %result: [0.0, 1.0, 2.0, 3.0]
ceil
Semântica
Realiza o teto de elemento a elemento do tensor operand e produz um tensor result.
Implementa a operação roundToIntegralTowardPositive da especificação IEEE-754. Para tipos quantizados, realiza
dequantize_op_quantize(ceil, operand, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result).
Exemplos
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand)< : (t>ens>or5xf32<) - t>ensor5xf32
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]
cholesky
Semântica
Calcula a decomposição de Cholesky de um lote de matrizes.
Mais formalmente, para todo i em index_space(result), result[i0, ..., iR-3, :, :] é uma decomposição de Cholesky de a[i0, ..., iR-3, :, :], na forma de uma matriz triangular inferior (se lower for true) ou triangular superior (se lower for false).
Os valores de saída no triângulo oposto, ou seja, o triângulo estrito superior ou inferior, respectivamente, são definidos pela implementação.
Se houver i em que a matriz de entrada não seja uma matriz hermitiana positiva definida, o comportamento será indefinido.
Para tipos quantizados, realiza
dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | a |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1-C3) |
| (I2) | lower |
constante do tipo i1 |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(a) = baseline_type(result). - (C2)
2 <= rank(a). - (C3)
dim(a, -2) = dim(a, -1).
Exemplos
// %a: [
// [1.0, 2.0, 3.0],
// [2.0, 20.0, 26.0],
// [3.0, 26.0, 70.0]
// ]
%result = "stablehlo.cholesky"(%a) {
lower = true
}< : (ten>sor>3x3xf32<) - ten>sor3x3xf64
// %result: [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
limitar
Semântica
Fixa cada elemento do tensor operand entre um valor mínimo e máximo e produz um tensor result. Mais formalmente, result[result_index] =
minimum(maximum(operand[result_index], min_element), max_element), em que min_element = rank(min) = 0 ? min[] : min[result_index], max_element = rank(max) = 0 ? max[] : max[result_index]. Para tipos quantizados, realiza dequantize_op_quantize(clamp, min, operand, max, type(result)).
Impor uma ordenação em números complexos envolve semântica surpreendente. Por isso, no futuro, planejamos remover o suporte a números complexos para essa operação (#560).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | min |
tensor ou tensor quantizado por tensor | (C1), (C3) |
| (I2) | operand |
tensor ou tensor quantizado por tensor | (C1 a C4) |
| (I3) | max |
tensor ou tensor quantizado por tensor | (C2), (C3) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor ou tensor quantizado por tensor | (C4) |
Restrições
- (C1)
rank(min) = 0 or shape(min) = shape(operand). - (C2)
rank(max) = 0 or shape(max) = shape(operand). - (C3)
baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max). - (C4)
baseline_type(operand) = baseline_type(result).
Exemplos
// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max)< : (t>ensor3xi<32, t>ensor3xi<32, t>ens>or3xi32<) - t>ensor3xi32
// %result: [5, 13, 20]
collective_broadcast
Semântica
Em cada grupo de processos na grade de processos do StableHLO, envie o valor do tensor operand do processo de origem para os processos de destino e produza um tensor result.
A operação divide a grade de processo do StableHLO em process_groups, que é definida da seguinte forma:
cross_replica(replica_groups)sechannel_id <= 0.cross_partition(replica_groups)sechannel_id > 0.
Depois disso, result@process é dado por:
operand@process_groups[i, 0]se existir umipara que o processo esteja emprocess_groups[i].broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))caso contrário.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor ou tensor quantizado por tensor | (C3) |
| (I2) | replica_groups |
número variádico de constantes de tensor unidimensionais do tipo si64 |
(C1), (C2) |
| (I3) | channel_id |
constante do tipo si64 |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor ou tensor quantizado por tensor | (C3) |
Restrições
- (C1)
is_unique(replica_groups). - (C2)
0 <= replica_groups < Nem queNé definido como:num_replicassecross_replicafor usado.num_partitionssecross_partitionfor usado.
- (C3)
type(result) = type(operand).
Exemplos
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_grou<ps = den>se[[2, 1]<] : ten>sor1x2xi64,
channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 0
} : (ten>sor>1x2xi64<) - ten>sor1x2xi64
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
collective_permute
Semântica
Em cada grupo de processos na grade de processos do StableHLO, envia o valor do tensor operand do processo de origem para o de destino e produz um tensor result.
A operação divide a grade de processo do StableHLO em process_groups, que é definida da seguinte forma:
cross_replica(source_target_pairs)sechannel_id <= 0.cross_partition(source_target_pairs)sechannel_id > 0.
Depois disso, result@process é dado por:
operand@process_groups[i, 0], se houver umiem queprocess_groups[i, 1] = process.broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))caso contrário.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor ou tensor quantizado por tensor | (C5) |
| (I2) | source_target_pairs |
Constante de tensor bidimensional do tipo si64 |
(C1 a C4) |
| (I3) | channel_id |
constante do tipo si64 |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
dim(source_target_pairs, 1) = 2. - (C2)
is_unique(source_target_pairs[:, 0]). - (C3)
is_unique(source_target_pairs[:, 1]). - (C4)
0 <= source_target_pairs < N, em queNé definido como:num_replicassecross_replicafor usado.num_partitionssecross_partitionfor usado.
- (C5)
type(result) = type(operand).
Exemplos
// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
source_target_pai<rs = dense[[0, 1>], [1, 2]<] : ten>sor2x2xi64,
channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 0
}< : (ten>sor>2x2xi64<) - ten>sor2x2xi64
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]
compare
Semântica
Realiza uma comparação elemento a elemento dos tensores lhs e rhs de acordo com comparison_direction e compare_type e produz um tensor result.
Os valores de comparison_direction e compare_type têm a seguinte semântica:
Para tipos de elementos booleanos e inteiros:
EQ:lhs = rhs.NE:lhs != rhs.GE:lhs >= rhs.GT:lhs > rhs.LE:lhs <= rhs.LT:lhs < rhs.
Para tipos de elementos de ponto flutuante com compare_type = FLOAT, a operação implementa as seguintes operações IEEE-754:
EQ:compareQuietEqual.NE:compareQuietNotEqual.GE:compareQuietGreaterEqual.GT:compareQuietGreater.LE:compareQuietLessEqual.LT:compareQuietLess.
Para tipos de elementos de ponto flutuante com compare_type = TOTALORDER, a operação usa a combinação de operações totalOrder e compareQuietEqual do IEEE-754.
Para tipos de elementos complexos, a comparação lexicográfica de pares (real, imag) é realizada usando o comparison_direction e o compare_type fornecidos.
Impor uma ordenação em números complexos envolve semântica surpreendente. Por isso, no futuro, planejamos remover o suporte a números complexos quando comparison_direction for GE, GT, LE ou LT (#560).
Para tipos quantizados, realiza dequantize_compare(lhs, rhs,
comparison_direction).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | lhs |
tensor ou tensor quantizado por tensor | (C1-C3) |
| (I2) | rhs |
tensor ou tensor quantizado por tensor | (C1-C2) |
| (I3) | comparison_direction |
enum de EQ, NE, GE, GT, LE e LT |
|
| (I4) | compare_type |
enum de FLOAT, TOTALORDER, SIGNED e UNSIGNED |
(C3) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor do tipo booleano | (C2) |
Restrições
- (C1)
baseline_element_type(lhs) = baseline_element_type(rhs). - (C2)
shape(lhs) = shape(rhs) = shape(result). - (C3)
compare_typeé definido como:SIGNEDseis_signed_integer(element_type(lhs)).UNSIGNEDseis_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs)).FLOATouTOTALORDERseis_float(element_type(lhs)).FLOATseis_complex(element_type(lhs)).
Exemplos
// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
comparison_direction = <#stablehlocomparison_di>rection LT,
compare_type = <#stablehlocomparison_>type FLOAT
}< : (t>ensor2xf<32, t>ens>or2xf32<) - >tensor2xi1
// %result: [true, false]
complexo
Semântica
Realiza a conversão de elemento por elemento para um valor complexo de um par de valores reais e imaginários, lhs e rhs, e produz um tensor result.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | lhs |
tensor do tipo f32 ou f64 |
(C1-C3) |
| (I2) | rhs |
tensor do tipo f32 ou f64 |
(C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de tipo complexo | (C2), (C3) |
Restrições
- (C1)
type(lhs) = type(rhs). - (C2)
shape(result) = shape(lhs). - (C3)
element_type(result)tem o tipocomplex<E>em queE = element_type(lhs).
Exemplos
// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs)< : (t>ensor2xf<64, t>ens>or2xf64<) - tenso<r2x>>complexf64
// %result: [(1.0, 2.0), (3.0, 4.0)]
composto
Semântica
Encapsula uma operação composta por outras operações do StableHLO, usando inputs e composite_attributes e produzindo results. A semântica da operação é implementada pelo atributo decomposition. A operação composite pode ser substituída pela decomposição dela sem mudar a semântica do programa. Nos casos em que a inclusão inline da decomposição não fornece a mesma semântica de operação, prefira usar custom_call.
O campo version (padrão 0) é usado para indicar quando a semântica de um composto muda.
Entradas
| Rótulo | Nome | Tipo |
|---|---|---|
| (I1) | inputs |
número variádico de valores |
| (I2) | name |
constante do tipo string |
| (I3) | composite_attributes |
dicionário de atributos |
| (I4) | decomposition |
constante do tipo string |
| (I5) | version |
constante do tipo si32 |
Saídas
| Nome | Tipo |
|---|---|
results |
número variádico de valores |
Restrições
- (C1)
is_namespaced_op_name(name) - (C2)
is_defined_in_parent_scope(decomposition) - (C3)
types(inputs...) == input_types(decomposition) - (C4)
types(results...) == output_types(decomposition)
Exemplos
%results = "stablehlo.composite"(%input0, %input1) {
name = "my_namespace.my_op",
composite_attributes = {
my_attribute = "my_value"
},
decomposition = @my_op,
< ve>rsion = <1 :> i3>2
} : (<ten>sorf32, tensorf32) - tensorf32
concatenate
Semântica
Concatena inputs ao longo da dimensão dimension na mesma ordem dos argumentos fornecidos e produz um tensor result. De maneira mais formal, result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1], em que:
id = d0 + ... + dk-1 + kd.dé igual adimension, ed0, ... são tamanhos dadª dimensão deinputs.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | inputs |
número variádico de tensores ou tensores quantizados por tensor | (C1 a C6) |
| (I2) | dimension |
constante do tipo si64 |
(C2), (C4), (C6) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor ou tensor quantizado por tensor | (C5-C6) |
Restrições
- (C1)
same(element_type(inputs...)). - (C2)
same(shape(inputs...))exceto paradim(inputs..., dimension). - (C3)
0 < size(inputs). - (C4)
0 <= dimension < rank(inputs[0]). - (C5)
element_type(result) = element_type(inputs[0]). - (C6)
shape(result) = shape(inputs[0]), exceto:dim(result, dimension) = dim(inputs[0], dimension) + ....
Exemplos
// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
dimension = 0 : i64
}< : (ten>sor3x2xi<64, ten>sor>1x2xi64<) - ten>sor4x2xi64
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
constante
Semântica
Produz um tensor output de uma constante value.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | value |
constante | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
output |
tensor ou tensor quantizado | (C1) |
Restrições
- (C1)
type(value) = type(output).
Exemplos
%output = "stablehlo.constant"() {
val<ue = dense[[0.0, 1.0], [>2.0, 3.0]<] : ten>sor2x2xf3>2
} : (<) - ten>sor2x2xf32
// %output: [[0.0, 1.0], [2.0, 3.0]]
fazer uma conversão
Semântica
Executa uma conversão elemento a elemento de um tipo de elemento para outro no tensor operand e produz um tensor result.
Para conversões boolean-to-any-supported-type, o valor false é convertido em zero, e o valor true é convertido em um. Para conversões de any-supported-type-to-boolean, um valor zero é convertido em false, e valores diferentes de zero são convertidos em true. Confira abaixo como isso funciona para tipos complexos.
Para conversões envolvendo inteiro para inteiro, inteiro para ponto flutuante ou ponto flutuante para ponto flutuante, se o valor de origem puder ser representado exatamente no tipo de destino, o valor resultante será essa representação exata. Caso contrário, o comportamento será definido (#180).
Para conversões que envolvem floating-point-to-integer, a parte fracionária é truncada. Se o valor truncado não puder ser representado no tipo de destino, o comportamento será definido posteriormente (#180).
A conversão envolvendo complexo para complexo segue o mesmo comportamento das conversões de ponto flutuante para ponto flutuante para converter partes reais e imaginárias.
Para conversões complex-to-any-other-type e any-other-type-to-complex, o valor imaginário de origem é ignorado ou o valor imaginário de destino é zerado, respectivamente. A conversão da parte real segue as conversões de ponto flutuante.
Em princípio, essa operação pode expressar a desquantização (conversão de tensores quantizados em tensores regulares), a quantização (conversão de tensores regulares em tensores quantizados) e a requantização (conversão entre tensores quantizados), mas no momento temos operações dedicadas para isso: uniform_dequantize para o primeiro caso de uso e uniform_quantize para o segundo e o terceiro. No futuro, essas duas operações poderão ser combinadas em convert (#1576).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor | (C1) |
Restrições
- (C1)
shape(operand) = shape(result).
Exemplos
// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand)< : (t>ens>or3xi64<) - tenso<r3x>>complexf64
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]
convolução
Semântica
Calcula produtos escalares entre janelas de lhs e intervalos de rhs e gera result. O diagrama a seguir mostra como os elementos em result são calculados com base em lhs e rhs usando um exemplo concreto.
De maneira mais formal, considere o seguinte reformulação das entradas em termos de lhs para expressar janelas de lhs:
lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension)).lhs_window_strides = lhs_shape(1, window_strides, 1).lhs_padding = lhs_shape([0, 0], padding, [0, 0]).lhs_base_dilations = lhs_shape(1, lhs_dilation, 1).lhs_window_dilations = lhs_shape(1, rhs_dilation, 1).
Essa reformulação usa as seguintes funções auxiliares:
lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]).result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]).permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]em quej[d] = i[permutation[d]].
Se feature_group_count = 1 e batch_group_count = 1, então para todo output_spatial_index em index_space(dim(result, output_spatial_dimensions...)), result[result_shape(:, output_spatial_index, :)] = dot_product em que:
padding_value = constant(0, element_type(lhs)).padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1).lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides.lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations).reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true]). Esse recurso parece não estar sendo usado. Por isso, planejamos removê-lo no futuro (#1181).dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension]).
Se feature_group_count > 1:
lhses = split(lhs, feature_group_count, input_feature_dimension).rhses = split(rhs, feature_group_count, kernel_output_feature_dimension).results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...).result = concatenate(results, output_feature_dimension).
Se batch_group_count > 1:
lhses = split(lhs, batch_group_count, input_batch_dimension).rhses = split(rhs, batch_group_count, kernel_output_feature_dimension).results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...).result = concatenate(results, output_feature_dimension).
Para tipos quantizados, realiza dequantize_op_quantize(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs,
type(result)).
Para tipos quantizados híbridos, realiza hybrid_dequantize_then_op(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | lhs |
tensor ou tensor quantizado por tensor | (C1), (C10-C11), (C14) (C25), (C27-C28), (C31-C32), (C34) |
| (I2) | rhs |
tensor ou tensor quantizado | (C1), (C14-C16), (C25), (C27-C29), (C31-C34) |
| (I3) | window_strides |
Constante de tensor unidimensional do tipo si64 |
(C2-C3), (C25) |
| (I4) | padding |
Constante de tensor bidimensional do tipo si64 |
(C4), (C25) |
| (I5) | lhs_dilation |
Constante de tensor unidimensional do tipo si64 |
(C5-C6), (C25) |
| (I6) | rhs_dilation |
Constante de tensor unidimensional do tipo si64 |
(C7-C8), (C25) |
| (I7) | window_reversal |
Constante de tensor unidimensional do tipo i1 |
(C9) |
| (I8) | input_batch_dimension |
constante do tipo si64 |
(C10), (C13), (C25) |
| (I9) | input_feature_dimension |
constante do tipo si64 |
(C11), (C13-C14) |
| (I10) | input_spatial_dimensions |
Constante de tensor unidimensional do tipo si64 |
(C12), (C13), (C25) |
| (I11) | kernel_input_feature_dimension |
constante do tipo si64 |
(C14), (C18) |
| (I12) | kernel_output_feature_dimension |
constante do tipo si64 |
(C15-C16), (C18), (C25), (C29) |
| (I13) | kernel_spatial_dimensions |
Constante de tensor unidimensional do tipo si64 |
(C17-C18), (C25) |
| (I14) | output_batch_dimension |
constante do tipo si64 |
(C20), (C25) |
| (I15) | output_feature_dimension |
constante do tipo si64 |
(C20), (C25), (C30) |
| (I16) | output_spatial_dimensions |
Constante de tensor unidimensional do tipo si64 |
(C19-C20), (C25) |
| (I17) | feature_group_count |
constante do tipo si64 |
(C11), (C14), (C16), (C21), (C23) |
| (I18) | batch_group_count |
constante do tipo si64 |
(C10), (C15), (C22), (C23), (C25) |
| (I19) | precision_config |
número variádico de enums de DEFAULT, HIGH e HIGHEST |
(C24) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor ou tensor quantizado | (C25-C28), (C30), (C32-34) |
Restrições
- (C1)
N = rank(lhs) = rank(rhs). - (C2)
size(window_strides) = N - 2. - (C3)
0 < window_strides. - (C4)
shape(padding) = [N - 2, 2]. - (C5)
size(lhs_dilation) = N - 2. - (C6)
0 < lhs_dilation. - (C7)
size(rhs_dilation) = N - 2. - (C8)
0 < rhs_dilation. - (C9)
size(window_reversal) = N - 2. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0. - (C12)
size(input_spatial_dimensions) = N - 2. - (C13) Dado
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:is_unique(input_dimensions).0 <= input_dimensions < N.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0. - (C17)
size(kernel_spatial_dimensions) = N - 2. - (C18) Dado
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:is_unique(kernel_dimensions).0 <= kernel_dimensions < N.
- (C19)
size(output_spatial_dimensions) = N - 2. - (C20) Dado
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:is_unique(output_dimensions).0 <= output_dimensions < N.
- (C21)
0 < feature_group_count. - (C22)
0 < batch_group_count. - (C23)
feature_group_count = 1 or batch_group_count = 1. - (C24)
size(precision_config) = 2. - (C25)
dim(result, result_dim)é definido como:dim(lhs, input_batch_dimension) / batch_group_countseresult_dim = output_batch_dimension.dim(rhs, kernel_output_feature_dimension)seresult_dim = output_feature_dimension.num_windowscaso contrário, em que:output_spatial_dimensions[spatial_dim] = result_dim.lhs_dim = input_spatial_dimensions[spatial_dim].rhs_dim = kernel_spatial_dimensions[spatial_dim].dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
- (C26)
rank(result) = N. - Se a operação usar tensores não quantizados:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result).
- (C27)
- Se a operação usar tensores quantizados:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs). - (C29) Se
is_per_axis_quantized(rhs), entãoquantization_dimension(rhs) = kernel_output_feature_dimension. - (C30) Se
is_per_axis_quantized(result), entãoquantization_dimension(result) = output_feature_dimension. - Se
is_quantized(lhs): - (C31)
storage_type(lhs) = storage_type(rhs). - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result). - (C33) Se
is_per_tensor_quantized(rhs), entãois_per_tensor_quantized(result). - Se
!is_quantized(lhs): - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result).
- (C28)
Exemplos
// %lhs: [[
// [
// [1], [2], [5], [6]
// ],
// [
// [3], [4], [7], [8]
// ],
// [
// [10], [11], [14], [15]
// ],
// [
// [12], [13], [16], [17]
// ]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strid<es = arra>yi64: 4, 4,
paddi<n>g = dense<0 : ten>sor2x2xi64,
lhs_dilati<on = arra>yi64: 2, 2,
rhs_dilati<on = arra>yi64: 1, 1,
window_revers<al = arrayi1: fa>lse, false,
// In the StableHLO dialect, dimension numbers are encoded vi<a:
// `[input >dim<ensions]x[kernel >di>mensions]-[output dimensions]`.
// "b" is batch dimension, "f" is feature dimension,
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" a<re spatial dimensions.
d>imension_num>bers = #stablehlo.conv[b, 0, 1, f]x[0, 1, i, o]-[b, 0, 1, f],
batch_group_count = 1 : i64,
fea<ture_group_count >= 1 : i64,
< precision_config> = [#stablehl<oprecision >DEFAULT,< #stablehlo>pre>cision <DEFAULT]
} >: (tensor1x4x4x1xi64, tensor3x3x1x1xi64) - tensor1x2x2x1xi64
// %result: [[
// [[10], [26]],
// [[46], [62]]
// ]]
cosseno
Semântica
Executa a operação de cosseno elemento a elemento no tensor operand e gera um tensor result. Dependendo do tipo de elemento, faça o seguinte:
- Para pontos flutuantes:
cosdo IEEE-754. - Para números complexos: cosseno complexo.
- Para tipos quantizados:
dequantize_op_quantize(cosine, operand, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result).
Exemplos
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.cosine"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[1.0, 0.0], [-1.0, 0.0]]
count_leading_zeros
Semântica
Realiza a contagem de elementos do número de bits zero à esquerda no tensor operand e produz um tensor result.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor de tipo inteiro | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de tipo inteiro | (C1) |
Restrições
- (C1)
type(operand) = type(result).
Exemplos
// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand)< : (ten>sor>2x2xi64<) - ten>sor2x2xi64
// %result: [[64, 63], [56, 0]]
custom_call
Semântica
Encapsula uma operação call_target_name definida pela implementação que usa
inputs e called_computations e produz results. has_side_effect, backend_config e api_version podem ser usados para fornecer metadados adicionais definidos pela implementação.
No momento, essa operação contém uma coleção bastante desorganizada de metadados que reflete a evolução orgânica da operação equivalente no compilador XLA. No futuro, planejamos unificar esses metadados (#741).
Entradas
| Rótulo | Nome | Tipo |
|---|---|---|
| (I1) | inputs |
número variádico de valores |
| (I2) | call_target_name |
constante do tipo string |
| (I3) | has_side_effect |
constante do tipo i1 |
| (I4) | backend_config |
constante do tipo string ou dicionário de atributos |
| (I5) | api_version |
constante do tipo si32 |
| (I6) | called_computations |
número variádico de constantes do tipo string |
| (I7) | output_operand_aliases |
especificar as partes de alias nas saídas e nos operandos |
Saídas
| Nome | Tipo |
|---|---|
results |
número variádico de valores |
(Suporte a GPU XLA) Destinos especiais de custom_call
Há três call_target_name especiais relacionados a tipos buffer:
CreateBuffer cria um buffer não inicializado, Pin cria um buffer inicializado
e Unpin desaloca um buffer e retorna o conteúdo do
buffer.
%uninitialized_buffer = "stablehlo.custom_call"() {
call_target_name = "CreateBuffer",
api_version> = 4 : <i32,
>} : () - memref4xf64
%initialized_buffer = "stablehlo.custom_call"(%init_value) {
call_target_name = "Pin&quo<t;,
> ap>i_versi<on = >4 : i32,
} : (tensor4xf64) - memref4xf64
%dealloc_buffer = "stablehlo.custom_call"(%initialized_buffer) {
call_target_na<me = >&qu>ot;Unpi<n&quo>t;,
api_version = 4 : i32,
} : (memref4xf64) - tensor4xf64
Alias
Algumas operações custom_call podem exigir uma parte nas saídas e outra nos operandos para compartilhar a mesma memória. Isso pode ser expresso por
output_operand_aliases. Uma representação de par de alias consiste em uma lista de índices de tuplas de saída que representam a parte de saída e um operand_index com uma lista de índices de tuplas de operando que representam a parte de operando. A lista de índices de tuplas de saída ou de operandos fica vazia se o tipo correspondente não for tuple e pode ser arbitrariamente longa para um tipo de tupla arbitrariamente aninhado. Isso é semelhante à representação do alias XLA.
As partes de saída e entrada em um par de alias precisam ter o mesmo tipo. Para operações
custom_call que não são chamadas para CreateBuffer, Pin e Unpin, um
operando buffer pode aparecer em no máximo um par de alias, e uma saída buffer
precisa aparecer em um par de alias.
Exemplos
%results = "stablehlo.custom_call"(%input0) {
call_target_name = "foo",
has_side_effect = false,
backend_config = {bar = 42 : i32},
api_version = 4 : i32,
called_computations <= [>@fo>o]
} : <(te>nsorf64) - tensorf64
%updated_buffer = "stablehlo.custom_call"(%buffer) {
call_target_name = "Update",
api_version = 4 : i32,
output_operand_aliases< = [
#stablehlo.output_operand_aliasoutput_tuple_indices = [],
operand_ind>ex = 0,
< oper>and>_tuple_<indic>es = []]
} : (memref4xf64) - memref4xf64
dividir
Semântica
Realiza a divisão elemento a elemento dos tensores dividendo lhs e divisor rhs e gera um tensor result. Dependendo do tipo de elemento, faça o seguinte:
- Para números inteiros: divisão inteira que produz o quociente algébrico com qualquer parte fracionária descartada.
- Para pontos flutuantes:
divisiondo IEEE-754. - Para números complexos: divisão complexa.
- Para tipos quantizados:
dequantize_op_quantize(divide, lhs, rhs, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | lhs |
tensor de tipo inteiro, de ponto flutuante ou complexo ou tensor quantizado por tensor | (C1) |
| (I2) | rhs |
tensor de tipo inteiro, de ponto flutuante ou complexo ou tensor quantizado por tensor | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de tipo inteiro, de ponto flutuante ou complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
Exemplos
// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs)< : (t>ensor4xf<32, t>ens>or4xf32<) - t>ensor4xf32
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]
dot_general
Semântica
Calcula produtos escalares entre fatias de lhs e fatias de rhs e produz um tensor result.
Mais formalmente, result[result_index] = dot_product, em que:
lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions].rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions].result_batching_index + result_lhs_index + result_rhs_index = result_indexem quesize(result_batching_index) = size(lhs_batching_dimensions),size(result_lhs_index) = size(lhs_result_dimensions)esize(result_rhs_index) = size(rhs_result_dimensions).transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions).transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :]).reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions)).transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions).transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :]).reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions)).dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y)).
Para tipos quantizados, realiza dequantize_op_quantize(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs, type(result)).
Para tipos quantizados híbridos, realiza hybrid_dequantize_then_op(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs).
precision_config controla a compensação entre velocidade e acurácia para
computações em back-ends de aceleradores. Pode ser um dos seguintes valores (no momento, a semântica desses valores de enumeração está subespecificada, mas planejamos resolver isso em #755):
DEFAULT: cálculo mais rápido, mas aproximação menos precisa do número original.HIGH: cálculo mais lento, mas aproximação mais precisa do número original.HIGHEST: cálculo mais lento, mas aproximação mais precisa do número original.
Um DotAlgorithm define as principais propriedades do algoritmo usado para implementar
a operação de ponto, que também define a precisão. Se os campos de atributo do algoritmo forem definidos, o precision_config precisará ser DEFAULT. DotAlgorithms
não têm um valor padrão, já que os parâmetros padrão são definidos
pela implementação. Assim, todos os campos do algoritmo de ponto podem ser definidos como None para especificar um algoritmo de ponto vazio, que usará o valor precision_config.
Os campos DotAlgorithm incluem:
lhs_precision_typeerhs_precision_type, as precisões para as quais os lados esquerdo e direito da operação são arredondados. Os tipos de precisão são independentes dos tipos de armazenamento das entradas e da saída.accumulation_typea precisão usada para acumulação.lhs_component_count,rhs_component_countenum_primitive_operationssão aplicados quando usamos um algoritmo que decompõe o lado esquerdo e/ou direito em vários componentes e faz várias operações de ponto "primitivas" nesses valores, geralmente para emular uma precisão maior (por exemplo, Como aproveitar o tipo de dados de inteligência artificial bfloat16 para cálculos de maior precisão: bf16_6x tf32_3x etc.). Para algoritmos sem decomposição, esses valores precisam ser definidos como1.allow_imprecise_accumulationpara especificar se o acúmulo em precisão menor é permitido para algumas etapas (por exemplo,CUBLASLT_MATMUL_DESC_FAST_ACCUM).
Exemplos de atributos DotAlgorithm:
// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false}
// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
rhs_precision_type = bf16,
accumulation_type = f32,
lhs_component_count = 3,
rhs_component_count = 3,
num_primitive_operations = 6,
allow_imprecise_accumulation = false}
// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
rhs_precision_type = f8e5m2,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = true}
Cabe às implementações decidir quais combinações são compatíveis. Em geral, não há garantia de que cada algoritmo seja compatível com cada tipo de acelerador pelo consumidor do StableHLO. Se um determinado algoritmo não for compatível, um erro deverá ser gerado em vez de usar uma alternativa. A verificação do StableHLO vai fornecer a melhor verificação possível, evitando algoritmos que não são conhecidos por serem compatíveis com qualquer hardware.
Consulte xla_data.proto > Algorithm
para conferir alguns valores de algoritmo compatíveis. O tíquete nº 2483 registra o plano de criar um documento centralizado sobre algoritmos compatíveis por back-end.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | lhs |
tensor ou tensor quantizado por tensor | (C5-C6), (C9-C10), (C12-C14), (C17-C18), (C20) |
| (I2) | rhs |
tensor ou tensor quantizado | (C7-C10), (C12-C20) |
| (I3) | lhs_batching_dimensions |
Constante de tensor unidimensional do tipo si64 |
(C1), (C3), (C5), (C9), (C12) |
| (I4) | rhs_batching_dimensions |
Constante de tensor unidimensional do tipo si64 |
(C1), (C4), (C7), (C9) |
| (I5) | lhs_contracting_dimensions |
Constante de tensor unidimensional do tipo si64 |
(C2), (C3), (C6), (C10) |
| (I6) | rhs_contracting_dimensions |
Constante de tensor unidimensional do tipo si64 |
(C2), (C4), (C8), (C10), (C16) |
| (I7) | precision_config |
número variádico de enums de DEFAULT, HIGH e HIGHEST |
(C11), (C21) |
| (I8) | lhs_precision_type |
FloatType ou TensorFloat32 | (C21) |
| (I9) | rhs_precision_type |
FloatType ou TensorFloat32 | (C21) |
| (I10) | accumulation_type |
FloatType ou TensorFloat32 | (C21) |
| (I11) | lhs_component_count |
constante do tipo si32 |
(C21), (C22) |
| (I12) | rhs_component_count |
constante do tipo si32 |
(C21), (C23) |
| (I13) | num_primitive_operations |
constante do tipo si32 |
(C21), (C24) |
| (I14) | allow_imprecise_accumulation |
constante do tipo bool |
(C21) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor ou tensor quantizado | (C12), (C14), (C18-C20) |
Restrições
- (C1)
size(lhs_batching_dimensions) = size(rhs_batching_dimensions). - (C2)
size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions). - (C3)
is_unique(lhs_batching_dimensions + lhs_contracting_dimensions). - (C4)
is_unique(rhs_batching_dimensions + rhs_contracting_dimensions). - (C5)
0 <= lhs_batching_dimensions < rank(lhs). - (C6)
0 <= lhs_contracting_dimensions < rank(lhs). - (C7)
0 <= rhs_batching_dimensions < rank(rhs). - (C8)
0 <= rhs_contracting_dimensions < rank(rhs). - (C9)
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...). - (C10)
dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...). - (C11)
size(precision_config) = 2. - (C12)
shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions). - Se a operação usar tensores não quantizados:
- (C13)
element_type(lhs) = element_type(rhs).
- (C13)
- Se a operação usar tensores quantizados:
- (C14)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs). - (C15)
zero_points(rhs) = 0. - (C16) Se
is_per_axis_quantized(rhs), entãoquantization_dimension(rhs)não está emrhs_contracting_dimensions. - Se
is_quantized(lhs): - (C17)
storage_type(lhs) = storage_type(rhs). - (C18)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result). - (C19) Se
is_per_tensor_quantized(rhs), entãois_per_tensor_quantized(result). - Se
!is_quantized(lhs): - (C20)
element_type(lhs) = expressed_type(rhs) = element_type(result).
- (C14)
- Se
!is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation):- (C21)
precision_config... = DEFAULT. - (C22)
0 < lhs_component_count. - (C23)
0 < rhs_component_count. - (C24)
0 < num_primitive_operations.
- (C21)
Exemplos
// %lhs: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
// %rhs: [
// [[1, 0],
// [0, 1]],
// [[1, 0],
// [0, 1]]
// ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
dot_dimension_numbers = #sta<blehlo.dot
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimension>s = [1]
,
precision_config = [<#stablehloprecisi>on DEFAULT, <#stablehloprecisi>on DEFAULT],
algorithm = #stablehlo.dot<_algorithm
lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation >= false
}< : (tenso>r2x2x2xi<64, tenso>r2x>2x2xi64<) - tenso>r2x2x2xi64
// %result: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
dynamic_broadcast_in_dim
Semântica
Essa operação é funcionalmente idêntica à operação broadcast_in_dim, mas o formato do resultado é especificado dinamicamente via output_dimensions.
A operação também aceita atributos opcionais known_expanding_dimensions, known_nonexpanding_dimensions para expressar conhecimento estático sobre o comportamento de expansão das dimensões.
Se não for especificado, todas as dimensões serão consideradas expansíveis.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor ou tensor quantizado | (C1-C2), (C5-C6), (C9) |
| (I2) | output_dimensions |
Tensor unidimensional de tipo inteiro | (C7) |
| (I3) | broadcast_dimensions |
Tensor constante unidimensional do tipo inteiro | (C2-C6) |
| (I4) | known_expanding_dimensions |
Tensor constante unidimensional do tipo inteiro | (C8-C9) |
| (I5) | known_nonexpanding_dimensions |
Tensor constante unidimensional do tipo inteiro | (C8-C9) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor ou tensor quantizado | (C1), (C3), (C5-C7) |
Restrições
- (C1)
element_type(result)é dado por:element_type(operand), se!is_per_axis_quantized(operand).element_type(operand), exceto quequantization_dimension(operand),scales(operand)ezero_points(operand)podem ser diferentes dequantization_dimension(result),scales(result)ezero_points(result), respectivamente, caso contrário.
- (C2)
size(broadcast_dimensions) = rank(operand). - (C3)
0 <= broadcast_dimensions < rank(result). - (C4)
is_unique(broadcast_dimensions). - (C5) Para todos os
demaxes(operand):dim(operand, d) = 1oudim(operand, d) = dim(result, broadcast_dimensions[d]).
- (C6) Se
is_per_axis_quantized(result):quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].- Se
dim(operand, quantization_dimension(operand)) = 1, entãoscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).
- (C7)
size(output_dimensions) = rank(result). - (C8)
is_unique(known_expanding_dimensions + known_nonexpanding_dimensions). - (C9)
0 <= known_expanding_dimensions < rank(operand). - (C10)
0 <= known_nonexpanding_dimensions < rank(operand).
Exemplos
// %operand: [
// [1, 2, 3]
// ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
broadcast_dimensio<ns = arra>yi64: 2, 1,
known_expanding_dimensio<ns = a>rrayi64: 0,
known_nonexpanding_dimensio<ns = a>rrayi64: 1
}< : (ten>sor1x3xi<64, t>ens>or3xi64<) - tenso>r2x3x2xi64
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
dynamic_conv
Semântica
Essa operação é funcionalmente idêntica à operação de convolução, mas o padding é especificado dinamicamente via padding.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | lhs |
tensor ou tensor quantizado por tensor | (C1), (C10-C11), (C14) (C25), (C26-C27), (C30-C31), (C33) |
| (I2) | rhs |
tensor ou tensor quantizado | (C1), (C14-C16), (C26-C28), (C30-C33) |
| (I3) | padding |
Tensor bidimensional do tipo inteiro | (C4) |
| (I4) | window_strides |
Constante de tensor unidimensional do tipo si64 |
(C2-C3) |
| (I5) | lhs_dilation |
Constante de tensor unidimensional do tipo si64 |
(C5-C6) |
| (I6) | rhs_dilation |
Constante de tensor unidimensional do tipo si64 |
(C7-C8) |
| (I7) | window_reversal |
Constante de tensor unidimensional do tipo i1 |
(C9) |
| (I8) | input_batch_dimension |
constante do tipo si64 |
(C10), (C13) |
| (I9) | input_feature_dimension |
constante do tipo si64 |
(C11), (C13-C14) |
| (I10) | input_spatial_dimensions |
Constante de tensor unidimensional do tipo si64 |
(C12), (C13) |
| (I11) | kernel_input_feature_dimension |
constante do tipo si64 |
(C14), (C18) |
| (I12) | kernel_output_feature_dimension |
constante do tipo si64 |
(C15-C16), (C18), (C28) |
| (I13) | kernel_spatial_dimensions |
Constante de tensor unidimensional do tipo si64 |
(C17-C18) |
| (I14) | output_batch_dimension |
constante do tipo si64 |
(C20) |
| (I15) | output_feature_dimension |
constante do tipo si64 |
(C20), (C29) |
| (I16) | output_spatial_dimensions |
Constante de tensor unidimensional do tipo si64 |
(C19-C20) |
| (I17) | feature_group_count |
constante do tipo si64 |
(C11), (C14), (C16), (C21), (C23) |
| (I18) | batch_group_count |
constante do tipo si64 |
(C10), (C15), (C22), (C23) |
| (I19) | precision_config |
número variádico de enums de DEFAULT, HIGH e HIGHEST |
(C24) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor ou tensor quantizado | (C25-C27), (C29), (C31-C33) |
Restrições
- (C1)
N = rank(lhs) = rank(rhs). - (C2)
size(window_strides) = N - 2. - (C3)
0 < window_strides. - (C4)
shape(padding) = [N - 2, 2]. - (C5)
size(lhs_dilation) = N - 2. - (C6)
0 < lhs_dilation. - (C7)
size(rhs_dilation) = N - 2. - (C8)
0 < rhs_dilation. - (C9)
size(window_reversal) = N - 2. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0. - (C12)
size(input_spatial_dimensions) = N - 2. - (C13) Dado
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:is_unique(input_dimensions).0 <= input_dimensions < N.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0. - (C17)
size(kernel_spatial_dimensions) = N - 2. - (C18) Dado
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:is_unique(kernel_dimensions).0 <= kernel_dimensions < N.
- (C19)
size(output_spatial_dimensions) = N - 2. - (C20) Dado
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:is_unique(output_dimensions).0 <= output_dimensions < N.
- (C21)
0 < feature_group_count. - (C22)
0 < batch_group_count. - (C23)
feature_group_count = 1 or batch_group_count = 1. - (C24)
size(precision_config) = 2. - (C25)
dim(result, result_dim)é definido como:dim(lhs, input_batch_dimension) / batch_group_countseresult_dim = output_batch_dimension.dim(rhs, kernel_output_feature_dimension)seresult_dim = output_feature_dimension.num_windowscaso contrário, em que:output_spatial_dimensions[spatial_dim] = result_dim.lhs_dim = input_spatial_dimensions[spatial_dim].rhs_dim = kernel_spatial_dimensions[spatial_dim].dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
- (C26)
rank(result) = N. - Se a operação usar tensores não quantizados:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result).
- (C27)
- Se a operação usar tensores quantizados:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs). - (C29) Se
is_per_axis_quantized(rhs), entãoquantization_dimension(rhs) = kernel_output_feature_dimension. - (C30) Se
is_per_axis_quantized(result), entãoquantization_dimension(result) = output_feature_dimension. - Se
is_quantized(lhs): - (C31)
storage_type(lhs) = storage_type(rhs). - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result). - (C33) Se
is_per_tensor_quantized(rhs), entãois_per_tensor_quantized(result). - Se
!is_quantized(lhs): - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result).
- (C28)
Exemplos
// %lhs: [[
// [[1], [2], [5], [6]],
// [[3], [4], [7], [8]],
// [[10], [11], [14], [15]],
// [[12], [13], [16], [17]]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
// %padding: [[1, 1],
// [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
window_strid<es = arra>yi64: 4, 4,
lhs_dilati<on = arra>yi64: 2, 2,
rhs_dilati<on = arra>yi64: 1, 1,
window_revers<al = arrayi1: fa>lse, false,
dimension_numbers = #stab<lehlo.convraw
input_batch_dimension = 0,
input_feature_dimension = 3,
input_spatial_dimensions = [0, 1],
kernel_input_feature_dimension = 2,
kernel_output_feature_dimension = 3,
kernel_spatial_dimensions = [0, 1],
output_batch_dimension = 0,
output_feature_dimension = 3,
output_spatial_dimensions => [1, 2]
,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
precision_config = [<#stablehloprecisi>on DEFAULT, <#stablehloprecisi>on DEFAULT]
}< : (tensor1>x4x4x1xi<64, tensor3>x3x1x1xi<64, ten>sor>2x2xi64<) - tensor1>x2x2x1xi64
// %result: [[
// [[1], [5]],
// [[10], [14]]
// ]]
dynamic_gather
Semântica
Essa operação é funcionalmente idêntica à operação
gather, com o slice_sizes especificado dinamicamente como um valor.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor ou tensor quantizado por tensor | (C1), (C7), (C10-C12), (C14) |
| (I2) | start_indices |
tensor de tipo inteiro | (C2), (C3), (C13) |
| (I3) | slice_sizes |
Tensor unidimensional de tipo inteiro | (C8), (C11-C13) |
| (I4) | offset_dims |
Constante de tensor unidimensional do tipo si64 |
(C1), (C4-C5), (C13) |
| (I5) | collapsed_slice_dims |
Constante de tensor unidimensional do tipo si64 |
(C1), (C6-C8), (C13) |
| (I6) | start_index_map |
Constante de tensor unidimensional do tipo si64 |
(C3), (C9), (C10) |
| (I7) | index_vector_dim |
constante do tipo si64 |
(C2), (C3), (C13) |
| (I8) | indices_are_sorted |
constante do tipo i1 |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor ou tensor quantizado por tensor | (C5), (C13-C14) |
Restrições
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims). - (C2)
0 <= index_vector_dim <= rank(start_indices). - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims). - (C5)
0 <= offset_dims < rank(result). - (C6)
is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims). - (C7)
0 <= collapsed_slice_dims < rank(operand). - (C8)
slice_sizes[collapsed_slice_dims...] <= 1. - (C9)
is_unique(start_index_map). - (C10)
0 <= start_index_map < rank(operand). - (C11)
size(slice_sizes) = rank(operand). - (C12)
0 <= slice_sizes <= shape(operand). - (C13)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)em que:batch_dim_sizes = shape(start_indices), exceto que o tamanho da dimensão destart_indicescorrespondente aindex_vector_dimnão está incluído.offset_dim_sizes = shape(slice_sizes), exceto que os tamanhos de dimensão emslice_sizescorrespondentes acollapsed_slice_dimsnão são incluídos.combinecolocabatch_dim_sizesem eixos correspondentes abatch_dimseoffset_dim_sizesem eixos correspondentes aoffset_dims.
- (C14)
element_type(operand) = element_type(result).
Exemplos
// %operand: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %start_indices: [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 2]]
// ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
dimension_numbers = #stable<hlo.gather
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vect>or_dim = 2,
indices_are_sorted = false
}< : (tenso>r3x4x2xi<64, tenso>r2x3x2xi<64, t>ens>or3xi64<) - tensor2>x3x2x2xi64
// %result: [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[9, 10], [11, 12]],
// [[11, 12], [13, 14]],
// [[17, 18], [19, 20]]
// ]
// ]
dynamic_iota
Semântica
Essa operação é funcionalmente idêntica à operação iota, mas a forma do resultado é especificada dinamicamente via output_shape.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | output_shape |
Tensor unidimensional de tipo inteiro | (C1), (C2) |
| (I2) | iota_dimension |
si64 |
(C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de tipo inteiro, de ponto flutuante ou complexo ou tensor quantizado por tensor | (C2) |
Restrições
- (C1)
0 <= iota_dimension < size(output_shape). - (C2)
rank(result) = size(output_shape).
Exemplos
%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
iota_dimension = 0 : i64
}< : (t>ens>or2xi64<) - ten>sor4x5xi64
// %result: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
dynamic_pad
Semântica
Essa operação é funcionalmente idêntica à operação pad, mas com edge_padding_low, edge_padding_high e interior_padding especificados dinamicamente como valores.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor ou tensor quantizado por tensor | (C1), (C2), (C4) |
| (I2) | padding_value |
Tensor de dimensão zero ou tensor quantizado por tensor | (C1) |
| (I3) | edge_padding_low |
Tensor unidimensional de tipo inteiro | (C1), (C4) |
| (I4) | edge_padding_high |
Tensor unidimensional de tipo inteiro | (C1), (C4) |
| (I5) | interior_padding |
Tensor unidimensional de tipo inteiro | (C2-C4) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor ou tensor quantizado por tensor | (C3-C6) |
Restrições
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result). - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand). - (C3)
0 <= interior_padding. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.
Exemplos
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
%edge_padding_low, %edge_padding_high, %interior_padding
)< : (ten>sor2x3xi<64,> tensori<64, t>ensor2xi<64, t>ensor2xi<64, t>ens>or2xi64<) - ten>sor5x9xi64
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
dynamic_reshape
Semântica
Essa operação é funcionalmente idêntica à operação reshape, mas o formato do resultado é especificado dinamicamente via output_shape.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor ou tensor quantizado | (C1-C3) |
| (I2) | output_shape |
Tensor unidimensional de tipo inteiro | (C4) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor ou tensor quantizado | (C1 a C4) |
Restrições
- (C1)
element_type(result)é dado por:element_type(operand), se!is_per_axis_quantized(operand).element_type(operand), exceto quequantization_dimension(operand)equantization_dimension(result)podem ser diferentes.
- (C2)
size(operand) = size(result). - (C3) Se
is_per_axis_quantized(operand):reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
- (C4)
size(output_shape) = rank(result).
Exemplos
// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape)< : (ten>sor2x3xi<64, t>ens>or2xi64<) - ten>sor3x2xi64
// %result: [[1, 2], [3, 4], [5, 6]]
dynamic_slice
Semântica
Extrai uma fatia do operand usando índices iniciais calculados dinamicamente e produz um tensor result. start_indices contém os índices iniciais da fatia para cada dimensão sujeita a possível ajuste, e slice_sizes contém os tamanhos da fatia para cada dimensão. De maneira mais formal, result[result_index] = operand[operand_index], em que:
adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes).operand_index = adjusted_start_indices + result_index.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor ou tensor quantizado por tensor | (C1), (C2), (C4) |
| (I2) | start_indices |
número variádico de tensores de dimensão zero do tipo inteiro | (C2), (C3) |
| (I3) | slice_sizes |
Constante de tensor unidimensional do tipo si64 |
(C2), (C4), (C5) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor ou tensor quantizado por tensor | (C1), (C5) |
Restrições
- (C1)
element_type(operand) = element_type(result). - (C2)
size(start_indices) = size(slice_sizes) = rank(operand). - (C3)
same(type(start_indices...)). - (C4)
0 <= slice_sizes <= shape(operand). - (C5)
shape(result) = slice_sizes.
Exemplos
// %operand: [
// [0, 0, 1, 1],
// [0, 0, 1, 1],
// [0, 0, 0, 0],
// [0, 0, 0, 0]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
slice_siz<es = arra>yi64: 2, 2
}< : (ten>sor4x4xi<32,> tensori<64,> te>nsori64<) - ten>sor2x2xi32
// %result: [
// [1, 1],
// [1, 1]
// ]
dynamic_update_slice
Semântica
Produz um tensor result igual ao tensor operand, exceto que a fração que começa em start_indices é atualizada com os valores em update.
Formalmente, result[result_index] é definido como:
update[update_index]se0 <= update_index < shape(update)em que:adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update)).update_index = result_index - adjusted_start_indices.
operand[result_index]caso contrário.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor ou tensor quantizado por tensor | (C1-C4), (C6) |
| (I2) | update |
tensor ou tensor quantizado por tensor | (C2), (C3), (C6) |
| (I3) | start_indices |
número variádico de tensores de dimensão zero do tipo inteiro | (C4), (C5) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
type(operand) = type(result). - (C2)
element_type(update) = element_type(operand). - (C3)
rank(update) = rank(operand). - (C4)
size(start_indices) = rank(operand). - (C5)
same(type(start_indices...)). - (C6)
0 <= shape(update) <= shape(operand).
Exemplos
// %operand: [
// [1, 1, 0, 0],
// [1, 1, 0, 0],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
// %update: [
// [1, 1],
// [1, 1]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
< : (ten>sor4x4xi<32, ten>sor2x2xi<32,> tensori<64,> te>nsori64<) - ten>sor4x4xi32
// %result: [
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
exponencial
Semântica
Executa uma operação exponencial de elemento a elemento no tensor operand e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:
- Para pontos flutuantes:
expdo IEEE-754. - Para números complexos: exponencial complexa.
- Para tipos quantizados:
dequantize_op_quantize(exponential, operand, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result).
Exemplos
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]
exponential_minus_one
Semântica
Executa uma operação exponencial de subtração de um elemento por elemento no tensor operand e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:
- Para pontos flutuantes:
expm1do IEEE-754. - Para números complexos: exponencial complexa menos um.
- Para tipos quantizados:
dequantize_op_quantize(exponential_minus_one, operand, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result).
Exemplos
// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand)< : (t>ens>or2xf64<) - t>ensor2xf64
// %result: [0.0, 1.71828187]
fft
Semântica
Executa as transformadas de Fourier direta e inversa para entradas/saídas reais e complexas.
fft_type é um destes:
FFT: encaminha a FFT complexa para complexa.IFFT: FFT complexa para complexa inversa.RFFT: encaminha a FFT real para complexa.IRFFT: FFT inversa de real para complexo (ou seja, recebe complexo e retorna real).
Mais formalmente, considerando a função fft que usa tensores unidimensionais de tipos complexos como entrada, produz tensores unidimensionais dos mesmos tipos como saída e calcula a transformação de Fourier discreta:
Para fft_type = FFT, result é definido como o resultado final de uma série de cálculos L em que L = size(fft_length). Por exemplo, para L = 3:
result1[i0, ..., :] = fft(operand[i0, ..., :]).result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).
Além disso, considerando a função ifft, que tem a mesma assinatura de tipo e calcula o inverso de fft:
Para fft_type = IFFT, result é definido como o inverso dos cálculos de fft_type = FFT. Por exemplo, para L = 3:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).result[i0, ..., :] = ifft(result2[i0, ..., :]).
Além disso, considerando a função rfft, que usa tensores unidimensionais de tipos de ponto flutuante, produz tensores unidimensionais de tipos complexos da mesma semântica de ponto flutuante e funciona da seguinte maneira:
rfft(real_operand) = truncated_resultem quecomplex_operand... = (real_operand..., 0.0).complex_result = fft(complex_operand).truncated_result = complex_result[:(rank(complex_result) / 2 + 1)].
Quando a transformação discreta de Fourier é calculada para operandos reais, os primeiros N/2 + 1 elementos do resultado definem de maneira inequívoca o restante do resultado. Portanto, o resultado de rfft é truncado para evitar o cálculo de elementos redundantes.
Para fft_type = RFFT, result é definido como o resultado final de uma série de cálculos L em que L = size(fft_length). Por exemplo, para L = 3:
result1[i0, ..., :] = rfft(operand[i0, ..., :]).result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).
Por fim, considerando a função irfft, que tem a mesma assinatura de tipo e calcula o inverso de rfft:
Para fft_type = IRFFT, result é definido como o inverso dos cálculos de fft_type = RFFT. Por exemplo, para L = 3:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).result[i0, ..., :] = irfft(result2[i0, ..., :]).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor de ponto flutuante ou tipo complexo | (C1), (C2), (C4), (C5) |
| (I2) | fft_type |
enum de FFT, IFFT, RFFT e IRFFT |
(C2), (C5) |
| (I3) | fft_length |
Constante de tensor unidimensional do tipo si64 |
(C1), (C3), (C4) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de ponto flutuante ou tipo complexo | (C2), (C4), (C5) |
Restrições
- (C1)
size(fft_length) <= rank(operand). - (C2) A relação entre os tipos de elementos
operanderesultvaria:- Se
fft_type = FFT,element_type(operand)eelement_type(result)tiverem o mesmo tipo complexo. - Se
fft_type = IFFT,element_type(operand)eelement_type(result)tiverem o mesmo tipo complexo. - Se
fft_type = RFFT,element_type(operand)for um tipo de ponto flutuante eelement_type(result)for um tipo complexo da mesma semântica de ponto flutuante. - Se
fft_type = IRFFT,element_type(operand)for um tipo complexo eelement_type(result)for um tipo de ponto flutuante da mesma semântica de ponto flutuante.
- Se
- (C3)
1 <= size(fft_length) <= 3. - (C4) Se entre
operanderesulthouver um tensorrealde um tipo de ponto flutuante, entãoshape(real)[-size(fft_length):] = fft_length. - (C5)
shape(result) = shape(operand), exceto:- Se
fft_type = RFFT,dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1. - Se
fft_type = IRFFT,dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1.
- Se
Exemplos
// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
fft_type = <#stablehloff>t_type FFT,
fft_leng<th = a>rrayi64: 4
}< : (tenso<r4x>>com>plexf32<) - tenso<r4x>>complexf32
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]
andar
Semântica
Realiza o piso elemento a elemento do tensor operand e produz um tensor result.
Implementa a operação roundToIntegralTowardNegative da especificação IEEE-754. Para tipos quantizados, realiza
dequantize_op_quantize(floor, operand, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result).
Exemplos
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand)< : (t>ens>or5xf32<) - t>ensor5xf32
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]
gather
Semântica
Coleta slices do tensor operand dos deslocamentos especificados em start_indices e produz um tensor result.
O diagrama a seguir mostra como os elementos em result são mapeados em elementos em operand usando um exemplo concreto. O diagrama escolhe alguns exemplos de índices result e explica em detalhes a quais índices operand eles correspondem.
Mais formalmente, result[result_index] = operand[operand_index], em que:
batch_dims = [d for d in axes(result) and d not in offset_dims].batch_index = result_index[batch_dims...].start_indexé definido como:start_indices[bi0, ..., :, ..., biN]em quebisão elementos individuais embatch_indexe:é inserido no índiceindex_vector_dim, seindex_vector_dim<rank(start_indices).[start_indices[batch_index]]caso contrário.
- Para
d_operandemaxes(operand),full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])sed_operand = start_index_map[d_start].full_start_index[d_operand] = 0caso contrário.
- Para
d_operandemaxes(operand),full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)]sed_operand = operand_batching_dims[i_batching]ed_start = start_indices_batching_dims[i_batching].full_batching_index[d_operand] = 0caso contrário.
offset_index = result_index[offset_dims...].full_offset_index = [oi0, ..., 0, ..., oiN]em queoisão elementos individuais emoffset_indexe0é inserido nos índices decollapsed_slice_dimseoperand_batching_dims.operand_index = full_start_index + full_batching_index + full_offset_index.
Se indices_are_sorted for true, a implementação poderá presumir que start_indices estão classificados em relação a start_index_map. Caso contrário, o comportamento será indefinido. De maneira mais formal, para todo i1 < i2 de indices(result), full_start_index(i1) <= full_start_index(i2).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor ou tensor quantizado por tensor | (C1), (C8), (C11), (C17), (C19-C21), (C23) |
| (I2) | start_indices |
tensor de tipo inteiro | (C2-C3), (C14), (C17), (C22) |
| (I3) | offset_dims |
Constante de tensor unidimensional do tipo si64 |
(C1), (C4-C5), (C22) |
| (I4) | collapsed_slice_dims |
Constante de tensor unidimensional do tipo si64 |
(C1), (C6-C9), (C22) |
| (I5) | operand_batching_dims |
Constante de tensor unidimensional do tipo si64 |
(C1), (C6), (C10-C12), (C16-C18), (C22) |
| (I6) | start_indices_batching_dims |
Constante de tensor unidimensional do tipo si64 |
(C13-C17) |
| (I7) | start_index_map |
Constante de tensor unidimensional do tipo si64 |
(C3), (C18-C19) |
| (I8) | index_vector_dim |
constante do tipo si64 |
(C2-C3), (C15), (C22) |
| (I9) | slice_sizes |
Constante de tensor unidimensional do tipo si64 |
(C9), (C12), (C20-C22) |
| (I10) | indices_are_sorted |
constante do tipo i1 |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor ou tensor quantizado por tensor | (C5), (C22-C23) |
Restrições
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims). - (C2)
0 <= index_vector_dim <= rank(start_indices). - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims). - (C5)
0 <= offset_dims < rank(result). - (C6)
is_unique(concatenate(collapsed_slice_dims, operand_batching_dims)) - (C7)
is_sorted(collapsed_slice_dims). - (C8)
0 <= collapsed_slice_dims < rank(operand). - (C9)
slice_sizes[collapsed_slice_dims...] <= 1. - (C10)
is_sorted(operand_batching_dims). - (C11)
0 <= operand_batching_dims < rank(operand). - (C12)
slice_sizes[operand_batching_dims...] <= 1. - (C13)
is_unique(start_indices_batching_dims). - (C14)
0 <= start_indices_batching_dims < rank(start_indices). - (C15)
index_vector_dim not in start_indices_batching_dims. - (C16)
size(operand_batching_dims) == size(start_indices_batching_dims). - (C17)
dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...). - (C18)
is_unique(concatenate(start_index_map, operand_batching_dims)). - (C19)
0 <= start_index_map < rank(operand). - (C20)
size(slice_sizes) = rank(operand). - (C21)
0 <= slice_sizes <= shape(operand). - (C22)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)em que:batch_dim_sizes = shape(start_indices), exceto que o tamanho da dimensão destart_indicescorrespondente aindex_vector_dimnão está incluído.offset_dim_sizes = slice_sizes, exceto que os tamanhos de dimensão emslice_sizescorrespondentes acollapsed_slice_dimseoperand_batching_dimsnão estão incluídos.combinecolocabatch_dim_sizesem eixos correspondentes abatch_dimseoffset_dim_sizesem eixos correspondentes aoffset_dims.
- (C23)
element_type(operand) = element_type(result).
Exemplos
// %operand: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %start_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
%result = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stable<hlo.gather
offset_dims = [3, 4],
collapsed_slice_dims = [1],
operand_batching_dims = [0],
start_indices_batching_dims = [1],
start_index_map = [2, 1],
index_vect>or_dim = 3,
slice_siz<es = arrayi64: >1, 1, 2, 2,
indices_are_sorted = false
}< : (tensor2>x3x4x2xi<32, tensor2>x2x>3x2xi64<) - tensor2x2>x3x2x2xi32
// %result: [
// [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[33, 34], [35, 36]],
// [[35, 36], [37, 38]],
// [[41, 42], [43, 44]]
// ]
// ],
// [
// [
// [[1, 2], [3, 4]],
// [[13, 14], [15, 16]],
// [[21, 22], [23, 24]]
// ],
// [
// [[43, 44], [45, 46]],
// [[33, 34], [35, 36]],
// [[27, 28], [29, 30]]
// ]
// ]
// ]
get_dimension_size
Semântica
Produz o tamanho do dimension especificado do operand. De maneira mais formal, result = dim(operand, dimension). A semântica se refere apenas ao componente de forma do tipo. O tipo de elemento pode ser qualquer coisa.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor ou tensor quantizado | (C1) |
| (I2) | dimension |
constante do tipo si64 |
(C1) |
Saídas
| Nome | Tipo |
|---|---|
result |
Tensor de dimensão zero do tipo si32 |
Restrições
- (C1)
0 <= dimension < rank(operand).
Exemplos
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
dimension = 1 : i64
}< : (ten>sor>2x3xi64<) -> tensori32
// %result: 3
get_tuple_element
Semântica
Extrai o elemento na posição index da tupla operand e produz um result. Mais formalmente, result = operand[index].
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tuple | (C1), (C2) |
| (I2) | index |
constante do tipo si32 |
(C1), (C2) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
qualquer valor | (C2) |
Restrições
- (C1)
0 <= index < size(operand). - (C2)
type(result) = tuple_element_types(operand)[index].
Exemplos
// %operand: ([1.0, 2.0], (3))
%result = "stablehlo.get_tuple_element"(<%operand) {index >= 0 : i32<} : (t<uplet>ensor2x<f64, t<upl>>>ete>nsori64<) - t>ensor2xf64
// %result: [1.0, 2.0]
se
Semântica
Produz a saída da execução de exatamente uma função de true_branch ou
false_branch, dependendo do valor de pred. Mais formalmente, result =
pred ? true_branch() : false_branch().
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | pred |
Tensor de dimensão zero do tipo i1 |
|
| (I2) | true_branch |
função | (C1-C3) |
| (I3) | false_branch |
função | (C1), (C2) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
results |
número variádico de tensores, tensores quantizados ou tokens | (C3) |
Restrições
- (C1)
input_types(true_branch) = input_types(false_branch) = []. - (C2)
output_types(true_branch) = output_types(false_branch). - (C3)
type(results...) = output_types(true_branch).
Exemplos
// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
"stablehlo.return"(%result_tr<ue_>bra>nch) : (tensori32) - ()
}, {
"stablehlo.return"(%<res>ult>_false_branch) :< (>ten>sori32)< - >()
}) : (tensori1) - tensori32
// %result: 10
imag
Semântica
Extrai a parte imaginária, elemento por elemento, do operand e produz um tensor result. De maneira mais formal, para cada elemento x:
imag(x) = is_complex(x) ? imaginary_part(x) :
constant(0, element_type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor de ponto flutuante ou tipo complexo | (C1), (C2) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de tipo ponto flutuante | (C1), (C2) |
Restrições
- (C1)
shape(result) = shape(operand). - (C2)
element_type(result)é definido como:complex_element_type(element_type(operand))seis_complex(operand).element_type(operand)caso contrário.
Exemplos
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand)< : (tenso<r2x>>com>plexf32<) - t>ensor2xf32
// %result: [2.0, 4.0]
infeed
Semântica
Lê dados do feed e produz results.
A semântica de infeed_config é definida pela implementação.
results consiste em valores de payload que vêm primeiro e um token que vem por último. No futuro, planejamos dividir o payload e o token em duas saídas separadas para melhorar a clareza (#670).
Entradas
| Rótulo | Nome | Tipo |
|---|---|---|
| (I1) | token |
token |
| (I2) | infeed_config |
constante do tipo string |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
results |
número variádico de tensores, tensores quantizados ou tokens | (C1-C3) |
Restrições
- (C1)
0 < size(results). - (C2)
is_empty(result[:-1])ouis_tensor(type(results[:-1])). - (C3)
is_token(type(results[-1])).
Exemplos
// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : >(!stable<hlo.tok>en) - (tensor2x2xi64, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
infeed_config> = "<;">
} : (!stablehlo.token) - (tensor2x2xi64, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]
iota
Semântica
Preenche um tensor output com valores em ordem crescente, começando do zero, ao longo da dimensão iota_dimension. Mais formalmente,
output[output_index] = constant(is_quantized(output) ?
quantize(output_index[iota_dimension], element_type(output)) :
output_index[iota_dimension], element_type(output)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | iota_dimension |
si64 |
(C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
output |
tensor de tipo inteiro, de ponto flutuante ou complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
0 <= iota_dimension < rank(output).
Exemplos
%output = "stablehlo.iota"() {
iota_dimension = 0 : i6>4
} : (<) - ten>sor4x5xi32
// %output: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
%output = "stablehlo.iota"() {
iota_dimensio>n = 1 :< i64
} >: () - tensor4x5xi32
// %output: [
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4]
// ]
is_finite
Semântica
Realiza uma verificação elemento a elemento para saber se o valor em x é finito (ou seja, não é +Inf, -Inf nem NaN) e produz um tensor y. Implementa a operação isFinite
da especificação IEEE-754. Para tipos quantizados, o resultado é sempre true.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | x |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
y |
tensor do tipo booleano | (C1) |
Restrições
- (C1)
shape(x) = shape(y).
Exemplos
// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x)< : (tens>or7xf64<) - >tensor7xi1
// %y: [false, false, false, true, true, true, true]
log
Semântica
Executa a operação de logaritmo elemento a elemento no tensor operand e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:
- Para pontos flutuantes:
logdo IEEE-754. - Para números complexos: logaritmo complexo.
- Para tipos quantizados:
dequantize_op_quantize(log, operand, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result).
Exemplos
// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]
log_plus_one
Semântica
Executa uma operação de logaritmo por elemento mais um no tensor operand e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:
- Para pontos flutuantes:
logp1do IEEE-754. - Para números complexos:
complex(log(hypot(real(x) + 1, imag(x))), atan2(imag(x), real(x) + 1)) - Para tipos quantizados:
dequantize_op_quantize(log_plus_one, operand, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result).
Exemplos
// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
logística
Semântica
Executa a operação logística elemento a elemento no tensor operand e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:
- Para pontos flutuantes:
division(1, addition(1, exp(-x)))do IEEE-754. - Para números complexos: logística complexa.
- Para tipos quantizados:
dequantize_op_quantize(logistic, operand, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result).
Exemplos
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]
mapa
Semântica
Aplica uma função de mapa computation a inputs ao longo do dimensions e produz um tensor result.
Mais formalmente, result[result_index] = computation(inputs...[result_index]).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | inputs |
número variádico de tensores ou tensores quantizados por tensor | (C1 a C4) |
| (I2) | dimensions |
Constante de tensor unidimensional do tipo si64 |
(C3) |
| (I3) | computation |
função | (C4) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor ou tensor quantizado por tensor | (C1), (C4) |
Restrições
- (C1)
shape(inputs...) = shape(result). - (C2)
0 < size(inputs) = N. - (C3)
dimensions = range(rank(inputs[0])). - (C4)
computationtem o tipo(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>, em queEi = element_type(inputs[i])eE' = element_type(result).
Exemplos
// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = stablehlo.multiply %arg0, %arg<1 :> tensori64
stablehlo.return %<0 :> tensori64
}) {
dimensio<ns = arra>yi64: 0, 1
}< : (ten>sor2x2xi<64, ten>sor>2x2xi64<) - ten>sor2x2xi64
// %result: [[0, 5], [12, 21]]
máximo
Semântica
Executa a operação de máximo por elemento nos tensores lhs e rhs e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:
- Para booleanos: OR lógico.
- Para números inteiros: número inteiro máximo.
- Para pontos flutuantes:
maximumdo IEEE-754. - Para números complexos: máximo lexicográfico para o par
(real, imaginary). Impor uma ordenação em números complexos envolve semântica surpreendente. Por isso, no futuro, planejamos remover o suporte a números complexos para essa operação (#560). - Para tipos quantizados:
dequantize_op_quantize(maximum, lhs, rhs, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | lhs |
tensor ou tensor quantizado por tensor | (C1) |
| (I2) | rhs |
tensor ou tensor quantizado por tensor | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
Exemplos
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 6], [7, 8]]
mínimo
Semântica
Executa a operação de mínimo por elemento nos tensores lhs e rhs e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:
- Para booleanos: AND lógico.
- Para números inteiros: mínimo de número inteiro.
- Para pontos flutuantes:
minimumdo IEEE-754. - Para números complexos: mínimo lexicográfico para o par
(real, imaginary). Impor uma ordenação em números complexos envolve semântica surpreendente. Por isso, no futuro, planejamos remover o suporte a números complexos para essa operação (#560). - Para tipos quantizados:
dequantize_op_quantize(minimum, lhs, rhs, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | lhs |
tensor ou tensor quantizado por tensor | (C1) |
| (I2) | rhs |
tensor ou tensor quantizado por tensor | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
Exemplos
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[1, 2], [3, 4]]
multiplicar
Semântica
Executa o produto elemento a elemento de dois tensores lhs e rhs e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:
- Para booleanos: AND lógico.
- Para números inteiros: multiplicação de números inteiros.
- Para pontos flutuantes:
multiplicationdo IEEE-754. - Para números complexos: multiplicação complexa.
- Para tipos quantizados:
dequantize_op_quantize(multiply, lhs, rhs, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | lhs |
tensor ou tensor quantizado por tensor | (C1) |
| (I2) | rhs |
tensor ou tensor quantizado por tensor | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result).
Exemplos
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 12], [21, 32]]
negate
Semântica
Executa a negação de elemento a elemento do tensor operand e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:
- Para números inteiros com sinal: negação de número inteiro.
- Para números inteiros sem sinal: bitcast para número inteiro com sinal, negação de número inteiro, bitcast de volta para número inteiro sem sinal.
- Para pontos flutuantes:
negatedo IEEE-754. - Para números complexos: negação complexa.
- Para tipos quantizados:
dequantize_op_quantize(negate, operand, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor de tipo inteiro, de ponto flutuante ou complexo ou tensor quantizado por tensor | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de tipo inteiro, de ponto flutuante ou complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result).
Exemplos
// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand)< : (t>ens>or2xi32<) - t>ensor2xi32
// %result: [0, 2]
// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"<(%operand<) :>> (t>ensor1x<complexf3<2) >>- tensor1xcomplexf32
// %result: [-2.5, -0.0]
não
Semântica
Executa um NOT elemento a elemento do tensor operand e produz um tensor result.
Dependendo do tipo de elemento, faça o seguinte:
- Para booleanos: NOT lógico.
- Para números inteiros: NOT bit a bit.
Argumentos
| Nome | Tipo | Restrições |
|---|---|---|
operand |
tensor de tipo booleano ou inteiro | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de tipo booleano ou inteiro | (C1) |
Restrições
- (C1)
type(operand) = type(result).
Exemplos
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand)< : (ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[-2, -3], [-4, -5]]
// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"<(%op>era>nd) : (<tens>or2xi1) - tensor2xi1
// %result: [false, true]
optimization_barrier
Semântica
Garante que as operações que produzem o operand sejam executadas antes de qualquer
operação que dependa do result e impede que as transformações do compilador
movam operações pela barreira. Fora isso, a operação é uma identidade, ou seja, result = operand.
Argumentos
| Nome | Tipo | Restrições |
|---|---|---|
operand |
número variádico de tensores, tensores ou tokens quantizados por tensor | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
número variádico de tensores, tensores ou tokens quantizados por tensor | (C1) |
Restrições
- (C1)
type(operand...) = type(result...).
Exemplos
// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1)< : >(tensorf<32,> te>nsorf32)< - >(tensorf<32,> tensorf32)
// %result0: 0.0
// %result1: 1.0
ou
Semântica
Executa OR bit a bit de dois tensores lhs e rhs e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:
- Para booleanos: OR lógico.
- Para números inteiros: OR bit a bit.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | lhs |
tensor de tipo inteiro ou booleano | (C1) |
| (I2) | rhs |
tensor de tipo inteiro ou booleano | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de tipo inteiro ou booleano | (C1) |
Restrições
- (C1)
type(lhs) = type(rhs) = type(result).
Exemplos
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 6], [7, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%<lhs, %>rhs) : (<tensor>2x2>xi1, te<nsor2x>2xi1) - tensor2x2xi1
// %result: [[false, true], [true, true]]
outfeed
Semântica
Grava inputs no outfeed e produz um token result.
A semântica de outfeed_config é definida pela implementação.
Entradas
| Rótulo | Nome | Tipo |
|---|---|---|
| (I1) | inputs |
número variádico de tensores ou tensores quantizados |
| (I2) | token |
token |
| (I3) | outfeed_config |
constante do tipo string |
Saídas
| Nome | Tipo |
|---|---|
result |
token |
Exemplos
%result = "stablehlo.outfeed"(%input0, %token) {
outfeed_config = &quo<t;"<>/span>
} : (tensor2x2x2xi64,> !stablehlo.token) - !stablehlo.token
pad
Semântica
Expande operand com padding ao redor do tensor e entre os elementos dele com o padding_value especificado.
edge_padding_low e edge_padding_high especificam a quantidade de padding adicionada na extremidade inferior (próxima ao índice 0) e na extremidade superior (próxima ao índice mais alto) de cada dimensão, respectivamente. A quantidade de padding pode ser negativa, em que o valor absoluto do padding negativo indica o número de elementos a serem removidos da dimensão especificada.
interior_padding especifica a quantidade de padding adicionada entre dois elementos em cada dimensão, que não pode ser negativa. O padding interno ocorre antes do padding de borda, de modo que um padding de borda negativo remove elementos do operando com padding interno.
Formalmente, result[result_index] é definido como:
operand[operand_index]seresult_index = edge_padding_low + operand_index * (interior_padding + 1).padding_valuecaso contrário.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor ou tensor quantizado por tensor | (C1), (C2), (C4) |
| (I2) | padding_value |
Tensor de dimensão zero ou tensor quantizado por tensor | (C1) |
| (I3) | edge_padding_low |
Constante de tensor unidimensional do tipo si64 |
(C1), (C4) |
| (I4) | edge_padding_high |
Constante de tensor unidimensional do tipo si64 |
(C1), (C4) |
| (I5) | interior_padding |
Constante de tensor unidimensional do tipo si64 |
(C2-C4) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor ou tensor quantizado por tensor | (C3-C6) |
Restrições
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result). - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand). - (C3)
0 <= interior_padding. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.
Exemplos
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
edge_padding_l<ow = arra>yi64: 0, 1,
edge_padding_hi<gh = arra>yi64: 2, 1,
interior_paddi<ng = arra>yi64: 1, 2
}< : (ten>sor2x3xi<32,> te>nsori32<) - ten>sor5x9xi32
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
partition_id
Semântica
Produz partition_id do processo atual.
Saídas
| Nome | Tipo |
|---|---|
result |
Tensor de dimensão 0 do tipo ui32 |
Exemplos
%result = "stablehlo.partition_id">;() : (<) - >tensorui32
popcnt
Semântica
Realiza uma contagem elemento a elemento do número de bits definidos no tensor operand e produz um tensor result.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor de tipo inteiro | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de tipo inteiro | (C1) |
Restrições
- (C1)
type(operand) = type(result).
Exemplos
// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand)< : (t>ens>or4xi64<) - t>ensor4xi64
// %result: [0, 1, 1, 7]
potência
Semântica
Realiza a exponenciação elemento a elemento do tensor lhs pelo tensor rhs e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:
- Para números inteiros: exponenciação de números inteiros.
- Para pontos flutuantes:
powdo IEEE-754. - Para números complexos: exponenciação complexa.
- Para tipos quantizados:
dequantize_op_quantize(power, lhs, rhs, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | lhs |
tensor de tipo inteiro, de ponto flutuante ou complexo ou tensor quantizado por tensor | (C1) |
| (I2) | rhs |
tensor de tipo inteiro, de ponto flutuante ou complexo ou tensor quantizado por tensor | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de tipo inteiro, de ponto flutuante ou complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result).
Exemplos
// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs)< : (t>ensor6xf<64, t>ens>or6xf64<) - t>ensor6xf64
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]
real
Semântica
Extrai a parte real, elemento por elemento, do operand e produz um tensor result. De maneira mais formal, para cada elemento x:
real(x) = is_complex(x) ? real_part(x) : x.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor de ponto flutuante ou tipo complexo | (C1), (C2) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de tipo ponto flutuante | (C1), (C2) |
Restrições
- (C1)
shape(result) = shape(operand). - (C2)
element_type(result)é definido como:complex_element_type(element_type(operand))seis_complex(operand).element_type(operand)caso contrário.
Exemplos
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand)< : (tenso<r2x>>com>plexf32<) - t>ensor2xf32
// %result: [1.0, 3.0]
recv
Semântica
Recebe dados de um canal com channel_id e produz results.
Se is_host_transfer for true, a operação vai transferir dados do host. Caso contrário, ele transfere dados de outro dispositivo com base nos valores de source_target_pairs. Essa flag duplica as informações fornecidas em
channel_type. Por isso, no futuro, planejamos manter apenas uma delas
(#666). Se is_host_transfer
= false e source_target_pairs for None ou vazio, será considerado
comportamento indefinido.
results consiste em valores de payload que vêm primeiro e um token que vem
por último. No futuro, planejamos dividir o payload e o token em duas saídas separadas para melhorar a clareza (#670).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | token |
token |
|
| (I2) | channel_id |
constante do tipo si64 |
|
| (I3) | channel_type |
enum de DEVICE_TO_DEVICE e DEVICE_TO_HOST |
(C5) |
| (I4) | is_host_transfer |
constante do tipo i1 |
(C5-C6) |
| (I5) | source_target_pairs |
Constante de tensor bidimensional do tipo si64 |
(C1-C4), (C6) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
results |
número variádico de tensores, tensores quantizados ou tokens | (C2-C4) |
Restrições
- (C1)
dim(source_target_pairs, 1) = 2. - (C2)
is_unique(source_target_pairs[:, 0]). - (C3)
is_unique(source_target_pairs[:, 1]). - (C4)
0 <= source_target_pairs < N, em queNé definido como:num_replicassecross_replicafor usado.num_partitionssecross_partitionfor usado.
- (C5)
channel_typeé definido como:DEVICE_TO_HOSTseis_host_transfer = true,DEVICE_TO_DEVICEcaso contrário.
Exemplos
%results0, %results1 = "stablehlo.recv"(%token) {
channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 1,
is_host_transfer = false,
source_target_pai<rs = dense[[0, 1>], [1, 2]<] : ten>sor2x2xi64
} : (!stablehl>o.token)< - (ten>sor2x2xi64, !stablehlo.token)
reduce
Semântica
Aplica uma função de redução body a inputs e init_values ao longo da dimensions e produz tensores results.
A ordem das reduções é definida pela implementação, o que significa que body e init_values precisam formar um monóide para garantir que a operação produza os mesmos resultados para todas as entradas em todas as implementações. No entanto, essa condição não é válida para muitas reduções conhecidas. Por exemplo, a adição de ponto flutuante para body e zero para init_values não formam um monóide porque a adição de ponto flutuante não é associativa.
Mais formalmente, results...[j0, ..., jR-1] = reduce(input_slices_converted), em que:
input_slices = inputs...[j0, ..., :, ..., jR-1], em que:são inseridos emdimensions.input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...).init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...).reduce(input_slices_converted) = exec(schedule)para alguma árvore bináriascheduleem que:exec(node) = body(exec(node.left), exec(node.right)).exec(leaf) = leaf.value.
scheduleé uma árvore binária completa definida pela implementação cuja travessia em ordem consiste em:- Valores de
input_slices_converted...[index]para todos osindexemindex_space(input_slices_converted)na ordem lexicográfica crescente deindex. - Intercalados com uma quantidade de
init_values_converteddefinida pela implementação em posições definidas pela implementação.
- Valores de
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | inputs |
número variádico de tensores ou tensores quantizados por tensor | (C1-C4), (C6), (C7) |
| (I2) | init_values |
número variádico de tensores de dimensão zero ou tensores quantizados por tensor | (C2), (C3) |
| (I3) | dimensions |
Constante de tensor unidimensional do tipo si64 |
(C4), (C5), (C7) |
| (I4) | body |
função | (C6) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
results |
número variádico de tensores ou tensores quantizados por tensor | (C3), (C7), (C8) |
Restrições
- (C1)
same(shape(inputs...)). - (C2)
element_type(inputs...) = element_type(init_values...). - (C3)
0 < size(inputs) = size(init_values) = size(results) = N. - (C4)
0 <= dimensions < rank(inputs[0]). - (C5)
is_unique(dimensions). - (C6)
bodytem o tipo(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)em queis_promotable(element_type(inputs[i]), Ei). - (C7)
shape(results...) = shape(inputs...), exceto que os tamanhos de dimensão deinputs...correspondentes adimensionsnão estão incluídos. - (C8)
element_type(results[i]) = Eipara todos osiem[0,N).
Exemplos
// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
"stable<hlo>.re>turn"(%0) : (tensori64) <- ()
}>) {
dimens<ions = >arrayi64<: 1>
} >: (tens<or1x6>xi64, tensori64) - tensor1xi64
// %result = [15]
reduce_precision
Semântica
Realiza a conversão elemento a elemento de operand para outro tipo de ponto flutuante que usa exponent_bits e mantissa_bits e volta para o tipo de ponto flutuante original, produzindo um tensor output.
Mais formalmente:
- Os bits da mantissa do valor original são atualizados para arredondar o valor original para o valor mais próximo representável com
mantissa_bitsusando a semânticaroundToIntegralTiesToEven. - Em seguida, se
mantissa_bitsfor menor que o número de bits da mantissa do valor original, os bits da mantissa serão truncados paramantissa_bits. - Em seguida, se os bits de expoente do resultado intermediário não se encaixarem no intervalo fornecido por
exponent_bits, o resultado intermediário vai transbordar para o infinito usando o sinal original ou vai entrar em underflow para zero usando o sinal original. - Para tipos quantizados, realiza
dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C1) |
| (I2) | exponent_bits |
constante do tipo si32 |
(C2) |
| (I3) | mantissa_bits |
constante do tipo si32 |
(C3) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
output |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(output). - (C2)
1 <= exponent_bits. - (C3)
0 <= mantissa_bits.
Exemplos
// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
exponent_bits = 5 : i32,
mantissa_bits = 10 : i32
}< : (t>ens>or6xf64<) - t>ensor6xf64
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]
reduce_scatter
Semântica
Em cada grupo de processos na grade de processos do StableHLO, realiza redução usando computations nos valores do tensor operand de cada processo, divide o resultado da redução ao longo de scatter_dimension em partes e dispersa as partes divididas entre os processos para produzir o result.
A operação divide a grade de processo do StableHLO em process_groups, que é definida da seguinte forma:
cross_replica(replica_groups)sechannel_id <= 0 and use_global_device_ids = false.cross_replica_and_partition(replica_groups)sechannel_id > 0 and use_global_device_ids = false.flattened_ids(replica_groups)sechannel_id > 0 and use_global_device_ids = true.
Depois, em cada process_group:
reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation).parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension).result@receiver = parts@sender[receiver_index]para todos ossenderemprocess_group, em quereceiver_index = process_group.index(receiver).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor ou tensor quantizado por tensor | (C1), (C2), (C7), (C8) |
| (I2) | scatter_dimension |
constante do tipo si64 |
(C1), (C2), (C8) |
| (I3) | replica_groups |
Constante de tensor bidimensional do tipo si64 |
(C3-C5) |
| (I4) | channel_id |
constante do tipo si64 |
(C6) |
| (I5) | use_global_device_ids |
constante do tipo i1 |
(C6) |
| (I6) | computation |
função | (C7) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor ou tensor quantizado por tensor | (C8-C9) |
Restrições
- (C1)
dim(operand, scatter_dimension) % dim(process_groups, 1) = 0. - (C2)
0 <= scatter_dimension < rank(operand). - (C3)
is_unique(replica_groups). - (C4)
size(replica_groups)é definido como:num_replicassecross_replicafor usado.num_replicassecross_replica_and_partitionfor usado.num_processesseflattened_idsfor usado.
- (C5)
0 <= replica_groups < size(replica_groups). - (C6) Se
use_global_device_ids = true, entãochannel_id > 0. - (C7)
computationtem o tipo(tensor<E>, tensor<E>) -> (tensor<E>)em queis_promotable(element_type(operand), E). - (C8)
shape(result) = shape(operand), exceto:dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1).
- (C9)
element_type(result) = E.
Exemplos
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
"stable<hlo>.re>turn"(%0) : (tensori64) - ()
}) {
scatter_dimension = 1 :< i64,
>replica_g<roups => dense[[0, 1]] : tensor1x2xi64,
channel_hand<le = #stablehlo.chan>nel_handleha<ndle = >0, >type = <0
} : (>tensor2x4xi64) - tensor2x2xi64
//
// %result@(0, 0): [[10, 12],
// [18, 20]]
// %result@(1, 0): [[14, 16],
// [22, 24]]
reduce_window
Semântica
Aplica uma função de redução body a janelas de inputs e init_values e gera results.
O diagrama a seguir mostra como os elementos em results... são calculados com base em inputs... usando um exemplo concreto.
De maneira mais formal, results...[result_index] = reduce(windows, init_values, axes(inputs...), body) (consulte reduce), em que:
padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1).window_start = result_index * window_strides.window_end = window_start + (window_dimensions - 1) * window_dilations + 1.windows = slice(padded_inputs..., window_start, window_end, window_dilations).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | inputs |
número variádico de tensores ou tensores quantizados por tensor | (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15) |
| (I2) | init_values |
número variádico de tensores de dimensão zero ou tensores quantizados por tensor | (C1), (C13) |
| (I3) | window_dimensions |
Constante de tensor unidimensional do tipo si64 |
(C4), (C5), (C15) |
| (I4) | window_strides |
Constante de tensor unidimensional do tipo si64 |
(C6), (C7), (C15) |
| (I5) | base_dilations |
Constante de tensor unidimensional do tipo si64 |
(C8), (C9), (C15) |
| (I6) | window_dilations |
Constante de tensor unidimensional do tipo si64 |
(C10), (C11), (C15) |
| (I7) | padding |
Constante de tensor bidimensional do tipo si64 |
(C12), (C15) |
| (I8) | body |
função | (C13) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
results |
número variádico de tensores ou tensores quantizados por tensor | (C1), (C14-C16) |
Restrições
- (C1)
0 < size(inputs) = size(init_values) = size(results) = N. - (C2)
same(shape(inputs...)). - (C3)
element_type(inputs...) = element_type(init_values...). - (C4)
size(window_dimensions) = rank(inputs[0]). - (C5)
0 < window_dimensions. - (C6)
size(window_strides) = rank(inputs[0]). - (C7)
0 < window_strides. - (C8)
size(base_dilations) = rank(inputs[0]). - (C9)
0 < base_dilations. - (C10)
size(window_dilations) = rank(inputs[0]). - (C11)
0 < window_dilations. - (C12)
shape(padding) = [rank(inputs[0]), 2]. - (C13)
bodytem o tipo(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)em queis_promotable(element_type(inputs[i]), Ei). - (C14)
same(shape(results...)). - (C15)
shape(results[0]) = num_windowsem que:dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1.padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1].dilated_window_shape = (window_dimensions - 1) * window_dilations + 1.is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape.num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1.
- (C16)
element_type(results[i]) = Eipara todos osiem[0,N).
Exemplos
// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
"stable<hlo>.re>turn"(%0) : (tensori64) - ()
})< {
wind>ow_dimensions = arrayi64: <2, 1,
w>indow_strides = arrayi64: <4, 1,
b>ase_dilations = arrayi64: 2,< 1,
win>dow_dilations = arr<ayi64: 3, 1,
p>adding = <dense[[>2, 1], [0, 0<]] : te>nsor2x2x<i64>
} >: (tens<or3x2xi>64, tensori64) - tensor2x2xi64
// %result = [[0, 0], [3, 4]]
resto
Semântica
Realiza o restante elemento a elemento dos tensores dividendo lhs e divisor rhs e gera um tensor result.
Mais formalmente, o sinal do resultado é extraído do dividendo, e o valor absoluto do resultado é sempre menor que o valor absoluto do divisor.
O restante é calculado como lhs - d * rhs, em que d é dado por:
- Para números inteiros:
stablehlo.divide(lhs, rhs). - Para números de ponto flutuante:
division(lhs, rhs)do IEEE-754 com atributo de arredondamentoroundTowardZero. - Para números complexos: a ser definido (#997).
- Para tipos quantizados:
dequantize_op_quantize(remainder, lhs, rhs, type(result)).
Para tipos de elementos de ponto flutuante, essa operação contrasta com a operação remainder da especificação IEEE-754, em que d é um valor integral mais próximo do valor exato de lhs/rhs com empates para par.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | lhs |
tensor de tipo inteiro, de ponto flutuante ou complexo ou tensor quantizado por tensor | (C1) |
| (I2) | rhs |
tensor de tipo inteiro, de ponto flutuante ou complexo ou tensor quantizado por tensor | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de tipo inteiro, de ponto flutuante ou complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result).
Exemplos
// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs)< : (t>ensor4xi<64, t>ens>or4xi64<) - t>ensor4xi64
// %result: [2, -2, 2, -2]
replica_id
Semântica
Produz replica_id do processo atual.
Saídas
| Nome | Tipo |
|---|---|
result |
Tensor de dimensão 0 do tipo ui32 |
Exemplos
%result = "stablehlo.replica_id">;() : (<) - >tensorui32
reshape
Semântica
Muda o formato do tensor operand para um tensor result. Conceitualmente, isso equivale a manter a mesma representação canônica, mas possivelmente mudar a forma, por exemplo, de tensor<2x3xf32> para tensor<3x2xf32> ou tensor<6xf32>.
Mais formalmente, result[result_index] = operand[operand_index] em que result_index e operand_index têm a mesma posição na ordenação lexicográfica de index_space(result) e index_space(operand).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor ou tensor quantizado | (C1-C3) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor ou tensor quantizado | (C1-C3) |
Restrições
- (C1)
element_type(result)é dado por:element_type(operand), se!is_per_axis_quantized(operand).element_type(operand), exceto quequantization_dimension(operand)equantization_dimension(result)podem ser diferentes.
- (C2)
size(operand) = size(result). - (C3) Se
is_per_axis_quantized(operand):reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
Exemplos
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand)< : (ten>sor>2x3xi32<) - ten>sor3x2xi32
// %result: [[1, 2], [3, 4], [5, 6]]
anular
Semântica
Inverte a ordem dos elementos no operand ao longo da dimensions especificada e produz um tensor result. De maneira mais formal, result[result_index] = operand[operand_index], em que:
operand_index[d] = dim(result, d) - result_index[d] - 1sedemdimensions.operand_index[d] = result_index[d]caso contrário.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor ou tensor quantizado por tensor | (C1), (C3) |
| (I2) | dimensions |
Constante de tensor unidimensional do tipo si64 |
(C2), (C3) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor ou tensor quantizado por tensor | (C1), (C3) |
Restrições
- (C1)
type(operand) = type(result). - (C2)
is_unique(dimensions). - (C3)
0 <= dimensions < rank(result).
Exemplos
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
dimensio<ns = a>rrayi64: 1
}< : (ten>sor>3x2xi32<) - ten>sor3x2xi32
// %result: [[2, 1], [4, 3], [6, 5]]
rng
Semântica
Gera números aleatórios usando o algoritmo rng_distribution e produz um tensor result de um formato específico shape.
Se rng_distribution = UNIFORM, os números aleatórios serão gerados seguindo a distribuição uniforme no intervalo [a, b). Se a >= b, o comportamento será indefinido.
Se rng_distribution = NORMAL, os números aleatórios serão gerados seguindo a distribuição normal com média = a e desvio padrão = b.
Se b < 0, o comportamento será indefinido.
A maneira exata de gerar números aleatórios é definida pela implementação. Por exemplo, eles podem ser deterministas ou não, e podem usar ou não um estado oculto.
Em conversas com muitas partes interessadas, essa operação foi considerada efetivamente descontinuada. Por isso, no futuro, planejamos remover ela (#597).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | a |
Tensor de dimensão zero de tipo inteiro, booleano ou de ponto flutuante | (C1), (C2) |
| (I2) | b |
Tensor de dimensão zero de tipo inteiro, booleano ou de ponto flutuante | (C1), (C2) |
| (I3) | shape |
Constante de tensor unidimensional do tipo si64 |
(C3) |
| (I4) | rng_distribution |
enum de UNIFORM e NORMAL |
(C2) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de tipo inteiro, booleano ou de ponto flutuante | (C1-C3) |
Restrições
- (C1)
element_type(a) = element_type(b) = element_type(result). - (C2) Se
rng_distribution = NORMAL, entãois_float(a). - (C3)
shape(result) = shape.
Exemplos
// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
rng_distribution = <#stablehlorng_distributi>on UNIFORM
}< : >(tensori<32,> tensori<32, t>ens>or2xi64<) - ten>sor3x3xi32
// %result: [
// [1, 0, 1],
// [1, 1, 1],
// [0, 0, 0]
// ]
rng_bit_generator
Semântica
Retorna um output preenchido com bits aleatórios uniformes e um estado de saída atualizado output_state usando o algoritmo gerador de números pseudoaleatórios rng_algorithm, dado um estado inicial initial_state. A saída é uma função determinista de initial_state, mas não é garantido que seja determinista entre implementações.
rng_algorithm é um destes:
DEFAULT: algoritmo definido pela implementação.THREE_FRY: variante definida pela implementação do algoritmo Threefry.*PHILOX: variante definida pela implementação do algoritmo Philox.*
* Consulte: Salmon et al. SC 2011. Números aleatórios paralelos: tão fácil quanto contar até três.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | rng_algorithm |
enum de DEFAULT, THREE_FRY e PHILOX |
(C2) |
| (I2) | initial_state |
Tensor unidimensional do tipo ui64 |
(C1), (C2) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
output_state |
Tensor unidimensional do tipo ui64 |
(C1) |
output |
tensor de tipo inteiro ou de ponto flutuante |
Restrições
- (C1)
type(initial_state) = type(output_state). - (C2)
size(initial_state)é definido como:- definido pela implementação se
rng_algorithm = DEFAULT. 2serng_algorithm = THREE_FRY.2ou3serng_algorithm = PHILOX.
- definido pela implementação se
Exemplos
// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
rng_algorithm = <#stablehlorng_algorithm> THREE_FRY
}< : (te>nso>r2xui64)< - (te>nsor2xui<64, tens>or2x2xui64)
// %output_state: [1, 6]
// %output: [
// [9236835810183407956, 16087790271692313299],
// [18212823393184779219, 2658481902456610144]
// ]
round_nearest_afz
Semântica
Realiza o arredondamento de elemento por elemento para o número inteiro mais próximo, desfazendo empates longe de zero, no tensor operand e produz um tensor result. Implementa a operação roundToIntegralTiesToAway da especificação IEEE-754. Para tipos quantizados, realiza dequantize_op_quantize(round_nearest_afz, operand, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result).
Exemplos
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]
round_nearest_even
Semântica
Realiza o arredondamento de elemento por elemento para o número inteiro mais próximo, resolvendo empates para o número inteiro par, no tensor operand e produz um tensor result. Implementa a operação roundToIntegralTiesToEven da especificação IEEE-754. Para tipos quantizados, realiza
dequantize_op_quantize(round_nearest_even, operand, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result).
Exemplos
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
rsqrt
Semântica
Executa a operação de raiz quadrada recíproca elemento a elemento no tensor operand e gera um tensor result. Dependendo do tipo de elemento, faça o seguinte:
- Para pontos flutuantes:
rSqrtdo IEEE-754. - Para números complexos: raiz quadrada recíproca complexa.
- Para tipos quantizados:
dequantize_op_quantize(rsqrt, operand, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result).
Exemplos
// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[1.0, 0.5], [0.33333343, 0.2]]
dispersão
Semântica
Produz tensores results iguais aos tensores inputs, exceto que várias fatias especificadas por scatter_indices são atualizadas com os valores updates usando update_computation.
O diagrama a seguir mostra como os elementos em updates... são mapeados em elementos em results... usando um exemplo concreto. O diagrama escolhe alguns exemplos de índices updates... e explica em detalhes a quais índices results... eles correspondem.
De maneira mais formal, para todo update_index em index_space(updates[0]):
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims].update_scatter_index = update_index[update_scatter_dims...].start_indexé definido como:scatter_indices[si0, ..., :, ..., siN]em quesisão elementos individuais emupdate_scatter_indexe:é inserido no índiceindex_vector_dim, seindex_vector_dim<rank(scatter_indices).[scatter_indices[update_scatter_index]]caso contrário.
- Para
d_inputemaxes(inputs[0]),full_start_index[d_input] = start_index[d_start]sed_input = scatter_dims_to_operand_dims[d_start].full_start_index[d_input] = 0caso contrário.
- Para
d_inputemaxes(inputs[0]),full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)]sed_input = input_batching_dims[i_batching]ed_start = scatter_indices_batching_dims[i_batching].full_batching_index[d_input] = 0caso contrário.
update_window_index = update_index[update_window_dims...].full_window_index = [wi0, ..., 0, ..., wiN]em quewisão elementos individuais emupdate_window_indexe0é inserido nos índices deinserted_window_dimseinput_batching_dims.result_index = full_start_index + full_batching_index + full_window_index.
Assim, results = exec(schedule, inputs), em que:
scheduleé uma permutação deindex_space(updates[0])definida pela implementação.exec([update_index, ...], results) = exec([...], updated_results)onde:- Se
result_indexestiver dentro dos limites deshape(results...) updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )updated_values = update_computation(results...[result_index], updates_converted)updated_resultsé uma cópia deresultscomresults...[result_index]definido comoupdated_values....- Como alternativa, faça o seguinte:
updated_results = results.
- Se
exec([], results) = results.
Se indices_are_sorted for true, a implementação poderá presumir que scatter_indices estão classificados em relação a scatter_dims_to_operand_dims. Caso contrário, o comportamento será indefinido. De maneira mais formal, para todo i1 < i2 de indices(result), full_start_index(i1) <= full_start_index(i2).
Se unique_indices for true, a implementação poderá presumir que todos os índices result_index que estão sendo dispersos são únicos. Se unique_indices for true, mas os índices que estão sendo dispersos não forem exclusivos, o comportamento será indefinido.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | inputs |
número variádico de tensores ou tensores quantizados por tensor | (C1), (C2), (C4-C6), (C11), (C13), (C18), (C21), (C23-C24) |
| (I2) | scatter_indices |
tensor de tipo inteiro | (C4), (C15), (C19), (C22) |
| (I3) | updates |
número variádico de tensores ou tensores quantizados por tensor | (C3-C6), (C8) |
| (I4) | update_window_dims |
Constante de tensor unidimensional do tipo si64 |
(C2), (C4), (C7-C8) |
| (I5) | inserted_window_dims |
Constante de tensor unidimensional do tipo si64 |
(C2), (C4), (C9-C11) |
| (I6) | input_batching_dims |
Constante de tensor unidimensional do tipo si64 |
(C2), (C4), (C9), (C12-13), (C17-18), (C20) |
| (I7) | scatter_indices_batching_dims |
Constante de tensor unidimensional do tipo si64 |
(C14-C18) |
| (I8) | scatter_dims_to_operand_dims |
Constante de tensor unidimensional do tipo si64 |
(C19-C21) |
| (I9) | index_vector_dim |
constante do tipo si64 |
(C4), (C16), (C19), (C22) |
| (I10) | indices_are_sorted |
constante do tipo i1 |
|
| (I11) | unique_indices |
constante do tipo i1 |
|
| (I12) | update_computation |
função | (C23) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
results |
número variádico de tensores ou tensores quantizados por tensor | (C24-C25) |
Restrições
- (C1)
same(shape(inputs...)). - (C2)
rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims) + size(input_batching_dims). - (C3)
same(shape(updates...)). - (C4)
shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes)em que:update_scatter_dim_sizes = shape(scatter_indices), exceto que o tamanho da dimensão descatter_indicescorrespondente aindex_vector_dimnão está incluído.update_window_dim_sizes <= shape(inputs[0]), exceto que os tamanhos de dimensão eminputs[0]correspondentes ainserted_window_dimseinput_batching_dimsnão estão incluídos.combinecolocaupdate_scatter_dim_sizesnos eixos correspondentes aupdate_scatter_dimseupdate_window_dim_sizesnos eixos correspondentes aupdate_window_dims.
- (C5)
0 < size(inputs) = size(updates) = N. - (C6)
element_type(updates...) = element_type(inputs...). - (C7)
is_unique(update_window_dims) and is_sorted(update_window_dims). - (C8)
0 <= update_window_dims < rank(updates[0]). - (C9)
is_unique(concatenate(inserted_window_dims, input_batching_dims)) - (C10)
is_sorted(inserted_window_dims). - (C11)
0 <= inserted_window_dims < rank(inputs[0]). - (C12)
is_sorted(input_batching_dims). - (C13)
0 <= input_batching_dims < rank(inputs[0])). - (C14)
is_unique(scatter_indices_batching_dims). - (C15)
0 <= scatter_indices_batching_dims < rank(scatter_indices). - (C16)
index_vector_dim not in scatter_indices_batching_dims. - (C17)
size(input_batching_dims) == size(scatter_indices_batching_dims). - (C18)
dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...). - (C19)
size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1. - (C20)
is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims)). - (C21)
0 <= scatter_dims_to_operand_dims < rank(inputs[0]). - (C22)
0 <= index_vector_dim <= rank(scatter_indices). - (C23)
update_computationtem o tipo(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), em queis_promotable(element_type(inputs[i]), Ei). - (C24)
shape(inputs...) = shape(results...). - (C25)
element_type(results[i]) = Eipara todos osiem[0,N).
Exemplos
// %input: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %scatter_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
// %update: [
// [
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
// ],
// [
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
// ]
// ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
"stable<hlo>.re>turn"(%0) : (tensori64) - ()
}) {
scatter_dimensio<n_numbers = #stablehlo.scatter
update_window_dims = [3, 4],
inserted_window_dims = [1],
input_batching_dims = [0],
scatter_indices_batching_dims = [1],
scatter_dims_to_operand_dims = [2>, 1],
index_vector_dim = 3,
indices_are_sorted = false,
uniq<ue_indices >= false
<} : (tensor>2x3x4x2x<i64, tensor2x>2x3>x2xi64,< tensor2x2x>3x2x2xi64) - tensor2x3x4x2xi64
// %result: [
// [
// [[3, 4], [6, 7], [6, 7], [7, 8]],
// [[9, 10],[11, 12], [15, 16], [17, 18]],
// [[17, 18], [19, 20], [22, 23], [24, 25]]
// ],
// [
// [[25, 26], [28, 29], [30, 31], [31, 32]],
// [[35, 36], [38, 39], [38, 39], [39, 40]],
// [[41, 42], [44, 45], [46, 47], [47, 48]]
// ]
// ]
selecionar
Semântica
Produz um tensor result em que cada elemento é selecionado do tensor on_true ou on_false com base no valor do elemento correspondente de pred.
De maneira mais formal, result[result_index] = pred_element ? on_true[result_index] :
on_false[result_index], em que pred_element = rank(pred) = 0 ? pred[] :
pred[result_index]. Para tipos quantizados, realiza
dequantize_select_quantize(pred, on_true, on_false, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | pred |
tensor do tipo i1 |
(C1) |
| (I2) | on_true |
tensor ou tensor quantizado por tensor | (C1-C2) |
| (I3) | on_false |
tensor ou tensor quantizado por tensor | (C2) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor ou tensor quantizado por tensor | (C2) |
Restrições
- (C1)
rank(pred) = 0 or shape(pred) = shape(on_true). - (C2)
baseline_type(on_true) = baseline_type(on_false) = baseline_type(result).
Exemplos
// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false)< : (te>nsor2x2x<i1, ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 2], [3, 8]]
select_and_scatter
Semântica
Dispersa os valores do tensor source usando scatter com base no resultado de reduce_window do tensor input usando select e produz um tensor result.
O diagrama a seguir mostra como os elementos em result são calculados com base em operand e source usando um exemplo concreto.
Mais formalmente:
selected_values = reduce_window_without_init(...)com as seguintes entradas:inputs = [operand].window_dimensions,window_stridesepadding, que são usados como estão.base_dilations = windows_dilations = 1.bodyé definido como:
def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>: return select(arg0, arg1) ? arg0 : arg1;em que
E = element_type(operand)ereduce_window_without_initfuncionam exatamente comoreduce_window, exceto que oscheduledoreducesubjacente (consulte reduce) não inclui valores de inicialização. No momento, não está especificado o que acontece se a janela correspondente não tiver valores (#731).result[result_index] = reduce([source_values], [init_value], [0], scatter)em que:source_values = [source[source_index] for source_index in source_indices].selected_index(source_index) = operand_indexseselected_values[source_index]tiver o elementooperanddeoperand_index.source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index].
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor ou tensor quantizado por tensor | (C1-C4), (C6), (C8-C11) |
| (I2) | source |
tensor ou tensor quantizado por tensor | (C1), (C2) |
| (I3) | init_value |
Tensor de dimensão zero ou tensor quantizado por tensor | (C3) |
| (I4) | window_dimensions |
Constante de tensor unidimensional do tipo si64 |
(C2), (C4), (C5) |
| (I5) | window_strides |
Constante de tensor unidimensional do tipo si64 |
(C2), (C6), (C7) |
| (I6) | padding |
Constante de tensor bidimensional do tipo si64 |
(C2), (C8) |
| (I7) | select |
função | (C9) |
| (I8) | scatter |
função | (C10) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor ou tensor quantizado por tensor | (C11-C12) |
Restrições
- (C1)
element_type(operand) = element_type(source). - (C2)
shape(source) = num_windowsem que:padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1].is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape.num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1.
- (C3)
element_type(init_value) = element_type(operand). - (C4)
size(window_dimensions) = rank(operand). - (C5)
0 < window_dimensions. - (C6)
size(window_strides) = rank(operand). - (C7)
0 < window_strides. - (C8)
shape(padding) = [rank(operand), 2]. - (C9)
selecttem o tipo(tensor<E>, tensor<E>) -> tensor<i1>em queE = element_type(operand). - (C10)
scattertem o tipo(tensor<E>, tensor<E>) -> tensor<E>em queis_promotable(element_type(operand), E). - (C11)
shape(operand) = shape(result). - (C12)
element_type(result) = E.
Exemplos
// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_di<rection = #stablehlocom>parison_directio<n G>E
} <: (>ten>sori64,< t>ensori64) - tensori1
"stable<hl>o.r>eturn"(%0) : (tensori1) <- (>)
}, {
^bb0(%<arg>0: tensori64, %arg1: tensori64):
%0 = "sta<ble>hlo.add&<quo>t;(>%arg0, <%ar>g1) : (tensori64, tensori64) - tensor<i64>
> "stablehlo.return"(%0) :< (tensori>64) - ()
}) {
window_dim<ensions => arrayi64: 3, 1,
<window_strides => arrayi64<: 2, 1,>
padding =< dense[>[0, 1], <[0, 0]]> : tenso<r2x>2xi>64
} : <(tensor>4x2xi64, tensor2x2xi64, tensori64) - tensor4x2xi64
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]
enviar
Semântica
Envia inputs para um canal channel_id. As entradas são enviadas para outros dispositivos na ordem especificada por source_target_pairs. A operação produz um token result.
Se is_host_transfer for true, a operação vai transferir dados para o host. Caso contrário, ele transfere dados para outro dispositivo com base nos valores de source_target_pairs. Essa flag duplica as informações fornecidas em
channel_type. Por isso, no futuro, planejamos manter apenas uma delas
(#666). Se is_host_transfer
= false e source_target_pairs for None ou vazio, será considerado
comportamento indefinido.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | inputs |
número variádico de tensores ou tensores quantizados | |
| (I2) | token |
token |
|
| (I3) | channel_id |
constante do tipo si64 |
|
| (I4) | channel_type |
enum de DEVICE_TO_DEVICE e DEVICE_TO_HOST |
(C5) |
| (I5) | is_host_transfer |
constante do tipo i1 |
(C5-C6) |
| (I6) | source_target_pairs |
Constante de tensor bidimensional do tipo si64 |
(C1-C4), (C6) |
Saídas
| Nome | Tipo |
|---|---|
result |
token |
Restrições
- (C1)
dim(source_target_pairs, 1) = 2. - (C2)
is_unique(source_target_pairs[:, 0]). - (C3)
is_unique(source_target_pairs[:, 1]). - (C4)
0 <= source_target_pairs < N, em queNé definido como:num_replicassecross_replicafor usado.num_partitionssecross_partitionfor usado.
- (C5)
channel_typeé definido como:DEVICE_TO_HOSTseis_host_transfer = true,DEVICE_TO_DEVICEcaso contrário.
Exemplos
%result = "stablehlo.send"(%operand, %token) {
channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 1,
is_host_transfer = false,
source_target_pai<rs = dense[[0, 1>], [1, 2]<] : ten>sor2x2xi64
}< : (ten>sor2x2xi64, !stablehl>o.token) - !stablehlo.token
shift_left
Semântica
Executa uma operação de deslocamento à esquerda bit a bit no tensor lhs pelo número rhs de bits e produz um tensor result.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | lhs |
tensor de tipo inteiro | (C1) |
| (I2) | rhs |
tensor de tipo inteiro | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de tipo inteiro | (C1) |
Restrições
- (C1)
type(lhs) = type(rhs) = type(result).
Exemplos
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs<): (t>ensor3xi<64, t>ens>or3xi64<) - t>ensor3xi64
// %result: [-2, 0, 8]
shift_right_arithmetic
Semântica
Realiza uma operação de deslocamento aritmético à direita de elemento a elemento no tensor lhs por rhs número de bits e gera um tensor result.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | lhs |
tensor de tipo inteiro | (C1) |
| (I2) | rhs |
tensor de tipo inteiro | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de tipo inteiro | (C1) |
Restrições
- (C1)
type(lhs) = type(rhs) = type(result).
Exemplos
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs<): (t>ensor3xi<64, t>ens>or3xi64<) - t>ensor3xi64
// %result: [-1, 0, 1]
shift_right_logical
Semântica
Executa a operação de deslocamento lógico à direita no tensor lhs por rhs bits e gera um tensor result.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | lhs |
tensor de tipo inteiro | (C1) |
| (I2) | rhs |
tensor de tipo inteiro | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de tipo inteiro | (C1) |
Restrições
- (C1)
type(lhs) = type(rhs) = type(result).
Exemplos
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs<): (t>ensor3xi<64, t>ens>or3xi64<) - t>ensor3xi64
// %result: [9223372036854775807, 0, 1]
assinatura
Semântica
Retorna o sinal do elemento operand e produz um tensor result.
De maneira mais formal, para cada elemento x, a semântica pode ser expressa usando a sintaxe do Python da seguinte maneira:
def sign(x):
if is_integer(x):
if compare(x, 0, LT, SIGNED): return -1
if compare(x, 0, EQ, SIGNED): return 0
return 1
elif is_float(x):
if is_nan(x): return NaN
if compare(x, -0.0, EQ, FLOAT): return -0.0
if compare(x, +0.0, EQ, FLOAT): return +0.0
if compare(x, 0.0, LT, FLOAT): return -1.0
return 1.0
elif is_complex(x):
if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
return divide(x, convert(abs(x), type(x)))
Para tipos quantizados, realiza
dequantize_op_quantize(sign, operand, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor de número inteiro com sinal, ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de número inteiro com sinal, ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result).
Exemplos
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
seno
Semântica
Executa uma operação de seno elemento a elemento no tensor operand e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:
- Para pontos flutuantes:
sindo IEEE-754. - Para números complexos: seno complexo.
- Para tipos quantizados:
dequantize_op_quantize(sine, operand, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result).
Exemplos
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.sine"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[0.0, 1.0], [0.0, -1.0]]
slice
Semântica
Extrai uma fatia do operand usando índices iniciais calculados estaticamente e produz um tensor result. start_indices contém os índices iniciais da fatia para cada dimensão, limit_indices contém os índices finais (exclusivos) da fatia para cada dimensão, e strides contém as etapas para cada dimensão.
Mais formalmente, result[result_index] = operand[operand_index] em que operand_index = start_indices + result_index * strides.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor ou tensor quantizado por tensor | (C1-C3), (C5) |
| (I2) | start_indices |
Constante de tensor unidimensional do tipo si64 |
(C2), (C3), (C5) |
| (I3) | limit_indices |
Constante de tensor unidimensional do tipo si64 |
(C2), (C3), (C5) |
| (I4) | strides |
Constante de tensor unidimensional do tipo si64 |
(C2), (C4) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor ou tensor quantizado por tensor | (C1), (C5) |
Restrições
- (C1)
element_type(operand) = element_type(result). - (C2)
size(start_indices) = size(limit_indices) = size(strides) = rank(operand). - (C3)
0 <= start_indices <= limit_indices <= shape(operand). - (C4)
0 < strides. - (C5)
shape(result) = ceil((limit_indices - start_indices) / strides).
Exemplos
// %operand: [
// [0, 0, 0, 0],
// [0, 0, 1, 1],
// [0, 0, 1, 1]
// ]
%result = "stablehlo.slice"(%operand) {
start_indic<es = arra>yi64: 1, 2,
limit_indic<es = arra>yi64: 3, 4,
strid<es = arra>yi64: 1, 1
}< : (ten>sor>3x4xi64<) - ten>sor2x2xi64
// % result: [
// [1, 1],
// [1, 1]
// ]
classificar
Semântica
Classifica fatias unidimensionais de inputs ao longo da dimensão dimension juntas, de acordo com um comparator, e produz results.
Ao contrário de entradas semelhantes em outras operações, dimension permite valores negativos, com a semântica descrita abaixo. No futuro, isso pode ser proibido por motivos de consistência (#1377).
Se is_stable for verdadeiro, a classificação será estável, ou seja, a ordem relativa dos elementos considerados iguais pelo comparador será preservada. No caso em que há uma única entrada, dois elementos e1 e e2 são considerados iguais pelo comparador se e somente se comparator(e1, e2) = comparator(e2, e1) = false. Confira abaixo a formalização de como isso é generalizado para várias entradas.
De maneira mais formal, para todo result_index em index_space(results[0]):
adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension.result_slice = [ri0, ..., :, ..., riR-1]em queriNsão elementos individuais emresult_indexe:é inserido emadjusted_dimension.inputs_together = (inputs[0]..., ..., inputs[N-1]...).results_together[result_slice] = sort(inputs_together[result_slice], comparator_together).- em que
sortclassifica uma fração unidimensional em ordem não decrescente, esperando quecomparator_togetherretornetruese o argumento do lado esquerdo for menor que o segundo argumento do lado direito. def comparator_together(lhs_together, rhs_together): args = [] for (lhs_el, rhs_el) in zip(lhs_together, rhs_together): args.append(lhs_el) args.append(rhs_el) return comparator(*args)(results[0]..., ..., results[N-1]...) = results_together.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | inputs |
número variádico de tensores ou tensores quantizados por tensor | (C1-C5) |
| (I2) | dimension |
constante do tipo si64 |
(C4) |
| (I3) | is_stable |
constante do tipo i1 |
|
| (I4) | comparator |
função | (C5) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
results |
número variádico de tensores ou tensores quantizados por tensor | (C2), (C3) |
Restrições
- (C1)
0 < size(inputs). - (C2)
type(inputs...) = type(results...). - (C3)
same(shape(inputs...) + shape(results...)). - (C4)
-R <= dimension < R, em queR = rank(inputs[0]). - (C5)
comparatortem o tipo(tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>, em queEi = element_type(inputs[i]).
Exemplos
// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64, %ar<g2:> tensori64, %ar<g3:> tensori64):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_di<rection = #stablehlocom>parison_directio<n G>T
} <: (>ten>sori64,< t>ensori64) - tensori1
"stablehlo.retu<rn>&qu>ot;(%predicate) : (tensori1) - ()
}) {
dimension = 0 : i64,
< is_st>able = t<rue
} :> (t>ensor2x3<xi64, t>ensor2x3<xi64) -> (tensor2x3xi64, tensor2x3xi64)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]
sqrt
Semântica
Executa a operação de raiz quadrada elemento a elemento no tensor operand e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:
- Para pontos flutuantes:
squareRootdo IEEE-754. - Para números complexos: raiz quadrada complexa.
- Para tipos quantizados:
dequantize_op_quantize(sqrt, operand, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result).
Exemplos
// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[0.0, 1.0], [2.0, 3.0]]
subtract
Semântica
Realiza a subtração elemento a elemento de dois tensores lhs e rhs e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:
- Para números inteiros: subtração de números inteiros.
- Para pontos flutuantes:
subtractiondo IEEE-754. - Para números complexos: subtração complexa.
- Para tipos quantizados:
dequantize_op_quantize(subtract, lhs, rhs, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | lhs |
tensor de tipo inteiro, de ponto flutuante ou complexo ou tensor quantizado por tensor | (C1) |
| (I2) | rhs |
tensor de tipo inteiro, de ponto flutuante ou complexo ou tensor quantizado por tensor | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de tipo inteiro, de ponto flutuante ou complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
Exemplos
// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs)< : (ten>sor2x2xf<32, ten>sor>2x2xf32)< - (ten>sor2x2xf32)
// %result: [[1, 2], [3, 4]]
tan
Semântica
Executa uma operação de tangente elemento a elemento no tensor operand e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:
- Para pontos flutuantes:
tando IEEE-754. - Para números complexos: tangente complexa.
- Para tipos quantizados:
dequantize_op_quantize(tan, operand, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result).
Exemplos
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.tan"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [
// [0.0, 1.63312e+16],
// [0.0, 5.44375e+15]
// ]
tanh
Semântica
Executa uma operação de tangente hiperbólica independente de elementos no tensor operand e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:
- Para pontos flutuantes:
tanhdo IEEE-754. - Para números complexos: tangente hiperbólica complexa.
- Para tipos quantizados:
dequantize_op_quantize(tanh, operand, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result).
Exemplos
// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand)< : (t>ens>or3xf32<) - t>ensor3xf32
// %result: [-0.76159416, 0.0, 0.76159416]
transpor
Semântica
Permuta as dimensões do tensor operand usando permutation e produz um tensor result. Mais formalmente, result[result_index] = operand[operand_index], em que result_index[d] = operand_index[permutation[d]].
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor ou tensor quantizado | (C1 a C4) |
| (I2) | permutation |
Constante de tensor unidimensional do tipo si64 |
(C2-C4) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor ou tensor quantizado | (C1), (C3-C4) |
Restrições
- (C1)
element_type(result)é dado por:element_type(operand), se!is_per_axis_quantized(operand).element_type(operand), exceto quequantization_dimension(operand)equantization_dimension(result)podem ser diferentes.
- (C2)
permutationé uma permutação derange(rank(operand)). - (C3)
shape(result) = dim(operand, permutation...). - (C4) Se
is_per_axis_quantized(result), entãoquantization_dimension(operand) = permutation(quantization_dimension(result)).
Exemplos
// %operand: [
// [[1,2], [3,4], [5,6]],
// [[7,8], [9,10], [11,12]]
// ]
%result = "stablehlo.transpose"(%operand) {
permutati<on = arrayi6>4: 2, 1, 0
}< : (tenso>r2x>3x2xi32<) - tenso>r2x3x2xi32
// %result: [
// [[1,7], [3,9], [5,11]],
// [[2,8], [4,10], [6,12]]
// ]
triangular_solve
Semântica
Resolve lotes de sistemas de equações lineares com matrizes de coeficientes triangulares inferiores ou superiores.
Mais formalmente, considerando a e b, result[i0, ..., iR-3, :, :] é a solução para op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :] quando left_side é true ou x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :] quando left_side é false, resolvendo a variável x em que op(a) é determinado por transpose_a, que pode ser um dos seguintes:
NO_TRANSPOSE: execute a operação usandoacomo está.TRANSPOSE: realiza a operação na transposição dea.ADJOINT: executa a operação na transposta conjugada dea.
Os dados de entrada são lidos apenas do triângulo inferior de a se lower for true ou do triângulo superior de a, caso contrário. Os dados de saída são retornados no mesmo triângulo. Os valores no outro triângulo são definidos pela implementação.
Se unit_diagonal for verdadeiro, a implementação poderá presumir que os elementos diagonais de a são iguais a 1. Caso contrário, o comportamento será indefinido.
Para tipos quantizados, realiza
dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower,
unit_diagonal, transpose_a), a, b, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | a |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1-C3) |
| (I2) | b |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1 a C4) |
| (I3) | left_side |
constante do tipo i1 |
(C3) |
| (I4) | lower |
constante do tipo i1 |
|
| (I5) | unit_diagonal |
constante do tipo i1 |
|
| (I6) | transpose_a |
enum de NO_TRANSPOSE, TRANSPOSE e ADJOINT |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_element_type(a) = baseline_element_type(b). - (C2)
2 <= rank(a) = rank(b) = R. - (C3) A relação entre
shape(a)eshape(b)é definida da seguinte maneira:shape(a)[:-3] = shape(b)[:-3].dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1).
- (C4)
baseline_type(b) = baseline_type(result).
Exemplos
// %a = [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
// %b = [
// [2.0, 0.0, 0.0],
// [4.0, 8.0, 0.0],
// [6.0, 10.0, 12.0]
// ]
%result = "stablehlo.triangular_solve"(%a, %b) {
left_side = true,
lower = true,
unit_diagonal = false,
transpose_a = <#stablehlotranspose NO>_TRANSPOSE
}< : (ten>sor3x3xf<32, ten>sor>3x3xf32<) - ten>sor3x3xf32
// %result: [
// [2.0, 0.0, 0.0],
// [0.0, 2.0, 0.0],
// [0.0, 0.0, 2.0]
// ]
tuple
Semântica
Produz uma tupla result dos valores val.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | val |
número variádico de valores | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tuple | (C1) |
Restrições
- (C1)
resulttem o tipotuple<E0, ..., EN-1>em queEi = type(val[i]).
Exemplos
// %val0: memref[1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1)< : (m>emref2x<f32, t<upl>>ete>nsori3<2) - t<uplem>emref2x<f32, t<upl>>>etensori32
// %result: (memref[1.0, 2.0], (3))
uniform_dequantize
Semântica
Realiza a conversão elemento a elemento do tensor quantizado operand para um tensor de ponto flutuante result de acordo com os parâmetros de quantização definidos pelo tipo operand.
Mais formalmente, result = dequantize(operand).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor quantizado | (C1), (C2) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de tipo ponto flutuante | (C1), (C2) |
Restrições
- (C1)
shape(operand) = shape(result). - (C2)
element_type(result) = expressed_type(operand).
Exemplos
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand)< : (tensor2x!qua<nt.uniformi8:f32:0, {0.1:-3>>0,0>.5:-20}<) - t>ensor2xf32
// %result: [4.0, 15.0]
uniform_quantize
Semântica
Realiza a conversão de elemento a elemento do tensor de ponto flutuante ou quantizado
operand para um tensor quantizado result de acordo com os parâmetros de quantização
definidos pelo tipo result.
Mais formalmente,
- Se
is_float(operand):result = quantize(operand, type(result)).
- Se
is_quantized(operand):float_result = dequantize(operand).result = quantize(float_result, type(result)).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
tensor de ponto flutuante ou tipo quantizado | (C1), (C2) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor quantizado | (C1), (C2) |
Restrições
- (C1)
shape(operand) = shape(result). - (C2)
expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand).
Exemplos
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand)< : (t>ens>or2xf32<) - tensor2x!qua<nt.uniformi8:f32:0, {0.1:-3>>0,0.5:-20}
// %result: [10, 10]
// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"<(%operand) : (te<nsor2x!quant.uniformi8:f32:>>0, >{0.1:-3<0,0.5:-20}) - te<nsor2x!quant.uniformi8:f32:>>0, {0.1:-20,0.2:-30}
// %result: [20, 45]
enquanto
Semântica
Produz a saída da execução da função body zero ou mais vezes enquanto a função cond gera true. De maneira mais formal, a semântica pode ser expressa usando a sintaxe do Python da seguinte maneira:
internal_state = operand
while cond(*internal_state):
internal_state = body(*internal_state)
results = internal_state
O comportamento de um loop infinito será definido (#383).
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | operand |
número variádico de valores | (C1-C3) |
| (I2) | cond |
função | (C1) |
| (I3) | body |
função | (C2) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
results |
número variádico de valores | (C3) |
Restrições
- (C1)
condtem o tipo(T0, ..., TN-1) -> tensor<i1>, em queTi = type(operand[i]). - (C2)
bodytem o tipo(T0, ..., TN-1) -> (T0, ..., TN-1), em queTi = type(operand[i]). - (C3)
type(results...) = type(operand...).
Exemplos
// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%cond = "stablehlo.compare"(%arg0, %ten) {
comparison_di<rection = #stablehlocom>parison_directio<n L>T
} <: (>ten>sori64,< t>ensori64) - tensori1
stablehlo.r<et>urn %cond : tensori1
}, {
< ^>bb0(%arg0: tens<ori>64, %arg1: tensori64):
%new_sum = stablehlo.add <%ar>g1, %one : tensori64
%new_i = stablehlo.add <%ar>g0, %one : tensori64
stablehlo.return %new_<i, >%new_sum< : >tensori64, te<nso>ri64
}) <: (>ten>sori64, <ten>sori64) <- (>tensori64, tensori64)
// %results0: 10
// %results1: 10
xor
Semântica
Executa XOR elemento a elemento de dois tensores lhs e rhs e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:
- Para booleanos: XOR lógico.
- Para números inteiros: XOR bit a bit.
Entradas
| Rótulo | Nome | Tipo | Restrições |
|---|---|---|---|
| (I1) | lhs |
tensor de tipo booleano ou inteiro | (C1) |
| (I2) | rhs |
tensor de tipo booleano ou inteiro | (C1) |
Saídas
| Nome | Tipo | Restrições |
|---|---|---|
result |
tensor de tipo booleano ou inteiro | (C1) |
Restrições
- (C1)
type(lhs) = type(rhs) = type(result).
Exemplos
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[4, 4], [4, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%<lhs, %>rhs) : (<tensor>2x2>xi1, te<nsor2x>2xi1) - tensor2x2xi1
// %result: [[false, true], [true, false]]
Interoperabilidade de dialetos
No momento, os programas StableHLO em uso às vezes contêm operações que não são definidas por StableHLO.
Módulo, função, chamada e retorno
O StableHLO usa operações MLIR upstream para ModuleOp, FuncOp, CallOp e ReturnOp. Isso foi feito para melhorar a interoperabilidade com a estrutura MLIR atual, já que muitas transmissões úteis são escritas para FuncOp e ModuleOp, e muitos pipelines de compilação esperam que essas operações estejam presentes. Garantias de compatibilidade total são aplicadas a essas operações. Se algo mudar nessas operações de maneira incompatível (por exemplo, remoção), os equivalentes do StableHLO serão adicionados para preservar a compatibilidade.
CHLO
O conjunto de operações CHLO contém operações de nível superior que se decompõem em StableHLO. No momento, não há garantias de compatibilidade para CHLO. Para garantir a compatibilidade, use a transmissão chlo-legalize-to-stablehlo antes da serialização.
Operações de forma
É um caso de uso comum na comunidade usar determinadas operações dos dialetos principais do MLIR em programas dinâmicos do StableHLO para realizar cálculos de forma.
Normalmente, elas incluem operações do dialeto shape, como shape_of ou num_elements, operações do dialeto tensor, como dim ou from_elements, e o tipo index integrado.
A RFC de dinamismo > O2
indica que esses tipos estão fora do escopo, mas algum suporte para tipos index é
incluído para fins de interoperabilidade. Não há garantias de compatibilidade para essas operações ou tipos. A transmissão shape-legalize-to-stablehlo
pode ser usada para converter essas operações em operações StableHLO totalmente compatíveis.
Operações descontinuadas
Há várias operações do StableHLO que foram herdadas do MHLO e estão sendo descontinuadas e removidas do StableHLO. Confira todos os detalhes sobre essas remoções em Limpeza do StableHLO v1.0 #2283. O problema do rastreador para essas descontinuações é #2340.
Essas operações se enquadram em algumas categorias:
- Categoria "Não em HLO" de operações do StableHLO. Inicialmente, elas faziam parte do conjunto de operações do StableHLO, mas depois foram consideradas inadequadas:
broadcast,create_token,cross-replica-sum,dot,einsum,torch_index_select,unary_einsum(#3). - Operações não usadas: essas operações podem ter sido úteis em algum momento, mas foram subdesenvolvidas ou os pipelines que as usavam foram refatorados para não precisarem mais delas. Isso inclui
map,tuple(#598),get_tuple_element,rng, comparações decomplex#560, e convoluçãowindow_reversal(#1181).
Algumas dessas operações podem ser removidas facilmente, já que podem ser expressas usando operações atuais (broadcast, create_token, cross-replica-sum, dot, unary_einsum) e serão removidas após o período de compatibilidade atual (seis meses). Outras ainda estão sendo analisadas para remoção (einsum, get_tuple_element, map, rng torch_index_select, tuple, complex comparações, window_reversal). Aguardando o feedback da comunidade, essas operações serão removidas ou adicionadas à especificação com suporte total. Até que esses futuros de operações sejam conhecidos, eles terão apenas seis meses de compatibilidade garantida.
Execução
Execução sequencial
Um programa StableHLO é executado fornecendo valores de entrada à função main e calculando valores de saída. Os valores de saída de uma função são calculados executando o gráfico de operações com raiz na operação return correspondente.
A ordem de execução é definida pela implementação, desde que esteja alinhada ao fluxo de dados, ou seja, se as operações forem executadas antes dos usos. No StableHLO, todas as operações que causam efeitos colaterais consomem e produzem um token (vários tokens podem ser multiplexados em um token usando after_all). Portanto, a ordem de execução dos efeitos colaterais também é alinhada ao fluxo de dados. Por exemplo, no programa abaixo, há duas ordens de execução possíveis: %0 → %1 → %2 → return e %1 → %0 → %2 → return.
func.func @main() -> tensor<f64> {
%0 = stablehlo.constant dense<1.0> : tensor<f64>
%1 = stablehlo.constant dense<2.0> : tensor<f64>
%2 = stablehlo.add %0, %1 : tensor<f64>
return %2 : tensor<f64>
}
Mais formalmente, um processo StableHLO é uma combinação de:
1) um programa StableHLO, 2) status de operação (ainda não executado,
já executado) e 3) valores intermediários em que o processo está trabalhando.
O processo começa com valores de entrada para a função main, passa pelo gráfico de operações que atualizam status de operação e valores intermediários e termina com valores de saída. Mais formalizações serão definidas posteriormente (#484).
Execução paralela
Os programas StableHLO podem ser executados em paralelo, organizados em uma grade de processos 2D de num_replicas por num_partitions, que têm o tipo ui32.
Na grade de processos do StableHLO, num_replicas * num_partitions processos do StableHLO são executados ao mesmo tempo. Cada processo tem um
process_id = (replica_id, partition_id) exclusivo, em que
replica_id em replica_ids = range(num_replicas) e
partition_id em partition_ids = range(num_partitions), ambos com
tipo ui32.
O tamanho da grade de processos é conhecido de forma estática para cada programa. No futuro, planejamos tornar isso uma parte explícita dos programas StableHLO #650, e a posição na grade de processos é conhecida de forma estática para cada processo. Cada processo tem acesso à posição na grade de processos usando as operações replica_id e partition_id.
Na grade de processo, os programas podem ser todos iguais (estilo "Programa único, vários dados"), todos diferentes (estilo "Vários programas, vários dados") ou algo intermediário. No futuro, planejamos incluir suporte para outras formas de definir programas paralelos do StableHLO, incluindo GSPMD (#619).
Na grade de processos, eles são principalmente independentes uns dos outros. Eles têm status de operação, valores de entrada/intermediários/saída separados, e a maioria das operações é executada separadamente entre os processos, com exceção de um pequeno número de operações coletivas descritas abaixo.
Como a execução da maioria das operações usa apenas valores do mesmo processo, geralmente não há ambiguidade ao se referir a esses valores pelos nomes.
No entanto, ao descrever a semântica das operações coletivas, isso é insuficiente, o que dá origem à notação name@process_id para se referir ao valor name em um processo específico. Nessa perspectiva, name não qualificado pode ser visto como uma abreviação de name@(replica_id(), partition_id()).
A ordem de execução entre processos é definida pela implementação, exceto pela sincronização introduzida pela comunicação ponto a ponto e operações coletivas conforme descrito abaixo.
Comunicação ponto a ponto
Os processos do StableHLO podem se comunicar entre si usando canais do StableHLO. Um canal é representado por um ID positivo do tipo
si64. Com várias operações, é possível enviar e receber valores de canais.
Outras formalizações, como a origem desses IDs de canal, como os programas de processo tomam conhecimento deles e que tipo de sincronização é introduzido por eles, ainda estão pendentes (#484).
Comunicação por streaming
Todo processo do StableHLO tem acesso a duas interfaces de streaming:
- Infeed que pode ser lido.
- Outfeed em que é possível gravar.
Ao contrário dos canais, que são usados para comunicação entre processos e, portanto, têm processos nas duas extremidades, as entradas e saídas têm a outra extremidade definida pela implementação.
Outras formalizações, como a influência da comunicação de streaming na ordem de execução e o tipo de sincronização introduzida por ela, ainda estão pendentes (#484).
Operações coletivas
Há seis operações coletivas no StableHLO: all_gather, all_reduce, all_to_all, collective_broadcast, collective_permute e reduce_scatter. Todas essas operações dividem os processos na grade de processos do StableHLO em grupos de processos do StableHLO e executam uma computação conjunta em cada grupo de processos, de forma independente de outros grupos.
Em cada grupo de processos, as operações coletivas podem introduzir uma barreira de sincronização. Mais formalização, por exemplo, detalhando quando exatamente essa sincronização acontece, como exatamente os processos chegam a essa barreira e o que acontece se eles não chegarem, está pendente (#484).
Se o grupo de processos envolver comunicação entre partições, ou seja, houver processos no grupo com IDs de partição diferentes, a execução da operação coletiva vai precisar de um canal, e a operação coletiva vai precisar fornecer um channel_id positivo do tipo si64. A comunicação entre réplicas não precisa de canais.
Os cálculos realizados pelas operações coletivas são específicos para operações individuais e são descritos nas seções de operações individuais acima. No entanto, as estratégias pelas quais a grade de processos é dividida em grupos de processos são compartilhadas entre essas operações e descritas nesta seção. De maneira mais formal, o StableHLO é compatível com as quatro estratégias a seguir.
cross_replica
Somente as comunicações entre réplicas acontecem em cada grupo de processos. Essa estratégia usa replica_groups, uma lista de listas de IDs de réplica, e calcula um produto cartesiano de replica_groups por partition_ids. replica_groups
precisa ter elementos exclusivos e abranger todas as replica_ids. De maneira mais formal, usando a sintaxe do Python:
def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
for partition_id in partition_ids:
process_group = []
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Por exemplo, para replica_groups = [[0, 1], [2, 3]] e num_partitions = 2,
cross_replica vai gerar
[[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]].
cross_partition
Somente as comunicações entre partições acontecem em cada grupo de processos. Essa estratégia usa partition_groups, uma lista de listas de IDs de partição, e calcula um produto cartesiano de partition_groups por replica_ids.
partition_groups precisa ter elementos exclusivos e cobrir todos os partition_ids.
De maneira mais formal, usando a sintaxe do Python:
def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
for partition_group in partition_groups:
for replica_id in replica_ids:
process_group = []
for partition_id in partition_group:
process_group.append((replica_id, partition_id))
yield process_group
Por exemplo, para partition_groups = [[0, 1]] e num_replicas = 4,
cross_partition vai gerar
[[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]].
cross_replica_and_partition
As comunicações entre réplicas e entre partições podem ocorrer em cada grupo de processos. Essa estratégia usa replica_groups, uma lista de listas de IDs de réplica, e calcula produtos cartesianos de cada replica_group por partition_ids. replica_groups precisa ter elementos exclusivos e cobrir todos os
replica_ids. De maneira mais formal, usando a sintaxe do Python:
def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
process_group = []
for partition_id in partition_ids:
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Por exemplo, para replica_groups = [[0, 1], [2, 3]] e num_partitions = 2,
cross_replica_and_partition vai gerar
[[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]].
flattened_ids
Essa estratégia usa flattened_id_groups, uma lista de listas de IDs de processo "simplificados"
no formato replica_id * num_partitions + partition_id, e
os transforma em IDs de processo. flattened_id_groups precisa ter elementos exclusivos
e cobrir todas as process_ids. De maneira mais formal, usando a sintaxe do Python:
def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
for flattened_id_group in flattened_id_groups:
process_group = []
for flattened_id in flattened_id_group:
replica_id = flattened_id // num_partitions
partition_id = flattened_id % num_partitions
process_group.append((replica_id, partition_id))
yield process_group
Por exemplo, para flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]],
num_replicas = 4 e num_partitions = 2, flattened_ids vai produzir
[[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]].
Precisão
No momento, o StableHLO não oferece garantias sobre a precisão numérica, mas isso pode mudar no futuro (#1156).
Semântica de execução da operação quantizada
A interpretação das operações quantizadas do StableHLO pode variar dependendo dos requisitos e recursos de hardware. Por exemplo, alguns hardwares podem optar por interpretar operações quantizadas usando uma estratégia de "desquantizar, realizar operação de ponto flutuante e, por fim, quantizar". Outros podem realizar todo o cálculo com aritmética de números inteiros. Consequentemente, a interpretação das operações quantizadas do StableHLO é determinada exclusivamente pela implementação específica. A interpretação da quantização híbrida (#1575) precisa ser baseada na semântica conforme prescrito na especificação (via 1792).
Erros
Os programas StableHLO são validados por um conjunto extenso de restrições para operações individuais, o que elimina muitas classes de erros antes do tempo de execução. No entanto, ainda é possível ter condições de erro, por exemplo, por estouros de números inteiros, acessos fora dos limites etc. A menos que sejam explicitamente indicados, todos esses erros resultam em um comportamento definido pela implementação, mas isso pode mudar no futuro (#1157).
Exceções de ponto flutuante
Como exceção a essa regra, as exceções de ponto flutuante em programas StableHLO têm um comportamento bem definido. Operações que resultam em exceções definidas pelo padrão IEEE-754 (operação inválida, divisão por zero, estouro, estouro negativo ou exceções inexatas) produzem resultados padrão (conforme definido no padrão) e continuam a execução sem gerar a flag de status correspondente, semelhante ao tratamento de exceções raiseNoFlag do padrão. As exceções para operações não padrão (por exemplo, operações aritméticas complexas e determinadas funções transcendentais) são definidas pela implementação.
Incompatibilidades de forma
O StableHLO oferece suporte a tensores com formato dinâmico. No entanto, as formas precisam concordar em tempo de execução. Caso contrário, o comportamento será indefinido. O StableHLO não fornece explicitamente uma operação que possa afirmar que um tensor tem um determinado formato durante a execução. Gerar código correto é responsabilidade do produtor.
Como exemplo específico, o programa abaixo é válido. No entanto, em tempo de execução, as formas exatas de %arg0 e %arg1 precisam ser iguais. Caso contrário, o comportamento do programa é indefinido:
func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
return %0 : tensor<?xi32>
}
Notation
Para descrever a sintaxe, este documento usa a versão ISO modificada da sintaxe EBNF (ISO/IEC 14977:1996, Wikipedia), com duas modificações: 1) as regras são definidas usando ::= em vez de =,
2) A concatenação é expressa usando justaposição em vez de ,.
Para descrever a semântica (ou seja, nas seções "Tipos", "Constantes" e "Operações"), usamos fórmulas baseadas na sintaxe do Python estendida com suporte para expressar operações de matriz de forma concisa, conforme descrito abaixo. Isso funciona bem para pequenos snippets de código, mas, em casos raros em que snippets maiores são necessários, usamos a sintaxe Python padrão, que sempre é apresentada explicitamente.
Fórmulas
Vamos analisar como as fórmulas funcionam com base em um exemplo da especificação dot_general. Uma das restrições para essa operação é a seguinte:
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).
Os nomes usados nessa fórmula vêm de duas fontes: 1) funções globais, ou seja, dim, 2) definições de membros do elemento de programa correspondente, ou seja, entradas lhs, lhs_batching_dimensions, rhs e rhs_batching_dimensions definidas na seção "Entradas" de dot_general.
Como mencionado acima, a sintaxe dessa fórmula é baseada em Python com algumas extensões orientadas à concisão. Para entender a fórmula, vamos transformá-la em uma sintaxe Python simples.
A) Nessas fórmulas, usamos = para representar a igualdade. Portanto, a primeira etapa para obter a sintaxe do Python é substituir = por ==, da seguinte maneira: dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...).
B) Além disso, essas fórmulas são compatíveis com reticências (...), que transformam expressões escalares em expressões de tensor. Em resumo, f(xs...) significa aproximadamente "para cada escalar x no tensor xs, calcule um escalar f(x) e retorne todos esses resultados escalares juntos como um resultado de tensor". Na sintaxe Python padrão, a fórmula de exemplo se torna: [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] ==
[dim(rhs, dim2) for dim2 in rhs_batching_dimensions].
Graças às reticências, muitas vezes é possível evitar trabalhar no nível de escalares individuais. No entanto, em alguns casos complicados, uma sintaxe semi-informal de nível inferior pode ser usada, como na fórmula start_indices[bi0, ..., :, ..., biN] da especificação gather. Para ser conciso, não fornecemos um formalismo exato para traduzir essa sintaxe para Python simples, na esperança de que ela ainda seja intuitivamente compreensível caso a caso.
Informe se algumas fórmulas específicas parecem opacas, e vamos tentar melhorá-las.
Além disso, você vai notar que as fórmulas usam reticências para expandir todos os tipos de listas, incluindo tensores, listas de tensores (que podem surgir de um número variádico de tensores), etc. Essa é outra área em que não fornecemos um formalismo exato (por exemplo, as listas não fazem parte do sistema de tipos StableHLO) e, em vez disso, confiamos na capacidade de compreensão intuitiva.
C) O último veículo de notação importante que usamos é a transmissão implícita. Embora o conjunto de operações StableHLO não seja compatível com a transmissão implícita, as fórmulas são, também em prol da concisão. Em resumo, se um escalar for usado em um contexto em que um tensor é esperado, o escalar será transmitido para a forma esperada.
Para continuar o exemplo de dot_general, confira outra restrição:
0 <= lhs_batching_dimensions < rank(lhs). Conforme definido na especificação dot_general, lhs_batching_dimensions é um tensor, mas 0 e rank(lhs) são escalares. Depois de aplicar a transmissão implícita, a fórmula vai se tornar [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)].
Quando aplicada a uma operação dot_general específica, essa fórmula resulta em um tensor de booleanos. Quando as fórmulas são usadas como restrições, a restrição é válida se a fórmula resultar em true ou em um tensor que tenha apenas elementos true.
Nomes
Em fórmulas, o escopo léxico inclui: 1) funções globais, 2) definições de membros,
3) definições locais. Confira abaixo a lista de funções globais. A lista de definições de elementos depende do elemento de programa a que a notação é aplicada:
- Para operações, as definições de membros incluem nomes apresentados nas seções "Entradas" e "Saídas".
- Para todo o resto, as definições de membros incluem partes estruturais do
elemento de programa, nomeadas de acordo com os não terminais EBNF correspondentes. Na maioria das vezes, os nomes dessas partes estruturais são obtidos convertendo os nomes dos não terminais para snake case (por exemplo,
IntegerLiteral=>integer_literal), mas às vezes os nomes são abreviados no processo (por exemplo,QuantizationStorageType=>storage_type). Nesse caso, os nomes são introduzidos explicitamente de maneira semelhante às seções "Entradas" / "Saídas" nas especificações de operação. - Além disso, as definições de membros sempre incluem
selfpara se referir ao elemento de programa correspondente.
Valores
Quando as fórmulas são avaliadas, elas funcionam com os seguintes tipos de valores:
1) Value (valores reais, por exemplo, dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>; eles sempre sabem os tipos),
2) Placeholder (valores futuros, por exemplo, lhs, rhs ou result; os valores reais ainda não são conhecidos, apenas os tipos),
3) Type (tipos conforme definido na seção "Tipos"),
4) Function (funções globais conforme definido na seção "Funções").
Dependendo do contexto, os nomes podem se referir a valores diferentes. Mais especificamente, a seção "Semântica" para operações (e equivalentes para outros elementos de programa) define a lógica de tempo de execução. Portanto, todas as entradas estão disponíveis como Value.
Em contraste, a seção "Restrições" para operações (e equivalentes) define a lógica de "tempo de compilação", ou seja, algo que normalmente é executado antes do tempo de execução. Assim, apenas entradas constantes estão disponíveis como Value, e outras entradas estão disponíveis apenas como Placeholder.
| Nomes | Em "Semântica" | Em "Restrições" |
|---|---|---|
| Funções globais | Function |
Function |
| Entradas constantes | Value |
Value |
| Entradas não constantes | Value |
Placeholder |
| Saídas | Value |
Placeholder |
| Definições locais | Depende da definição | Depende da definição |
Vamos considerar um exemplo de operação transpose:
%result = "stablehlo.transpose"(%operand) {
permutati<on = dens>e[2, 1, 0<] : t>ensor3xi64
}< : (tenso>r2x>3x2xi32<) - tenso>r2x3x2xi32
Para essa operação, permutation é uma constante. Portanto, está disponível como um Value
na semântica e nas restrições. Por outro lado, operand e result estão
disponíveis como um Value na semântica, mas apenas como um Placeholder nas restrições.
Funções
Construção de tipos
Não há funções que possam ser usadas para construir tipos. Em vez disso, usamos diretamente a sintaxe de tipo porque ela geralmente é mais concisa. Por exemplo,
(tensor<E>, tensor<E>) -> (tensor<E>) em vez de function_type(
[tensor_type([], E), tensor_type([], E)], [tensor_type([], E)]).
Funções em tipos
element_typeé definido em tipos de tensores e tipos de tensores quantizados e retorna, respectivamente, a parteTensorElementTypeouQuantizedTensorElementTypedoTensorTypeouQuantizedTensorTypecorrespondente.
def element_type(x: Value | Placeholder | Type):
if type(x) == TensorType:
return tensor_element_type(x)
if type(x) == QuantizedTensorType:
return quantized_tensor_element_type(x)
if type(x) is not Type:
return element_type(type(x))
is_per_axis_quantized(x: Value | Placeholder | Type) -> Valueé um atalho parais_quantized(x) and quantization_dimension(x) is not None.is_per_tensor_quantized(x: Value | Placeholder | Type) -> Valueé um atalho parais_quantized(x) and quantization_dimension(x) is None.is_promotable(x: Type, y: Type) -> boolverifica se o tipoxpode ser promovido para o tipoy. QuandoxeysãoQuantizedTensorElementTypes, a promoção é aplicada apenas aostorage_type. Essa versão específica da promoção é usada no contexto do cálculo de redução. Consulte a RFC para mais detalhes.
def is_promotable(x: Type, y: Type) -> Value:
is_same_type = (is_bool(x) and is_bool(y)) or
(is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
(is_complex(x) and is_complex(y)) or
(is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
if is_same_type == False:
return False
if is_integer(x) or is_float(x):
return bitwidth(x) <= bitwidth(y)
if is_complex(x):
return bitwidth(element_type(x)) <= bitwidth(element_type(y))
if is_quantized(x):
return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))
return false
is_quantized(x: Value | Placeholder | Type) -> Valueé um atalho parais_quantized_tensor_element_type(x).is_type_name(x: Value | Placeholder | Type) -> Value. Disponível para todos os tipos. Por exemplo,is_float(x)retornatruesexfor umFloatType. Sexfor um valor ou marcador de posição, essa função será um atalho parais_type_name(type(x)).max_value(x: Type) -> Valueretorna o valor máximo de umTensorElementType. Sexnão for umTensorElementType, retornaráNone.min_value(x: Type) -> Valueretorna o menor valor possível de umTensorElementType. Sexnão for umTensorElementType, retornaráNone.member_name(x: Value | Placeholder | Type) -> Any. Disponível para todas as definições de membrosmember_namede todos os tipos. Por exemplo,tensor_element_type(x)retorna a parteTensorElementTypede umTensorTypecorrespondente. Sexfor um valor ou marcador de posição, essa função será um atalho paramember_name(type(x)). Sexnão for um tipo que tenha um membro adequado ou um valor ou marcador de posição desse tipo, vai retornarNone.is_empty_algorithm(*args: Type)verifica se todos os campos do algoritmo de ponto estão definidos comoNone. Isso é necessário porque os algoritmos de ponto têm comportamentos padrão definidos pela implementação. Portanto, especificar um valor padrão seria incorreto.
Construção de valores
operation_name(*xs: Value | Type) -> Value. Disponível para todas as operações. Por exemplo,add(lhs, rhs)usa dois valores de tensorlhserhse retorna o resultado da avaliação da operaçãoaddcom essas entradas. Para algumas operações, comobroadcast_in_dim, os tipos de saídas são "de suporte de carga", ou seja, necessários para avaliar uma operação. Nesse caso, a função usa esses tipos como argumentos.
Funções em valores
Todos os operadores e funções do Python estão disponíveis. Por exemplo, as notações de assinatura e segmentação do Python estão disponíveis para indexar tensores, tensores quantizados e tuplas.
to_destination_type(x: Value, destination_type: Type) -> Valueé definido em tensores e retorna o valor convertido dexcom base emtype(x)edestination_typeda seguinte maneira:
def to_destination_type(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x
if is_quantized(destination_type):
if is_quantized(type(x)):
return quantize(x, destination_type)
assert is_float(type(x))
return quantize(x, destination_type)
if is_quantized(type(x)):
assert destination_type = expressed_type(type(x))
return dequantize(type(x))
return convert(x, destination_type)
Há uma discussão inicial sobre a fusão das operações convert, uniform_quantize e
uniform_dequantize (#1576).
Depois da fusão, não precisamos da função acima e podemos usar o nome da operação
para convert.
is_nan(x: Value) -> Valueé definido em tensores e retornatruese todos os elementos dexforemNaNoufalsede outra forma. Sexnão for um tensor, vai retornarNone.is_sorted(x: Value) -> Valueé definido em tensores e retornatruese os elementos dexforem classificados em ordem crescente em relação à ordem lexicográfica crescente dos índices oufalsecaso contrário. Sexnão for um tensor, retornaráNone.is_unique(x: Value) -> Valueé definido em tensores e retornatruesexnão tiver elementos duplicados oufalsecaso contrário. Sexnão for um tensor, vai retornarNone.member_name(x: Value) -> Anyé definido para todas as definições de membrosmember_namede todos os valores. Por exemplo,real_part(x)retorna a parteRealPartde umComplexConstantcorrespondente. Sexnão for um valor que tenha um membro apropriado, vai retornarNone.same(x: Value) -> Valueé definido em tensores e retornatruese os elementos dexforem todos iguais entre si oufalsecaso contrário. Se o tensor não tiver elementos, isso será considerado "todos iguais entre si", ou seja, a função vai retornartrue. Sexnão for um tensor, retornaráNone.split(x: Value, num_results: Value, axis: Value) -> Valueé definido em tensores e retorna fatiasnum_resultsdexao longo do eixoaxis. Sexnão for um tensor oudim(x, axis) % num_results != 0, retornaráNone.is_defined_in_parent_scope(x: Value) -> Valueé definido em strings e retornatruesexfor o nome de uma função definida no mesmo escopo da função principal da operação relevante.is_namespaced_op_name(x: Value) -> Valueé definido em strings e retornatruesexfor um nome de operação válido, ou seja, respeitar a seguinte expressão regular:[a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+
Cálculos de forma
axes(x: Value | Placeholder | Type) -> Valueé um atalho pararange(rank(x)).dim(x: Value | Placeholder | Type, axis: Value) -> Valueé um atalho parashape(x)[axis].dims(x: Value | Placeholder | Type, axes: List) -> Listé um atalho paralist(map(lambda axis: dim(x, axis), axes)).index_space(x: Value | Placeholder | Type) -> Valueé definido em tensores e retorna índicessize(x)para osTensorTypecorrespondentes classificados em ordem lexicográfica crescente, ou seja,[0, ..., 0],[0, ..., 1], ...,shape(x) - 1. Sexnão for um tipo de tensor, um tipo de tensor quantizado, um valor ou um marcador de posição de um desses tipos, vai retornarNone.rank(x: Value | Placeholder | Type) -> Valueé um atalho parasize(shape(x)).shape(x: Value | Placeholder | Type) -> Valueé definido na seção "Funções em tipos" usandomember_name.size(x: Value | Placeholder | Type) -> Valueé um atalho parareduce(lambda x, y: x * y, shape(x)).
Cálculos de quantização
def baseline_element_type(x: Value | Placeholder | Type) -> Typeé um atalho paraelement_type(baseline_type(x)).baseline_typeé definido em tipos de tensores e tipos de tensores quantizados e os transforma em um "valor de referência", ou seja, um tipo com o mesmo formato, mas com os parâmetros de quantização do tipo de elemento redefinidos para os valores padrão. Isso é usado como um truque útil para comparar os tipos de tensores e tensores quantizados de maneira uniforme, o que é necessário com frequência. Para tipos quantizados, isso permite comparar tipos ignorando os parâmetros de quantização. Ou seja,shape,storage_type,expressed_type,storage_min,storage_maxequantization_dimension(para tipo quantizado por eixo) precisam corresponder, masscalesezero pointspodem ser diferentes.
def baseline_type(x: Value | Placeholder | Type) -> Type:
if type(x) == TensorType:
return x
if type(x) == QuantizedTensorType:
element_type = quantized_tensor_element_type(x)
baseline_element_type = QuantizedTensorElementType(
storage_type = storage_type(element_type),
storage_min = storage_min(element_type),
storage_max = storage_max(element_type),
expressed_type = expressed_type(element_type),
quantization_dimension = quantization_dimension(element_type),
scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
return QuantizedTensorType(shape(x), baseline_element_type)
if type(x) is not Type:
return baseline_element_type(type(x))
dequantizeé definido em tipos de tensores quantizados e os transforma em tipos de tensores de ponto flutuante. Isso acontece convertendo elementos quantizados, que representam valores inteiros do tipo de armazenamento, em valores de ponto flutuante correspondentes do tipo expresso usando o ponto zero e a escala associados ao tipo de elemento quantizado.
def compute_zero_points(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
zero_points[i] = zero_points(quantized_type)[i[d]]
return zero_points
def compute_scales(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
type(result_type))
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
scales[i] = scales(quantized_type)[i[d]]
return scales
def dequantize(x: Value) -> Value:
assert is_quantized(x)
x_storage = bitcast_convert(x, storage_type(x))
x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
x_expressed_sub = convert(x_storage_sub, expressed_type(x))
return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
quantizeé definido em tipos de tensor de ponto flutuante e os transforma em tipos de tensor quantizados. Isso acontece convertendo valores de ponto flutuante do tipo expresso em valores inteiros correspondentes do tipo de armazenamento usando o ponto zero e a escala associados ao tipo de elemento quantizado.
def quantize(x: Value, result_type: Type) -> Value:
assert is_float(x) and is_quantized(result_type)
zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
converted_zero_points = convert(zero_points, expressed_type(result_type))
converted_min = convert(storage_min(result_type), expressed_type(result_type))
converted_max = convert(storage_max(result_type), expressed_type(result_type))
x_scaled = x / compute_scales(result_type, type(x))
x_scaled_add_zp = x_scaled + converted_zero_points
x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
x_rounded = round_nearest_even(x_clamped)
return convert(x_rounded, result_type)
dequantize_op_quantizeé usado para especificar cálculos de elemento a elemento em tensores quantizados. Ele desquantiza, ou seja, transforma elementos quantizados em seus tipos expressos, realiza uma operação e quantiza, ou seja, transforma os resultados de volta em seus tipos de armazenamento. No momento, essa função só funciona para quantização por tensor. A quantização por eixo está em andamento (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
inputs = inputs_and_output_type[:-1]
output_type = inputs_and_output_type[-1]
float_inputs = map(dequantize, inputs)
float_result = op(*float_inputs)
return quantize(float_result, output_type)
def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
inputs = inputs_and_output_type[:-3]
float_inputs = map(dequantize, inputs)
float_results = op(*float_inputs)
return map(quantize, float_results, inputs_and_output_type[-3:])
def dequantize_compare(lhs, rhs, comparison_direction):
float_lhs = dequantize(lhs)
float_rhs = dequantize(rhs)
return compare(float_lhs, float_rhs, comparison_direction, FLOAT)
def dequantize_select_quantize(pred, on_true, on_false, output_type):
float_on_true = dequantize(on_true)
float_on_false = dequantize(on_false)
float_result = select(pred, float_on_true, float_on_false)
return quantize(float_result, output_type)
hybrid_dequantize_then_opé usado para especificar a quantização somente de peso para op híbrida que aceita lhs em ponto flutuante e rhs em tipos quantizados. Ele desquantiza entradas quantizadas nos tipos expressos e realiza cálculos em ponto flutuante. O tipo de elemento do tensor float lhs e o tipo expresso do tensor rhs quantizado precisam ser idênticos.
def hybrid_dequantize_then_op(op, lhs, rhs):
assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
return op(lhs, dequantize(rhs))
Cálculos de grade
cross_partition(replica_groups: Value) -> Value. Consulte a seção "cross_replica" acima.cross_replica(replica_groups: Value) -> Value. Consulte a seção "cross_replica" acima.cross_replica_and_partition(replica_groups: Value) -> Value. Consulte a seção "cross_replica_and_partition" acima.flattened_ids(replica_groups: Value) -> Value. Consulte a seção "flattened_ids" acima.
Dinamismo
Os valores do StableHLO podem ter tamanhos de dimensão dinâmicos, por exemplo, tensor<?xi64>.
No entanto, os valores do StableHLO não podem ter um número dinâmico de dimensões (dinamismo não classificado, por exemplo, tensor<*xi64>). Operandos e resultados podem usar tamanhos de dimensão dinâmicos, mesmo que haja restrições nos tamanhos. As restrições serão verificadas estaticamente, se possível. Caso contrário, serão adiadas para o tempo de execução, e incompatibilidades vão resultar em comportamento indefinido. Consulte os exemplos abaixo.
Incompatibilidades de forma para operações unárias elementares
Considere o seguinte programa de exemplo:
func.func @foo(%arg0: tensor<?xf64>) {
%0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
return
}
Esse programa é incomum porque não é comum saber o formato do resultado, mas não o formato da entrada. No entanto, esse é um programa StableHLO válido. Não é possível validar estaticamente a operação abs neste programa, porque o formato exato do operando é desconhecido. No entanto, as formas são certamente compatíveis, e isso pode ser verificado de forma estática: ? pode se tornar 2 no tempo de execução, e não haveria problema. No entanto, ? também pode ser algum outro número inteiro, e nesse caso o comportamento é indefinido.
Se o tamanho de uma dimensão for dinâmico no resultado, não poderá haver comportamento indefinido. Na verdade, não há um tamanho "esperado", então não pode haver uma incompatibilidade.
Incompatibilidade de formatos para operações binárias elementares
Considere o seguinte programa de exemplo:
func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
%0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
return
}
No caso de operações binárias elemento a elemento, as formas das entradas e do resultado precisam concordar no tempo de execução. No momento da compilação, as dimensões estáticas precisam ser iguais. Caso contrário, elas só precisam ser compatíveis. Se qualquer dimensão for dinâmica nas entradas, poderá haver um comportamento indefinido no tempo de execução, porque o tamanho dinâmico pode não corresponder ao tamanho correspondente no outro operando (estático ou dinâmico). Se todas as entradas forem estáticas, não importa se o resultado é dinâmico ou não: as dimensões conhecidas estaticamente serão verificadas de forma estática, e as dimensões dinâmicas não impõem restrições.
Incompatibilidades de forma para operações que usam a forma de saída como um operando
Considere o seguinte programa de exemplo:
func.func @foo(%arg0: tensor<2xi32>) {
%0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
return
}
Os valores no operando de forma durante a execução precisam corresponder à forma do resultado. Caso contrário, o comportamento será indefinido. Ou seja, no tempo de execução, %arg0 precisa ter um valor de dense<[3, 4]> : tensor<2xi32>. Se o operando de forma for constante, isso poderá ser verificado de forma estática. Se o formato do resultado for totalmente dinâmico, não poderá haver uma incompatibilidade.