StableHLO é um conjunto para operações de alto nível (HLO, na sigla em inglês) em modelos de machine learning (ML). O StableHLO funciona como uma camada de portabilidade entre diferentes frameworks e compiladores de ML: frameworks de ML que produzem programas StableHLO são compatíveis com compiladores de ML que consomem programas StableHLO.
Nosso objetivo é simplificar e acelerar o desenvolvimento de ML criando mais interoperabilidade entre vários frameworks de ML (como TensorFlow, JAX e PyTorch) e compiladores de ML (como XLA e IREE). Por isso, este documento fornece uma especificação para a linguagem de programação StableHLO.
Essa especificação contém três seções principais. Primeiro, a seção Programas descreve a estrutura dos programas StableHLO, que consistem em funções StableHLO, que consistem em operações do StableHLO. Nessa estrutura, a seção Ops especifica a semântica de operações individuais. A seção Execution oferece semântica para todas essas operações executadas juntas em um programa. Por fim, a seção Notação discute a notação usada ao longo da especificação.
Para ver a especificação de uma versão anterior do StableHLO, abra o repositório na versão com tag relevante. Por exemplo, a especificação StableHLO v0.19.0 (link em inglês). Para conferir as mudanças que ocorreram em cada upgrade de versão secundária do StableHLO, consulte o registro de versão em VhloDialect.td.
Programas
Program ::= {Func}
Os programas StableHLO consistem em um número arbitrário de funções StableHLO.
Veja abaixo um programa de exemplo com uma função @main
que tem três entradas (%image
, %weights
e %bias
) e uma saída. O corpo da função
tem 6 ops.
func.func @main(
%image: tensor<28x28xf32>,
%weights: tensor<784x10xf32>,
%bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
%0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
%1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
%2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
%3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
%4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
"func.return"(%4): (tensor<1x10xf32>) -> ()
}
Funções
Func ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput ::= ValueType
FuncBody ::= {Op}
As funções StableHLO (também chamadas de funções nomeadas) têm um identificador, entradas/saídas e um corpo. No futuro, planejamos introduzir mais metadados para funções a fim de melhorar a compatibilidade com o HLO (425, 626, 740 e 744).
Identificadores
FuncId ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
| '%' letter {letter | digit}
letter ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit ::= '0' | ... | '9'
Os identificadores StableHLO são semelhantes aos de várias linguagens de programação, com duas peculiaridades: 1) todos os identificadores têm identificadores que distinguem diferentes tipos de identificadores; 2) identificadores de valor podem ser completamente numéricos para simplificar a geração de programas StableHLO.
Tipos
Type ::= ValueType | NonValueType
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
Os tipos de StableHLO são categorizados em tipos de valor (também chamados de tipos de primeira classe) que representam valores de StableHLO e tipos sem valor, que descrevem outros elementos do programa. Os tipos StableHLO são semelhantes aos tipos em muitas linguagens de programação. A principal peculiaridade é a natureza específica do domínio do StableHLO, que resulta em alguns resultados incomuns (por exemplo, tipos escalares não são tipos de valor).
TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'
Os tipos de tensor representam tensores, ou seja, matrizes multidimensionais. Elas têm uma forma e um tipo de elemento, em que uma forma representa tamanhos de dimensão não negativos ou desconhecidos na ordem crescente das dimensões correspondentes (que também são chamadas de eixo) numeradas de 0
a R-1
. O número de dimensões R
é chamado de classificação. Por exemplo, tensor<2x3xf32>
é
um tipo de tensor com a forma 2x3
e o tipo de elemento f32
. Ele tem duas dimensões (ou, em outras palavras, dois eixos): a 0a dimensão e a 1a dimensão, com tamanhos 2 e 3. A classificação é 2.
As formas podem ser parcial ou completamente desconhecidas (dinâmicas). Por exemplo, tensor<?x2xf64>
é parcialmente desconhecido e tensor<?x?xf64>
é completamente desconhecido. Os tamanhos de dimensão dinâmica são representados por um ?
. As formas não podem ser desclassificadas.
No futuro, planejamos explorar a extensão dos tipos de tensores para além dos tamanhos de dimensão e tipos de elementos, por exemplo, para incluir layouts (#629) e esparsidade (#1078).
QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
QuantizationStorageType
['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
':' QuantizationExpressedType
[':' QuantizationDimension]
',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerConstant
QuantizationStorageMax ::= IntegerConstant
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerConstant
QuantizationParameters ::= QuantizationParameter
| '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale ':' QuantizationZeroPoint
QuantizationScale ::= FloatConstant
QuantizationZeroPoint ::= IntegerConstant
Nome | Tipo | Restrições |
---|---|---|
storage_type |
tipo de número inteiro | (C1 a C3) e (C8) |
storage_min |
constante de número inteiro | (C1), (C3) e (C7) |
storage_max |
constante de número inteiro | (C2), (C3) e (C7) |
expressed_type |
tipo de ponto flutuante | (C4) |
quantization_dimension |
constante de número inteiro opcional | (C10 a C12) |
scales |
número variado de constantes de ponto flutuante | (C4 a C6), (C9), (C10) (C13) |
zero_points |
número variável de constantes inteiras | (C7 a C9) |
Os tipos de elementos quantizados representam valores inteiros de um tipo de armazenamento no
intervalo de storage_min
a storage_max
(inclusive) que correspondem a
valores de ponto flutuante de um tipo expresso. Para um determinado valor inteiro i
,
o valor de ponto flutuante correspondente f
pode ser calculado como
f = (i - zero_point) * scale
, em que scale
e zero_point
são chamados de
parâmetros de quantização. storage_min
e storage_max
são opcionais na gramática, mas têm valores padrão min_value(storage_type)
e max_value(storage_type)
, respectivamente. Os tipos de elementos quantizados têm as
seguintes restrições:
- (C1)
type(storage_min) = storage_type
. - (C2)
type(storage_max) = storage_type
. - (C3)
min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type)
. - (C4)
type(scales...) = expressed_type
. - (C5)
0 < scales
. - (C6)
is_finite(scales...)
. - (C7)
storage_min <= zero_points <= storage_max
. - (C8)
type(zero_points...) = storage_type
. - (C9)
size(scales) = size(zero_points)
. - (C10) Se for
is_empty(quantization_dimension)
, entãosize(scales) = 1
. - (C11)
0 <= quantization_dimension
.
No momento, QuantizationScale
é uma constante de ponto flutuante, mas há
um grande interesse em escalas baseadas em números inteiros, representadas com multiplicadores e
mudanças. Planejamos investigar isso em breve
(#1404).
Há uma discussão em andamento sobre a semântica de QuantizationZeroPoint
,
incluindo o tipo, os valores e se pode haver apenas um ou
potencialmente vários pontos zero em um tipo de tensor quantizado. Com base nos
resultados dessa discussão, a especificação sobre pontos zero pode mudar
no futuro (#1405).
Outra discussão em andamento envolve a semântica de QuantizationStorageMin
e QuantizationStorageMax
para determinar se alguma restrição precisa ser
imposta a esses valores e aos valores dos tensores quantizados
(#1406).
Por fim, planejamos explorar a representação de escalas desconhecidas e pontos zero, da mesma forma que planejamos explorar a representação de tamanhos de dimensão desconhecidos (#1407).
Os tipos de tensores quantizados representam tensores com elementos quantizados. Esses tensores são exatamente os mesmos que os regulares, exceto pelo fato de que os elementos deles têm tipos de elemento quantizados, em vez de tipos de elemento regulares.
Em tensores quantizados, ela pode ser por tensor, ou seja, ter
um scale
e zero_point
para o tensor inteiro ou pode ser por eixo,
ou seja, com vários scales
e zero_points
, um par por fração de
uma dimensão específica de quantization_dimension
. Mais formalmente, em um tensor t
com quantização por eixo, há frações dim(t, quantization_dimension)
do quantization_dimension
: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :]
etc. Todos os elementos na i
a fração usam scales[i]
e zero_points[i]
como parâmetros de quantização. Os tipos de tensores quantizados têm as seguintes
restrições:
- Para quantização por tensor:
- Nenhuma restrição adicional.
- Para quantização por eixo:
- (C13)
quantization_dimension < rank(self)
. - (C14)
dim(self, quantization_dimension) = size(scales)
.
- (C13)
TokenType ::= 'token'
Os tipos de token representam tokens, ou seja, valores opacos produzidos e consumidos por algumas operações. Os tokens são usados para impor uma ordem de execução em operações, conforme descrito na seção Execução.
TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]
Os tipos de tuplas representam tuplas, ou seja, listas heterogêneas. As tuplas são um recurso
legado que existe apenas para compatibilidade com HLO. No HLO, as tuplas são usadas para representar entradas e saídas variadas. No StableHLO, entradas e
saídas variadas têm suporte nativo, e o único uso de tuplas no StableHLO é
representar a HLO ABI de maneira abrangente, em que, por exemplo, T
, tuple<T>
e
tuple<tuple<T>>
podem ser significativamente diferentes, dependendo de uma
implementação específica. No futuro, planejamos fazer mudanças na HLO ABI
que podem permitir a remoção de tipos de tupla do StableHLO
(#598).
TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ'
| 'f8E4M3B11FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
Os tipos de elementos representam elementos dos tipos de tensor. Ao contrário de muitas linguagens de
programação, esses tipos não são de primeira classe no StableHLO. Isso significa que os programas StableHLO não podem representar diretamente valores desses tipos. Como resultado, é idiomático representar valores escalares do tipo T
com valores de tensores 0-dimensionais do tipo tensor<T>
.
- O tipo booleano representa os valores booleanos
true
efalse
. - Os tipos de números inteiros podem ser assinados (
si
) ou não (ui
) e ter uma das larguras de bits aceitas (4
,8
,16
,32
ou64
). Os tipossiN
assinados representam valores inteiros de-2^(N-1)
a2^(N-1)-1
, incluídos, e os tiposuiN
não assinados representam valores inteiros de0
a2^N-1
. - Os tipos de ponto flutuante podem ser um dos seguintes:
- Tipos
f8E4M3FN
ef8E5M2
correspondentes, respectivamente, às codificaçõesE4M3
eE5M2
do formato FP8 descrito em Formatos FP8 para aprendizado profundo. - Tipos
f8E4M3FNUZ
ef8E5M2FNUZ
correspondentes às codificaçõesE4M3
eE5M2
dos formatos FP8 descritos em Formatos numéricos de 8 bits para redes neurais profundas. - Tipo
f8E4M3B11FNUZ
correspondente à codificaçãoE4M3
dos formatos FP8 descritos em Treinamento e inferência de ponto flutuante de 8 bits híbridos (HFP8) para redes neurais profundas (links em inglês). - Tipo
bf16
correspondente ao formatobfloat16
descrito em BFloat16: o segredo do alto desempenho em Cloud TPUs (link em inglês). - Os tipos
f16
,f32
ef64
que correspondem, respectivamente, aos formatosbinary16
("meia precisão"),binary32
("precisão única") ebinary64
("precisão dupla") descritos no padrão IEEE 754.
- Tipos
- Tipos complexos representam valores complexos que têm uma parte real e uma parte imaginária do mesmo tipo de elemento. Os tipos complexos com suporte são
complex<f32>
(as duas partes são do tipof32
) ecomplex<f64>
(as duas partes são do tipof64
).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]
Os tipos de função representam funções nomeadas e anônimas. Elas têm tipos de entrada (a lista de tipos no lado esquerdo de ->
) e de saída (a lista de tipos no lado direito de ->
). Em muitas linguagens de programação, os tipos de função são de primeira classe, mas não em StableHLO.
StringType ::= 'string'
String type representa sequências de bytes. Ao contrário de várias linguagens de programação, o tipo de string não é de primeira classe em StableHLO e é usado apenas para especificar metadados estáticos para elementos do programa.
Operações
As operações StableHLO (também chamadas de ops) representam um conjunto fechado de operações de alto nível em modelos de machine learning. Como discutido acima, a sintaxe do StableHLO é muito inspirada no MLIR, que não é necessariamente a alternativa mais ergonômica, mas é provavelmente a melhor opção para o objetivo do StableHLO de criar mais interoperabilidade entre os frameworks e os compiladores de ML.
Op ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic ::= 'abs' | 'add' | ...
As operações StableHLO (também chamadas de ops) têm um nome,
entradas/saídas e uma assinatura. O nome consiste no prefixo stablehlo.
e um mnemônico que identifica exclusivamente uma das operações compatíveis. Veja abaixo uma lista abrangente de todas as operações compatíveis.
OpInputs ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue ::= ValueId
OpInputFuncs ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs ::= [OpOutput {',' OpOutput} '=']
OpOutput ::= ValueId
As operações consomem entradas e produzem saídas. As entradas são categorizadas em
valores de entrada (calculados durante a execução), funções de entrada (fornecidos
estaticamente, porque as funções de StableHLO não são valores de primeira classe) e
atributos de entrada (também fornecidos estaticamente). O tipo de entradas e saídas
consumidas e produzidas por uma operação depende da mnemônica dela. Por exemplo, a operação add
consome dois valores de entrada e produz um valor de saída. Em comparação, a
operação select_and_scatter
consome três valores de entrada, duas funções de entrada e
três atributos de entrada.
OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused ::= '^' digit {digit}
| '^' letter {letter | digit}
As funções de entrada (que também são chamadas de funções anônimas) são muito
semelhantes às funções nomeadas, exceto pelo seguinte: 1) elas não têm um identificador (por isso
o nome "anônimo") e 2) não declaram tipos de saída (os tipos de saída são
inferidos da operação return
dentro da função).
A sintaxe das funções de entrada inclui uma parte não utilizada atualmente (consulte a produção de Unused
acima), que existe para compatibilidade com MLIR. Em MLIR,
há um conceito mais geral de "regiões" que pode ter vários "blocos"
de operações conectados por operações de salto. Esses blocos têm IDs que correspondem
à produção de Unused
, para que possam ser diferenciados entre si.
O StableHLO não tem operações de salto, portanto, a parte correspondente da sintaxe MLIR não é
usada (mas ainda está presente).
OpInputAttr ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName ::= letter {letter | digit}
OpInputAttrValue ::= Constant
Os atributos de entrada têm um nome e um valor que é uma das constantes
compatíveis. Eles são a principal maneira de especificar metadados estáticos para elementos
do programa. Por exemplo, a op concatenate
usa o atributo dimension
para
especificar a dimensão com que os valores de entrada são concatenados. Da mesma forma,
a op slice
usa vários atributos, como start_indices
e limit_indices
,
para especificar os limites usados para dividir o valor de entrada.
No momento, os programas StableHLO no modo geral às vezes contêm atributos que não são descritos neste documento. No futuro, planejamos absorver esses atributos na opset do StableHLO ou proibir que eles apareçam nos programas do StableHLO. Enquanto isso, confira a lista desses atributos:
layout
(629, link em inglês).mhlo.frontend_attributes
(628, link em inglês).mhlo.sharding
(619, link em inglês).output_operand_aliases
(#740).- Metadados de local (#594).
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'
A assinatura de operações consiste nos tipos de todos os valores de entrada (a lista de tipos no
lado esquerdo de ->
) e nos tipos de todos os valores de saída (a lista de
tipos no lado direito de ->
). Estritamente, os tipos de entrada são
redundantes, e os tipos de saída também são quase sempre redundantes (porque, para
a maioria das operações StableHLO, os tipos de saída podem ser inferidos a partir das entradas). No entanto, a assinatura da operação faz parte deliberadamente da sintaxe do StableHLO para compatibilidade com o MLIR.
Confira abaixo um exemplo de operação com mnemônico select_and_scatter
. Ela consome três
valores de entrada (%operand
, %source
e %init_value
), duas funções de entrada
e três atributos de entrada (window_dimensions
, window_strides
e padding
).
A assinatura da operação inclui apenas os tipos de valores de entrada,
mas não os tipos de funções e atributos de entrada fornecidos inline.
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>
Constantes
Constant ::= BooleanConstant
| IntegerConstant
| FloatConstant
| ComplexConstant
| TensorConstant
| QuantizedTensorConstant
| StringConstant
| EnumConstant
As constantes StableHLO têm um literal e um tipo que, juntos, representam
um valor de StableHLO. Geralmente, o tipo faz parte da sintaxe constante, exceto
quando não é ambíguo. Por exemplo, uma constante booleana não ambígua tem o tipo i1
,
enquanto uma constante de número inteiro pode ter vários tipos possíveis.
BooleanConstant ::= BooleanLiteral
BooleanLiteral ::= 'true' | 'false'
Constantes booleanas representam os valores booleanos true
e false
. Constantes booleanas têm tipo i1
.
IntegerConstant ::= IntegerLiteral ':' IntegerType
IntegerLiteral ::= ['-' | '+'] DecimalDigits
| ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit ::= '0' | ... | '9'
hexadecimalDigit ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'
Constantes inteiras representam valores inteiros por strings que usam notação decimal ou hexadecimal. Outras bases, como binária ou octal, não são aceitas. As constantes inteiras têm as seguintes restrições:
- (C1)
is_wellformed(integer_literal, integer_type)
.
FloatConstant ::= FloatLiteral ':' FloatType
FloatLiteral ::= SignPart IntegerPart FractionalPart ScientificPart
| '0x' [HexadecimalDigits]
SignPart ::= ['-' | '+']
IntegerPart ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]
Constantes de ponto flutuante representam valores de ponto flutuante via strings que usam notação decimal ou científica. Além disso, a notação hexadecimal pode ser usada para especificar diretamente os bits no formato de ponto flutuante do tipo correspondente. As constantes de ponto flutuante têm as seguintes restrições:
- (C1) Se a notação não hexadecimal for usada,
is_wellformed(float_literal, float_type)
. - (C2) Se a notação hexadecimal for usada,
size(hexadecimal_digits) = num_bits(float_type) / 4
.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral ::= '(' RealPart ',' ImaginaryPart ')'
RealPart ::= FloatLiteral
ImaginaryPart ::= FloatLiteral
Constantes complexas representam valores complexos usando listas de uma parte real
(vem primeiro) e uma parte imaginária (vem depois). Por exemplo,
(1.0, 0.0) : complex<f32>
representa 1.0 + 0.0i
e
(0.0, 1.0) : complex<f32>
representa 0.0 + 1.0i
. A ordem em que essas partes são armazenadas na memória é definida pela implementação. Constantes complexas
têm as seguintes restrições:
- (C1)
is_wellformed(real_part, complex_element_type(complex_type))
. - (C2)
is_wellformed(imaginary_part, complex_element_type(complex_type))
.
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral
As constantes de tensor representam valores de tensores que usam listas aninhadas especificadas usando a notação NumPy. Por exemplo, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
representa um valor do tensor com o seguinte mapeamento de índices para elementos:
{0, 0} => 1
, {0, 1} => 2
, {0, 2} => 3
, {1, 0} => 4
, {1, 1} => 5
,
{1, 2} => 6
. A ordem em que esses elementos são armazenados na memória é
definida pela implementação. As constantes do tensor têm as seguintes restrições:
- (C1)
has_syntax(tensor_literal, element_type(tensor_type))
, em que:has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type)
.has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type)
.
- (C2)
has_shape(tensor_literal, shape(tensor_type))
, em que:has_shape(element_literal: Syntax, []) = true
.has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:])
.- Caso contrário,
false
.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
As constantes de tensor quantizadas representam os valores de tensores quantizados usando a mesma notação que as constantes de tensor, com elementos especificados como constantes do seu tipo de armazenamento. As constantes de tensores quantizadas têm as seguintes restrições:
- (C1)
has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type))
. - (C2)
has_shape(quantized_tensor_literal, shape(quantized_tensor_type))
.
StringConstant ::= StringLiteral
StringLiteral ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))
Os literais de string consistem em bytes especificados com caracteres ASCII e
sequências de escape. Eles não dependem de codificação, portanto, a interpretação desses bytes é definida pela implementação. Os literais de string têm o tipo string
.
Ops
abs
Semântica
Executa uma operação abs com elementos no tensor operand
e produz um
tensor result
. Dependendo do tipo de elemento, faça o seguinte:
- Para números inteiros com sinal: módulo inteiro.
- Para pontos flutuantes:
abs
do IEEE-754. - Para números complexos: módulo complexo.
- Para tipos quantizados:
dequantize_op_quantize(abs, operand, type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor de número inteiro assinado, ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1 a C2) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de número inteiro assinado ou tipo de ponto flutuante ou tensor quantizado por tensor | (C1 a C2) |
Restrições
- (C1)
shape(result) = shape(operand)
. - (C2)
baseline_element_type(result)
é definido como:complex_element_type(element_type(operand))
se foris_complex(operand)
.- Caso contrário,
baseline_element_type(operand)
.
Exemplos
// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]
adicionar
Semântica
Executa a adição elemento de dois tensores lhs
e rhs
e produz um
tensor result
. Dependendo do tipo de elemento, faça o seguinte:
- Para booleanos: OR lógico.
- Para números inteiros: adição de números inteiros.
- Para pontos flutuantes:
addition
do IEEE-754. - Para números complexos: adição complexa.
- Para tipos quantizados:
dequantize_op_quantize(add, lhs, rhs, type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | lhs |
tensor ou quantizado | (C1 a C6) |
(I2) | rhs |
tensor ou quantizado | (C1 a C5) (C7) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou quantizado | (C1 a C7) |
Restrições
- Se a operação usar tensores não quantizados:
- (C1)
type(lhs) = type(rhs) = type(result)
.
- (C1)
- Se a operação usar tensores quantizados:
- (C2)
is_quantized(lhs) and is_quantized(rhs) and is_quantized(result)
. - (C3)
storage_type(lhs) = storage_type(rhs) = storage_type(result)
. - (C4)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C5)
(is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result)
. - (C6) Se for
is_per_axis_quantized(lhs)
, entãoquantization_dimension(lhs) = quantization_dimension(result)
. - (C7) Se for
is_per_axis_quantized(rhs)
, entãoquantization_dimension(rhs) = quantization_dimension(result)
.
- (C2)
Exemplos
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]
after_all
Semântica
Garante que as operações que produzem o inputs
sejam executadas antes de qualquer
operação que dependa de result
. A execução dessa operação não faz nada,
ela só existe para estabelecer dependências de dados de result
a inputs
.
Entradas
Rótulo | Nome | Tipo |
---|---|---|
(I1) | inputs |
número variável de token |
Saídas
Nome | Tipo |
---|---|
result |
token |
Exemplos
// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token
all_gather
Semântica
Dentro de cada grupo de processos na grade de processos do StableHLO, os valores
do tensor operand
de cada processo são concatenados com all_gather_dim
e um
tensor result
é produzido.
A operação divide a grade do processo StableHLO em process_groups
, que é
definida da seguinte maneira:
cross_replica(replica_groups)
sechannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
sechannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
sechannel_id > 0 and use_global_device_ids = true
.
Depois, em cada process_group
:
operands@receiver = [operand@sender for sender in process_group]
para todos osreceiver
emprocess_group
.result@process = concatenate(operands@process, all_gather_dim)
para todos osprocess
emprocess_group
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor ou tensor quantizado por tensor | (C1) e (C6). |
(I2) | all_gather_dim |
constante do tipo si64 |
(C1) e (C6). |
(I3) | replica_groups |
Constante do tensor bidimensional do tipo si64 |
(C2 a C4) |
(I4) | channel_id |
constante do tipo si64 |
(C5) |
(I5) | use_global_device_ids |
constante do tipo i1 |
(C5) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou tensor quantizado por tensor | (C6) |
Restrições
- (C1)
0 <= all_gather_dim < rank(operand)
. - (C2)
is_unique(replica_groups)
. - (C3)
size(replica_groups)
é definido como:num_replicas
secross_replica
for usado.num_replicas
secross_replica_and_partition
for usado.num_processes
seflattened_ids
for usado.
- (C4)
0 <= replica_groups < size(replica_groups)
. - (C5) Se for
use_global_device_ids = true
, entãochannel_id > 0
. - (C6)
type(result) = type(operand)
, exceto:dim(result, all_gather_dim) = dim(operand, all_gather_dim) * dim(process_groups, 1)
.
Exemplos
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
%result = "stablehlo.all_gather"(%operand) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<2x2xi64>) -> tensor<2x4xi64>
// %result@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
all_reduce
Semântica
Dentro de cada grupo de processos na grade de processos do StableHLO, é aplicada uma função
de redução computation
aos valores do tensor operand
de cada processo
e produz um tensor result
.
A operação divide a grade do processo StableHLO em process_groups
, que é
definida da seguinte maneira:
cross_replica(replica_groups)
sechannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
sechannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
sechannel_id > 0 and use_global_device_ids = true
.
Depois, em cada process_group
:
result@process[result_index] = exec(schedule)
para alguma árvore bináriaschedule
em que:exec(node)
=computation(exec(node.left), exec(node.right))
exec(leaf)
=leaf.value
schedule
é uma árvore binária definida pela implementação cuja travessia em ordem éto_destination_type(operands@process_group...[result_index], type(func_inputs(computation)[0]))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor ou tensor quantizado por tensor | (C5) e (C6) |
(I2) | replica_groups |
número variado de constantes de tensor unidimensionais do tipo si64 |
(C1 a C3) |
(I3) | channel_id |
constante do tipo si64 |
(C4) |
(I4) | use_global_device_ids |
constante do tipo i1 |
(C4) |
(I5) | computation |
função | (C5) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou tensor quantizado por tensor | (C6 a C7) |
Restrições
- (C1)
is_unique(replica_groups)
. - (C2)
size(replica_groups)
é definido como:num_replicas
secross_replica
for usado.num_replicas
secross_replica_and_partition
for usado.num_processes
seflattened_ids
for usado.
- (C3)
0 <= replica_groups < size(replica_groups)
. - (C4) Se for
use_global_device_ids = true
, entãochannel_id > 0
. - (C5)
computation
tem o tipo(tensor<E>, tensor<E>) -> (tensor<E>)
, em queis_promotable(element_type(operand), E)
. - (C6)
shape(result) = shape(operand)
. - (C7)
element_type(result) = E
.
Exemplos
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [1, 2, 3, 4]
// %operand@(1, 0): [5, 6, 7, 8]
%result = "stablehlo.all_reduce"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<4xi64>) -> tensor<4xi64>
// %result@(0, 0): [6, 8, 10, 12]
// %result@(1, 0): [6, 8, 10, 12]
all_to_all
Semântica
Dentro de cada grupo de processos na grade de processos do StableHLO, divide os valores do
tensor operand
ao longo de split_dimension
em partes, dispersa as partes
divididas entre os processos, concatena as partes dispersas com
concat_dimension
e produz um tensor result
.
A operação divide a grade do processo StableHLO em process_groups
, que é
definida da seguinte maneira:
cross_replica(replica_groups)
se forchannel_id <= 0
.cross_partition(replica_groups)
se forchannel_id > 0
.
Depois, em cada process_group
:
split_parts@sender = split(operand@sender, split_count, split_dimension)
para todos ossender
emprocess_group
.scattered_parts@receiver = [split_parts@sender[receiver_index] for sender in process_group]
, em quereceiver_index = process_group.index(receiver)
.result@process = concatenate(scattered_parts@process, concat_dimension)
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor ou tensor quantizado por tensor | (C1 a C3) e (C9). |
(I2) | split_dimension |
constante do tipo si64 |
(C1), (C2) (C9) |
(I3) | concat_dimension |
constante do tipo si64 |
(C3) e (C9). |
(I4) | split_count |
constante do tipo si64 |
(C2), (C4), (C8) (C9) |
(I5) | replica_groups |
Constante do tensor bidimensional do tipo si64 |
(C5 a C8) |
(I6) | channel_id |
constante do tipo si64 |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou tensor quantizado por tensor | (C9) |
Restrições
- (C1)
0 <= split_dimension < rank(operand)
. - (C2)
dim(operand, split_dimension) % split_count = 0
. - (C3)
0 <= concat_dimension < rank(operand)
. - (C4)
0 < split_count
. - (C5)
is_unique(replica_groups)
. - (C6)
size(replica_groups)
é definido como:num_replicas
secross_replica
for usado.num_partitions
secross_partition
for usado.
- (C7)
0 <= replica_groups < size(replica_groups)
. - (C8)
dim(replica_groups, 1) = split_count
. - (C9)
type(result) = type(operand)
, exceto:dim(result, split_dimension) = dim(operand, split_dimension) / split_count
.dim(result, concat_dimension) = dim(operand, concat_dimension) * split_count
.
Exemplos
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.all_to_all"(%operand) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
} : (tensor<2x4xi64>) -> tensor<4x2xi64>
// %result@(0, 0): [[1, 2],
// [5, 6],
// [9, 10],
// [13, 14]]
// %result@(1, 0): [[3, 4],
// [7, 8],
// [11, 12],
// [15, 16]]
e
Semântica
Executa o elemento AND de dois tensores lhs
e rhs
e produz um tensor result
. Dependendo do tipo de elemento, faça o seguinte:
- Para booleanos: AND lógico.
- Para números inteiros: AND bit a bit.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | lhs |
tensor de tipo booleano ou inteiro | (C1) |
(I2) | rhs |
tensor de tipo booleano ou inteiro | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de tipo booleano ou inteiro | (C1) |
Restrições
- (C1)
type(lhs) = type(rhs) = type(result)
.
Exemplos
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]
atan2
Semântica
Executa a operação atan2 com elementos nos tensores lhs
e rhs
e produz um
tensor result
. Dependendo do tipo de elemento, faça o seguinte:
- Para pontos flutuantes:
atan2
do IEEE-754. - Para números complexos: atan2 complexo.
- Para tipos quantizados:
dequantize_op_quantize(atan2, lhs, rhs, type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | lhs |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
(I2) | rhs |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Exemplos
// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]
batch_norm_grad
Semântica
Calcula gradientes de várias entradas de retropropagação de batch_norm_training
de grad_output
e produz tensores grad_operand
, grad_scale
e grad_offset
. Mais formalmente, essa operação pode ser expressa como uma decomposição para
operações StableHLO existentes usando a sintaxe do Python, da seguinte maneira:
def compute_sum(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
return sum
def compute_mean(operand, feature_index):
sum = compute_sum(operand, feature_index)
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
# Broadcast inputs to type(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance`
# Intermediate values will be useful for computing gradients
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
# Use the implementation from batchnorm_expander.cc in XLA
# Temporary variables have exactly the same names as in the C++ code
elements_per_feature = broadcast_in_dim(
constant(divide(size(operand), dim(operand, feature_index)),
element_type(grad_output)),
[], type(operand))
i1 = multiply(grad_output, elements_per_feature)
i2 = broadcast_in_dim(
compute_sum(grad_output, feature_index), [feature_index], type(operand))
i3 = broadcast_in_dim(
compute_sum(multiply(grad_output, centered_operand), feature_index),
[feature_index], type(operand))
i4 = multiply(i3, centered_operand)
i5 = divide(i4, add(variance_bcast, epsilon_bcast))
i6 = subtract(subtract(i1, i2), i5)
grad_operand =
multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
grad_scale =
compute_sum(multiply(grad_output, normalized_operand), feature_index)
grad_offset = compute_sum(grad_output, feature_index)
return grad_operand, grad_scale, grad_offset
Para tipos quantizados, executa
dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean,
variance, grad_output: batch_norm_grad(operand, scale, mean, variance,
grad_output, epsilon, feature_index), operand, scale, mean, variance,
grad_output, type(grad_operand), type(grad_scale), type(feature_index))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C1 a C3) e C5. |
(I2) | scale |
Tensor unidimensional do tipo quantizado de ponto flutuante ou por tensor | (C2), (C4) (C5) |
(I3) | mean |
Tensor unidimensional do tipo quantizado de ponto flutuante ou por tensor | (C2) e (C4). |
(I4) | variance |
Tensor unidimensional do tipo quantizado de ponto flutuante ou por tensor | (C2) e (C4). |
(I5) | grad_output |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C2) e (C3). |
(I6) | epsilon |
constante do tipo f32 |
|
(I7) | feature_index |
constante do tipo si64 |
(C1) e (C5). |
Saídas
Nome | Tipo | Restrições |
---|---|---|
grad_operand |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C2) e (C3). |
grad_scale |
Tensor unidimensional do tipo quantizado de ponto flutuante ou por tensor | (C2) e (C4). |
grad_offset |
Tensor unidimensional do tipo quantizado de ponto flutuante ou por tensor | (C2) e (C4). |
Restrições
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,mean
,variance
,grad_output
,grad_operand
,grad_scale
egrad_offset
têm o mesmobaseline_element_type
. - (C3)
operand
,grad_output
egrad_operand
têm a mesma forma. - (C4)
scale
,mean
,variance
,grad_scale
egrad_offset
têm a mesma forma. - (C5)
size(scale) = dim(operand, feature_index)
.
Exemplos
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
// [[0.1, 0.1], [0.1, 0.1]],
// [[0.1, 0.1], [0.1, 0.1]]
// ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
// [[0.0, 0.0], [0.0, 0.0]],
// [[0.0, 0.0], [0.0, 0.0]]
// ]
// %grad_scale: [0.0, 0.0]
// %grad_offset: [0.4, 0.4]
batch_norm_inference
Semântica
Normaliza o tensor operand
em todas as dimensões, exceto a
dimensão feature_index
, e produz um tensor result
. Mais formalmente, essa
operação pode ser expressa como uma decomposição para operações StableHLO atuais
usando a sintaxe do Python, da seguinte maneira:
def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
# Broadcast inputs to shape(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance` instead of
# computing them like `batch_norm_training` does.
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
return add(multiply(scale_bcast, normalized_operand), offset_bcast)
Para tipos quantizados, executa
dequantize_op_quantize(lambda operand, scale, offset, mean, variance:
batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index), operand, scale, offset, mean, variance, type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C1 a C7) |
(I2) | scale |
Tensor unidimensional do tipo quantizado de ponto flutuante ou por tensor | (C2) e (C3). |
(I3) | offset |
Tensor unidimensional do tipo quantizado de ponto flutuante ou por tensor | (C2) e (C4). |
(I4) | mean |
Tensor unidimensional do tipo quantizado de ponto flutuante ou por tensor | (C5) |
(I5) | variance |
Tensor unidimensional do tipo quantizado de ponto flutuante ou por tensor | (C2) e (C6) |
(I6) | epsilon |
constante do tipo f32 |
|
(I7) | feature_index |
constante do tipo si64 |
(C1), (C3 a C6) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C2) e (C7). |
Restrições
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,mean
,variance
eresult
têm o mesmobaseline_element_type
. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(mean) = dim(operand, feature_index)
. - (C6)
size(variance) = dim(operand, feature_index)
. - (C7)
baseline_type(operand) = baseline_type(result)
.
Exemplos
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
batch_norm_training
Semântica
Calcula a média e a variância em todas as dimensões, exceto a dimensão feature_index
, e normaliza o tensor operand
que produz os tensores output
, batch_mean
e batch_var
. Mais formalmente, essa operação pode ser expressa como uma
decomposição para operações StableHLO existentes usando a sintaxe do Python, da seguinte
maneira:
def compute_mean(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def compute_variance(operand, feature_index):
mean = compute_mean(operand, feature_index)
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
centered_operand = subtract(operand, mean_bcast)
return compute_mean(mul(centered_operand, centered_operand), feature_index)
def batch_norm_training(operand, scale, offset, epsilon, feature_index):
mean = compute_mean(operand, feature_index)
variance = compute_variance(operand, feature_index)
return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index),
mean, variance
Para tipos quantizados, executa
dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset:
batch_norm_training(operand, scale, offset, epsilon, feature_index), operand,
scale, offset, type(output), type(batch_mean), type(batch_var))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C1) |
(I2) | scale |
Tensor unidimensional de ponto flutuante ou por tensor quantizado | (C2) e (C3). |
(I3) | offset |
Tensor unidimensional de ponto flutuante ou por tensor quantizado | (C2) e (C4). |
(I4) | epsilon |
constante do tipo f32 |
(C1), (C3 a C6) |
(I5) | feature_index |
constante do tipo si64 |
(C1), (C3 a C6) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
output |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C7) |
batch_mean |
Tensor unidimensional de ponto flutuante ou por tensor quantizado | (C2) e (C5). |
batch_var |
Tensor unidimensional de ponto flutuante ou por tensor quantizado | (C2) e (C6) |
Restrições
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,batch_mean
,batch_var
eoutput
têm o mesmobaseline_element_type
. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(batch_mean) = dim(operand, feature_index)
. - (C6)
size(batch_var) = dim(operand, feature_index)
. - (C7)
baseline_type(output) = baseline_type(operand)
.
Exemplos
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
(tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]
bitcast_convert
Semântica
Executa uma operação de bitcast no tensor operand
e produz um tensor result
,
em que os bits de todo o tensor operand
são reinterpretados usando o
tipo de tensor result
.
Mais formalmente, considerando E = element_type(operand)
, E' = element_type(result)
e R = rank(operand)
:
- Se for
num_bits(E') < num_bits(E)
,bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])
. - Se for
num_bits(E') > num_bits(E)
,bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])
. - Se for
num_bits(E') = num_bits(E)
,bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])
.
bits
retorna a representação na memória de um determinado valor, e o comportamento
é definido pela implementação, porque a representação exata dos tensores é
definida pela implementação, assim como a representação exata dos tipos de elementos.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor ou quantizado | (C1 a C2) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou quantizado | (C1 a C2) |
Restrições
- (C1) Com
E = is_quantized(operand) ? storage_type(operand) : element_type(operand)
,E' = is_quantized(result) ? storage_type(result) : element_type(result)
eR = rank(operand)
:- Se for
num_bits(E') = num_bits(E)
,shape(result) = shape(operand)
. - Se
num_bits(E') < num_bits(E)
: rank(result) = R + 1
.dim(result, i) = dim(operand, i)
para todos os0 <= i < R
.dim(result, R) * num_bits(E') = num_bits(E)
.- Se
num_bits(E') > num_bits(E)
: rank(result) = R - 1
.dim(result, i) = dim(operand, i)
para todos os0 <= i < R
.dim(operand, R - 1) * num_bits(E) = num_bits(E')
.
- Se for
- (C2) Se for
is_complex(operand) or is_complex(result)
, entãois_complex(operand) and is_complex(result)
.
Exemplos
// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation
broadcast_in_dim
Semântica
Expande as dimensões e/ou a classificação de um tensor de entrada duplicando os dados
no tensor operand
e produz um tensor result
. Mais formalmente,
result[result_index] = operand[operand_index]
, em que todos os d
em
axes(operand)
:
operand_index[d] = 0
se fordim(operand, d) = 1
.- Caso contrário,
operand_index[d] = result_index[broadcast_dimensions[d]]
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor ou quantizado | (C1 a C2) e C5 a C6. |
(I2) | broadcast_dimensions |
constante do tensor unidimensional do tipo si64 |
(C2 a C6) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou quantizado | (C1), (C3) e (C5 a C6) |
Restrições
- (C1)
element_type(result)
é fornecido por:element_type(operand)
, se!is_per_axis_quantized(operand)
.element_type(operand)
, exceto quequantization_dimension(operand)
,scales(operand)
ezero_points(operand)
podem ser diferentes da resp.quantization_dimension(result)
,scales(result)
ezero_points(result)
. Caso contrário.
- (C2)
size(broadcast_dimensions) = rank(operand)
. - (C3)
0 <= broadcast_dimensions < rank(result)
. - (C4)
is_unique(broadcast_dimensions)
. - (C5) Para todos os
d
emaxes(operand)
:dim(operand, d) = 1
oudim(operand, d) = dim(result, broadcast_dimensions[d])
.
- (C6) Se
is_per_axis_quantized(result)
:quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
.- Se for
dim(operand, quantization_dimension(operand)) = 1
, entãoscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
.
Exemplos
// %operand: [
// [1, 2, 3]
// ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
caso
Semântica
Produz a saída da execução de exatamente uma função de branches
,
dependendo do valor de index
. Mais formalmente, result = selected_branch()
,
em que:
selected_branch = branches[index]
se for0 <= index < size(branches)
.- Caso contrário,
selected_branch = branches[-1]
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | index |
Tensor 0 dimensional do tipo si32 |
|
(I2) | branches |
número variado de funções | (C1 a C4) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
results |
número variado de tensores, tensores quantizados ou tokens | (C4) |
Restrições
- (C1)
0 < size(branches)
. - (C2)
input_types(branches...) = []
. - (C3)
same(output_types(branches...))
. - (C4)
type(results...) = output_types(branches[0])
.
Exemplos
// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
"stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
"stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]
CBRT
Semântica
Executa uma operação de raiz cúbica por elemento no tensor operand
e produz um
tensor result
. Dependendo do tipo de elemento, faça o seguinte:
- Para pontos flutuantes:
rootn(x, 3)
do IEEE-754. - Para números complexos: raiz cúbica complexa.
- Para tipos quantizados:
dequantize_op_quantize(cbrt, operand, type(result))
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemplos
// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]
ceil
Semântica
Executa o ceil com elementos do tensor operand
e produz um tensor result
.
Implementa a operação roundToIntegralTowardPositive
da especificação
IEEE-754. Para tipos quantizados, executa
dequantize_op_quantize(ceil, operand, type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemplos
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]
Cholesky
Semântica
Calcula a decomposição de Cholesky de um lote de matrizes.
Mais formalmente, para todos os i
em index_space(result)
,
result[i0, ..., iR-3, :, :]
é uma decomposição de Cholesky de
a[i0, ..., iR-3, :, :]
, na forma de uma matriz triangular inferior
(se lower
for true
) ou triangular superior (se lower
for false
).
Os valores de saída no triângulo oposto, ou seja, o triângulo superior restrito ou
o triângulo inferior restrito correspondente, são definidos pela implementação.
Se houver i
em que a matriz de entrada não é uma matriz positiva hermitiana definida, o comportamento é indefinido.
Para tipos quantizados, executa
dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | a |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1 a C3) |
(I2) | lower |
constante do tensor 0 dimensional do tipo i1 |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(a) = baseline_type(result)
. - (C2)
2 <= rank(a)
. - (C3)
dim(a, -2) = dim(a, -1)
.
Exemplos
// %a: [
// [1.0, 2.0, 3.0],
// [2.0, 20.0, 26.0],
// [3.0, 26.0, 70.0]
// ]
%result = "stablehlo.cholesky"(%a) {
lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
limitar
Semântica
Fixa todos os elementos do tensor operand
entre um valor mínimo e um máximo
e produz um tensor result
. Mais formalmente, result[result_index] =
minimum(maximum(operand[result_index], min_element), max_element)
,
em que min_element = rank(min) = 0 ? min[] : min[result_index]
,
max_element = rank(max) = 0 ? max[] : max[result_index]
. Para tipos quantizados,
executa dequantize_op_quantize(clamp, min, operand, max, type(result))
.
Impor uma ordenação em números complexos envolve uma semântica surpreendente. Portanto, no futuro, planejamos remover o suporte a números complexos para essa operação (#560).
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | min |
tensor ou tensor quantizado por tensor | (C1) e (C3). |
(I2) | operand |
tensor ou tensor quantizado por tensor | (C1 a C4) |
(I3) | max |
tensor ou tensor quantizado por tensor | (C2) e (C3). |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou tensor quantizado por tensor | (C4) |
Restrições
- (C1)
rank(min) = 0 or shape(min) = shape(operand)
. - (C2)
rank(max) = 0 or shape(max) = shape(operand)
. - (C3)
baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max)
. - (C4)
baseline_type(operand) = baseline_type(result)
.
Exemplos
// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]
collective_broadcast
Semântica
Dentro de cada grupo de processos na grade de processos do StableHLO, envie o valor do tensor operand
do processo de origem para os processos de destino e produza um tensor result
.
A operação divide a grade do processo StableHLO em process_groups
, que é
definida da seguinte maneira:
cross_replica(replica_groups)
se forchannel_id <= 0
.cross_partition(replica_groups)
se forchannel_id > 0
.
Depois disso, result@process
é fornecido por:
operand@process_groups[i, 0]
se houver umi
de modo que o processo esteja emprocess_groups[i]
.broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
caso contrário.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor ou tensor quantizado por tensor | (C3) |
(I2) | replica_groups |
número variado de constantes de tensor unidimensionais do tipo si64 |
(C1) e (C2). |
(I3) | channel_id |
constante do tipo si64 |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou tensor quantizado por tensor | (C3) |
Restrições
- (C1)
is_unique(replica_groups)
. - (C2)
0 <= replica_groups < N
, em queN
é definido como:num_replicas
secross_replica
for usado.num_partitions
secross_partition
for usado.
- (C3)
type(result) = type(operand)
.
Exemplos
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
collective_permute
Semântica
Dentro de cada grupo de processos na grade de processos do StableHLO, envia o valor do tensor operand
do processo de origem para o processo de destino e produz um tensor result
.
A operação divide a grade do processo StableHLO em process_groups
, que é
definida da seguinte maneira:
cross_replica(source_target_pairs)
se forchannel_id <= 0
.cross_partition(source_target_pairs)
se forchannel_id > 0
.
Depois disso, result@process
é fornecido por:
operand@process_groups[i, 0]
, se existir umi
de tal forma queprocess_groups[i, 1] = process
.broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
caso contrário.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor ou tensor quantizado por tensor | (C5) |
(I2) | source_target_pairs |
Constante do tensor bidimensional do tipo si64 |
(C1 a C4) |
(I3) | channel_id |
constante do tipo si64 |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
dim(source_target_pairs, 1) = 2
. - (C2)
is_unique(source_target_pairs[:, 0])
. - (C3)
is_unique(source_target_pairs[:, 1])
. - (C4)
0 <= source_target_pairs < N
, em queN
é definido como:num_replicas
secross_replica
for usado.num_partitions
secross_partition
for usado.
- (C5)
type(result) = type(operand)
.
Exemplos
// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]
compare
Semântica
Executa comparação em elementos dos tensores lhs
e rhs
de acordo com
comparison_direction
e compare_type
e produz um tensor result
.
Os valores de comparison_direction
e compare_type
têm a seguinte
semântica:
Para tipos de elementos booleanos e inteiros:
EQ
:lhs = rhs
.NE
:lhs != rhs
.GE
:lhs >= rhs
.GT
:lhs > rhs
.LE
:lhs <= rhs
.LT
:lhs < rhs
.
Para tipos de elementos de ponto flutuante com compare_type = FLOAT
, a operação implementa
as seguintes operações IEEE-754:
EQ
:compareQuietEqual
.NE
:compareQuietNotEqual
.GE
:compareQuietGreaterEqual
.GT
:compareQuietGreater
.LE
:compareQuietLessEqual
.LT
:compareQuietLess
.
Para tipos de elementos de ponto flutuante com compare_type = TOTALORDER
, a operação usa a combinação de operações totalOrder
e compareQuietEqual
do IEEE-754.
Para tipos de elementos complexos, a comparação lexicográfica de pares (real, imag)
é
realizada usando os comparison_direction
e compare_type
fornecidos.
Impor uma ordenação em números complexos envolve semântica inesperada.
Portanto, no futuro, planejamos remover o suporte a números complexos
quando comparison_direction
for GE
, GT
, LE
ou LT
(560).
Para tipos quantizados, executa dequantize_compare(lhs, rhs,
comparison_direction)
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | lhs |
tensor ou tensor quantizado por tensor | (C1 a C3) |
(I2) | rhs |
tensor ou tensor quantizado por tensor | (C1 a C2) |
(I3) | comparison_direction |
enumeração de EQ , NE , GE , GT , LE e LT |
|
(I4) | compare_type |
enumeração de FLOAT , TOTALORDER , SIGNED e UNSIGNED |
(C3) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de tipo booleano | (C2) |
Restrições
- (C1)
baseline_element_type(lhs) = baseline_element_type(rhs)
. - (C2)
shape(lhs) = shape(rhs) = shape(result)
. - (C3)
compare_type
é definido como:SIGNED
se foris_signed_integer(element_type(lhs))
.UNSIGNED
se foris_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs))
.FLOAT
ouTOTALORDER
seis_float(element_type(lhs))
.FLOAT
se foris_complex(element_type(lhs))
.
Exemplos
// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
comparison_direction = #stablehlo<comparison_direction LT>,
compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]
complexo
Semântica
Executa a conversão elemento para um valor complexo a partir de um par de valores reais e
imaginários, lhs
e rhs
, e produz um tensor result
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | lhs |
tensor do tipo f32 ou f64 |
(C1 a C3) |
(I2) | rhs |
tensor do tipo f32 ou f64 |
(C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de tipo complexo | (C2) e (C3). |
Restrições
- (C1)
type(lhs) = type(rhs)
. - (C2)
shape(result) = shape(lhs)
. - (C3)
element_type(result)
tem o tipocomplex<E>
, em queE = element_type(lhs)
.
Exemplos
// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]
composto
Semântica
Encapsula uma operação composta (composta) de outras operações do StableHLO,
usando inputs
e composite_attributes
e produzindo results
. A
semântica da operação é implementada pelo atributo decomposition
. A
op composite
pode ser substituída pela própria decomposição sem mudar a semântica
do programa. Nos casos em que a decomposição in-line não fornece a mesma
semântica de operação, prefira usar custom_call
.
O campo version
(o padrão é 0
) é usado para indicar quando a
semântica de um composto muda.
Entradas
Rótulo | Nome | Tipo |
---|---|---|
(I1) | inputs |
número variável de valores |
(I2) | name |
constante do tipo string |
(I3) | composite_attributes |
dicionário de atributos |
(I4) | decomposition |
constante do tipo string |
(I5) | version |
constante do tipo si32 |
Saídas
Nome | Tipo |
---|---|
results |
número variável de valores |
Restrições
- (C1)
is_namespaced_op_name(name)
- (C2)
is_defined_in_parent_scope(decomposition)
- (C3)
types(inputs...) == input_types(decomposition)
- (C4)
types(results...) == output_types(decomposition)
Exemplos
%results = "stablehlo.composite"(%input0, %input1) {
name = "my_namespace.my_op",
composite_attributes = {
my_attribute = "my_value"
},
decomposition = @my_op,
version = 1 : i32
} : (tensor<f32>, tensor<f32>) -> tensor<f32>
concatenate
Semântica
Concatena inputs
com a dimensão dimension
na mesma ordem que os argumentos fornecidos e produz um tensor result
. Mais formalmente,
result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1]
, em que:
id = d0 + ... + dk-1 + kd
.d
é igual adimension
, ed0
, ... sãod
a dimensão deinputs
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | inputs |
número variado de tensores ou tensores quantizados por tensor | (C1 a C6) |
(I2) | dimension |
constante do tipo si64 |
(C2), (C4) e (C6) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou tensor quantizado por tensor | (C5 a C6) |
Restrições
- (C1)
same(element_type(inputs...))
. - (C2)
same(shape(inputs...))
, excetodim(inputs..., dimension)
. - (C3)
0 < size(inputs)
. - (C4)
0 <= dimension < rank(inputs[0])
. - (C5)
element_type(result) = element_type(inputs[0])
. - (C6)
shape(result) = shape(inputs[0])
, exceto:dim(result, dimension) = dim(inputs[0], dimension) + ...
.
Exemplos
// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
constante
Semântica
Produz um tensor output
com base em uma constante value
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | value |
constante | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
output |
tensor ou quantizado | (C1) |
Restrições
- (C1)
type(value) = type(output)
.
Exemplos
%output = "stablehlo.constant"() {
value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]
fazer uma conversão
Semântica
Executa uma conversão por elemento de um tipo de elemento para outro no
tensor operand
e produz um tensor result
.
Em conversões de boolean-to-any-supported-type, o valor false
é convertido em zero, e o valor true
é convertido em um. Para
any-supported-type-to-boolean, um valor zero é convertido em
false
, e valores diferentes de zero são convertidos em true
. Veja abaixo como isso
funciona para tipos complexos.
Para conversões que envolvem número inteiro para inteiro, número inteiro para ponto flutuante ou ponto flutuante para o ponto flutuante, se o valor de origem puder ser representado exatamente no tipo de destino, o valor do resultado será essa representação exata. Caso contrário, o comportamento será definido (#180).
Para conversões que envolvem floating-point-to-integer, a parte fracionária é truncada. Se o valor truncado não puder ser representado no tipo de destino, o comportamento será definido a ser definido (#180).
A conversão que envolve complexo para complexo segue o mesmo comportamento das conversões de ponto flutuante para ponto flutuante para a conversão de partes reais e imaginárias.
Nas conversões de complex-to-any-other-type e complex-to-any-other-type, o valor imaginário de origem é ignorado ou o valor imaginário de destino é zerado, respectivamente. A conversão da parte real segue as conversões de ponto flutuante.
Em princípio, essa operação poderia expressar desquantização (conversão de tensores quantizados em tensores regulares), quantização (conversão de tensores regulares em tensores quantizados) e requantização (conversão entre tensores quantizados), mas no momento temos operações dedicadas para isso, uniform_dequantize
para o primeiro caso de uso e uniform_quantize
para o segundo caso de uso. No futuro, essas duas operações podem ser mescladas
em convert
(#1576).
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
Tensor | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
Tensor | (C1) |
Restrições
- (C1)
shape(operand) = shape(result)
.
Exemplos
// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]
convolução
Semântica
Calcula produtos escalares entre janelas de lhs
e frações de rhs
e produz
result
. O diagrama a seguir mostra como os elementos em result
são calculados de
lhs
e rhs
usando um exemplo concreto.
Mais formalmente, considere a seguinte reformulação das entradas em termos de lhs
para poder expressar janelas de lhs
:
lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension))
.lhs_window_strides = lhs_shape(1, window_strides, 1)
.lhs_padding = lhs_shape([0, 0], padding, [0, 0])
.lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)
.lhs_window_dilations = lhs_shape(1, rhs_dilation, 1)
.
Esse reenquadramento usa as seguintes funções auxiliares:
lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])
.result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])
.permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]
, em quej[d] = i[permutation[d]]
.
Se for feature_group_count = 1
e batch_group_count = 1
, para todos os
output_spatial_index
em index_space(dim(result, output_spatial_dimensions...))
,
result[result_shape(:, output_spatial_index, :)] = dot_product
em que:
padding_value = constant(0, element_type(lhs))
.padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)
.lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides
.lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations)
.reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true])
. Esse recurso parece não ser usado. Por isso, planejamos removê-lo no futuro (#1181).dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension])
.
Se feature_group_count > 1
:
lhses = split(lhs, feature_group_count, input_feature_dimension)
.rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)
.results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
.
Se batch_group_count > 1
:
lhses = split(lhs, batch_group_count, input_batch_dimension)
.rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)
.results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
Para tipos quantizados, executa dequantize_op_quantize(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs,
type(result))
.
Para tipos quantizados híbridos, executa hybrid_dequantize_then_op(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs)
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | lhs |
tensor ou tensor quantizado por tensor | (C1), (C10-C11), (C14) (C25), (C27-C28), (C31-C32) (C34) |
(I2) | rhs |
tensor ou quantizado | (C1), (C14 a C16), (C25), (C27 a C29) (C31 a C34) |
(I3) | window_strides |
constante do tensor unidimensional do tipo si64 |
(C2 a C3) (C25) |
(I4) | padding |
Constante do tensor bidimensional do tipo si64 |
(C4) e (C25). |
(I5) | lhs_dilation |
constante do tensor unidimensional do tipo si64 |
(C5 a C6) e (C25) |
(I6) | rhs_dilation |
constante do tensor unidimensional do tipo si64 |
(C7 a C8) e C25. |
(I7) | window_reversal |
constante do tensor unidimensional do tipo i1 |
(C9) |
(I8) | input_batch_dimension |
constante do tipo si64 |
(C10), (C13) e (C25) |
(I9) | input_feature_dimension |
constante do tipo si64 |
(C11) (C13 a C14). |
(I10) | input_spatial_dimensions |
constante do tensor unidimensional do tipo si64 |
(C12), (C13) e (C25) |
(I11) | kernel_input_feature_dimension |
constante do tipo si64 |
(C14) e (C18). |
(I12) | kernel_output_feature_dimension |
constante do tipo si64 |
(C15 a C16), (C18), (C25) (C29) |
(I13) | kernel_spatial_dimensions |
constante do tensor unidimensional do tipo si64 |
(C17 a C18) (C25) |
(I14) | output_batch_dimension |
constante do tipo si64 |
(C20) e (C25) |
(I15) | output_feature_dimension |
constante do tipo si64 |
(C20), (C25) (C30) |
(I16) | output_spatial_dimensions |
constante do tensor unidimensional do tipo si64 |
(C19 a C20) (C25) |
(I17) | feature_group_count |
constante do tipo si64 |
(C11), (C14), (C16), (C21) (C23) |
(I18) | batch_group_count |
constante do tipo si64 |
(C10), (C15), (C22), (C23) (C25) |
(I19) | precision_config |
número variável de enumerações de DEFAULT , HIGH e HIGHEST |
(C24) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou quantizado | (C25 a C28), (C30) (C32 a 34) |
Restrições
- (C1)
N = rank(lhs) = rank(rhs)
. - (C2)
size(window_strides) = N - 2
. - (C3)
0 < window_strides
. - (C4)
shape(padding) = [N - 2, 2]
. - (C5)
size(lhs_dilation) = N - 2
. - (C6)
0 < lhs_dilation
. - (C7)
size(rhs_dilation) = N - 2
. - (C8)
0 < rhs_dilation
. - (C9)
size(window_reversal) = N - 2
. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
. - (C12)
size(input_spatial_dimensions) = N - 2
. - (C13) Considerando
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
:is_unique(input_dimensions)
.0 <= input_dimensions < N
.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
. - (C17)
size(kernel_spatial_dimensions) = N - 2
. - (C18) Considerando
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
:is_unique(kernel_dimensions)
.0 <= kernel_dimensions < N
.
- (C19)
size(output_spatial_dimensions) = N - 2
. - (C20) Considerando
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
:is_unique(output_dimensions)
.0 <= output_dimensions < N
.
- (C21)
0 < feature_group_count
. - (C22)
0 < batch_group_count
. - (C23)
feature_group_count = 1 or batch_group_count = 1
. - (C24)
size(precision_config) = 2
. - (C25)
dim(result, result_dim)
é definido como:dim(lhs, input_batch_dimension) / batch_group_count
se forresult_dim = output_batch_dimension
.dim(rhs, kernel_output_feature_dimension)
se forresult_dim = output_feature_dimension
.num_windows
, caso contrário:output_spatial_dimensions[spatial_dim] = result_dim
.lhs_dim = input_spatial_dimensions[spatial_dim]
.rhs_dim = kernel_spatial_dimensions[spatial_dim]
.dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
.dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
.num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
- (C26)
rank(result) = N
. - Se a operação usar tensores não quantizados:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
.
- (C27)
- Se a operação usar tensores quantizados:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C29) Se
is_per_axis_quantized(rhs)
, entãoquantization_dimension(rhs) = kernel_output_feature_dimension
. - (C30) Se for
is_per_axis_quantized(result)
, entãoquantization_dimension(result) = output_feature_dimension
. - Se
is_quantized(lhs)
: - (C31)
storage_type(lhs) = storage_type(rhs)
. - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C33) Se
is_per_tensor_quantized(rhs)
, entãois_per_tensor_quantized(result)
. - Se
!is_quantized(lhs)
: - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C28)
Exemplos
// %lhs: [[
// [
// [1], [2], [5], [6]
// ],
// [
// [3], [4], [7], [8]
// ],
// [
// [10], [11], [14], [15]
// ],
// [
// [12], [13], [16], [17]
// ]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strides = array<i64: 4, 4>,
padding = dense<0> : tensor<2x2xi64>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
// In the StableHLO dialect, dimension numbers are encoded via:
// `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
// "b" is batch dimension, "f" is feature dimension,
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" are spatial dimensions.
dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
batch_group_count = 1 : i64,
feature_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[10], [26]],
// [[46], [62]]
// ]]
cosseno
Semântica
Executa a operação de cosseno com elementos no tensor operand
e produz um
tensor result
. Dependendo do tipo de elemento, faça o seguinte:
- Para pontos flutuantes:
cos
do IEEE-754. - Para números complexos: cosseno complexo.
- Para tipos quantizados:
dequantize_op_quantize(cosine, operand, type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemplos
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]
count_leading_zeros
Semântica
Executa a contagem elemento a elemento do número de zero bits iniciais no tensor operand
e produz um tensor result
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor do tipo inteiro | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor do tipo inteiro | (C1) |
Restrições
- (C1)
type(operand) = type(result)
.
Exemplos
// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]
custom_call
Semântica
Encapsula uma operação definida pela implementação call_target_name
que usa
inputs
e called_computations
e produz results
. has_side_effect
, backend_config
e api_version
podem ser usados para fornecer outros metadados definidos pela implementação.
No momento, essa operação contém uma coleção de metadados bastante desorganizada que reflete a evolução orgânica da operação equivalente no compilador XLA. No futuro, planejamos unificar esses metadados (#741).
Entradas
Rótulo | Nome | Tipo |
---|---|---|
(I1) | inputs |
número variável de valores |
(I2) | call_target_name |
constante do tipo string |
(I3) | has_side_effect |
constante do tipo i1 |
(I4) | backend_config |
constante do tipo string |
(I5) | api_version |
constante do tipo si32 |
(I6) | called_computations |
número variado de constantes do tipo string |
Saídas
Nome | Tipo |
---|---|
results |
número variável de valores |
Exemplos
%results = "stablehlo.custom_call"(%input0) {
call_target_name = "foo",
has_side_effect = false,
backend_config = "bar",
api_version = 1 : i32,
called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>
dividir
Semântica
Executa a divisão por elemento dos tensores de dividendo lhs
e divisor rhs
e
produz um tensor result
. Dependendo do tipo de elemento, faça o seguinte:
- Para números inteiros: divisão de números inteiros que produz o quociente algébico com qualquer parte fracionária descartada.
- Para pontos flutuantes:
division
do IEEE-754. - Para números complexos: divisão complexa.
- Para tipos quantizados:
dequantize_op_quantize(divide, lhs, rhs, type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | lhs |
tensor de número inteiro, ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
(I2) | rhs |
tensor de número inteiro, ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de número inteiro, ponto flutuante, tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Exemplos
// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]
dot_general
Semântica
Calcula produtos escalares entre frações de lhs
e frações de rhs
e produz um
tensor result
.
Mais formalmente, result[result_index] = dot_product
, em que:
lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions]
.rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions]
.result_batching_index + result_lhs_index + result_rhs_index = result_index
, em quesize(result_batching_index) = size(lhs_batching_dimensions)
,size(result_lhs_index) = size(lhs_result_dimensions)
esize(result_rhs_index) = size(rhs_result_dimensions)
.transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions)
.transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :])
.reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions))
.transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions)
.transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :])
.reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions))
.dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y))
Para tipos quantizados, executa dequantize_op_quantize(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs, type(result))
.
Para tipos quantizados híbridos, executa hybrid_dequantize_then_op(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs)
.
precision_config
controla a compensação entre velocidade e precisão para
cálculos em back-ends de aceleradores. Pode ser um dos seguintes. No
momento, a semântica desses valores de tipo enumerado está subespecificada, mas estamos
planejando resolver isso no
#755:
DEFAULT
: cálculo mais rápido, mas aproximação menos precisa do número original.HIGH
: cálculo mais lento, mas com aproximação mais precisa do número original.HIGHEST
: cálculo mais lento, mas aproximação mais precisa do número original.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | lhs |
tensor ou tensor quantizado por tensor | (C5 a C6), (C9 a C10), (C12 a C14), (C17 a C18) e (C20) |
(I2) | rhs |
tensor ou quantizado | (C7 a C10) e C12 a C20. |
(I3) | lhs_batching_dimensions |
constante do tensor unidimensional do tipo si64 |
(C1), (C3), (C5), (C9), (C12) |
(I4) | rhs_batching_dimensions |
constante do tensor unidimensional do tipo si64 |
(C1), (C4), (C7), (C9) |
(I5) | lhs_contracting_dimensions |
constante do tensor unidimensional do tipo si64 |
(C2), (C3), (C6), (C10) |
(I6) | rhs_contracting_dimensions |
constante do tensor unidimensional do tipo si64 |
(C2), (C4), (C8), (C10), (C16) |
(I7) | precision_config |
número variável de enumerações de DEFAULT , HIGH e HIGHEST |
(C11) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou quantizado | (C12), (C14) (C18 a C20) |
Restrições
- (C1)
size(lhs_batching_dimensions) = size(rhs_batching_dimensions)
. - (C2)
size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions)
. - (C3)
is_unique(lhs_batching_dimensions + lhs_contracting_dimensions)
. - (C4)
is_unique(rhs_batching_dimensions + rhs_contracting_dimensions)
. - (C5)
0 <= lhs_batching_dimensions < rank(lhs)
. - (C6)
0 <= lhs_contracting_dimensions < rank(lhs)
. - (C7)
0 <= rhs_batching_dimensions < rank(rhs)
. - (C8)
0 <= rhs_contracting_dimensions < rank(rhs)
. - (C9)
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
. - (C10)
dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...)
. - (C11)
size(precision_config) = 2
. - (C12)
shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions)
. - Se a operação usar tensores não quantizados:
- (C13)
element_type(lhs) = element_type(rhs)
.
- (C13)
- Se a operação usar tensores quantizados:
- (C14)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C15)
zero_points(rhs) = 0
. - (C16) Se for
is_per_axis_quantized(rhs)
, entãoquantization_dimension(rhs)
não está emrhs_contracting_dimensions
. - Se
is_quantized(lhs)
: - (C17)
storage_type(lhs) = storage_type(rhs)
. - (C18)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C19) Se for
is_per_tensor_quantized(rhs)
, entãois_per_tensor_quantized(result)
. - Se
!is_quantized(lhs)
: - (C20)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C14)
Exemplos
// %lhs: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
// %rhs: [
// [[1, 0],
// [0, 1]],
// [[1, 0],
// [0, 1]]
// ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [1]
>,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
dynamic_broadcast_in_dim
Semântica
Essa operação é funcionalmente idêntica a broadcast_in_dim, mas o formato resultante é especificado dinamicamente por output_dimensions
.
A operação também aceita atributos opcionais known_expanding_dimensions
e known_non_expanding_dimensions
para expressar conhecimento estático sobre o comportamento de expansão das dimensões.
Se não for especificado, presume-se que todas as dimensões sejam possivelmente expansíveis.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor ou quantizado | (C1 a C2), (C5 a C6) |
(I2) | output_dimensions |
Tensor unidimensional de tipo inteiro | (C7) |
(I3) | broadcast_dimensions |
Tensor constante unidimensional de tipo inteiro | (C2 a C6) |
(I4) | known_expanding_dimensions |
Tensor constante unidimensional de tipo inteiro | (C8 a C9). |
(I5) | known_non_expanding_dimensions |
Tensor constante unidimensional de tipo inteiro | (C8 a C9). |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou quantizado | (C1), (C3) e (C5 a C7) |
Restrições
- (C1)
element_type(result)
é fornecido por:element_type(operand)
, se!is_per_axis_quantized(operand)
.element_type(operand)
, exceto quequantization_dimension(operand)
,scales(operand)
ezero_points(operand)
podem ser diferentes da resp.quantization_dimension(result)
,scales(result)
ezero_points(result)
. Caso contrário.
- (C2)
size(broadcast_dimensions) = rank(operand)
. - (C3)
0 <= broadcast_dimensions < rank(result)
. - (C4)
is_unique(broadcast_dimensions)
. - (C5) Para todos os
d
emaxes(operand)
:dim(operand, d) = 1
oudim(operand, d) = dim(result, broadcast_dimensions[d])
.
- (C6) Se
is_per_axis_quantized(result)
:quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
.- Se for
dim(operand, quantization_dimension(operand)) = 1
, entãoscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
.
- (C7)
size(output_dimensions) = rank(result)
. - (C8)
is_unique(known_expanding_dimensions + known_non_expanding_dimensions)
. - (C9)
0 <= known_expanding_dimensions < rank(operand)
. - (C10)
0 <= known_non_expanding_dimensions < rank(operand)
.
Exemplos
// %operand: [
// [1, 2, 3]
// ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
broadcast_dimensions = array<i64: 2, 1>,
known_expanding_dimensions = array<i64: 0>,
known_non_expanding_dimensions = array<i64: 1>
} : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
dynamic_conv
Semântica
Essa operação é funcionalmente idêntica à operação de convolução, mas o padding é especificado dinamicamente por padding
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | lhs |
tensor ou tensor quantizado por tensor | (C1), (C10-C11), (C14) (C25), (C26-C27), (C30-C31) (C33) |
(I2) | rhs |
tensor ou quantizado | (C1), (C14 a C16), (C26 a C28) e (C30 a C33) |
(I3) | padding |
Tensor bidimensional de tipo inteiro | (C4) |
(I4) | window_strides |
constante do tensor unidimensional do tipo si64 |
(C2 a C3) |
(I5) | lhs_dilation |
constante do tensor unidimensional do tipo si64 |
(C5 a C6) |
(I6) | rhs_dilation |
constante do tensor unidimensional do tipo si64 |
(C7 a C8) |
(I7) | window_reversal |
constante do tensor unidimensional do tipo i1 |
(C9) |
(I8) | input_batch_dimension |
constante do tipo si64 |
(C10) e (C13). |
(I9) | input_feature_dimension |
constante do tipo si64 |
(C11) (C13 a C14). |
(I10) | input_spatial_dimensions |
constante do tensor unidimensional do tipo si64 |
(C12) e (C13). |
(I11) | kernel_input_feature_dimension |
constante do tipo si64 |
(C14) e (C18). |
(I12) | kernel_output_feature_dimension |
constante do tipo si64 |
(C15 a C16), (C18) e (C28) |
(I13) | kernel_spatial_dimensions |
constante do tensor unidimensional do tipo si64 |
(C17 a C18) |
(I14) | output_batch_dimension |
constante do tipo si64 |
(C20) |
(I15) | output_feature_dimension |
constante do tipo si64 |
(C20) e (C29) |
(I16) | output_spatial_dimensions |
constante do tensor unidimensional do tipo si64 |
(C19 a C20) |
(I17) | feature_group_count |
constante do tipo si64 |
(C11), (C14), (C16), (C21) (C23) |
(I18) | batch_group_count |
constante do tipo si64 |
(C10), (C15), (C22) (C23) |
(I19) | precision_config |
número variável de enumerações de DEFAULT , HIGH e HIGHEST |
(C24) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou quantizado | (C25 a C27), (C29) e (C31 a C33) |
Restrições
- (C1)
N = rank(lhs) = rank(rhs)
. - (C2)
size(window_strides) = N - 2
. - (C3)
0 < window_strides
. - (C4)
shape(padding) = [N - 2, 2]
. - (C5)
size(lhs_dilation) = N - 2
. - (C6)
0 < lhs_dilation
. - (C7)
size(rhs_dilation) = N - 2
. - (C8)
0 < rhs_dilation
. - (C9)
size(window_reversal) = N - 2
. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
. - (C12)
size(input_spatial_dimensions) = N - 2
. - (C13) Considerando
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
:is_unique(input_dimensions)
.0 <= input_dimensions < N
.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
. - (C17)
size(kernel_spatial_dimensions) = N - 2
. - (C18) Considerando
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
:is_unique(kernel_dimensions)
.0 <= kernel_dimensions < N
.
- (C19)
size(output_spatial_dimensions) = N - 2
. - (C20) Considerando
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
:is_unique(output_dimensions)
.0 <= output_dimensions < N
.
- (C21)
0 < feature_group_count
. - (C22)
0 < batch_group_count
. - (C23)
feature_group_count = 1 or batch_group_count = 1
. - (C24)
size(precision_config) = 2
. - (C25)
dim(result, result_dim)
é definido como:dim(lhs, input_batch_dimension) / batch_group_count
se forresult_dim = output_batch_dimension
.dim(rhs, kernel_output_feature_dimension)
se forresult_dim = output_feature_dimension
.num_windows
, caso contrário:output_spatial_dimensions[spatial_dim] = result_dim
.lhs_dim = input_spatial_dimensions[spatial_dim]
.rhs_dim = kernel_spatial_dimensions[spatial_dim]
.dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
.dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
.num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
- (C26)
rank(result) = N
. - Se a operação usar tensores não quantizados:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
.
- (C27)
- Se a operação usar tensores quantizados:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C29) Se
is_per_axis_quantized(rhs)
, entãoquantization_dimension(rhs) = kernel_output_feature_dimension
. - (C30) Se for
is_per_axis_quantized(result)
, entãoquantization_dimension(result) = output_feature_dimension
. - Se
is_quantized(lhs)
: - (C31)
storage_type(lhs) = storage_type(rhs)
. - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C33) Se
is_per_tensor_quantized(rhs)
, entãois_per_tensor_quantized(result)
. - Se
!is_quantized(lhs)
: - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C28)
Exemplos
// %lhs: [[
// [[1], [2], [5], [6]],
// [[3], [4], [7], [8]],
// [[10], [11], [14], [15]],
// [[12], [13], [16], [17]]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
// %padding: [[1, 1],
// [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
window_strides = array<i64: 4, 4>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
dimension_numbers = #stablehlo.conv<raw
input_batch_dimension = 0,
input_feature_dimension = 3,
input_spatial_dimensions = [0, 1],
kernel_input_feature_dimension = 2,
kernel_output_feature_dimension = 3,
kernel_spatial_dimensions = [0, 1],
output_batch_dimension = 0,
output_feature_dimension = 3,
output_spatial_dimensions = [1, 2]
>,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[1], [5]],
// [[10], [14]]
// ]]
dynamic_gather
Semântica
Essa operação é funcionalmente idêntica à operação de reunir, com o slice_sizes
especificado dinamicamente como um valor.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor ou tensor quantizado por tensor | (C1), (C7), (C10-C12) (C14) |
(I2) | start_indices |
tensor do tipo inteiro | (C2), (C3) e (C13) |
(I3) | slice_sizes |
Tensor unidimensional de tipo inteiro | (C8) (C11 a C13). |
(I4) | offset_dims |
constante do tensor unidimensional do tipo si64 |
(C1), (C4 a C5) e (C13) |
(I5) | collapsed_slice_dims |
constante do tensor unidimensional do tipo si64 |
(C1), (C6 a C8) e (C13) |
(I6) | start_index_map |
constante do tensor unidimensional do tipo si64 |
(C3), (C9) (C10) |
(I7) | index_vector_dim |
constante do tipo si64 |
(C2), (C3) e (C13) |
(I8) | indices_are_sorted |
constante do tipo i1 |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou tensor quantizado por tensor | (C5) (C13 a C14). |
Restrições
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims)
. - (C2)
0 <= index_vector_dim <= rank(start_indices)
. - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
. - (C5)
0 <= offset_dims < rank(result)
. - (C6)
is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims)
. - (C7)
0 <= collapsed_slice_dims < rank(operand)
. - (C8)
slice_sizes[collapsed_slice_dims...] <= 1
. - (C9)
is_unique(start_index_map)
. - (C10)
0 <= start_index_map < rank(operand)
. - (C11)
size(slice_sizes) = rank(operand)
. - (C12)
0 <= slice_sizes <= shape(operand)
. - (C13)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
, em que:batch_dim_sizes = shape(start_indices)
, exceto que o tamanho da dimensão destart_indices
correspondente aindex_vector_dim
não está incluído.offset_dim_sizes = shape(slice_sizes)
, exceto que os tamanhos de dimensão emslice_sizes
correspondentes acollapsed_slice_dims
não são incluídos.combine
colocabatch_dim_sizes
nos eixos correspondentes abatch_dims
eoffset_dim_sizes
nos eixos correspondentes aoffset_dims
.
- (C14)
element_type(operand) = element_type(result)
.
Exemplos
// %operand: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %start_indices: [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 2]]
// ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
dimension_numbers = #stablehlo.gather<
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vector_dim = 2>,
indices_are_sorted = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi64>
// %result: [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[9, 10], [11, 12]],
// [[11, 12], [13, 14]],
// [[17, 18], [19, 20]]
// ]
// ]
dynamic_iota
Semântica
Essa operação é funcionalmente idêntica à operação iota, mas o formato do resultado é especificado dinamicamente por output_shape
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | output_shape |
Tensor unidimensional de tipo inteiro | (C1) e (C2). |
(I2) | iota_dimension |
si64 |
(C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de número inteiro, ponto flutuante, tipo complexo ou tensor quantizado por tensor | (C2) |
Restrições
- (C1)
0 <= iota_dimension < size(output_shape)
. - (C2)
rank(result) = size(output_shape)
.
Exemplos
%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
iota_dimension = 0 : i64
} : (tensor<2xi64>) -> tensor<4x5xi64>
// %result: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
dynamic_pad
Semântica
Essa operação é funcionalmente idêntica à operação de preenchimento, mas com edge_padding_low
, edge_padding_high
e interior_padding
especificados dinamicamente como valores.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor ou tensor quantizado por tensor | (C1), (C2) e (C4) |
(I2) | padding_value |
Tensor 0 dimensional ou tensor quantizado por tensor | (C1) |
(I3) | edge_padding_low |
Tensor unidimensional de tipo inteiro | (C1) e (C4). |
(I4) | edge_padding_high |
Tensor unidimensional de tipo inteiro | (C1) e (C4). |
(I5) | interior_padding |
Tensor unidimensional de tipo inteiro | (C2 a C4) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou tensor quantizado por tensor | (C3 a C6) |
Restrições
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
. - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
. - (C3)
0 <= interior_padding
. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
.
Exemplos
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
%edge_padding_low, %edge_padding_high, %interior_padding
) : (tensor<2x3xi64>, tensor<i64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x9xi64>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
dynamic_reshape
Semântica
Essa operação é funcionalmente idêntica à operação de remodelamento, mas o formato resultante é especificado dinamicamente por output_shape
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor ou quantizado | (C1 a C3) |
(I2) | output_shape |
Tensor unidimensional de tipo inteiro | (C4) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou quantizado | (C1 a C4) |
Restrições
- (C1)
element_type(result)
é fornecido por:element_type(operand)
, se!is_per_axis_quantized(operand)
.element_type(operand)
, exceto quequantization_dimension(operand)
equantization_dimension(result)
podem ser diferentes.
- (C2)
size(operand) = size(result)
. - (C3) Se
is_per_axis_quantized(operand)
:reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
.reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
- (C4)
size(output_shape) = rank(result)
.
Exemplos
// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
// %result: [[1, 2], [3, 4], [5, 6]]
dynamic_slice
Semântica
Extrai uma fração do operand
usando índices iniciais calculados dinamicamente
e produz um tensor result
. start_indices
contém os índices iniciais da fração para cada dimensão sujeita a possíveis ajustes, e slice_sizes
contém os tamanhos da fatia para cada dimensão. Mais formalmente,
result[result_index] = operand[operand_index]
, em que:
adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes)
.operand_index = adjusted_start_indices + result_index
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor ou tensor quantizado por tensor | (C1), (C2) e (C4) |
(I2) | start_indices |
número variado de tensores 0-dimensionais do tipo inteiro | (C2) e (C3). |
(I3) | slice_sizes |
constante do tensor unidimensional do tipo si64 |
(C2), (C4) (C5) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou tensor quantizado por tensor | (C1) e (C5). |
Restrições
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(slice_sizes) = rank(operand)
. - (C3)
same(type(start_indices...))
. - (C4)
0 <= slice_sizes <= shape(operand)
. - (C5)
shape(result) = slice_sizes
.
Exemplos
// %operand: [
// [0, 0, 1, 1],
// [0, 0, 1, 1],
// [0, 0, 0, 0],
// [0, 0, 0, 0]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
slice_sizes = array<i64: 2, 2>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
// [1, 1],
// [1, 1]
// ]
dynamic_update_slice
Semântica
Produz um tensor result
que é igual ao tensor operand
, exceto que a fração que começa em start_indices
é atualizada com os valores em update
.
Mais formalmente, result[result_index]
é definido como:
update[update_index]
se0 <= update_index < shape(update)
, em que:adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update))
.update_index = result_index - adjusted_start_indices
.
- Caso contrário,
operand[result_index]
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor ou tensor quantizado por tensor | (C1 a C4) e (C6) |
(I2) | update |
tensor ou tensor quantizado por tensor | (C2), (C3) e (C6) |
(I3) | start_indices |
número variado de tensores 0-dimensionais do tipo inteiro | (C4) e (C5). |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
type(operand) = type(result)
. - (C2)
element_type(update) = element_type(operand)
. - (C3)
rank(update) = rank(operand)
. - (C4)
size(start_indices) = rank(operand)
. - (C5)
same(type(start_indices...))
. - (C6)
0 <= shape(update) <= shape(operand)
.
Exemplos
// %operand: [
// [1, 1, 0, 0],
// [1, 1, 0, 0],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
// %update: [
// [1, 1],
// [1, 1]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
: (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
exponencial
Semântica
Executa uma operação exponencial com elemento no tensor operand
e produz um
tensor result
. Dependendo do tipo de elemento, faça o seguinte:
- Para pontos flutuantes:
exp
do IEEE-754. - Para números complexos: exponencial complexa.
- Para tipos quantizados:
dequantize_op_quantize(exponential, operand, type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemplos
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]
exponential_minus_one
Semântica
Executa a operação exponencial de elemento menos um no tensor operand
e
gera um tensor result
. Dependendo do tipo de elemento, faça o seguinte:
- Para pontos flutuantes:
expm1
do IEEE-754. - Para números complexos: exponencial complexa menos um.
- Para tipos quantizados:
dequantize_op_quantize(exponential_minus_one, operand, type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemplos
// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]
fft.
Semântica
Executa as transformações de Fourier direta e inversa para entradas/saídas reais e complexas.
fft_type
é um destes:
FFT
: encaminha FFT complexa para complexa.IFFT
: FFT inversa de complexa para complexa.RFFT
: encaminha FFT real para complexa.IRFFT
: FFT inversa entre real e complexa (ou seja, toma complexa, retorna real).
Mais formalmente, considerando a função fft
, que usa tensores unidimensionais de tipos complexos como entrada, produz tensores unidimensionais do mesmo tipo que a saída e calcula a transformação discreta de Fourier:
Para fft_type = FFT
, result
é definido como o resultado final de uma série de cálculos L em que L = size(fft_length)
. Por exemplo, para L = 3
:
result1[i0, ..., :] = fft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
.result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Além disso, considerando a função ifft
, que tem a mesma assinatura de tipo e
calcula o inverso de fft
:
Para fft_type = IFFT
, result
é definido como o inverso dos cálculos
para fft_type = FFT
. Por exemplo, para L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
.result[i0, ..., :] = ifft(result2[i0, ..., :])
.
Além disso, considerando a função rfft
, que usa tensores unidimensionais de
tipos de ponto flutuante, produz tensores unidimensionais de tipos complexos da
mesma semântica de ponto flutuante e funciona da seguinte maneira:
rfft(real_operand) = truncated_result
, ondecomplex_operand... = (real_operand..., 0.0)
.complex_result = fft(complex_operand)
.truncated_result = complex_result[:(rank(complex_result) / 2 + 1)]
.
Quando a transformação discreta de Fourier é calculada para operandos reais, os primeiros
elementos N/2 + 1
do resultado definem inequivocamente o restante do resultado,
de modo que o resultado de rfft
é truncado para evitar a computação de elementos redundantes.
Para fft_type = RFFT
, result
é definido como o resultado final de uma série de cálculos L em que L = size(fft_length)
. Por exemplo, para L = 3
:
result1[i0, ..., :] = rfft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
.result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Por fim, considerando a função irfft
, que tem a mesma assinatura de tipo e calcula o inverso de rfft
:
Para fft_type = IRFFT
, result
é definido como o inverso dos cálculos
para fft_type = RFFT
. Por exemplo, para L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
.result[i0, ..., :] = irfft(result2[i0, ..., :])
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor de ponto flutuante ou tipo complexo | (C1), (C2), (C4), (C5) |
(I2) | fft_type |
enumeração de FFT , IFFT , RFFT e IRFFT |
(C2) e (C5). |
(I3) | fft_length |
constante do tensor unidimensional do tipo si64 |
(C1), (C3) e (C4) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de ponto flutuante ou tipo complexo | (C2), (C4) (C5) |
Restrições
- (C1)
size(fft_length) <= rank(operand)
. - (C2) A relação entre os tipos de elementos
operand
eresult
varia:- Se
fft_type = FFT
,element_type(operand)
eelement_type(result)
tiverem o mesmo tipo complexo. - Se
fft_type = IFFT
,element_type(operand)
eelement_type(result)
tiverem o mesmo tipo complexo. - Se
fft_type = RFFT
,element_type(operand)
será um tipo de ponto flutuante eelement_type(result)
será um tipo complexo da mesma semântica de ponto flutuante. - Se
fft_type = IRFFT
,element_type(operand)
for um tipo complexo eelement_type(result)
for um tipo de ponto flutuante da mesma semântica de ponto flutuante.
- Se
- (C3)
1 <= size(fft_length) <= 3
. - (C4) Se estiver entre
operand
eresult
, houver um tensorreal
de um tipo de ponto flutuante, entãoshape(real)[-size(fft_length):] = fft_length
. - (C5)
shape(result) = shape(operand)
, exceto:- Se for
fft_type = RFFT
,dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1
. - Se for
fft_type = IRFFT
,dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1
.
- Se for
Exemplos
// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
fft_type = #stablehlo<fft_type FFT>,
fft_length = array<i64: 4>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]
Andar
Semântica
Executa o mínimo elemento do tensor operand
e produz um tensor result
.
Implementa a operação roundToIntegralTowardNegative
da especificação
IEEE-754. Para tipos quantizados, executa
dequantize_op_quantize(floor, operand, type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemplos
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]
reunir
Semântica
Coleta frações do tensor operand
dos deslocamentos especificados em start_indices
e produz um tensor result
.
O diagrama a seguir mostra como os elementos em result
são mapeados em elementos em
operand
usando um exemplo concreto. O diagrama escolhe alguns exemplos de índices de result
e explica em detalhes a quais índices de operand
eles correspondem.
Mais formalmente, result[result_index] = operand[operand_index]
, em que:
batch_dims = [d for d in axes(result) and d not in offset_dims]
.batch_index = result_index[batch_dims...]
.start_index
é definido como:start_indices[bi0, ..., :, ..., biN]
, em quebi
são elementos individuais embatch_index
e:
é inserido no índice deindex_vector_dim
, seindex_vector_dim
<rank(start_indices)
.- Caso contrário,
[start_indices[batch_index]]
.
- Para
d_operand
emaxes(operand)
,full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])
sed_operand = start_index_map[d_start]
.- Caso contrário,
full_start_index[d_operand] = 0
.
- Para
d_operand
emaxes(operand)
,full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)]
sed_operand = operand_batching_dims[i_batching]
ed_start = start_indices_batching_dims[i_batching]
.- Caso contrário,
full_batching_index[d_operand] = 0
.
offset_index = result_index[offset_dims...]
.full_offset_index = [oi0, ..., 0, ..., oiN]
, em queoi
são elementos individuais emoffset_index
, e0
é inserido em índices decollapsed_slice_dims
eoperand_batching_dims
.operand_index = full_start_index + full_batching_index + full_offset_index
Se indices_are_sorted
for true
, a implementação poderá presumir que
start_indices
estão classificados em relação a start_index_map
. Caso contrário, o
comportamento será indefinido. De forma mais formal, para todos os i1 < i2
de indices(result)
,
full_start_index(i1) <= full_start_index(i2)
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor ou tensor quantizado por tensor | (C1), (C8), (C11), (C17), (C19 a C21) (C23) |
(I2) | start_indices |
tensor do tipo inteiro | (C2 a C3), (C14), (C17) e (C22) |
(I3) | offset_dims |
constante do tensor unidimensional do tipo si64 |
(C1), (C4 a C5) e (C22). |
(I4) | collapsed_slice_dims |
constante do tensor unidimensional do tipo si64 |
(C1), (C6 a C9) e (C22). |
(I5) | operand_batching_dims |
constante do tensor unidimensional do tipo si64 |
(C1), (C6), (C10-C12), (C16 a C18) (C22) |
(I6) | start_indices_batching_dims |
constante do tensor unidimensional do tipo si64 |
(C13 a C17) |
(I7) | start_index_map |
constante do tensor unidimensional do tipo si64 |
(C3) (C18 a C19) |
(I8) | index_vector_dim |
constante do tipo si64 |
(C2 a C3), (C15) e (C22) |
(I9) | slice_sizes |
constante do tensor unidimensional do tipo si64 |
(C9), (C12) (C20-C22) |
(I10) | indices_are_sorted |
constante do tipo i1 |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou tensor quantizado por tensor | (C5) e C22 a C23. |
Restrições
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims)
. - (C2)
0 <= index_vector_dim <= rank(start_indices)
. - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
. - (C5)
0 <= offset_dims < rank(result)
. - (C6)
is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))
- (C7)
is_sorted(collapsed_slice_dims)
. - (C8)
0 <= collapsed_slice_dims < rank(operand)
. - (C9)
slice_sizes[collapsed_slice_dims...] <= 1
. - (C10)
is_sorted(operand_batching_dims)
. - (C11)
0 <= operand_batching_dims < rank(operand)
. - (C12)
slice_sizes[operand_batching_dims...] <= 1
. - (C13)
is_unique(start_indices_batching_dims)
. - (C14)
0 <= start_indices_batching_dims < rank(start_indices)
. - (C15)
index_vector_dim not in start_indices_batching_dims
. - (C16)
size(operand_batching_dims) == size(start_indices_batching_dims)
. - (C17)
dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...)
. - (C18)
is_unique(concatenate(start_index_map, operand_batching_dims))
. - (C19)
0 <= start_index_map < rank(operand)
. - (C20)
size(slice_sizes) = rank(operand)
. - (C21)
0 <= slice_sizes <= shape(operand)
. - (C22)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
, em que:batch_dim_sizes = shape(start_indices)
, exceto que o tamanho da dimensão destart_indices
correspondente aindex_vector_dim
não está incluído.offset_dim_sizes = slice_sizes
, exceto pelo fato de que os tamanhos de dimensão emslice_sizes
correspondentes acollapsed_slice_dims
eoperand_batching_dims
não são incluídos.combine
colocabatch_dim_sizes
nos eixos correspondentes abatch_dims
eoffset_dim_sizes
nos eixos correspondentes aoffset_dims
.
- (C23)
element_type(operand) = element_type(result)
.
Exemplos
// %operand: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %start_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
%result = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stablehlo.gather<
offset_dims = [3, 4],
collapsed_slice_dims = [1],
operand_batching_dims = [0],
start_indices_batching_dims = [1],
start_index_map = [2, 1],
index_vector_dim = 3>,
slice_sizes = array<i64: 1, 1, 2, 2>,
indices_are_sorted = false
} : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32>
// %result: [
// [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[33, 34], [35, 36]],
// [[35, 36], [37, 38]],
// [[41, 42], [43, 44]]
// ]
// ],
// [
// [
// [[1, 2], [3, 4]],
// [[13, 14], [15, 16]],
// [[21, 22], [23, 24]]
// ],
// [
// [[43, 44], [45, 46]],
// [[33, 34], [35, 36]],
// [[27, 28], [29, 30]]
// ]
// ]
// ]
get_dimension_size
Semântica
Produz o tamanho do dimension
especificado do operand
. Mais formalmente,
result = dim(operand, dimension)
. A semântica diz respeito apenas ao componente
de forma do tipo. O tipo de elemento pode ser qualquer coisa.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor ou quantizado | (C1) |
(I2) | dimension |
constante do tipo si64 |
(C1) |
Saídas
Nome | Tipo |
---|---|
result |
Tensor 0 dimensional do tipo si32 |
Restrições
- (C1)
0 <= dimension < rank(operand)
.
Exemplos
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3
get_tuple_element
Semântica
Extrai o elemento na posição index
da tupla operand
e produz um
result
. Mais formalmente, result = operand[index]
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tuple | (C1) e (C2). |
(I2) | index |
constante do tipo si32 |
(C1) e (C2). |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
qualquer tipo com suporte | (C2) |
Restrições
- (C1)
0 <= index < size(operand)
. - (C2)
type(result) = tuple_element_types(operand)[index]
.
Exemplos
// %operand: ([1.0, 2.0], (3))
index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]
if
Semântica
Produz a saída da execução de exatamente uma função de true_branch
ou
false_branch
, dependendo do valor de pred
. Mais formalmente, result =
pred ? true_branch() : false_branch()
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | pred |
Tensor 0 dimensional do tipo i1 |
|
(I2) | true_branch |
função | (C1 a C3) |
(I3) | false_branch |
função | (C1) e (C2). |
Saídas
Nome | Tipo | Restrições |
---|---|---|
results |
número variado de tensores, tensores quantizados ou tokens | (C3) |
Restrições
- (C1)
input_types(true_branch) = input_types(false_branch) = []
. - (C2)
output_types(true_branch) = output_types(false_branch)
. - (C3)
type(results...) = output_types(true_branch)
.
Exemplos
// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
"stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
"stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10
imagem
Semântica
Extrai a parte imaginária, com elementos, do operand
e produz um
tensor result
. Mais formalmente, para cada elemento x
:
imag(x) = is_complex(x) ? imaginary_part(x) :
constant(0, element_type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor de ponto flutuante ou tipo complexo | (C1) e (C2). |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor do tipo de ponto flutuante | (C1) e (C2). |
Restrições
- (C1)
shape(result) = shape(operand)
. - (C2)
element_type(result)
é definido como:complex_element_type(element_type(operand))
se foris_complex(operand)
.- Caso contrário,
element_type(operand)
.
Exemplos
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]
entrada
Semântica
Lê dados da entrada e produz results
.
A semântica de infeed_config
é definida pela implementação.
results
consistem em valores de payload que vêm primeiro e um token que vem
por último. No futuro, planejamos dividir o payload e o token em duas
saídas separadas para melhorar a clareza
(#670).
Entradas
Rótulo | Nome | Tipo |
---|---|---|
(I1) | token |
token |
(I2) | infeed_config |
constante do tipo string |
Saídas
Nome | Tipo | Restrições |
---|---|---|
results |
número variado de tensores, tensores quantizados ou tokens | (C1 a C3) |
Restrições
- (C1)
0 < size(results)
. - (C2)
is_empty(result[:-1])
ouis_tensor(type(results[:-1]))
. - (C3)
is_token(type(results[-1]))
.
Exemplos
// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]
Iota
Semântica
Preenche um tensor output
com valores em ordem crescente, começando de zero
ao longo da dimensão iota_dimension
. Mais formalmente,
output[output_index] = constant(is_quantized(output) ?
quantize(output_index[iota_dimension], element_type(output)) :
output_index[iota_dimension], element_type(output))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | iota_dimension |
si64 |
(C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
output |
tensor de número inteiro, ponto flutuante, tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
0 <= iota_dimension < rank(output)
.
Exemplos
%output = "stablehlo.iota"() {
iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
%output = "stablehlo.iota"() {
iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4]
// ]
is_finite
Semântica
Executa a verificação elemento-chave se o valor em x
é finito (ou seja, não é
+Inf, -Inf ou NaN) e produz um tensor y
. Implementa a operação isFinite
da especificação IEEE-754. Para tipos quantizados, o resultado é
sempre true
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | x |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
y |
tensor de tipo booleano | (C1) |
Restrições
- (C1)
shape(x) = shape(y)
.
Exemplos
// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]
log
Semântica
Executa uma operação de logaritmo de elemento no tensor operand
e produz um
tensor result
. Dependendo do tipo de elemento, faça o seguinte:
- Para pontos flutuantes:
log
do IEEE-754. - Para números complexos: logaritmo complexo.
- Para tipos quantizados:
dequantize_op_quantize(log, operand, type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemplos
// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]
log_plus_one
Semântica
Executa o logaritmo de elemento mais uma operação no tensor operand
e
gera um tensor result
. Dependendo do tipo de elemento, faça o seguinte:
- Para pontos flutuantes:
logp1
do IEEE-754. - Para números complexos: logaritmo complexo mais um.
- Para tipos quantizados:
dequantize_op_quantize(log_plus_one, operand, type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemplos
// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
logística
Semântica
Executa uma operação logística por elementos no tensor operand
e produz um
tensor result
. Dependendo do tipo de elemento, faça o seguinte:
- Para pontos flutuantes:
division(1, addition(1, exp(-x)))
do IEEE-754. - Para números complexos: logística complexa.
- Para tipos quantizados:
dequantize_op_quantize(logistic, operand, type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemplos
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]
mapa
Semântica
Aplica uma função de mapa computation
a inputs
ao longo do dimensions
e produz um tensor result
.
Mais formalmente, result[result_index] = computation(inputs...[result_index])
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | inputs |
número variado de tensores ou tensores quantizados por tensor | (C1 a C4) |
(I2) | dimensions |
constante do tensor unidimensional do tipo si64 |
(C3) |
(I3) | computation |
função | (C4) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou tensor quantizado por tensor | (C1) e (C4). |
Restrições
- (C1)
shape(inputs...) = shape(result)
. - (C2)
0 < size(inputs) = N
. - (C3)
dimensions = range(rank(inputs[0]))
. - (C4)
computation
tem o tipo(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>
, em queEi = element_type(inputs[i])
eE' = element_type(result)
.
Exemplos
// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
stablehlo.return %0 : tensor<i64>
}) {
dimensions = array<i64: 0, 1>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]
máximo
Semântica
Executa a operação máxima por elemento nos tensores lhs
e rhs
e produz um
tensor result
. Dependendo do tipo de elemento, faça o seguinte:
- Para booleanos: OR lógico.
- Para números inteiros: máximo de números inteiros.
- Para pontos flutuantes:
maximum
do IEEE-754. - Para números complexos: máximo lexicográfico para o par
(real, imaginary)
. Impor uma ordenação em números complexos envolve uma semântica surpreendente. Portanto, no futuro, planejamos remover o suporte a números complexos para essa operação (#560). - Para tipos quantizados:
dequantize_op_quantize(maximum, lhs, rhs, type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | lhs |
tensor ou tensor quantizado por tensor | (C1) |
(I2) | rhs |
tensor ou tensor quantizado por tensor | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Exemplos
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]
mínimo
Semântica
Executa uma operação mínima com elementos nos tensores lhs
e rhs
e produz um
tensor result
. Dependendo do tipo de elemento, faça o seguinte:
- Para booleanos: AND lógico.
- Para números inteiros: mínimo de números inteiros.
- Para pontos flutuantes:
minimum
do IEEE-754. - Para números complexos: mínimo lexicográfico para o par
(real, imaginary)
. Impor uma ordenação em números complexos envolve uma semântica surpreendente. Portanto, no futuro, planejamos remover o suporte a números complexos para essa operação (#560). - Para tipos quantizados:
dequantize_op_quantize(minimum, lhs, rhs, type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | lhs |
tensor ou tensor quantizado por tensor | (C1) |
(I2) | rhs |
tensor ou tensor quantizado por tensor | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Exemplos
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]
multiplicar
Semântica
Executa o produto com elementos de dois tensores lhs
e rhs
e produz um
tensor result
. Dependendo do tipo de elemento, faça o seguinte:
- Para booleanos: AND lógico.
- Para números inteiros: multiplicação de números inteiros.
- Para pontos flutuantes:
multiplication
do IEEE-754. - Para números complexos: multiplicação complexa.
- Para tipos quantizados:
dequantize_op_quantize(multiply, lhs, rhs, type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | lhs |
tensor ou tensor quantizado por tensor | (C1) |
(I2) | rhs |
tensor ou tensor quantizado por tensor | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemplos
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]
negate
Semântica
Executa a negação elemento-do do tensor operand
e produz um tensor
result
. Dependendo do tipo de elemento, faça o seguinte:
- Para números inteiros com assinatura: negação de número inteiro.
- Para números inteiros sem assinatura: bitcast para número inteiro assinado, negação de número inteiro, bitcast de volta para número inteiro não assinado.
- Para pontos flutuantes:
negate
do IEEE-754. - Para números complexos: negação complexa.
- Para tipos quantizados:
dequantize_op_quantize(negate, operand, type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor de número inteiro, ponto flutuante, tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de número inteiro, ponto flutuante, tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemplos
// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]
// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]
não
Semântica
Executa NOT elemento do tensor operand
e produz um tensor result
.
Dependendo do tipo de elemento, faça o seguinte:
- Para booleanos: NOT lógico.
- Para números inteiros: NOT bit a bit.
Argumentos
Nome | Tipo | Restrições |
---|---|---|
operand |
tensor de tipo booleano ou inteiro | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de tipo booleano ou inteiro | (C1) |
Restrições
- (C1)
type(operand) = type(result)
.
Exemplos
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]
// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]
optimization_barrier
Semântica
Garante que as operações que produzem o operand
sejam executadas antes de qualquer
operação que dependa de result
e impede que as transformações do compilador
movam operações pela barreira. Fora isso, a operação é
uma identidade, ou seja, result = operand
.
Argumentos
Nome | Tipo | Restrições |
---|---|---|
operand |
número variado de tensores, tensores quantizados por tensor ou tokens | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
número variado de tensores, tensores quantizados por tensor ou tokens | (C1) |
Restrições
- (C1)
type(operand...) = type(result...)
.
Exemplos
// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0
ou
Semântica
Executa OR com elemento de dois tensores lhs
e rhs
e produz um
tensor result
. Dependendo do tipo de elemento, faça o seguinte:
- Para booleanos: OR lógico.
- Para números inteiros: OR bit a bit.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | lhs |
tensor de tipo inteiro ou booleano | (C1) |
(I2) | rhs |
tensor de tipo inteiro ou booleano | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de tipo inteiro ou booleano | (C1) |
Restrições
- (C1)
type(lhs) = type(rhs) = type(result)
.
Exemplos
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]
saída
Semântica
Grava inputs
na saída de saída e produz um token result
.
A semântica de outfeed_config
é definida pela implementação.
Entradas
Rótulo | Nome | Tipo |
---|---|---|
(I1) | inputs |
número variado de tensores ou tensores quantizados |
(I2) | token |
token |
(I3) | outfeed_config |
constante do tipo string |
Saídas
Nome | Tipo |
---|---|
result |
token |
Exemplos
%result = "stablehlo.outfeed"(%input0, %token) {
outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token
almofada
Semântica
Expande operand
ao redor do tensor e também entre os elementos
dele com o padding_value
especificado.
edge_padding_low
e edge_padding_high
especificam a quantidade de padding adicionado
na parte inferior (ao lado do índice 0) e na parte superior (próxima ao índice mais alto) de
cada dimensão, respectivamente. A quantidade de padding pode ser negativa, em que o
valor absoluto do padding negativo indica o número de elementos que serão removidos
da dimensão especificada.
interior_padding
especifica a quantidade de padding adicionado entre dois elementos em cada dimensão que não podem ser negativos. O preenchimento interno ocorre
antes do preenchimento da borda negativo, de modo que o preenchimento da borda negativo remova elementos do
operando com preenchimento interno.
Mais formalmente, result[result_index]
é definido como:
operand[operand_index]
seresult_index = edge_padding_low + operand_index * (interior_padding + 1)
.- Caso contrário,
padding_value
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor ou tensor quantizado por tensor | (C1), (C2) e (C4) |
(I2) | padding_value |
Tensor 0 dimensional ou tensor quantizado por tensor | (C1) |
(I3) | edge_padding_low |
constante do tensor unidimensional do tipo si64 |
(C1) e (C4). |
(I4) | edge_padding_high |
constante do tensor unidimensional do tipo si64 |
(C1) e (C4). |
(I5) | interior_padding |
constante do tensor unidimensional do tipo si64 |
(C2 a C4) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou tensor quantizado por tensor | (C3 a C6) |
Restrições
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
. - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
. - (C3)
0 <= interior_padding
. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
.
Exemplos
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
edge_padding_low = array<i64: 0, 1>,
edge_padding_high = array<i64: 2, 1>,
interior_padding = array<i64: 1, 2>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
partition_id
Semântica
Produz partition_id
do processo atual.
Saídas
Nome | Tipo |
---|---|
result |
Tensor 0 dimensional do tipo ui32 |
Exemplos
%result = "stablehlo.partition_id"() : () -> tensor<ui32>
Popcnt
Semântica
Executa a contagem em elementos do número de bits definidos no tensor operand
e produz um tensor result
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor do tipo inteiro | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor do tipo inteiro | (C1) |
Restrições
- (C1)
type(operand) = type(result)
.
Exemplos
// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]
potência
Semântica
Executa a exponenciação com elementos do tensor lhs
por rhs
e
gera um tensor result
. Dependendo do tipo de elemento, faça o seguinte:
- Para números inteiros: expressão inteira.
- Para pontos flutuantes:
pow
do IEEE-754. - Para números complexos: exponenciação complexa.
- Para tipos quantizados:
dequantize_op_quantize(power, lhs, rhs, type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | lhs |
tensor de número inteiro, ponto flutuante, tipo complexo ou tensor quantizado por tensor | (C1) |
(I2) | rhs |
tensor de número inteiro, ponto flutuante, tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de número inteiro, ponto flutuante, tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemplos
// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]
real
Semântica
Extrai a parte real, por elemento, do operand
e produz um tensor
result
. Mais formalmente, para cada elemento x
:
real(x) = is_complex(x) ? real_part(x) : x
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor de ponto flutuante ou tipo complexo | (C1) e (C2). |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor do tipo de ponto flutuante | (C1) e (C2). |
Restrições
- (C1)
shape(result) = shape(operand)
. - (C2)
element_type(result)
é definido como:complex_element_type(element_type(operand))
se foris_complex(operand)
.- Caso contrário,
element_type(operand)
.
Exemplos
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]
recv
Semântica
Recebe dados de um canal com channel_id
e produz results
.
Se is_host_transfer
for true
, a operação transferirá dados do
host. Caso contrário, a transferência será feita de outro dispositivo. O que isso significa é
definido pela implementação. Essa flag duplica as informações fornecidas em
channel_type
. Portanto, no futuro, planejamos manter apenas uma delas
(#666).
results
consistem em valores de payload que vêm primeiro e um token que vem
por último. No futuro, planejamos dividir o payload e o token em duas
saídas separadas para melhorar a clareza
(#670).
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | token |
token |
(C4) |
(I2) | channel_id |
constante do tipo si64 |
|
(I3) | channel_type |
tipo enumerado de DEVICE_TO_DEVICE e HOST_TO_DEVICE |
(C1) |
(I4) | is_host_transfer |
constante do tipo i1 |
(C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
results |
número variado de tensores, tensores quantizados ou tokens | (C2 a C4) |
Restrições
- (C1)
channel_type
é definido como:HOST_TO_DEVICE
seis_host_transfer = true
,- Caso contrário,
DEVICE_TO_DEVICE
.
- (C2)
0 < size(results)
. - (C3)
is_empty(result[:-1])
ouis_tensor(type(results[:-1]))
. - (C4)
is_token(type(results[-1]))
.
Exemplos
%results0, %results1 = "stablehlo.recv"(%token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
reduce
Semântica
Aplica uma função de redução body
a inputs
e init_values
ao longo do
dimensions
e produz tensores results
.
A ordem das reduções é definida pela implementação, o que significa que body
e
init_values
precisam formar um monoide para garantir que a operação produza os
mesmos resultados para todas as entradas em todas as implementações. No entanto, essa condição
não é válida para muitas reduções comuns. Por exemplo, a adição de ponto flutuante para body
e zero para init_values
não formam um monoide porque a adição de ponto flutuante não é associativa.
Mais formalmente, results...[j0, ..., jR-1] = reduce(input_slices_converted)
, em que:
input_slices = inputs...[j0, ..., :, ..., jR-1]
, em que:
são inseridos emdimensions
.input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...)
.init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)
.reduce(input_slices_converted) = exec(schedule)
para alguma árvore bináriaschedule
em que:exec(node) = body(exec(node.left), exec(node.right))
.exec(leaf) = leaf.value
.
schedule
é uma árvore binária completa definida pela implementação cuja travessia em ordem consiste em:- Valores
input_slices_converted...[index]
, para todos osindex
emindex_space(input_slices_converted)
na ordem lexicográfica crescente deindex
. - Intercalada com uma quantidade definida pela implementação de
init_values_converted
em posições definidas pela implementação.
- Valores
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | inputs |
número variado de tensores ou tensores quantizados por tensor | (C1 a C4), (C6) e (C7) |
(I2) | init_values |
número variado de tensores 0-dimensionais ou tensores quantizados por tensor | (C2) e (C3). |
(I3) | dimensions |
constante do tensor unidimensional do tipo si64 |
(C4), (C5) e (C7) |
(I4) | body |
função | (C6) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
results |
número variado de tensores ou tensores quantizados por tensor | (C3), (C7), (C8) |
Restrições
- (C1)
same(shape(inputs...))
. - (C2)
element_type(inputs...) = element_type(init_values...)
. - (C3)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C4)
0 <= dimensions < rank(inputs[0])
. - (C5)
is_unique(dimensions)
. - (C6)
body
tem o tipo(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
, em queis_promotable(element_type(inputs[i]), Ei)
. - (C7)
shape(results...) = shape(inputs...)
, exceto pelo fato de que os tamanhos deinputs...
correspondentes adimensions
não são incluídos. - (C8)
element_type(results[i]) = Ei
para todos osi
em[0,N)
.
Exemplos
// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]
reduce_precision
Semântica
Executa a conversão com elementos de operand
para outro tipo de ponto flutuante
que usa exponent_bits
e mantissa_bits
e de volta ao tipo de ponto flutuante
original e produz um tensor output
.
Mais formalmente:
- Os bits de mantissa do valor original são atualizados para arredondar o valor
original para o valor mais próximo representável com
mantissa_bits
usando semânticaroundToIntegralTiesToEven
. - Em seguida, se
mantissa_bits
for menor que o número de bits de mantissa do valor original, os bits de mantissa serão truncados emmantissa_bits
. - Em seguida, se os bits expoentes do resultado intermediário não se encaixarem no
intervalo fornecido por
exponent_bits
, o resultado intermediário vai ultrapassar o infinito usando o sinal original ou os underflows para zero usando o sinal original. - Para tipos quantizados, executa
dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C1) |
(I2) | exponent_bits |
constante do tipo si32 |
(C2) |
(I3) | mantissa_bits |
constante do tipo si32 |
(C3) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
output |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(output)
. - (C2)
1 <= exponent_bits
. - (C3)
0 <= mantissa_bits
.
Exemplos
// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
exponent_bits = 5 : i32,
mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]
reduce_scatter
Semântica
Dentro de cada grupo de processos na grade de processos StableHLO, realiza a redução usando computations
sobre os valores do tensor operand
de cada processo, divide o resultado da redução junto a scatter_dimension
em partes e dispersa
as partes divididas entre os processos para produzir o result
.
A operação divide a grade do processo StableHLO em process_groups
, que é
definida da seguinte maneira:
cross_replica(replica_groups)
sechannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
sechannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
sechannel_id > 0 and use_global_device_ids = true
.
Depois, em cada process_group
:
reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation)
.parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension)
.result@receiver = parts@sender[receiver_index]
para todos ossender
emprocess_group
, em quereceiver_index = process_group.index(receiver)
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor ou tensor quantizado por tensor | (C1), (C2), (C7), (C8) |
(I2) | scatter_dimension |
constante do tipo si64 |
(C1), (C2) e (C8) |
(I3) | replica_groups |
Constante do tensor bidimensional do tipo si64 |
(C3 a C5) |
(I4) | channel_id |
constante do tipo si64 |
(C6) |
(I5) | use_global_device_ids |
constante do tipo i1 |
(C6) |
(I6) | computation |
função | (C7) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou tensor quantizado por tensor | (C8 a C9). |
Restrições
- (C1)
dim(operand, scatter_dimension) % dim(process_groups, 1) = 0
. - (C2)
0 <= scatter_dimension < rank(operand)
. - (C3)
is_unique(replica_groups)
. - (C4)
size(replica_groups)
é definido como:num_replicas
secross_replica
for usado.num_replicas
secross_replica_and_partition
for usado.num_processes
seflattened_ids
for usado.
- (C5)
0 <= replica_groups < size(replica_groups)
. - (C6) Se for
use_global_device_ids = true
, entãochannel_id > 0
. - (C7)
computation
tem o tipo(tensor<E>, tensor<E>) -> (tensor<E>)
, em queis_promotable(element_type(operand), E)
. - (C8)
shape(result) = shape(operand)
, exceto:dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1)
.
- (C9)
element_type(result) = E
.
Exemplos
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
// [18, 20]]
// %result@(1, 0): [[14, 16],
// [22, 24]]
reduce_window
Semântica
Aplica uma função de redução body
a janelas de inputs
e init_values
e produz results
.
O diagrama a seguir mostra como os elementos em results...
são calculados de
inputs...
usando um exemplo concreto.
Mais formalmente,
results...[result_index] = reduce(windows, init_values, axes(inputs...), body)
(consulte reduzir), em que:
padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1)
.window_start = result_index * window_strides
.window_end = window_start + (window_dimensions - 1) * window_dilations + 1
.windows = slice(padded_inputs..., window_start, window_end, window_dilations)
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | inputs |
número variado de tensores ou tensores quantizados por tensor | (C1 a C4), (C6), (C8), (C10), (C12), (C13), (C15) |
(I2) | init_values |
número variado de tensores 0-dimensionais ou tensores quantizados por tensor | (C1) e (C13). |
(I3) | window_dimensions |
constante do tensor unidimensional do tipo si64 |
(C4), (C5) (C15) |
(I4) | window_strides |
constante do tensor unidimensional do tipo si64 |
(C6), (C7) e (C15) |
(I5) | base_dilations |
constante do tensor unidimensional do tipo si64 |
(C8), (C9) (C15) |
(I6) | window_dilations |
constante do tensor unidimensional do tipo si64 |
(C10), (C11), (C15) |
(I7) | padding |
Constante do tensor bidimensional do tipo si64 |
(C12) e (C15) |
(I8) | body |
função | (C13) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
results |
número variado de tensores ou tensores quantizados por tensor | (C1), (C14 a C16). |
Restrições
- (C1)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C2)
same(shape(inputs...))
. - (C3)
element_type(inputs...) = element_type(init_values...)
. - (C4)
size(window_dimensions) = rank(inputs[0])
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(inputs[0])
. - (C7)
0 < window_strides
. - (C8)
size(base_dilations) = rank(inputs[0])
. - (C9)
0 < base_dilations
. - (C10)
size(window_dilations) = rank(inputs[0])
. - (C11)
0 < window_dilations
. - (C12)
shape(padding) = [rank(inputs[0]), 2]
. - (C13)
body
tem o tipo(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
, em queis_promotable(element_type(inputs[i]), Ei)
. - (C14)
same(shape(results...))
. - (C15)
shape(results[0]) = num_windows
, em que:dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1
.padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]
.dilated_window_shape = (window_dimensions - 1) * window_dilations + 1
.is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape
.num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1
.
- (C16)
element_type(results[i]) = Ei
para todos osi
em[0,N)
.
Exemplos
// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 2, 1>,
window_strides = array<i64: 4, 1>,
base_dilations = array<i64: 2, 1>,
window_dilations = array<i64: 3, 1>,
padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]
restante
Semântica
Executa o restante com elementos dos tensores lhs
e divisor rhs
e
produz um tensor result
.
Mais formalmente, o sinal do resultado é retirado do dividendo, e o valor absoluto do resultado é sempre menor que o valor absoluto do divisor.
O restante é calculado como lhs - d * rhs
, em que d
é fornecido por:
- Para números inteiros:
stablehlo.divide(lhs, rhs)
. - Para pontos flutuantes:
division(lhs, rhs)
do IEEE-754 com atributo de arredondamentoroundTowardZero
. - Para números complexos: a definir (#997).
- Para tipos quantizados:
dequantize_op_quantize(remainder, lhs, rhs, type(result))
.
Para tipos de elementos de ponto flutuante, essa operação é diferente da
operação remainder
da especificação IEEE-754, em que d
é um valor integral
mais próximo do valor exato de lhs/rhs
com vínculos a pares.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | lhs |
tensor de número inteiro, ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
(I2) | rhs |
tensor de número inteiro, ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de número inteiro, ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemplos
// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]
replica_id
Semântica
Produz replica_id
do processo atual.
Saídas
Nome | Tipo |
---|---|
result |
Tensor 0 dimensional do tipo ui32 |
Exemplos
%result = "stablehlo.replica_id"() : () -> tensor<ui32>
remodelar
Semântica
Executa a remodelação do tensor operand
para um tensor result
. Conceitualmente, significa manter a mesma representação canônica, mas alterar potencialmente a forma, por exemplo, de tensor<2x3xf32>
para tensor<3x2xf32>
ou tensor<6xf32>
.
Mais formalmente, result[result_index] = operand[operand_index]
, em que result_index
e operand_index
têm a mesma posição na ordem lexicográfica de index_space(result)
e index_space(operand)
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor ou quantizado | (C1 a C3) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou quantizado | (C1 a C3) |
Restrições
- (C1)
element_type(result)
é fornecido por:element_type(operand)
, se!is_per_axis_quantized(operand)
.element_type(operand)
, exceto quequantization_dimension(operand)
equantization_dimension(result)
podem ser diferentes.
- (C2)
size(operand) = size(result)
. - (C3) Se
is_per_axis_quantized(operand)
:reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
.reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
Exemplos
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]
anular
Semântica
Inverte a ordem dos elementos em operand
ao longo do dimensions
especificado
e produz um tensor result
. Mais formalmente,
result[result_index] = operand[operand_index]
, em que:
operand_index[d] = dim(result, d) - result_index[d] - 1
sed
emdimensions
.- Caso contrário,
operand_index[d] = result_index[d]
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor ou tensor quantizado por tensor | (C1) e (C3). |
(I2) | dimensions |
constante do tensor unidimensional do tipo si64 |
(C2) e (C3). |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou tensor quantizado por tensor | (C1) e (C3). |
Restrições
- (C1)
type(operand) = type(result)
. - (C2)
is_unique(dimensions)
. - (C3)
0 <= dimensions < rank(result)
.
Exemplos
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]
RG
Semântica
Gera números aleatórios usando o algoritmo rng_distribution
e produz um tensor result
de uma determinada forma shape
.
Se for rng_distribution = UNIFORM
, os números aleatórios serão gerados
seguindo a distribuição uniforme no intervalo [a, b)
. Se for a >= b
,
o comportamento será indefinido.
Se for rng_distribution = NORMAL
, os números aleatórios serão gerados seguindo a distribuição normal com média = a
e desvio padrão = b
.
Se for b < 0
, o comportamento será indefinido.
A maneira exata como os números aleatórios são gerados é definida pela implementação. Por exemplo, elas podem ou não ser deterministas, e podem ou não usar o estado oculto.
Em conversas com muitas partes interessadas, essa operação surgiu como efetivamente descontinuada. Portanto, no futuro, planejamos removê-la (#597).
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | a |
Tensor 0 dimensional de número inteiro, booleano ou de ponto flutuante | (C1) e (C2). |
(I2) | b |
Tensor 0 dimensional de número inteiro, booleano ou de ponto flutuante | (C1) e (C2). |
(I3) | shape |
constante do tensor unidimensional do tipo si64 |
(C3) |
(I4) | rng_distribution |
tipo enumerado de UNIFORM e NORMAL |
(C2) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de tipo inteiro, booleano ou ponto flutuante | (C1 a C3) |
Restrições
- (C1)
element_type(a) = element_type(b) = element_type(result)
. - (C2) Se for
rng_distribution = NORMAL
, entãois_float(a)
. - (C3)
shape(result) = shape
.
Exemplos
// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
// [1, 0, 1],
// [1, 1, 1],
// [0, 0, 0]
// ]
rng_bit_generator
Semântica
Retorna uma output
preenchida com bits aleatórios uniformes e um estado de saída atualizado
output_state
usando o algoritmo gerador de números pseudoaleatórios rng_algorithm
,
considerando um estado inicial initial_state
. A saída será uma função determinística de initial_state
, mas não será determinista entre as implementações.
rng_algorithm
é um destes:
DEFAULT
: algoritmo definido pela implementação.THREE_FRY
: variante definida pela implementação do algoritmo Threefry.*PHILOX
: variante definida pela implementação do algoritmo Philox.*
* Consulte: Salmon et al. SC 2011. Números aleatórios paralelos: tão fáceis quanto 1, 2, 3.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | rng_algorithm |
enumeração de DEFAULT , THREE_FRY e PHILOX |
(C2) |
(I2) | initial_state |
Tensor unidimensional do tipo ui64 |
(C1) e (C2). |
Saídas
Nome | Tipo | Restrições |
---|---|---|
output_state |
Tensor unidimensional do tipo ui64 |
(C1) |
output |
tensor de tipo inteiro ou de ponto flutuante |
Restrições
- (C1)
type(initial_state) = type(output_state)
. - (C2)
size(initial_state)
é definido como:- definido pela implementação se
rng_algorithm = DEFAULT
. 2
se forrng_algorithm = THREE_FRY
.2
ou3
serng_algorithm = PHILOX
.
- definido pela implementação se
Exemplos
// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
// [9236835810183407956, 16087790271692313299],
// [18212823393184779219, 2658481902456610144]
// ]
round_nearest_afz
Semântica
Executa arredondamentos elementares para o número inteiro mais próximo, quebrando vínculos de zero, no tensor operand
, e produz um tensor result
. Implementa
a operação roundToIntegralTiesToAway
da especificação IEEE-754. Para
tipos quantizados, executa
dequantize_op_quantize(round_nearest_afz, operand, type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemplos
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]
round_nearest_even
Semântica
Executa arredondamentos por elemento em direção ao número inteiro mais próximo, desfazendo vínculos com o número inteiro par no tensor operand
e produz um tensor result
. Implementa a operação roundToIntegralTiesToEven
da especificação
IEEE-754. Para tipos quantizados, executa
dequantize_op_quantize(round_nearest_even, operand, type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de tipo de ponto flutuante ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemplos
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
rsqrt
Semântica
Executa uma operação de raiz quadrada recíproca com elementos no tensor operand
e
gera um tensor result
. Dependendo do tipo de elemento, faça o seguinte:
- Para pontos flutuantes:
rSqrt
do IEEE-754. - Para números complexos: raiz quadrada recíproca complexa.
- Para tipos quantizados:
dequantize_op_quantize(rsqrt, operand, type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemplos
// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]
scatter
Semântica
Produz tensores results
, que são iguais a inputs
, exceto pelo fato de
várias frações especificadas por scatter_indices
serem atualizadas com os valores
updates
usando update_computation
.
O diagrama a seguir mostra como os elementos em updates...
são mapeados em elementos em
results...
usando um exemplo concreto. O diagrama escolhe alguns exemplos de índices de
updates...
e explica em detalhes a quais índices de results...
eles
correspondem.
Mais formalmente, para todos os update_index
em index_space(updates[0])
:
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]
.update_scatter_index = update_index[update_scatter_dims...]
.start_index
é definido como:scatter_indices[si0, ..., :, ..., siN]
, em quesi
são elementos individuais emupdate_scatter_index
, e:
é inserido no índiceindex_vector_dim
, seindex_vector_dim
<rank(scatter_indices)
.- Caso contrário,
[scatter_indices[update_scatter_index]]
.
- Para
d_input
emaxes(inputs[0])
,full_start_index[d_input] = start_index[d_start]
sed_input = scatter_dims_to_operand_dims[d_start]
.- Caso contrário,
full_start_index[d_input] = 0
.
- Para
d_input
emaxes(inputs[0])
,full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)]
sed_input = input_batching_dims[i_batching]
ed_start = scatter_indices_batching_dims[i_batching]
.- Caso contrário,
full_batching_index[d_input] = 0
.
update_window_index = update_index[update_window_dims...]
.full_window_index = [wi0, ..., 0, ..., wiN]
, em quewi
são elementos individuais emupdate_window_index
, e0
é inserido em índices deinserted_window_dims
einput_batching_dims
.result_index = full_start_index + full_batching_index + full_window_index
.
Considerando isso, results = exec(schedule, inputs)
, em que:
schedule
é uma permutação definida pela implementação deindex_space(updates[0])
.exec([update_index, ...], results) = exec([...], updated_results)
em que:- Se
result_index
estiver dentro dos limites parashape(results...)
updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
updated_values = update_computation(results...[result_index], updates_converted)
updated_results
é uma cópia deresults
comresults...[result_index]
definido comoupdated_values...
.- Como alternativa, faça o seguinte:
updated_results = results
.
- Se
exec([], results) = results
.
Se indices_are_sorted
for true
, a implementação poderá presumir que
scatter_indices
estão classificados em relação a scatter_dims_to_operand_dims
.
Caso contrário, o comportamento será indefinido. Mais formalmente, para todos os i1 < i2
de
indices(result)
, full_start_index(i1)
<= full_start_index(i2)
.
Se unique_indices
for true
, a implementação poderá presumir que todos os
índices de result_index
que estão sendo dispersos são exclusivos. Se unique_indices
for
true
, mas os índices que estão sendo distribuídos não forem exclusivos, o comportamento será
indefinido.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | inputs |
número variado de tensores ou tensores quantizados por tensor | (C1), (C2), (C4 a C6), (C11), (C13), (C18), (C21), (C23 a C24) |
(I2) | scatter_indices |
tensor do tipo inteiro | (C4), (C15), (C19), (C22) |
(I3) | updates |
número variado de tensores ou tensores quantizados por tensor | (C3 a C6) e (C8). |
(I4) | update_window_dims |
constante do tensor unidimensional do tipo si64 |
(C2), (C4) e (C7 a C8) |
(I5) | inserted_window_dims |
constante do tensor unidimensional do tipo si64 |
(C2), (C4) (C9 a C11) |
(I6) | input_batching_dims |
constante do tensor unidimensional do tipo si64 |
(C2), (C4), (C9), (C12-13), (C17-18), (C20) |
(I7) | scatter_indices_batching_dims |
constante do tensor unidimensional do tipo si64 |
(C14 a C18) |
(I8) | scatter_dims_to_operand_dims |
constante do tensor unidimensional do tipo si64 |
(C19 a C21) |
(I9) | index_vector_dim |
constante do tipo si64 |
(C4), (C16), (C19), (C22) |
(I10) | indices_are_sorted |
constante do tipo i1 |
|
(I11) | unique_indices |
constante do tipo i1 |
|
(I12) | update_computation |
função | (C23) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
results |
número variado de tensores ou tensores quantizados por tensor | (C24 a C25) |
Restrições
- (C1)
same(shape(inputs...))
. - (C2) `rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims)
- size(input_batching_dims)`.
- (C3)
same(shape(updates...))
. - (C4)
shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes)
, em que:update_scatter_dim_sizes = shape(scatter_indices)
, exceto pelo fato de o tamanho da dimensão descatter_indices
correspondente aindex_vector_dim
não ser incluído.update_window_dim_sizes <= shape(inputs[0])
, exceto que os tamanhos de dimensão eminputs[0]
correspondentes ainserted_window_dims
einput_batching_dims
não estão incluídos.combine
colocaupdate_scatter_dim_sizes
nos eixos correspondentes aupdate_scatter_dims
eupdate_window_dim_sizes
nos eixos correspondentes aupdate_window_dims
.
- (C5)
0 < size(inputs) = size(updates) = N
. - (C6)
element_type(updates...) = element_type(inputs...)
. - (C7)
is_unique(update_window_dims) and is_sorted(update_window_dims)
. - (C8)
0 <= update_window_dims < rank(updates[0])
. - (C9)
is_unique(concatenate(inserted_window_dims, input_batching_dims))
- (C10)
is_sorted(inserted_window_dims)
. - (C11)
0 <= inserted_window_dims < rank(inputs[0])
. - (C12)
is_sorted(input_batching_dims)
. - (C13)
0 <= input_batching_dims < rank(inputs[0]))
. - (C14)
is_unique(scatter_indices_batching_dims)
. - (C15)
0 <= scatter_indices_batching_dims < rank(scatter_indices)
. - (C16)
index_vector_dim not in scatter_indices_batching_dims
. - (C17)
size(input_batching_dims) == size(scatter_indices_batching_dims)
. - (C18)
dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...)
. - (C19)
size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1
. - (C20)
is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims))
. - (C21)
0 <= scatter_dims_to_operand_dims < rank(inputs[0])
. - (C22)
0 <= index_vector_dim <= rank(scatter_indices)
. - (C23)
update_computation
tem o tipo(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
, em queis_promotable(element_type(inputs[i]), Ei)
. - (C24)
shape(inputs...) = shape(results...)
. - (C25)
element_type(results[i]) = Ei
para todos osi
na[0,N)
.
Exemplos
// %input: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %scatter_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
// %update: [
// [
// [[1, 1], [1, 1], [1, 1]],
// [[1, 1], [1, 1], [1, 1]]
// ],
// [
// [[1, 1], [1, 1], [1, 1]],
// [[1, 1], [1, 1], [1, 1]]
// ]
// ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [3, 4],
inserted_window_dims = [1],
input_batching_dims = [0],
scatter_indices_batching_dims = [1],
scatter_dims_to_operand_dims = [2, 1],
index_vector_dim = 3>,
indices_are_sorted = false,
unique_indices = false
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
// %result: [
// [
// [[3, 4], [6, 7], [6, 7], [7, 8]],
// [[9, 10],[11, 12], [15, 16], [17, 18]],
// [[17, 18], [19, 20], [22, 23], [24, 25]]
// ],
// [
// [[25, 26], [28, 29], [30, 31], [31, 32]],
// [[35, 36], [38, 39], [38, 39], [39, 40]],
// [[41, 42], [44, 45], [46, 47], [47, 48]]
// ]
// ]
select
Semântica
Produz um tensor result
em que cada elemento é selecionado do tensor on_true
ou
on_false
com base no valor do elemento correspondente de pred
.
Mais formalmente, result[result_index] = pred_element ? on_true[result_index] :
on_false[result_index]
, em que pred_element = rank(pred) = 0 ? pred[] :
pred[result_index]
. Para tipos quantizados, executa
dequantize_select_quantize(pred, on_true, on_false, type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | pred |
tensor do tipo i1 |
(C1) |
(I2) | on_true |
tensor ou tensor quantizado por tensor | (C1 a C2) |
(I3) | on_false |
tensor ou tensor quantizado por tensor | (C2) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou tensor quantizado por tensor | (C2) |
Restrições
- (C1)
rank(pred) = 0 or shape(pred) = shape(on_true)
. - (C2)
baseline_type(on_true) = baseline_type(on_false) = baseline_type(result)
.
Exemplos
// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]
select_and_scatter
Semântica
Distribui os valores do tensor source
usando scatter
com base no
resultado de reduce_window
do tensor input
usando select
e produz
um tensor result
.
O diagrama a seguir mostra como os elementos em result
são calculados de
operand
e source
usando um exemplo concreto.
Mais formalmente:
selected_values = reduce_window_without_init(...)
pelas seguintes entradas:inputs = [operand].
window_dimensions
,window_strides
epadding
, que são usados no estado em que se encontram.base_dilations = windows_dilations = 1
.body
é definido como:
def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>: return select(arg0, arg1) ? arg0 : arg1;
em que
E = element_type(operand)
ereduce_window_without_init
funcionam exatamente comoreduce_window
, exceto pelo fato de que oschedule
dareduce
subjacente (consulte reduzir) não inclui valores init. No momento, não é especificado o que acontece se a janela correspondente não tiver valores (#731).result[result_index] = reduce([source_values], [init_value], [0], scatter)
, em que:source_values = [source[source_index] for source_index in source_indices]
.selected_index(source_index) = operand_index
seselected_values[source_index]
tiver o elementooperand
deoperand_index
.source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index]
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor ou tensor quantizado por tensor | (C1 a C4), (C6) e (C8 a C11) |
(I2) | source |
tensor ou tensor quantizado por tensor | (C1) e (C2). |
(I3) | init_value |
Tensor 0 dimensional ou tensor quantizado por tensor | (C3) |
(I4) | window_dimensions |
constante do tensor unidimensional do tipo si64 |
(C2), (C4) (C5) |
(I5) | window_strides |
constante do tensor unidimensional do tipo si64 |
(C2), (C6) e (C7) |
(I6) | padding |
Constante do tensor bidimensional do tipo si64 |
(C2) e (C8). |
(I7) | select |
função | (C9) |
(I8) | scatter |
função | (C10) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou tensor quantizado por tensor | (C11 a C12) |
Restrições
- (C1)
element_type(operand) = element_type(source)
. - (C2)
shape(source) = num_windows
, em que:padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1]
.is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape
.num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1
.
- (C3)
element_type(init_value) = element_type(operand)
. - (C4)
size(window_dimensions) = rank(operand)
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(operand)
. - (C7)
0 < window_strides
. - (C8)
shape(padding) = [rank(operand), 2]
. - (C9)
select
tem o tipo(tensor<E>, tensor<E>) -> tensor<i1>
, em queE = element_type(operand)
. - (C10)
scatter
tem o tipo(tensor<E>, tensor<E>) -> tensor<E>
, em queis_promotable(element_type(operand), E)
. - (C11)
shape(operand) = shape(result)
. - (C12)
element_type(result) = E
.
Exemplos
// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 3, 1>,
window_strides = array<i64: 2, 1>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]
send
Semântica
Envia inputs
para um canal channel_id
e produz um token result
.
Se is_host_transfer
for true
, a operação transfere dados para o host. Caso contrário, ela vai transferir os dados para outro dispositivo. O que isso significa é
definido pela implementação. Essa flag duplica as informações fornecidas em
channel_type
. Portanto, no futuro, planejamos manter apenas uma delas
(#666).
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | inputs |
número variado de tensores ou tensores quantizados | |
(I2) | token |
token |
|
(I3) | channel_id |
constante do tipo si64 |
|
(I4) | channel_type |
tipo enumerado de DEVICE_TO_DEVICE e DEVICE_TO_HOST |
(C1) |
(I5) | is_host_transfer |
constante do tipo i1 |
(C1) |
Saídas
Nome | Tipo |
---|---|
result |
token |
Restrições
- (C1)
channel_type
é definido como:DEVICE_TO_HOST
seis_host_transfer = true
,- Caso contrário,
DEVICE_TO_DEVICE
.
Exemplos
%result = "stablehlo.send"(%operand, %token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token
shift_left
Semântica
Executa a operação de deslocamento para a esquerda (elemento) no tensor lhs
por número rhs
de bits e produz um tensor result
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | lhs |
tensor do tipo inteiro | (C1) |
(I2) | rhs |
tensor do tipo inteiro | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor do tipo inteiro | (C1) |
Restrições
- (C1)
type(lhs) = type(rhs) = type(result)
.
Exemplos
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]
shift_right_arithmetic
Semântica
Executa a operação aritmética de deslocamento para a direita com elementos no tensor lhs
por
número rhs
de bits e produz um tensor result
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | lhs |
tensor do tipo inteiro | (C1) |
(I2) | rhs |
tensor do tipo inteiro | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor do tipo inteiro | (C1) |
Restrições
- (C1)
type(lhs) = type(rhs) = type(result)
.
Exemplos
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]
shift_right_logical
Semântica
Executa a operação lógica de deslocamento para a direita no elemento lhs
pelo número
rhs
de bits e produz um tensor result
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | lhs |
tensor do tipo inteiro | (C1) |
(I2) | rhs |
tensor do tipo inteiro | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor do tipo inteiro | (C1) |
Restrições
- (C1)
type(lhs) = type(rhs) = type(result)
.
Exemplos
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]
de igual.
Semântica
Retorna o sinal de operand
com elementos e produz um tensor result
.
Mais formalmente, para cada elemento x
, a semântica pode ser expressa usando
a sintaxe do Python da seguinte maneira:
def sign(x):
if is_integer(x):
if compare(x, 0, LT, SIGNED): return -1
if compare(x, 0, EQ, SIGNED): return 0
return 1
elif is_float(x):
if is_nan(x): return NaN
if compare(x, -0.0, EQ, FLOAT): return -0.0
if compare(x, +0.0, EQ, FLOAT): return +0.0
if compare(x, 0.0, LT, FLOAT): return -1.0
return 1.0
elif is_complex(x):
if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
return divide(x, convert(abs(x), type(x)))
Para tipos quantizados, executa
dequantize_op_quantize(sign, operand, type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor de número inteiro assinado, ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de número inteiro assinado, ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemplos
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
seno
Semântica
Executa uma operação de seno por elemento no tensor operand
e produz um tensor
result
. Dependendo do tipo de elemento, faça o seguinte:
- Para pontos flutuantes:
sin
do IEEE-754. - Para números complexos: seno complexo.
- Para tipos quantizados:
dequantize_op_quantize(sine, operand, type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemplos
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]
slice
Semântica
Extrai uma fração do operand
usando índices iniciais calculados estaticamente
e produz um tensor result
. start_indices
contém os índices iniciais da
fração para cada dimensão, limit_indices
contém os índices finais
(exclusivos) da fração de cada dimensão, e strides
contém os passos
de cada dimensão.
Mais formalmente, result[result_index] = operand[operand_index]
, em que
operand_index = start_indices + result_index * strides
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor ou tensor quantizado por tensor | (C1 a C3) e C5. |
(I2) | start_indices |
constante do tensor unidimensional do tipo si64 |
(C2), (C3) e (C5) |
(I3) | limit_indices |
constante do tensor unidimensional do tipo si64 |
(C2), (C3) e (C5) |
(I4) | strides |
constante do tensor unidimensional do tipo si64 |
(C2) e (C4). |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou tensor quantizado por tensor | (C1) e (C5). |
Restrições
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(limit_indices) = size(strides) = rank(operand)
. - (C3)
0 <= start_indices <= limit_indices <= shape(operand)
. - (C4)
0 < strides
. - (C5)
shape(result) = ceil((limit_indices - start_indices) / strides)
.
Exemplos
// %operand: [
// [0, 0, 0, 0],
// [0, 0, 1, 1],
// [0, 0, 1, 1]
// ]
%result = "stablehlo.slice"(%operand) {
start_indices = array<i64: 1, 2>,
limit_indices = array<i64: 3, 4>,
strides = array<i64: 1, 1>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
// [1, 1],
// [1, 1]
// ]
sort
Semântica
Classifica fatias unidimensionais de inputs
ao longo da dimensão dimension
juntas,
de acordo com um comparator
, e produz results
.
Ao contrário de entradas semelhantes em outras operações, dimension
permite valores negativos, com a semântica descrita abaixo. No futuro, isso pode ser proibido
por motivos de consistência
(#1377).
Se is_stable
for verdadeiro, a classificação será estável, ou seja, a ordem relativa de
elementos considerados iguais pelo comparador será preservada. Para o caso
em que há uma única entrada, dois elementos e1
e e2
são considerados
iguais pelo comparador somente se
comparator(e1, e2) = comparator(e2, e1) = false
. Confira a formalização abaixo para saber como isso generaliza para várias entradas.
Mais formalmente, para todos os result_index
em index_space(results[0])
:
adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension
.result_slice = [ri0, ..., :, ..., riR-1]
, em queriN
são elementos individuais emresult_index
, e:
é inserido emadjusted_dimension
.inputs_together = (inputs[0]..., ..., inputs[N-1]...)
.results_together[result_slice] = sort(inputs_together[result_slice], comparator_together)
.- em que
sort
classifica uma fração unidimensional em ordem não decrescente esperando quecomparator_together
retornetrue
se o argumento do lado esquerdo for menor que o segundo argumento à direita. def comparator_together(lhs_together, rhs_together): args = [] for (lhs_el, rhs_el) in zip(lhs_together, rhs_together): args.append(lhs_el) args.append(rhs_el) return comparator(*args)
(results[0]..., ..., results[N-1]...) = results_together
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | inputs |
número variado de tensores ou tensores quantizados por tensor | (C1 a C5) |
(I2) | dimension |
constante do tipo si64 |
(C4) |
(I3) | is_stable |
constante do tipo i1 |
|
(I4) | comparator |
função | (C5) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
results |
número variado de tensores ou tensores quantizados por tensor | (C2) e (C3). |
Restrições
- (C1)
0 < size(inputs)
. - (C2)
type(inputs...) = type(results...)
. - (C3)
same(shape(inputs...) + shape(results...))
. - (C4)
-R <= dimension < R
, em queR = rank(inputs[0])
. - (C5)
comparator
tem o tipo(tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>
, em queEi = element_type(inputs[i])
.
Exemplos
// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
dimension = 0 : i64,
is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]
sqrt
Semântica
Executa uma operação de raiz quadrada com elemento no tensor operand
e produz um
tensor result
. Dependendo do tipo de elemento, faça o seguinte:
- Para pontos flutuantes:
squareRoot
do IEEE-754. - Para números complexos: raiz quadrada complexa.
- Para tipos quantizados:
dequantize_op_quantize(sqrt, operand, type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemplos
// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]
subtract
Semântica
Executa a subtração com elementos de dois tensores lhs
e rhs
e produz um
tensor result
. Dependendo do tipo de elemento, faça o seguinte:
- Para números inteiros: subtração de números inteiros.
- Para pontos flutuantes:
subtraction
do IEEE-754. - Para números complexos: subtração complexa.
- Para tipos quantizados:
dequantize_op_quantize(subtract, lhs, rhs, type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | lhs |
tensor de número inteiro, ponto flutuante, tipo complexo ou tensor quantizado por tensor | (C1) |
(I2) | rhs |
tensor de número inteiro, ponto flutuante, tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de número inteiro, ponto flutuante, tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Exemplos
// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]
Tanh
Semântica
Executa uma operação de tangente hiperbólica com elementos no tensor operand
e
gera um tensor result
. Dependendo do tipo de elemento, faça o seguinte:
- Para pontos flutuantes:
tanh
do IEEE-754. - Para números complexos: tangente hiperbólica complexa.
- Para tipos quantizados:
dequantize_op_quantize(tanh, operand, type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemplos
// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]
transpor
Semântica
Altera as dimensões do tensor operand
usando permutation
e produz um
tensor result
. Mais formalmente, result[result_index] = operand[operand_index]
,
em que result_index[d] = operand_index[permutation[d]]
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor ou quantizado | (C1 a C4) |
(I2) | permutation |
constante do tensor unidimensional do tipo si64 |
(C2 a C4) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor ou quantizado | (C1) e C3 a C4. |
Restrições
- (C1)
element_type(result)
é fornecido por:element_type(operand)
, se!is_per_axis_quantized(operand)
.element_type(operand)
, exceto quequantization_dimension(operand)
equantization_dimension(result)
podem ser diferentes.
- (C2)
permutation
é uma permutação derange(rank(operand))
. - (C3)
shape(result) = dim(operand, permutation...)
. - (C4) Se
is_per_axis_quantized(result)
, entãoquantization_dimension(operand) = permutation(quantization_dimension(result))
.
Exemplos
// %operand: [
// [[1,2], [3,4], [5,6]],
// [[7,8], [9,10], [11,12]]
// ]
%result = "stablehlo.transpose"(%operand) {
permutation = array<i64: 2, 1, 0>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
// [[1,7], [3,9], [5,11]],
// [[2,8], [4,10], [6,12]]
// ]
triangular_solve
Semântica
Resolve lotes de sistemas de equações lineares com matrizes de coeficiente triangulares inferiores ou superiores.
Mais formalmente, considerando a
e b
, result[i0, ..., iR-3, :, :]
é a solução
para op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :]
quando left_side
é
true
ou x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :]
quando
left_side
é false
, resolvendo a variável x
em que op(a)
é determinado
por transpose_a
, que pode ser uma destas opções:
NO_TRANSPOSE
: executa a operação usandoa
no estado em que se encontra.TRANSPOSE
: executa a operação na transposição dea
.ADJOINT
: executa a operação na transposição conjugada dea
.
Os dados de entrada serão lidos apenas no triângulo inferior de a
, se lower
for true
ou
no triângulo superior de a
. Os dados de saída são retornados no mesmo triângulo.
Os valores no outro triângulo são definidos pela implementação.
Se unit_diagonal
for verdadeiro, a implementação poderá presumir que os elementos diagonais
de a
são iguais a 1. Caso contrário, o comportamento será indefinido.
Para tipos quantizados, executa
dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower,
unit_diagonal, transpose_a), a, b, type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | a |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1 a C3) |
(I2) | b |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1 a C4) |
(I3) | left_side |
constante do tipo i1 |
(C3) |
(I4) | lower |
constante do tipo i1 |
|
(I5) | unit_diagonal |
constante do tipo i1 |
|
(I6) | transpose_a |
enumeração de NO_TRANSPOSE , TRANSPOSE e ADJOINT |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor | (C1) |
Restrições
- (C1)
baseline_element_type(a) = baseline_element_type(b)
. - (C2)
2 <= rank(a) = rank(b) = R
. - (C3) A relação entre
shape(a)
eshape(b)
é definida da seguinte maneira:shape(a)[:-3] = shape(b)[:-3]
.dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1)
.
- (C4)
baseline_type(b) = baseline_type(result)
.
Exemplos
// %a = [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
// %b = [
// [2.0, 0.0, 0.0],
// [4.0, 8.0, 0.0],
// [6.0, 10.0, 12.0]
// ]
%result = "stablehlo.triangular_solve"(%a, %b) {
left_side = true,
lower = true,
unit_diagonal = false,
transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
// [2.0, 0.0, 0.0],
// [0.0, 2.0, 0.0],
// [0.0, 0.0, 2.0]
// ]
tuple
Semântica
Produz uma tupla result
a partir dos valores val
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | val |
número variável de valores | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tuple | (C1) |
Restrições
- (C1)
result
tem o tipotuple<E0, ..., EN-1>
, em queEi = type(val[i])
.
Exemplos
// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))
uniform_dequantize
Semântica
Executa a conversão em elementos do tensor quantizado operand
em um
tensor de ponto flutuante result
, de acordo com os parâmetros de quantização definidos
pelo tipo operand
.
Mais formalmente, result = dequantize(operand)
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor quantizado | (C1) e (C2). |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor do tipo de ponto flutuante | (C1) e (C2). |
Restrições
- (C1)
shape(operand) = shape(result)
. - (C2)
element_type(result) = expressed_type(operand)
.
Exemplos
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]
uniform_quantize
Semântica
Executa a conversão com elementos do tensor de ponto flutuante ou quantizado operand
para um tensor quantizado result
de acordo com os parâmetros de quantização definidos pelo tipo result
.
Mais formalmente,
- Se
is_float(operand)
:result = quantize(operand, type(result))
.
- Se
is_quantized(operand)
:float_result = dequantize(operand)
.result = quantize(float_result, type(result))
.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
tensor de ponto flutuante ou quantizado | (C1) e (C2). |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor quantizado | (C1) e (C2). |
Restrições
- (C1)
shape(operand) = shape(result)
. - (C2)
expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand)
.
Exemplos
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]
// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]
enquanto
Semântica
Produz a saída da execução da função body
0 ou mais vezes, enquanto a
função cond
gera true
. Mais formalmente, a semântica pode ser expressa usando a sintaxe do Python da seguinte maneira:
internal_state = operand
while cond(*internal_state):
internal_state = body(*internal_state)
results = internal_state
O comportamento de um loop infinito ainda será definido (#383).
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | operand |
número variado de tensores, tensores quantizados ou tokens | (C1 a C3) |
(I2) | cond |
função | (C1) |
(I3) | body |
função | (C2) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
results |
número variado de tensores, tensores quantizados ou tokens | (C3) |
Restrições
- (C1)
cond
tem o tipo(T0, ..., TN-1) -> tensor<i1>
, em queTi = type(operand[i])
. - (C2)
body
tem o tipo(T0, ..., TN-1) -> (T0, ..., TN-1)
, em queTi = type(operand[i])
. - (C3)
type(results...) = type(operand...)
.
Exemplos
// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%cond = "stablehlo.compare"(%arg0, %ten) {
comparison_direction = #stablehlo<comparison_direction LT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %cond : tensor<i1>
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%new_sum = stablehlo.add %arg1, %one : tensor<i64>
%new_i = stablehlo.add %arg0, %one : tensor<i64>
stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10
Xor
Semântica
Executa XOR com elemento de dois tensores lhs
e rhs
e produz um
tensor result
. Dependendo do tipo de elemento, faça o seguinte:
- Para booleanos: XOR lógico.
- Para números inteiros: XOR bit a bit.
Entradas
Rótulo | Nome | Tipo | Restrições |
---|---|---|---|
(I1) | lhs |
tensor de tipo booleano ou inteiro | (C1) |
(I2) | rhs |
tensor de tipo booleano ou inteiro | (C1) |
Saídas
Nome | Tipo | Restrições |
---|---|---|
result |
tensor de tipo booleano ou inteiro | (C1) |
Restrições
- (C1)
type(lhs) = type(rhs) = type(result)
.
Exemplos
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]
Interoperabilidade do dialeto
No momento, os programas do StableHLO à solta às vezes contêm operações que não são definidas pelo StableHLO.
Módulo, função, chamada e retorno
O StableHLO usa operações MLIR upstream para ModuleOp, FuncOp, CallOp e ReturnOp. Isso foi feito para melhorar a interoperabilidade com o maquinário MLIR, já que muitos passagens úteis são gravadas visando FuncOp e ModuleOp, e muitos pipelines de compilação esperam que essas operações estejam presentes. Garantias de compatibilidade total são aplicadas a essas operações. Se algo mudar nas operações de maneira incompatível (ou seja, remoção), os equivalentes do StableHLO serão adicionados para preservar a compatibilidade.
CLO
O opset CHLO contém operações de nível superior que são decompostas em StableHLO. No momento, não há garantias de compatibilidade para CHLO. Para garantir a compatibilidade, o passe chlo-legalize-to-stablehlo (link em inglês) precisa ser usado antes da serialização.
Operações de forma
É um caso de uso comum na comunidade o uso de certas operações dos dialetos MLIR principais em programas StableHLO dinâmicos para realizar cálculos de formas.
Geralmente, isso inclui operações de dialeto shape
, como shape_of
ou num_elements
, operações de dialeto tensor
como dim
ou from_elements
e o tipo index
integrado.
O Dynamism RFC > O2
indica que eles estão fora do escopo, mas algum suporte para os tipos index
está
incluído para fins de interoperabilidade. Não há garantias de compatibilidade para essas operações ou tipos. A passagem shape-legalize-to-stablehlo
pode ser usada para converter essas operações em operações StableHLO totalmente compatíveis.
Operações descontinuadas
Há várias operações do StableHLO herdadas do MHLO (link em inglês), que foram descontinuadas e estão prestes a sair do StableHLO. Todos os detalhes sobre essas remoções podem ser encontrados no StableHLO v1.0 Cleanup #2283 (link em inglês). O problema do rastreador para essas descontinuações é #2340.
Essas operações se enquadram em algumas categorias:
- Categoria "Not in HLO" de operações StableHLO. Inicialmente, elas faziam parte da
opset StableHLO, mas foram considerados mais recentes:
broadcast
,create_token
,cross-replica-sum
,dot
,einsum
,torch_index_select
eunary_einsum
(no 3). - Operações não usadas: essas operações podem ter sido úteis em algum momento, mas as operações
estavam subdesenvolvidas ou os pipelines que as usavam foram
refatorados para não exigir mais delas. Isso inclui
map
,tuple
(#598),get_tuple_element
,rng
, comparações decomplex
, #560 e convoluçãowindow_reversal
(#1181).
Algumas dessas operações podem ser removidas facilmente, já que podem ser expressas usando
as operações atuais (broadcast
, create_token
, cross-replica-sum
, dot
,
unary_einsum
) e serão removidas após a janela de compatibilidade
atual passar (6 meses). Outras ainda estão sendo analisadas para remoção (comparações einsum
, get_tuple_element
, map
, rng
torch_index_select
, tuple
, complex
, window_reversal
). Se o feedback da comunidade estiver pendente, essas operações serão removidas ou adicionadas à especificação com suporte total. Até que essas operações futuras sejam conhecidas, há garantia de apenas seis meses de compatibilidade.
Execução
Execução sequencial
Um programa StableHLO é executado fornecendo valores de entrada para a função main
e calculando valores de saída. Os valores de saída de uma função são calculados
executando o gráfico de operações com acesso root na operação return
correspondente.
A ordem de execução é definida pela implementação desde que esteja alinhada com o Dataflow, ou seja, se as operações forem executadas antes do uso. No StableHLO, todas as operações com efeitos colaterais consomem um token e produzem um token (vários tokens podem ser multiplexados em um token via after_all
). Portanto, a ordem de execução dos efeitos colaterais também é alinhada com o Dataflow. Por exemplo, no programa abaixo,
há duas ordens de execução possíveis: %0
→ %1
→ %2
→ return
e
%1
→ %0
→ %2
→ return
.
func.func @main() -> tensor<f64> {
%0 = stablehlo.constant dense<1.0> : tensor<f64>
%1 = stablehlo.constant dense<2.0> : tensor<f64>
%2 = stablehlo.add %0, %1 : tensor<f64>
return %2 : tensor<f64>
}
Mais formalmente, um processo StableHLO é uma combinação de:
1) um programa StableHLO, 2) status de operação (ainda não executado,
já executado) e 3) valores intermediários em que o processo está trabalhando.
O processo começa com valores de entrada para a função main
, avança pelo gráfico de operações que atualizam status de operações e valores intermediários e termina com valores de saída. Outras formalizações ainda serão definidas
(#484).
Execução paralela
Os programas StableHLO podem ser executados em paralelo, organizados em uma grade de processos 2D
de num_replicas
por num_partitions
, que tem o tipo ui32
.
Na grade de processos do StableHLO, num_replicas * num_partitions
de processos do StableHLO
são executados ao mesmo tempo. Cada processo tem um
process_id = (replica_id, partition_id)
exclusivo, em que
replica_id
em replica_ids = range(num_replicas)
e
partition_id
em partition_ids = range(num_partitions)
, ambos com
tipo ui32
.
O tamanho da grade de processos é conhecido estaticamente para cada programa. No
futuro, planejamos torná-lo uma parte explícita dos programas StableHLO
#650, e a posição
dentro da grade de processos é conhecida estaticamente para todos os processos. Cada processo tem
acesso à própria posição na grade de processos por meio das operações replica_id
e
partition_id
.
Dentro da grade de processos, os programas podem ser todos iguais (no estilo "Programa único, vários dados"), podem ser diferentes (no estilo "Vários programas, vários dados") ou algo entre eles. No futuro, planejamos incluir suporte a outras expressões idiomáticas de definição de programas StableHLO paralelos, incluindo o GSPMD (#619).
Dentro da grade de processos, os processos são praticamente independentes entre si. Eles têm status de operação separados, valores de entrada/intermediário/saída separados e a maioria das operações é executada separadamente entre processos, com a exceção de um pequeno número de operações coletivas descritas abaixo.
Considerando que a execução da maioria das operações usa apenas valores do mesmo processo, geralmente não há ambiguidade em se referir a esses valores pelos nomes.
No entanto, ao descrever semântica de operações coletivas, isso é insuficiente e
gera a notação name@process_id
para se referir ao valor name
em um processo específico. Nessa perspectiva, name
não qualificado pode ser visto como uma abreviação de name@(replica_id(), partition_id())
.
A ordem de execução nos processos é definida pela implementação, exceto pela sincronização introduzida pela comunicação ponto a ponto e pelas operações coletivas, conforme descrito abaixo.
Comunicação ponto a ponto
Os processos do StableHLO podem se comunicar uns com os outros por meio de canais StableHLO. Um canal é representado por um ID positivo do tipo
si64
. Por meio de várias operações, é possível enviar valores a canais e recebê-los de canais.
Ainda vão ser definidas para formalização, como de onde esses IDs de canal vêm, como os programas se tornam cientes deles e que tipo de sincronização é introduzido por eles (#484).
Comunicação por streaming
Cada processo do StableHLO tem acesso a duas interfaces de streaming:
- Infeed que possam ser lidos.
- Saída para gravação.
Ao contrário dos canais, que são usados para comunicação entre processos e, portanto, têm processos em ambas as extremidades, as entradas e saídas têm a outra implementação definida.
Ainda a formalização, por exemplo, como a comunicação por streaming influencia a ordem de execução e que tipo de sincronização é introduzida por ela, ainda vai ser definida (#484).
Operações coletivas
Há seis operações coletivas no StableHLO: all_gather
, all_reduce
,
all_to_all
, collective_broadcast
, collective_permute
e
reduce_scatter
. Todas essas operações dividem os processos na grade de processos do StableHLO em grupos de processos do StableHLO e executam um cálculo conjunto dentro de cada grupo de processos, independentemente de outros grupos de processos.
Dentro de cada grupo de processos, as operações coletivas podem introduzir uma barreira de sincronização. Outras formalizações, por exemplo, elaborar sobre quando exatamente essa sincronização acontece, como os processos chegam a essa barreira e o que acontece se não acontecerem, será definida (#484).
Se o grupo de processos envolver comunicação entre partições, ou seja, houver
processos no grupo com IDs de partição diferentes, a execução
da operação coletiva precisará de um canal, e a operação coletiva precisará fornecer um
channel_id
positivo do tipo si64
. A comunicação entre réplicas não precisa de canais.
Os cálculos realizados pelas operações coletivas são específicos para operações individuais e são descritos nas seções de operação individuais acima. No entanto, as estratégias pelas quais a grade de processos é dividida em grupos de processos são compartilhadas entre essas operações e estão descritas nesta seção. Mais formalmente, o StableHLO oferece suporte às quatro estratégias a seguir.
cross_replica
Apenas as comunicações entre réplicas acontecem dentro de cada grupo de processos. Essa
estratégia usa replica_groups
, uma lista de listas de IDs de réplica, e calcula
um produto cartesiano de replica_groups
por partition_ids
. replica_groups
precisa ter elementos exclusivos e abranger todos os replica_ids
. Mais formalmente, usando a sintaxe do Python:
def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
for partition_id in partition_ids:
process_group = []
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Por exemplo, para replica_groups = [[0, 1], [2, 3]]
e num_partitions = 2
, cross_replica
produzirá [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]]
.
cross_partition
Apenas comunicações entre partições acontecem dentro de cada grupo de processos. Essa
estratégia usa partition_groups
, uma lista de listas de IDs de partição, e
calcula um produto cartesiano de partition_groups
por replica_ids
.
partition_groups
precisa ter elementos exclusivos e abranger todos os partition_ids
.
Mais formalmente, usando a sintaxe do Python:
def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
for partition_group in partition_groups:
for replica_id in replica_ids:
process_group = []
for partition_id in partition_group:
process_group.append((replica_id, partition_id))
yield process_group
Por exemplo, para partition_groups = [[0, 1]]
e num_replicas = 4
, cross_partition
produzirá [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]]
.
cross_replica_and_partition
As comunicações entre réplicas e partições podem ocorrer dentro de cada grupo de processos. Essa estratégia usa replica_groups
, uma lista de listas de
IDs de réplica, e calcula produtos cartesianos de cada replica_group
por
partition_ids
. replica_groups
precisa ter elementos exclusivos e abranger todos os
replica_ids
. Mais formalmente, usando a sintaxe do Python:
def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
process_group = []
for partition_id in partition_ids:
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Por exemplo, para replica_groups = [[0, 1], [2, 3]]
e num_partitions = 2
, cross_replica_and_partition
produzirá [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]]
.
flattened_ids
Essa estratégia usa flattened_id_groups
, uma lista de listas de IDs de processos "nivelados"
na forma de replica_id * num_partitions + partition_id
, e
as transforma em IDs de processo. flattened_id_groups
precisa ter elementos exclusivos
e abranger todos os process_ids
. Mais formalmente, usando a sintaxe do Python:
def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
for flattened_id_group in flattened_id_groups:
process_group = []
for flattened_id in flattened_id_group:
replica_id = flattened_id // num_partitions
partition_id = flattened_id % num_partitions
process_group.append((replica_id, partition_id))
yield process_group
Por exemplo, para flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]
,
num_replicas = 4
e num_partitions = 2
, flattened_ids
vai produzir
[[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]]
.
Precisão
No momento, o StableHLO não oferece garantias sobre precisão numérica, mas isso pode mudar no futuro (#1156).
Semântica de execução da operação quantizada
A interpretação das operações StableHLO quantizadas pode variar dependendo dos requisitos e recursos de hardware. Por exemplo, alguns hardwares podem interpretar operações quantizadas usando uma estratégia de "desquantização, execução de operação de ponto flutuante e, por fim, quantização". Outros podem fazer o cálculo completo com aritmética de números inteiros. Consequentemente, a interpretação das operações StableHLO quantizadas é determinada exclusivamente pela implementação específica. A interpretação da quantização híbrida (#1575) precisa ser baseada na semântica dela, conforme prescrita na especificação (em 1792).
Erros
Os programas StableHLO são validados por um amplo conjunto de restrições para operações individuais, o que exclui muitas classes de erros antes do ambiente de execução. No entanto, condições de erro ainda são possíveis, por exemplo, por estouro de números inteiros, acessos fora dos limites etc. A menos que explicitamente chamados, todos esses erros resultam em um comportamento definido pela implementação, mas isso pode mudar no futuro (#1157).
Exceções de ponto flutuante
Como exceção a essa regra, as exceções de ponto flutuante em programas StableHLO
têm um comportamento bem definido. As operações que resultam em exceções definidas pelo
padrão IEEE-754 (operação inválida, divisão por zero, estouro, subfluxo ou
exceções imprecisas) produzem resultados padrão (conforme definido no padrão) e
continuam a execução sem gerar a flag de status correspondente, semelhante ao
processamento de exceções raiseNoFlag
do padrão. As exceções para operações
fora do padrão (por exemplo, aritmética complexa e determinadas funções transcendentais) são
definidas pela implementação.
Incompatibilidade de formas
O StableHLO oferece suporte a tensores de formato dinâmico. No entanto, as formas precisam estar de acordo no momento da execução. Caso contrário, o comportamento será indefinido. O StableHLO não fornece explicitamente uma operação que possa declarar que um tensor tem uma determinada forma no momento da execução. Gerar o código correto é responsabilidade do produtor.
Como exemplo específico, o programa abaixo é válido. No entanto, no momento da execução, as
formas exatas de %arg0
e %arg1
precisarão ser as mesmas. Caso contrário, o
comportamento do programa será indefinido:
func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
return %0 : tensor<?xi32>
}
Notation
Para descrever a sintaxe, este documento está usando a variação ISO modificada da sintaxe EBNF (ISO/IEC 14977:1996, Wikipédia), com duas modificações: 1) as regras são definidas usando ::=
em vez de =
.
2) A concatenação é expressa usando justaposição em vez de ,
.
Para descrever a semântica, ou seja, nas seções "Tipos", "Constantes" e "Operações", usamos fórmulas baseadas na sintaxe do Python estendida com suporte para expressar de maneira concisa as operações de matriz, conforme descrito abaixo. Isso funciona bem para pequenos snippets de código, mas em casos raros, quando são necessários snippets maiores, usamos a sintaxe Python básica, sempre introduzida explicitamente.
fórmulas
Vamos explorar como as fórmulas funcionam com base em um exemplo da especificação
dot_general
. Uma das restrições para essa operação tem esta aparência:
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
.
Os nomes usados nessa fórmula vêm de duas fontes: 1) funções globais,
ou seja, dim
; 2) definições de membro do elemento do programa correspondente, ou seja, entradas
lhs
, lhs_batching_dimensions
, rhs
e rhs_batching_dimensions
definidas na seção "Entradas" de dot_general
.
Como mencionado acima, a sintaxe dessa fórmula é baseada em Python com algumas extensões orientadas para concisão. Para dar sentido à fórmula, vamos transformá-la na sintaxe básica do Python.
A) Nessas fórmulas, estamos usando =
para representar a igualdade. Portanto, a primeira etapa para receber a sintaxe do Python é substituir =
por ==
, da seguinte maneira: dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...)
.
B) Além disso, essas fórmulas são compatíveis com elipses (...
), que transformam expressões escalares
em expressões de tensor. Em poucas palavras, f(xs...)
significa "para cada x
escalar no tensor xs
, calcule um f(x)
escalar e, em seguida, retorne todos esses resultados escalares juntos como um resultado do tensor". Na sintaxe básica do Python,
nossa fórmula de exemplo se transforma em:
[dim(lhs, dim1) for dim1 in lhs_batching_dimensions] ==
[dim(rhs, dim2) for dim2 in rhs_batching_dimensions]
.
Graças às elipses, muitas vezes é possível evitar trabalhar no nível de
elipses individuais. No entanto, em alguns casos complicados, a sintaxe semiinformal
de nível mais baixo pode ser usada como na fórmula start_indices[bi0, ..., :, ..., biN]
da especificação gather
. A serviço da concisão, não fornecemos um formalismo exato para traduzir essa sintaxe para Python baunilha, na esperança de que ela ainda seja intuitivamente compreensível caso a caso.
Informe-nos se algumas fórmulas específicas parecerem opacas, e tentaremos aprimorá-las.
Além disso, as fórmulas usam reticências para expandir todos os tipos de listas, incluindo tensores, listas de tensores (que, por exemplo, podem surgir de um número variado de tensores) etc. Essa é outra área em que não fornecemos um formalismo exato (por exemplo, as listas nem fazem parte do sistema de tipo StableHLO) e, em vez disso, confiamos na compreensão intuitiva.
C) O último veículo de notação notável que empregamos é a transmissão implícita. Embora a opset de StableHLO não ofereça suporte à transmissão implícita, as fórmulas são, também para concisão. Em resumo, se um escalar for usado em um contexto em que um tensor é esperado, ele será transmitido para a forma esperada.
Para continuar o exemplo de dot_general
, aqui está outra restrição:
0 <= lhs_batching_dimensions < rank(lhs)
. Conforme definido na especificação dot_general
, lhs_batching_dimensions
é um tensor. No entanto, 0
e rank(lhs)
são escalares. Depois de aplicar a transmissão implícita, a fórmula vai se tornar
[0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)]
.
Quando aplicada a uma operação dot_general
específica, essa fórmula será avaliada como um tensor de booleanos. Quando fórmulas são usadas como restrições, a
restrição será mantida se a fórmula for avaliada como true
ou para um tensor que
tenha apenas elementos true
.
Nomes
Em fórmulas, o escopo lexical inclui: 1) funções globais, 2) definições de membro,
3) definições locais. Confira abaixo a lista de funções globais. A lista de definições de elementos depende do elemento do programa em que a notação é aplicada:
- Para operações, as definições de membros incluem nomes introduzidos nas seções "Entradas" e "Saídas".
- Para todo o restante, as definições de membro incluem partes estruturais do elemento do programa, nomeados de acordo com os não terminais de EBNF. Na maioria
das vezes, os nomes dessas partes estruturais são obtidos ao converter os
nomes dos não terminais em Snake Case (por exemplo,
IntegerLiteral
=>integer_literal
), mas às vezes os nomes são abreviados no processo (por exemplo,QuantizationStorageType
=>storage_type
). Nesse caso, os nomes são introduzidos explicitamente de maneira explícita às seções "Entradas" / "Saídas" nas especificações de operação. - Além disso, as definições de membro sempre incluem
self
para se referir ao elemento do programa correspondente.
Valores
Quando as fórmulas são avaliadas, elas funcionam com os seguintes tipos de valores:
1) Value
(valores reais, por exemplo, dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
;
elas sempre sabem os tipos),
2) Placeholder
(valores futuros, por exemplo, lhs
, rhs
ou result
; os valores reais
ainda não são conhecidos, somente seus tipos são conhecidos),
3) Type
(tipos conforme definido na seção "Tipos"),
4) Function
(funções globais, conforme definido na seção "Funções").
Dependendo do contexto, os nomes podem estar se referindo a valores diferentes. Mais
especificamente, a seção "Semântica" para operações (e equivalentes para outros elementos
do programa) define a lógica do ambiente de execução. Portanto, todas as entradas estão disponíveis como Value
.
Em contraste, a seção "Restrições" para operações (e equivalentes) define
a lógica de "tempo de compilação", ou seja, algo que normalmente é executado antes do tempo de execução.
Por isso, apenas entradas constantes estão disponíveis como Value
, e outras entradas estão
disponíveis apenas como Placeholder
.
Nomes | Em "Semântica" | Em "Restrições" |
---|---|---|
Funções globais | Function |
Function |
Entradas constantes | Value |
Value |
Entradas não constantes | Value |
Placeholder |
Saídas | Value |
Placeholder |
Definições locais | Depende da definição | Depende da definição |
Vamos considerar um exemplo de operação transpose
:
%result = "stablehlo.transpose"(%operand) {
permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
Para essa operação, permutation
é uma constante, por isso está disponível como um Value
em semântica e restrições. Por outro lado, operand
e result
estão
disponíveis como um Value
na semântica, mas apenas como um Placeholder
em restrições.
Funções
Construção de tipos
Não há funções que possam ser usadas para construir tipos. Em vez disso, usamos diretamente a sintaxe de tipo porque ela geralmente é mais concisa. Por exemplo,
(tensor<E>, tensor<E>) -> (tensor<E>)
em vez de function_type(
[tensor_type([], E), tensor_type([], E)], [tensor_type([], E)])
.
Funções em tipos
element_type
é definido nos tipos de tensor e nos tipos de tensores quantizados e retorna, respectivamente, a parteTensorElementType
ouQuantizedTensorElementType
doTensorType
ouQuantizedTensorType
correspondente.
def element_type(x: Value | Placeholder | Type):
if type(x) == TensorType:
return tensor_element_type(x)
if type(x) == QuantizedTensorType:
return quantized_tensor_element_type(x)
if type(x) is not Type:
return element_type(type(x))
is_per_axis_quantized(x: Value | Placeholder | Type) -> Value
é um atalho parais_quantized(x) and quantization_dimension(x) is not None
.is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value
é um atalho parais_quantized(x) and quantization_dimension(x) is None
.is_promotable(x: Type, y: Type) -> bool
verifica se o tipox
pode ser promovido para o tipoy
. Quandox
ey
foremQuantizedTensorElementType
s, a promoção será aplicada apenas aostorage_type
. Essa versão específica da promoção é usada atualmente no contexto do cálculo de redução (consulte o RFC para mais detalhes).
def is_promotable(x: Type, y: Type) -> Value:
is_same_type = (is_bool(x) and is_bool(y)) or
(is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
(is_complex(x) and is_complex(y)) or
(is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
if is_same_type == False:
return False
if is_integer(x) or is_float(x):
return bitwidth(x) <= bitwidth(y)
if is_complex(x):
return bitwidth(element_type(x)) <= bitwidth(element_type(y))
if is_quantized(x):
return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))
return false
is_quantized(x: Value | Placeholder | Type) -> Value
é um atalho parais_quantized_tensor_element_type(x)
.is_type_name(x: Value | Placeholder | Type) -> Value
. Disponível para todos os tipos. Por exemplo,is_float(x)
retornatrue
sex
for umFloatType
. Sex
for um valor ou marcador, essa função será um atalho parais_type_name(type(x))
.max_value(x: Type) -> Value
retorna o valor máximo de umaTensorElementType
. Sex
não for umTensorElementType
, retornaNone
.min_value(x: Type) -> Value
retorna o valor mínimo possível de umaTensorElementType
. Sex
não for umTensorElementType
, retornaNone
.member_name(x: Value | Placeholder | Type) -> Any
: disponível para todas as definições de membromember_name
de todos os tipos. Por exemplo,tensor_element_type(x)
retorna a parteTensorElementType
de umTensorType
correspondente. Sex
for um valor ou marcador, essa função será um atalho paramember_name(type(x))
. Sex
não for um tipo que tenha um membro apropriado, um valor ou um marcador desse tipo, retornaNone
.
Construção de valores
operation_name(*xs: Value | Type) -> Value
. Disponível para todas as operações. Por exemplo,add(lhs, rhs)
usa dois valores de tensorlhs
erhs
e retorna a saída da avaliação da operaçãoadd
com essas entradas. Em algumas operações, por exemplo,broadcast_in_dim
, os tipos das saídas são "suporte de carga", ou seja, necessários para avaliar uma operação. Nesse caso, a função usa esses tipos como argumentos.
Funções nos valores
Todos os operadores e funções do Python estão disponíveis. Por exemplo, as notação de assinatura e fracionamento do Python estão disponíveis para indexação em tensores, tensores quantizados e tuplas.
to_destination_type(x: Value, destination_type: Type) -> Value
é definido em tensores e retorna o valor convertido dex
com base emtype(x)
edestination_type
da seguinte maneira:
def to_destination_type(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x
if is_quantized(destination_type):
if is_quantized(type(x)):
return quantize(x, destination_type)
assert is_float(type(x))
return quantize(x, destination_type)
if is_quantized(type(x)):
assert destination_type = expressed_type(type(x))
return dequantize(type(x))
return convert(x, destination_type)
Há uma discussão inicial sobre a mesclagem das operações convert
, uniform_quantize
e
uniform_dequantize
(#1576).
Após a mesclagem, a função acima não é necessária e podemos usar o nome da operação
para convert
.
is_nan(x: Value) -> Value
é definido nos tensores e retornatrue
se todos os elementos dex
foremNaN
oufalse
, caso contrário. Sex
não for um tensor, retornaráNone
.is_sorted(x: Value) -> Value
é definido nos tensores e retornatrue
se os elementos dex
estiverem classificados em ordem crescente em relação à ordem lexicográfica crescente dos índices. Caso contrário, retornafalse
. Sex
não for um tensor, retornaráNone
.is_unique(x: Value) -> Value
é definido em tensores e retornatrue
sex
não tiver elementos duplicados oufalse
, caso contrário. Sex
não for um tensor, retornaráNone
.member_name(x: Value) -> Any
é definido para todas as definições de membromember_name
de todos os valores. Por exemplo,real_part(x)
retorna a parteRealPart
de umComplexConstant
correspondente. Sex
não for um valor que tenha um membro apropriado, retornaNone
.same(x: Value) -> Value
é definido nos tensores e retornatrue
se os elementos dex
forem todos iguais ou se os elementos defalse
forem iguais. Se o tensor não tiver elementos, isso conta como "todos iguais entre si", ou seja, a função retornatrue
. Sex
não for um tensor, retornaNone
.split(x: Value, num_results: Value, axis: Value) -> Value
é definido em tensores e retorna fatiasnum_results
dex
ao longo do eixoaxis
. Sex
não for um tensor oudim(x, axis) % num_results != 0
, retornaNone
.is_defined_in_parent_scope(x: Value) -> Value
é definido em strings e retornatrue
sex
for o nome de uma função definida no mesmo escopo que a função pai da op relevante.is_namespaced_op_name(x: Value) -> Value
é definido em strings e retornatrue
sex
for um nome de operação válido, ou seja, ele respeita a seguinte expressão regular:[a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+
Cálculos de formas
axes(x: Value | Placeholder | Type) -> Value
é um atalho pararange(rank(x))
.dim(x: Value | Placeholder | Type, axis: Value) -> Value
é um atalho parashape(x)[axis]
.dims(x: Value | Placeholder | Type, axes: List) -> List
é um atalho paralist(map(lambda axis: dim(x, axis), axes))
.index_space(x: Value | Placeholder | Type) -> Value
é definido em tensores e retorna índicessize(x)
para oTensorType
correspondente classificado em ordem lexicográfica crescente, ou seja,[0, ..., 0]
,[0, ..., 1]
, ...,shape(x) - 1
. Sex
não for um tipo de tensor, um tipo de tensor quantizado, um valor ou um marcador de posição de um desses tipos, retornaNone
.rank(x: Value | Placeholder | Type) -> Value
é um atalho parasize(shape(x))
.shape(x: Value | Placeholder | Type) -> Value
é definido na seção "Funções em tipos" viamember_name
.size(x: Value | Placeholder | Type) -> Value
é um atalho parareduce(lambda x, y: x * y, shape(x))
.
Computações de quantização
def baseline_element_type(x: Value | Placeholder | Type) -> Type
é um atalho paraelement_type(baseline_type(x))
.baseline_type
é definido em tipos de tensores e tipos de tensores quantizados e os transforma em um "valor de referência", ou seja, um tipo com a mesma forma, mas com os parâmetros de quantização do tipo de elemento redefinidos para os valores padrão. Isso é usado como um truque útil para comparar os tipos de tensor e quantizados de maneira uniforme, o que é necessário com bastante frequência. Para tipos quantizados, isso permite comparar tipos ignorando os parâmetros de quantização, ou seja,shape
,storage_type
,expressed_type
,storage_min
,storage_max
equantization_dimension
(para o tipo quantizado por eixo) precisam ser correspondentes, masscales
ezero points
podem ser diferentes.
def baseline_type(x: Value | Placeholder | Type) -> Type:
if type(x) == TensorType:
return x
if type(x) == QuantizedTensorType:
element_type = quantized_tensor_element_type(x)
baseline_element_type = QuantizedTensorElementType(
storage_type = storage_type(element_type),
storage_min = storage_min(element_type),
storage_max = storage_max(element_type),
expressed_type = expressed_type(element_type),
quantization_dimension = quantization_dimension(element_type),
scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
return QuantizedTensorType(shape(x), baseline_element_type)
if type(x) is not Type:
return baseline_element_type(type(x))
dequantize
é definido em tipos de tensores quantizados e os transforma em tipos de tensores de ponto flutuante. Isso acontece pela conversão de elementos quantizados, que representam valores inteiros do tipo de armazenamento em valores de ponto flutuante correspondentes do tipo expresso usando o ponto zero e a escala associados ao tipo de elemento quantizado.
def compute_zero_points(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
zero_points[i] = zero_points(quantized_type)[i[d]]
return zero_points
def compute_scales(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
type(result_type))
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
scales[i] = scales(quantized_type)[i[d]]
return scales
def dequantize(x: Value) -> Value:
assert is_quantized(x)
x_storage = bitcast_convert(x, storage_type(x))
x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
x_expressed_sub = convert(x_storage_sub, expressed_type(x))
return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
quantize
é definido em tipos de tensores de ponto flutuante e os transforma em tipos de tensores quantizados. Isso acontece por meio da conversão de valores de ponto flutuante do tipo expresso em valores inteiros correspondentes do tipo de armazenamento usando o ponto zero e a escala associados ao tipo de elemento quantizado.
def quantize(x: Value, result_type: Type) -> Value:
assert is_float(x) and is_quantized(result_type)
zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
converted_zero_points = convert(zero_points, expressed_type(result_type))
converted_min = convert(storage_min(result_type), expressed_type(result_type))
converted_max = convert(storage_max(result_type), expressed_type(result_type))
x_scaled = x / compute_scales(result_type, type(x))
x_scaled_add_zp = x_scaled + converted_zero_points
x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
x_rounded = round_nearest_even(x_clamped)
return convert(x_rounded, result_type)
dequantize_op_quantize
é usado para especificar cálculos com elementos em tensores quantizados. Ela desquantiza, ou seja, transforma elementos quantizados nos tipos expressos, executa uma operação e quantiza, ou seja, transforma os resultados de volta nos tipos de armazenamento. No momento, essa função só funciona para quantização por tensor. A quantização por eixo está em desenvolvimento (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
inputs = inputs_and_output_type[:-1]
output_type = inputs_and_output_type[-1]
float_inputs = map(dequantize, inputs)
float_result = op(*float_inputs)
return quantize(float_result, output_type)
def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
inputs = inputs_and_output_type[:-3]
float_inputs = map(dequantize, inputs)
float_results = op(*float_inputs)
return map(quantize, float_results, inputs_and_output_type[-3:])
def dequantize_compare(lhs, rhs, comparison_direction):
float_lhs = dequantize(lhs)
float_rhs = dequantize(rhs)
return compare(float_lhs, float_rhs, comparison_direction, FLOAT)
def dequantize_select_quantize(pred, on_true, on_false, output_type):
float_on_true = dequantize(on_true)
float_on_false = dequantize(on_false)
float_result = select(pred, float_on_true, float_on_false)
return quantize(float_result, output_type)
hybrid_dequantize_then_op
é usado para especificar a quantização somente de peso para operações híbridas que aceita lhs em ponto flutuante e rhs em tipos quantizados. Ele desquantiza entradas quantizadas nos tipos expressos delas e realiza cálculos em flutuação. O tipo de elemento de tensor lhs de flutuação e o tipo expresso de tensor rhs quantizado precisam ser idênticos.
def hybrid_dequantize_then_op(op, lhs, rhs):
assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
return op(lhs, dequantize(rhs))
Computação em grade
cross_partition(replica_groups: Value) -> Value
. Veja a seção "cross_replica" acima.cross_replica(replica_groups: Value) -> Value
. Veja a seção "cross_replica" acima.cross_replica_and_partition(replica_groups: Value) -> Value
. Veja a seção "cross_replica_and_partition" acima.flattened_ids(replica_groups: Value) -> Value
. Consulte a seção "Flated_ids" acima.
Dinamismo
Os valores de StableHLO podem ter tamanhos de dimensão dinâmicos, como tensor<?xi64>
.
No entanto, os valores do StableHLO não podem ter um número dinâmico de dimensões (dinamismo
não classificado, por exemplo, tensor<*xi64>
). Os operandos e resultados podem usar tamanhos de dimensão
dinâmicos, mesmo se houver restrições nos tamanhos. As restrições serão
verificadas estaticamente, se possível. Caso contrário, elas serão adiadas para o ambiente de execução e
as incompatibilidades resultarão em um comportamento indefinido. Consulte os exemplos abaixo.
Incompatibilidade de formas em operações elementares de unário
Considere o seguinte programa de brinquedos:
func.func @foo(%arg0: tensor<?xf64>) {
%0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
return
}
Esse programa é incomum, porque não é comum conhecer a forma do resultado, mas não a forma da entrada. No entanto, esse é um programa StableHLO
válido. Não é possível validar estaticamente a operação abs
neste
programa porque a forma exata do operando é desconhecida. No entanto, as formas
certamente são compatíveis, e isso pode ser verificado estaticamente: ?
pode acabar sendo
2
no momento da execução, e não há problema. No entanto, ?
também pode ser
algum outro número inteiro. Nesse caso, o comportamento é indefinido.
Se um tamanho de dimensão for dinâmico no resultado, não poderá haver comportamento indefinido. De fato, não há um tamanho "esperado", portanto, não pode haver uma incompatibilidade.
Incompatibilidade de formas em operações binárias com elementos
Considere o seguinte programa de brinquedos:
func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
%0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
return
}
Quando se trata de operações binárias com elementos, as formas das entradas e o resultado precisam ser iguais no momento da execução. Durante a compilação, as dimensões estáticas precisam ser iguais. Caso contrário, elas só precisarão ser compatíveis. Se qualquer dimensão for dinâmica nas entradas, poderá haver comportamento indefinido no momento da execução, já que o tamanho dinâmico pode não corresponder ao tamanho correspondente no outro operando (seja estático ou dinâmico). Se todas as entradas forem estáticas, o resultado será dinâmico ou não será relevante. As dimensões conhecidas estaticamente serão verificadas estaticamente, e as dinâmicas não imporão nenhuma restrição.
Incompatibilidades de forma para operações que assumem a forma de saída como um operando
Considere o seguinte programa de brinquedos:
func.func @foo(%arg0: tensor<2xi32>) {
%0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
return
}
Os valores no operando da forma no momento da execução precisam corresponder à forma do resultado.
Caso contrário, o comportamento será indefinido. Ou seja, no momento da execução, %arg0
precisa ter um
valor dense<[3, 4]> : tensor<2xi32>
. Se o operando de forma for constante, isso
poderá ser verificado estaticamente. Se o formato do resultado for totalmente dinâmico, não
pode haver incompatibilidade.