StableHLO è un insieme di operazioni per operazioni di alto livello (HLO) nei modelli di machine learning (ML). StableHLO funziona come un livello di portabilità tra diversi framework ML e compilatori ML: i framework ML che producono programmi StableHLO sono compatibili con i compilatori ML che utilizzano programmi StableHLO.
Il nostro obiettivo è semplificare e accelerare lo sviluppo ML creando una maggiore interoperabilità tra vari framework ML (come TensorFlow, JAX e PyTorch) e compilatori ML (come XLA e IREE). A tal fine, questo documento fornisce una specifica per il linguaggio di programmazione StableHLO.
Questa specifica contiene tre sezioni principali. Innanzitutto, la sezione Programmi descrive la struttura dei programmi StableHLO, costituiti da funzioni StableHLO, a loro volta composte da operazioni StableHLO. All'interno di questa struttura, la sezione Ops specifica la semantica delle singole operazioni. La sezione Esecuzione fornisce la semantica per tutte queste operazioni eseguite insieme all'interno di un programma. Infine, la sezione Notazione illustra la notazione utilizzata nella specifica.
Programmi
Program ::= {Func}
I programmi StableHLO sono costituiti da un numero arbitrario di funzioni StableHLO.
Di seguito è riportato un programma di esempio con una funzione @main
che ha 3 input
(%image
, %weights
e %bias
) e 1 output. Il corpo della funzione
ha 6 operazioni.
func.func @main(
%image: tensor<28x28xf32>,
%weights: tensor<784x10xf32>,
%bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
%0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
%1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
%2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
%3 = "stablehlo.constant"() { value = dense<0.0> : tensor<1x10xf32> } : () -> tensor<1x10xf32>
%4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
"func.return"(%4): (tensor<1x10xf32>) -> ()
}
Funzioni
Func ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput ::= '%' ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput ::= ValueType
FuncBody ::= {Op}
Le funzioni StableHLO (chiamate anche funzioni con nome) hanno un identificatore, input/output e un corpo. In futuro, abbiamo in programma di introdurre metadati aggiuntivi per le funzioni al fine di ottenere una migliore compatibilità con HLO (#425, #626, #740, #744).
Identificatori
FuncId ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
| '%' letter {letter | digit}
letter ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit ::= '0' | ... | '9'
Gli identificatori StableHLO sono simili a quelli di molti linguaggi di programmazione, con due peculiarità: 1) tutti gli identificatori hanno sigil che distinguono diversi tipi di identificatori, 2) gli identificatori dei valori possono essere completamente numerici per semplificare la generazione di programmi StableHLO.
Tipi
Type ::= ValueType | NonValueType
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
I tipi StableHLO sono classificati in tipi di valore (chiamati anche tipi di prima classe) che rappresentano i valori StableHLO e tipi non valore che descrivono altri elementi del programma. I tipi StableHLO sono simili a quelli di molti linguaggi di programmazione, con la principale peculiarità della natura specifica del dominio di StableHLO, che genera alcuni risultati insoliti (ad esempio, i tipi scalari non sono tipi di valori).
TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit}
I tipi di Tensor rappresentano tensori, ovvero array multidimensionali. Hanno una forma e un tipo di elemento, dove una forma rappresenta le dimensioni non negative nell'ordine crescente delle dimensioni corrispondenti (chiamate anche assi) numerate da 0
a R-1
. Il numero di dimensioni R
è chiamato ranking. Ad esempio, tensor<2x3xf32>
è un tipo tensore con forma 2x3
e tipo di elemento f32
. Presenta due dimensioni (o, in altre parole, due assi), la 0° dimensione e la 1° dimensione, le cui dimensioni sono 2 e 3. Il suo ranking è 2.
Definisce il supporto per forme statiche in cui le dimensioni delle dimensioni sono note in modo statico. In futuro, abbiamo in programma di introdurre il supporto anche per le forme dinamiche in cui le dimensioni delle dimensioni sono parzialmente o completamente sconosciute (n. 8). Inoltre, prevediamo di esplorare l'estensione dei tipi di tensori oltre le dimensioni delle dimensioni e i tipi di elementi, ad esempio per includere layout (#629) e sparsità (#1078).
QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
QuantizationStorageType
['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
':' QuantizationExpressedType
[':' QuantizationDimension]
',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerConstant
QuantizationStorageMax ::= IntegerConstant
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerConstant
QuantizationParameters ::= QuantizationParameter
| '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale ':' QuantizationZeroPoint
QuantizationScale ::= FloatConstant
QuantizationZeroPoint ::= IntegerConstant
Nome | Tipo | Vincoli |
---|---|---|
storage_type |
tipo intero | (C1-C4), (C9) |
storage_min |
costante intera | (C2), (C4), (C8) |
storage_max |
costante intera | (C3), (C4), (C8) |
expressed_type |
tipo a virgola mobile | (C1), (C5) |
quantization_dimension |
costante intera facoltativa | (C11-C13) |
scales |
numero variadico di costanti in virgola mobile | (C5-C7), (C10), (C11), (C13) |
zero_points |
numero variadico di costanti numeriche | (C8-C10) |
I tipi di elementi quantizzati rappresentano i valori interi di un tipo di archiviazione nell'intervallo da storage_min
a storage_max
(incluso) che corrispondono a valori in virgola mobile di un tipo espresso. Per un determinato valore intero i
, il corrispondente valore in virgola mobile f
può essere calcolato come f = (i - zero_point) * scale
, dove scale
e zero_point
sono chiamati parametri di quantizzazione. storage_min
e storage_max
sono facoltativi nella grammatica, ma hanno rispettivamente valori predefiniti di min_value(storage_type)
e max_value(storage_type)
. I tipi di elementi quantizzati hanno i seguenti vincoli:
- (C1)
num_bits(storage_type) < num_bits(expressed_type)
. - (C2)
type(storage_min) = storage_type
. - (C3)
type(storage_max) = storage_type
. - (C4)
min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type)
. - (C5)
type(scales...) = expressed_type
. - (C6)
0 < scales
. - (C7)
is_finite(scales...)
. - (C8)
storage_min <= zero_points <= storage_max
. - (C9)
type(zero_points...) = storage_type
. - (C10)
size(scales) = size(zero_points)
. - (C11) Se
is_empty(quantization_dimension)
, allorasize(scales) = 1
. - (C12)
0 <= quantization_dimension
.
Al momento QuantizationScale
è una costante in virgola mobile, ma c'è un forte interesse per le scale basate su numeri interi, rappresentate con moltiplicatori e spostamenti. Abbiamo in programma di effettuare questa valutazione nel prossimo futuro (#1404).
È in corso una discussione sulla semantica di QuantizationZeroPoint
, compresi il tipo, i valori e se possono esserci solo uno o più punti zero in un tipo di tensore quantizzato. In base ai risultati di questa discussione, la specifica relativa a zero punti potrebbe cambiare in futuro (#1405).
Un'altra discussione in corso riguarda la semantica di QuantizationStorageMin
e QuantizationStorageMax
per determinare se è necessario
applicare vincoli a questi valori e ai valori dei tensori quantizzati
(#1406).
Infine, stiamo pianificando di esplorare la rappresentazione di scale e punti zero sconosciuti, in modo simile a come prevediamo di esplorare la rappresentazione di dimensioni delle dimensioni sconosciute (#1407).
I tipi di tensori quantizzati rappresentano tensori con elementi quantizzati. Questi tensori sono esattamente uguali ai tensori regolari, tranne per il fatto che i loro elementi hanno tipi di elementi quantizzati, invece dei tipi di elementi regolari.
Nei tensori quantizzati, la quantizzazione può essere per tensore, ovvero avere
un scale
e zero_point
per l'intero tensore, oppure per asse,
ovvero avere più scales
e zero_points
, una coppia per sezione di
una determinata dimensione quantization_dimension
. Più formalmente, in un tensore t
con quantizzazione per asse, ci sono dim(t, quantization_dimension)
sezioni di quantization_dimension
: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :]
e così via. Tutti gli elementi nella i
a sezione utilizzano scales[i]
e zero_points[i]
come parametri di quantizzazione. I tipi di tensori quantizzati hanno i seguenti vincoli:
- Per la quantizzazione per tensore:
- Nessun vincolo aggiuntivo.
- Per la quantizzazione per asse:
- (C12)
quantization_dimension < rank(self)
. - (C13)
dim(self, quantization_dimension) = size(scales)
.
- (C12)
TokenType ::= 'token'
I tipi di token rappresentano i token, ovvero i valori opachi prodotti e utilizzati da alcune operazioni. I token vengono utilizzati per imporre l'ordine di esecuzione sulle operazioni, come descritto nella sezione Esecuzione.
TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]
I tipi di tuple rappresentano tuple, ovvero elenchi eterogenei. Le tuple sono una funzionalità legacy che esiste solo per la compatibilità con HLO. In HLO, le tuple vengono
utilizzate per rappresentare input e output variadici. In StableHLO, gli input e gli output variadici sono supportati in modo nativo e l'unico utilizzo delle tuple in StableHLO è rappresentare in modo completo HLO ABI, dove ad esempio T
, tuple<T>
e tuple<tuple<T>>
possono essere materialmente diversi a seconda di una particolare implementazione. In futuro, abbiamo in programma di apportare modifiche ad HLO ABI,
che potrebbero consentirci di rimuovere i tipi di tuple da StableHLO
(#598).
TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ'
| 'f8E4M3B11FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
I tipi di elementi rappresentano elementi di tipi tensori. A differenza di molti linguaggi di programmazione, questi tipi non sono di prima classe in StableHLO. Ciò significa che i programmi StableHLO non possono rappresentare direttamente valori di questi tipi (di conseguenza, è idiomatico rappresentare valori scalari di tipo T
con valori di tensori 0 dimensioni di tipo tensor<T>
).
- Il tipo booleano rappresenta i valori booleani
true
efalse
. - I tipi di numeri interi possono essere con segno (
si
) o senza firma (ui
) e avere una delle larghezze di bit supportate (4
,8
,16
,32
o64
). I tipisiN
firmati rappresentano valori interi compresi tra-2^(N-1)
e2^(N-1)-1
, mentre i tipiuiN
senza segno rappresentano valori interi compresi tra0
e2^N-1
. - I tipi con virgola mobile possono essere uno dei seguenti:
- I tipi
f8E4M3FN
ef8E5M2
corrispondono rispettivamente alle codificheE4M3
eE5M2
del formato FP8 descritto in Formati FP8 per il deep learning. - I tipi
f8E4M3FNUZ
ef8E5M2FNUZ
corrispondono alle codificazioniE4M3
eE5M2
dei formati FP8 descritti in Formati numerici a 8 bit per reti neurali profonde. f8E4M3B11FNUZ
corrispondente alla codificaE4M3
dei formati FP8 descritti in Addestramento e inferenza per reti neurali profonde a 8-bit Floating Point (HFP8) ibrido.- di tipo
bf16
corrispondente al formatobfloat16
descritto in BFloat16: il secret per ottenere prestazioni elevate sulle Cloud TPU. - I tipi
f16
,f32
ef64
corrispondono rispettivamente ai formatibinary16
("precisione a metà"),binary32
("precisione singola") ebinary64
("precisione doppia") descritti nello standard IEEE 754.
- I tipi
- I tipi complessi rappresentano valori complessi che hanno una parte reale e una parte immaginaria dello stesso tipo di elemento. I tipi complessi supportati sono
complex<f32>
(entrambe le parti sono di tipof32
) ecomplex<f64>
(entrambe le parti sono di tipof64
).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]
I tipi di funzione rappresentano funzioni sia con nome che anonime. Hanno tipi di input (l'elenco dei tipi sul lato sinistro di ->
) e tipi di output (l'elenco dei tipi a destra di ->
). In molti linguaggi di programmazione, i tipi di funzione sono di prima classe, ma non in StableHLO.
StringType ::= 'string'
Il tipo di stringa rappresenta le sequenze di byte. A differenza di molti linguaggi di programmazione, il tipo di stringa non è il primo in StableHLO e viene utilizzato solo per specificare metadati statici per gli elementi del programma.
Suite operativa
Le operazioni StableHLO (chiamate anche operazioni) rappresentano un insieme chiuso di operazioni di alto livello nei modelli di machine learning. Come discusso in precedenza, la sintassi di StableHLO è fortemente ispirata a MLIR, che non è necessariamente l'alternativa più ergonomica, ma è probabilmente la soluzione migliore per l'obiettivo di StableHLO di creare una maggiore interoperabilità tra framework ML e compilatori ML.
Op ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic ::= 'abs' | 'add' | ...
Le operazioni StableHLO (chiamate anche operazioni) hanno un nome, input/output e una firma. Il nome è composto dal prefisso stablehlo.
e da una mnemonica che identifica in modo univoco una delle operazioni supportate. Di seguito è riportato un elenco completo di tutte le operazioni supportate.
Al momento, i programmi StableHLO all'aperto a volte contengono operazioni non descritte in questo documento. In futuro, abbiamo in programma di assorbire queste operazioni nell'opset StableHLO o di impedirne la visualizzazione nei programmi StableHLO. Nel frattempo, ecco un elenco di queste operazioni:
builtin.module
,func.func
,func.call
efunc.return
(#425).chlo
operazioni (#602).- Categoria "Non in HLO" delle operazioni StableHLO. Inizialmente facevano parte dell'opset StableHLO, ma in seguito sono state ritenute non adatte:
broadcast
,create_token
,cross-replica-sum
,dot
,einsum
,torch_index_select
,unary_einsum
(#3). - Categoria "Dinamismo" delle operazioni StableHLO. Sono state sottoposte a bootstrapping
da MHLO, ma non le abbiamo ancora specificate:
compute_reshape_shape
,cstr_reshapable
,dynamic_broadcast_in_dim
,dynamic_conv
,dynamic_gather
,dynamic_iota
,dynamic_pad
,dynamic_reshape
,real_dynamic_slice
,set_dimension_size
(#8). - Calcoli delle forme, che includono le operazioni
arith
,shape
etensor
(#8).
OpInputs ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue ::= ValueId
OpInputFuncs ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs ::= [OpOutput {',' OpOutput} '=']
OpOutput ::= ValueId
Il team operativo consuma gli input e produce output. Gli input sono classificati in valori di input (calcolati durante l'esecuzione), funzioni di input (fornite in modo statico, perché in StableHLO le funzioni non sono valori di prima classe) e attributi di input (forniti anche in modo statico). Il tipo di input e output
consumati e prodotti da un'operazione dipende dal suo mnemonico. Ad esempio, l'operazione add
consuma 2 valori di input e produce 1 valore di output. In confronto, l'operazione select_and_scatter
consuma 3 valori di input, 2 funzioni di input e 3 attributi di input.
OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused ::= '^' digit {digit}
| '^' letter {letter | digit}
Le funzioni di input (chiamate anche funzioni anonime) sono molto
simili alle funzioni con nome, tranne per il fatto che: 1) non hanno un identificatore (da qui il nome "anonimo "), 2) non dichiarano i tipi di output (i tipi di output vengono dedotti dall'operazione return
all'interno della funzione).
La sintassi per le funzioni di input include una parte attualmente inutilizzata (consulta la produzione Unused
sopra), necessaria per la compatibilità con MLIR. In MLIR, esiste un concetto più generale di "regioni", che possono avere più "blocchi" di operazioni collegate tra loro tramite salti. Questi blocchi hanno ID che corrispondono alla produzione Unused
, in modo che possano essere distinti.
StableHLO non prevede operazioni di salto, quindi la parte corrispondente della sintassi MLIR è inutilizzata (ma è ancora presente).
OpInputAttr ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName ::= letter {letter | digit}
OpInputAttrValue ::= Constant
Gli attributi di input contengono un nome e un valore che rappresentano una delle costanti supportate. Sono il modo principale per specificare metadati statici per gli elementi del programma. Ad esempio, l'operazione concatenate
utilizza l'attributo dimension
per specificare la dimensione con cui sono concatenati i valori di input. In modo analogo, l'operazione slice
utilizza più attributi come start_indices
e limit_indices
per specificare i limiti utilizzati per suddividere il valore di input.
Al momento, i programmi StableHLO allo stato brado a volte contengono attributi non descritti in questo documento. In futuro, abbiamo in programma di assorbire questi attributi nell'opset StableHLO o di impedirne la visualizzazione nei programmi StableHLO. Nel frattempo, ecco un elenco di questi attributi:
layout
(n. 629).mhlo.frontend_attributes
(n. 628).mhlo.sharding
(n. 619).output_operand_aliases
(n. 740).- Metadati posizione (#594).
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'
La firma op è composta dai tipi di tutti i valori di input (l'elenco dei tipi sul lato sinistro di ->
) e dai tipi di tutti i valori di output (l'elenco dei tipi a destra di ->
). A rigore, i tipi di input sono ridondanti e anche i tipi di output sono quasi sempre ridondanti (perché per la maggior parte delle operazioni StableHLO, i tipi di output possono essere dedotti dagli input). Ciononostante, la firma dell'op
fa parte deliberatamente della sintassi StableHLO per garantire la compatibilità con MLIR.
Di seguito è riportato un esempio di op il cui mnemonico è select_and_scatter
. Consuma 3 valori di input (%operand
, %source
e %init_value
), 2 funzioni di input e 3 attributi di input (window_dimensions
, window_strides
e padding
). Nota come la firma dell'operazione include solo i tipi dei suoi valori di input (ma non i tipi di funzioni di input e attributi che sono forniti in linea).
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>
Costanti
Constant ::= BooleanConstant
| IntegerConstant
| FloatConstant
| ComplexConstant
| TensorConstant
| QuantizedTensorConstant
| StringConstant
| EnumConstant
Le costanti StableHLO hanno un valore letterale e un tipo che insieme rappresentano un valore StableHLO. In genere, il tipo fa parte della sintassi costante, tranne quando è inequivocabile (ad es. una costante booleana ha il tipo i1
inequivocabile, mentre una costante intera può avere più tipi possibili).
BooleanConstant ::= BooleanLiteral
BooleanLiteral ::= 'true' | 'false'
Le costanti booleane rappresentano i valori booleani true
e false
. Le costanti booleane hanno il tipo i1
.
IntegerConstant ::= IntegerLiteral ':' IntegerType
IntegerLiteral ::= ['-' | '+'] DecimalDigits
| ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit ::= '0' | ... | '9'
hexadecimalDigit ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'
Le costanti numeri interi rappresentano valori interi tramite stringhe che utilizzano la notazione decimale o esadecimale. Altre basi, ad esempio binarie o ottali, non sono supportate. Le costanti numeriche hanno i seguenti vincoli:
- (C1)
is_wellformed(integer_literal, integer_type)
.
FloatConstant ::= FloatLiteral ':' FloatType
FloatLiteral ::= SignPart IntegerPart FractionalPart ScientificPart
| '0x' [HexadecimalDigits]
SignPart ::= ['-' | '+']
IntegerPart ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]
Le costanti a virgola mobile rappresentano valori con virgola mobile tramite stringhe che utilizzano la notazione decimale o scientifica. Inoltre, la notazione esadecimale può essere utilizzata per specificare direttamente i bit sottostanti nel formato a virgola mobile del tipo corrispondente. Le costanti a virgola mobile hanno i seguenti vincoli:
- (C1) Se viene utilizzata la notazione non esadecimale,
is_wellformed(float_literal, float_type)
. - (C2) Se viene utilizzata la notazione esadecimale,
size(hexadecimal_digits) = num_bits(float_type) / 4
.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral ::= '(' RealPart ',' ImaginaryPart ')'
RealPart ::= FloatLiteral
ImaginaryPart ::= FloatLiteral
Le costanti complesse rappresentano valori complessi utilizzando gli elenchi di una parte reale (viene per prima) e di una parte immaginaria (seconda). Ad esempio, (1.0, 0.0) : complex<f32>
rappresenta 1.0 + 0.0i
e (0.0, 1.0) : complex<f32>
rappresenta 0.0 + 1.0i
. L'ordine in cui le parti vengono archiviate in memoria è definito dall'implementazione. Le costanti complesse hanno i seguenti vincoli:
- (C1)
is_wellformed(real_part, complex_element_type(complex_type))
. - (C2)
is_wellformed(imaginary_part, complex_element_type(complex_type))
.
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral
Le costanti Tensor rappresentano i valori di tensore utilizzando elenchi nidificati specificati tramite la notazione NumPy. Ad esempio, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
rappresenta un valore tensore con la seguente mappatura dagli indici agli elementi:
{0, 0} => 1
, {0, 1} => 2
, {0, 2} => 3
, {1, 0} => 4
, {1, 1} => 5
,
{1, 2} => 6
. L'ordine in cui questi elementi vengono archiviati in memoria è definito dall'implementazione. Le costanti tensore hanno i seguenti vincoli:
- (C1)
has_syntax(tensor_literal, element_type(tensor_type))
, dove:has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type)
.has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type)
.
- (C2)
has_shape(tensor_literal, shape(tensor_type))
, dove:has_shape(element_literal: Syntax, []) = true
.has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:])
.- altrimenti
false
.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
Le costanti dei tensori quantizzati rappresentano i valori dei tensori quantizzati utilizzando la stessa notazione delle costanti del tensore, con elementi specificati come costanti del loro tipo di archiviazione. Le costanti dei tensori quantizzati hanno i seguenti vincoli:
- (C1)
has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type))
. - (C2)
has_shape(quantized_tensor_literal, shape(quantized_tensor_type))
.
StringConstant ::= StringLiteral
StringLiteral ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))
I valori letterali stringa sono costituiti da byte specificati utilizzando caratteri ASCII e sequenze di escape. Poiché sono indipendenti dalla codifica, l'interpretazione
di questi byte è definita dall'implementazione. I valori letterali stringa hanno il tipo string
.
Operazioni
abs
Semantica
Esegue un'operazione di assembramento a livello di elemento sul tensore operand
e produce un tensore di result
. A seconda del tipo di elemento, procedi nel seguente modo:
- Per i numeri interi firmati: modulo intero.
- Per i numeri in virgola mobile:
abs
dallo standard IEEE-754. - Per i numeri complessi: modulo complesso.
- Per i tipi quantizzati:
dequantize_op_quantize(abs, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di numero intero firmato, in virgola mobile, complesso o tensore quantizzato per tensore | (C1-C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo intero firmato o in virgola mobile o tensore quantizzato per tensore | (C1-C2) |
Vincoli
- (C1)
shape(result) = shape(operand)
. - (C2)
baseline_element_type(result)
è definito come:complex_element_type(element_type(operand))
seis_complex(operand)
.baseline_element_type(operand)
in caso contrario.
Esempi
// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]
add
Semantica
Esegue l'aggiunta a livello di elemento di due tensori lhs
e rhs
e produce un tensore result
. A seconda del tipo di elemento, procedi nel seguente modo:
- Per i valori booleani: OR logico.
- Per i numeri interi: addizione di numeri interi.
- Per i numeri in virgola mobile:
addition
dallo standard IEEE-754. - Per i numeri complessi: addizione complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(add, lhs, rhs, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore o tensore quantizzato per tensore | (C1) |
(I2) | rhs |
tensore o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Esempi
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]
after_all
Semantica
Garantisce che le operazioni che generano inputs
vengano eseguite prima di qualsiasi operazione che dipende da result
. L'esecuzione di questa operazione non ha alcun effetto,
esiste solo per stabilire dipendenze dei dati da result
a inputs
.
Input
Etichetta | Nome | Tipo |
---|---|---|
(I1) | inputs |
numero variadico di token |
Output
Nome | Tipo |
---|---|
result |
token |
Esempi
// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token
all_gather
Semantica
All'interno di ogni gruppo di processo nella griglia di processo StableHLO, concatena i valori del tensore operand
di ogni processo lungo all_gather_dim
e produce un tensore result
.
L'operazione suddivide la griglia del processo StableHLO in process_groups
, che è
definito come segue:
cross_replica(replica_groups)
sechannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
sechannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
sechannel_id > 0 and use_global_device_ids = true
.
Successivamente, entro ogni process_group
:
operands@receiver = [operand@sender for sender in process_group]
per tuttireceiver
inprocess_group
.result@process = concatenate(operands@process, all_gather_dim)
per tuttiprocess
inprocess_group
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato per tensore | (C1), (C6) |
(I2) | all_gather_dim |
costante di tipo si64 |
(C1), (C6) |
(I3) | replica_groups |
Costante del tensore bidimensionale di tipo si64 |
(C2-C4) |
(I4) | channel_id |
costante di tipo si64 |
(C5) |
(I5) | use_global_device_ids |
costante di tipo i1 |
(C5) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C6) |
Vincoli
- (C1)
0 <= all_gather_dim < rank(operand)
. - (C2)
is_unique(replica_groups)
. - (C3)
size(replica_groups)
è definito come:num_replicas
se viene utilizzatocross_replica
.num_replicas
se viene utilizzatocross_replica_and_partition
.num_processes
se viene utilizzatoflattened_ids
.
- (C4)
0 <= replica_groups < size(replica_groups)
. - (C5) Se
use_global_device_ids = true
, allorachannel_id > 0
. - (C6)
type(result) = type(operand)
tranne:dim(result, all_gather_dim) = dim(operand, all_gather_dim) * dim(process_groups, 1)
.
Esempi
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
%result = "stablehlo.all_gather"(%operand) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<2x2xi64>) -> tensor<2x4xi64>
// %result@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
all_reduce
Semantica
All'interno di ogni gruppo di processo nella griglia di processo StableHLO, applica una funzione di riduzione computation
ai valori del tensore operand
di ciascun processo e produce un tensore result
.
L'operazione suddivide la griglia del processo StableHLO in process_groups
, che è
definito come segue:
cross_replica(replica_groups)
sechannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
sechannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
sechannel_id > 0 and use_global_device_ids = true
.
Successivamente, entro ogni process_group
:
result@process[result_index] = exec(schedule)
per un albero binarioschedule
dove:exec(node)
=computation(exec(node.left), exec(node.right))
.exec(leaf)
=leaf.value
.
schedule
è una struttura binaria definita dall'implementazione il cui attraversamento in ordine èto_destination_type(operands@process_group...[result_index], type(func_inputs(computation)[0]))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato per tensore | (C5), (C6) |
(I2) | replica_groups |
numero variadico di costanti dei tensori unidimensionali di tipo si64 |
(C1-C3) |
(I3) | channel_id |
costante di tipo si64 |
(C4) |
(I4) | use_global_device_ids |
costante di tipo i1 |
(C4) |
(I5) | computation |
funzione | (C5) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C6-C7) |
Vincoli
- (C1)
is_unique(replica_groups)
. - (C2)
size(replica_groups)
è definito come:num_replicas
se viene utilizzatocross_replica
.num_replicas
se viene utilizzatocross_replica_and_partition
.num_processes
se viene utilizzatoflattened_ids
.
- (C3)
0 <= replica_groups < size(replica_groups)
. - (C4) Se
use_global_device_ids = true
, allorachannel_id > 0
. - (C5)
computation
ha il tipo(tensor<E>, tensor<E>) -> (tensor<E>)
, doveis_promotable(element_type(operand), E)
. - (C6)
shape(result) = shape(operand)
. - (C7)
element_type(result) = E
.
Esempi
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [1, 2, 3, 4]
// %operand@(1, 0): [5, 6, 7, 8]
%result = "stablehlo.all_reduce"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<i64>) -> tensor<i64>
// %result@(0, 0): [6, 8, 10, 12]
// %result@(1, 0): [6, 8, 10, 12]
all_to_all
Semantica
All'interno di ogni gruppo di processi nella griglia di processo StableHLO, suddivide i valori del tensore operand
lungo split_dimension
in parti, distribuisce le parti divise tra i processi, concatena le parti sparse insieme concat_dimension
e produce un tensore result
.
L'operazione suddivide la griglia del processo StableHLO in process_groups
, che è
definito come segue:
cross_replica(replica_groups)
sechannel_id <= 0
.cross_partition(replica_groups)
sechannel_id > 0
.
Successivamente, entro ogni process_group
:
split_parts@sender = split(operand@sender, split_count, split_dimension)
per tutti isender
inprocess_group
.scattered_parts@receiver = [split_parts@sender[receiver_index] for sender in process_group]
dovereceiver_index = process_group.index(receiver)
.result@process = concatenate(scattered_parts@process, concat_dimension)
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato per tensore | (C1-C3), (C9) |
(I2) | split_dimension |
costante di tipo si64 |
(C1), (C2), (C9) |
(I3) | concat_dimension |
costante di tipo si64 |
(C3), (C9) |
(I4) | split_count |
costante di tipo si64 |
(C2), (C4), (C8), (C9) |
(I5) | replica_groups |
Costante del tensore bidimensionale di tipo si64 |
(C5-C8) |
(I6) | channel_id |
costante di tipo si64 |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C9) |
Vincoli
- (C1)
0 <= split_dimension < rank(operand)
. - (C2)
dim(operand, split_dimension) % split_count = 0
. - (C3)
0 <= concat_dimension < rank(operand)
. - (C4)
0 < split_count
. - (C5)
is_unique(replica_groups)
. - (C6)
size(replica_groups)
è definito come:num_replicas
se viene utilizzatocross_replica
.num_partitions
se viene utilizzatocross_partition
.
- (C7)
0 <= replica_groups < size(replica_groups)
. - (C8)
dim(replica_groups, 1) = split_count
. - (C9)
type(result) = type(operand)
tranne:dim(result, split_dimension) = dim(operand, split_dimension) / split_count
.dim(result, concat_dimension) = dim(operand, concat_dimension) * split_count
.
Esempi
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.all_to_all"(%operand) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
} : (tensor<2x4xi64>) -> tensor<4x2xi64>
// %result@(0, 0): [[1, 2],
// [5, 6],
// [9, 10],
// [13, 14]]
// %result@(1, 0): [[3, 4],
// [7, 8],
// [11, 12],
// [15, 16]]
e
Semantica
Esegue l'operatore AND per elementi di due tensori lhs
e rhs
e produce un tensore result
. A seconda del tipo di elemento, procedi nel seguente modo:
- Per i valori booleani: AND logico.
- Per i numeri interi: AND a livello di bit.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di tipo booleano o intero | (C1) |
(I2) | rhs |
tensore di tipo booleano o intero | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo booleano o intero | (C1) |
Vincoli
- (C1)
type(lhs) = type(rhs) = type(result)
.
Esempi
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]
atan2
Semantica
Esegue un'operazione atan2 a livello di elemento su lhs
e rhs
tensore e produce un tensore result
. A seconda del tipo di elemento, procedi nel seguente modo:
- Per i numeri in virgola mobile:
atan2
dallo standard IEEE-754. - Per i numeri complessi: atan2 complesso.
- Per i tipi quantizzati:
dequantize_op_quantize(atan2, lhs, rhs, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
(I2) | rhs |
tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Esempi
// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]
batch_norm_grad
Semantica
Calcola i gradienti di diversi input di batch_norm_training
retropropagazione
da grad_output
e produce tensori grad_operand
, grad_scale
e grad_offset
. Più formalmente, questa operazione può essere espressa come una scomposizione di operazioni StableHLO esistenti utilizzando la sintassi Python, come segue:
def compute_sum(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
return sum
def compute_mean(operand, feature_index):
sum = compute_sum(operand, feature_index)
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
# Broadcast inputs to type(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance`
# Intermediate values will be useful for computing gradients
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
# Use the implementation from batchnorm_expander.cc in XLA
# Temporary variables have exactly the same names as in the C++ code
elements_per_feature = broadcast_in_dim(
constant(divide(size(operand), dim(operand, feature_index)),
element_type(grad_output)),
[], type(operand))
i1 = multiply(grad_output, elements_per_feature)
i2 = broadcast_in_dim(
compute_sum(grad_output, feature_index), [feature_index], type(operand))
i3 = broadcast_in_dim(
compute_sum(multiply(grad_output, centered_operand), feature_index),
[feature_index], type(operand))
i4 = multiply(i3, centered_operand)
i5 = divide(i4, add(variance_bcast, epsilon_bcast))
i6 = subtract(subtract(i1, i2), i5)
grad_operand =
multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
grad_scale =
compute_sum(multiply(grad_output, normalized_operand), feature_index)
grad_offset = compute_sum(grad_output, feature_index)
return grad_operand, grad_scale, grad_offset
Per i tipi quantizzati, esegue dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean,
variance, grad_output: batch_norm_grad(operand, scale, mean, variance,
grad_output, epsilon, feature_index), operand, scale, mean, variance,
grad_output, type(grad_operand), type(grad_scale), type(feature_index))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1-C3), (C5) |
(I2) | scale |
Tensore unidimensionale di tipo quantizzato in virgola mobile o per tensore | (C2), (C4), (C5) |
(I3) | mean |
Tensore unidimensionale di tipo quantizzato in virgola mobile o per tensore | (C2), (C4) |
(I4) | variance |
Tensore unidimensionale di tipo quantizzato in virgola mobile o per tensore | (C2), (C4) |
(I5) | grad_output |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C2), (C3) |
(I6) | epsilon |
costante di tipo f32 |
|
(I7) | feature_index |
costante di tipo si64 |
(C1), (C5) |
Output
Nome | Tipo | Vincoli |
---|---|---|
grad_operand |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C2), (C3) |
grad_scale |
Tensore unidimensionale di tipo quantizzato in virgola mobile o per tensore | (C2), (C4) |
grad_offset |
Tensore unidimensionale di tipo quantizzato in virgola mobile o per tensore | (C2), (C4) |
Vincoli
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,mean
,variance
,grad_output
,grad_operand
,grad_scale
egrad_offset
hanno lo stessobaseline_element_type
. - (C3)
operand
,grad_output
egrad_operand
hanno la stessa forma. - (C4)
scale
,mean
,variance
,grad_scale
egrad_offset
hanno la stessa forma. - (C5)
size(scale) = dim(operand, feature_index)
.
Esempi
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
// [[0.1, 0.1], [0.1, 0.1]],
// [[0.1, 0.1], [0.1, 0.1]]
// ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
// [[0.0, 0.0], [0.0, 0.0]],
// [[0.0, 0.0], [0.0, 0.0]]
// ]
// %grad_scale: [0.0, 0.0]
// %grad_offset: [0.4, 0.4]
batch_norm_inference
Semantica
Normalizza il tensore operand
in tutte le dimensioni tranne la dimensione feature_index
e produce un tensore result
. Più formalmente, questa operazione può essere espressa come una scomposizione di operazioni StableHLO esistenti utilizzando la sintassi Python, come segue:
def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
# Broadcast inputs to shape(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance` instead of
# computing them like `batch_norm_training` does.
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
return add(multiply(scale_bcast, normalized_operand), offset_bcast)
Per i tipi quantizzati, esegue dequantize_op_quantize(lambda operand, scale, offset, mean, variance:
batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index), operand, scale, offset, mean, variance, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1-C7) |
(I2) | scale |
Tensore unidimensionale di tipo quantizzato in virgola mobile o per tensore | (C2), (C3) |
(I3) | offset |
Tensore unidimensionale di tipo quantizzato in virgola mobile o per tensore | (C2), (C4) |
(I4) | mean |
Tensore unidimensionale di tipo quantizzato in virgola mobile o per tensore | (C5) |
(I5) | variance |
Tensore unidimensionale di tipo quantizzato in virgola mobile o per tensore | (C2), (C6) |
(I6) | epsilon |
costante di tipo f32 |
|
(I7) | feature_index |
costante di tipo si64 |
(C1), (C3-C6) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C2), (C7) |
Vincoli
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,mean
,variance
eresult
hanno lo stessobaseline_element_type
. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(mean) = dim(operand, feature_index)
. - (C6)
size(variance) = dim(operand, feature_index)
. - (C7)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
batch_norm_training
Semantica
Calcola la media e la varianza su tutte le dimensioni, ad eccezione della dimensione feature_index
, e normalizza il tensore operand
che produce i tensori output
, batch_mean
e batch_var
. Più formalmente, questa operazione può essere espressa come una scomposizione di operazioni StableHLO esistenti utilizzando la sintassi Python, come segue:
def compute_mean(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def compute_variance(operand, feature_index):
mean = compute_mean(operand, feature_index)
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
centered_operand = subtract(operand, mean_bcast)
return compute_mean(mul(centered_operand, centered_operand), feature_index)
def batch_norm_training(operand, scale, offset, epsilon, feature_index):
mean = compute_mean(operand, feature_index)
variance = compute_variance(operand, feature_index)
return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index),
mean, variance
Per i tipi quantizzati, esegue dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset:
batch_norm_training(operand, scale, offset, epsilon, feature_index), operand,
scale, offset, type(output), type(batch_mean), type(batch_var))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
(I2) | scale |
Tensore unidimensionale di virgola mobile o per tensore quantizzato | (C2), (C3) |
(I3) | offset |
Tensore unidimensionale di virgola mobile o per tensore quantizzato | (C2), (C4) |
(I4) | epsilon |
costante di tipo f32 |
(C1), (C3-C6) |
(I5) | feature_index |
costante di tipo si64 |
(C1), (C3-C6) |
Output
Nome | Tipo | Vincoli |
---|---|---|
output |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C7) |
batch_mean |
Tensore unidimensionale di virgola mobile o per tensore quantizzato | (C2), (C5) |
batch_var |
Tensore unidimensionale di virgola mobile o per tensore quantizzato | (C2), (C6) |
Vincoli
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,batch_mean
,batch_var
eoutput
hanno lo stessobaseline_element_type
. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(batch_mean) = dim(operand, feature_index)
. - (C6)
size(batch_var) = dim(operand, feature_index)
. - (C7)
baseline_type(output) = baseline_type(operand)
.
Esempi
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
(tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]
bitcast_convert
Semantica
Esegue un'operazione bitcast sul tensore operand
e produce un tensore result
in cui i bit dell'intero tensore operand
vengono reinterpretati utilizzando il tipo del tensore result
.
Più formalmente, dati E = element_type(operand)
, E' = element_type(result)
e R = rank(operand)
:
- Se
num_bits(E') < num_bits(E)
,bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])
. - Se
num_bits(E') > num_bits(E)
,bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])
. - Se
num_bits(E') = num_bits(E)
,bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])
.
bits
restituisce la rappresentazione in memoria di un determinato valore e il suo comportamento è definito dall'implementazione perché la rappresentazione esatta dei tensori è definita dall'implementazione e anche la rappresentazione esatta dei tipi di elementi è definita dall'implementazione.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o quantizzato | (C1-C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o quantizzato | (C1-C2) |
Vincoli
- (C1) Dati
E = is_quantized(operand) ? storage_type(operand) : element_type(operand)
,E' = is_quantized(result) ? storage_type(result) : element_type(result)
eR = rank(operand)
:- Se
num_bits(E') = num_bits(E)
,shape(result) = shape(operand)
. - Se
num_bits(E') < num_bits(E)
: rank(result) = R + 1
.dim(result, i) = dim(operand, i)
per tutti i0 <= i < R
.dim(result, R) * num_bits(E') = num_bits(E)
.- Se
num_bits(E') > num_bits(E)
: rank(result) = R - 1
.dim(result, i) = dim(operand, i)
per tutti i0 <= i < R
.dim(operand, R - 1) * num_bits(E) = num_bits(E')
.
- Se
- (C2) Se
is_complex(operand) or is_complex(result)
,is_complex(operand) and is_complex(result)
.
Esempi
// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation
broadcast_in_dim
Semantica
Espande le dimensioni e/o il ranking di un tensore di input duplicando i dati
nel tensore operand
e produce un tensore result
. Più formalmente,
result[result_index] = operand[operand_index]
dove per tutti i d
in
axes(operand)
:
operand_index[d] = 0
sedim(operand, d) = 1
.operand_index[d] = result_index[broadcast_dimensions[d]]
in caso contrario.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o quantizzato | (C1-C2), (C5-C6) |
(I2) | broadcast_dimensions |
Costante del tensore unidimensionale di tipo si64 |
(C2-C6) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o quantizzato | (C1), (C3), (C5-C6) |
Vincoli
- (C1)
element_type(result)
è dato da:element_type(operand)
, se!is_per_axis_quantized(operand)
.element_type(operand)
eccetto chequantization_dimension(operand)
,scales(operand)
ezero_points(operand)
potrebbero differire daquantization_dimension(result)
,scales(result)
ezero_points(result)
rispettivamente.
- (C2)
size(broadcast_dimensions) = rank(operand)
. - (C3)
0 <= broadcast_dimensions < rank(result)
. - (C4)
is_unique(broadcast_dimensions)
. - (C5) Per tutti i
d
inaxes(operand)
:dim(operand, d) = 1
odim(operand, d) = dim(result, broadcast_dimensions[d])
.
- (C6) Se
is_per_axis_quantized(result)
:quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
.- Se
dim(operand, quantization_dimension(operand)) = 1
, allorascales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
.
Esempi
// %operand: [
// [1, 2, 3]
// ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
richiesta
Semantica
Restituisce l'output dall'esecuzione di una sola funzione di branches
a seconda del valore di index
. A livello più formale, result = selected_branch()
dove:
selected_branch = branches[index]
se0 <= index < size(branches)
.selected_branch = branches[-1]
in caso contrario.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | index |
Tensore 0-dimensionale di tipo si32 |
|
(I2) | branches |
numero variadico di funzioni | (C1-C4) |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadico di tensori, tensori quantizzati o token | (C4) |
Vincoli
- (C1)
0 < size(branches)
. - (C2)
input_types(branches...) = []
. - (C3)
same(output_types(branches...))
. - (C4)
type(results...) = output_types(branches[0])
.
Esempi
// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
"stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
"stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]
CBRT
Semantica
Esegue un'operazione di radice cubica per elemento sul tensore operand
e produce un
tensore result
. A seconda del tipo di elemento, procedi nel seguente modo:
- Per i numeri in virgola mobile:
rootn(x, 3)
dallo standard IEEE-754. - Per i numeri complessi: radice cubica complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(cbrt, operand, type(result))
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]
ceil
Semantica
Esegue il Ceil per elemento del tensore operand
e produce un tensore result
.
Implementa l'operazione roundToIntegralTowardPositive
dalla specifica IEEE-754. Per i tipi quantizzati, esegue dequantize_op_quantize(ceil, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]
Cholesky
Semantica
Calcola la scomposizione di Cholesky di un batch di matrici.
Più formalmente, per tutti gli i
in index_space(result)
,
result[i0, ..., iR-3, :, :]
è una decomposizione di Cholesky di
a[i0, ..., iR-3, :, :]
, sotto forma di una matrice triangolare inferiore
(se lower
è true
) o triangolare superiore (se lower
è false
).
I valori di output nel triangolo opposto, ovvero il triangolo superiore stretto o il triangolo inferiore stretto, sono definiti dall'implementazione.
Se esiste i
in cui la matrice di input non è una matrice hermitiana con definizione positiva, il comportamento è indefinito.
Per i tipi quantizzati, esegue dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | a |
tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore | (C1-C3) |
(I2) | lower |
Costante del tensore 0 dimensionale di tipo i1 |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(a) = baseline_type(result)
. - (C2)
2 <= rank(a)
. - (C3)
dim(a, -2) = dim(a, -1)
.
Esempi
// %a: [
// [1.0, 2.0, 3.0],
// [2.0, 20.0, 26.0],
// [3.0, 26.0, 70.0]
// ]
%result = "stablehlo.cholesky"(%a) {
lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
clampare
Semantica
Collega ogni elemento del tensore operand
tra un valore minimo e un valore massimo e produce un tensore result
. A livello più formale, result[result_index] =
minimum(maximum(operand[result_index], min_element), max_element)
,
dove min_element = rank(min) = 0 ? min[] : min[result_index]
,
max_element = rank(max) = 0 ? max[] : max[result_index]
. Per i tipi quantizzati,
esegue dequantize_op_quantize(clamp, min, operand, max, type(result))
.
L'imposizione di un ordinamento sui numeri complessi comporta una semantica sorprendente, quindi in futuro abbiamo intenzione di rimuovere il supporto per i numeri complessi per questa operazione (#560).
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | min |
tensore o tensore quantizzato per tensore | (C1), (C3) |
(I2) | operand |
tensore o tensore quantizzato per tensore | (C1-C4) |
(I3) | max |
tensore o tensore quantizzato per tensore | (C2), (C3) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C4) |
Vincoli
- (C1)
rank(min) = 0 or shape(min) = shape(operand)
. - (C2)
rank(max) = 0 or shape(max) = shape(operand)
. - (C3)
baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max)
. - (C4)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]
collective_broadcast
Semantica
All'interno di ogni gruppo di processo nella griglia di processo StableHLO, invia il valore del tensore operand
dal processo di origine ai processi di destinazione e produci un tensore result
.
L'operazione suddivide la griglia del processo StableHLO in process_groups
, che è
definito come segue:
cross_replica(replica_groups)
sechannel_id <= 0
.cross_partition(replica_groups)
sechannel_id > 0
.
In seguito, il valore result@process
sarà fornito da:
operand@process_groups[i, 0]
se esiste uni
in modo che il processo si trovi inprocess_groups[i]
.broadcast_in_dim(constant(0, element_type(result)), [], type(result))
altrimenti.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore | (C3) |
(I2) | replica_groups |
numero variadico di costanti dei tensori unidimensionali di tipo si64 |
(C1), (C2) |
(I3) | channel_id |
costante di tipo si64 |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore | (C3) |
Vincoli
- (C1)
is_unique(replica_groups)
. - (C2)
0 <= replica_groups < N
, doveN
è definito come:num_replicas
se viene utilizzatocross_replica
.num_partitions
se viene utilizzatocross_partition
.
- (C3)
type(result) = type(operand)
.
Esempi
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
collective_permute
Semantica
All'interno di ogni gruppo di processo nella griglia di processo StableHLO, invia il valore del tensore operand
dal processo di origine al processo di destinazione e produce un tensore result
.
L'operazione suddivide la griglia del processo StableHLO in process_groups
, che è
definito come segue:
cross_replica(source_target_pairs)
sechannel_id <= 0
.cross_partition(source_target_pairs)
sechannel_id > 0
.
In seguito, il valore result@process
sarà fornito da:
operand@process_groups[i, 0]
, se esiste uni
cheprocess_groups[i, 1] = process
.broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
altrimenti.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato per tensore | (C5) |
(I2) | source_target_pairs |
Costante del tensore bidimensionale di tipo si64 |
(C1-C4) |
(I3) | channel_id |
costante di tipo si64 |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
dim(source_target_pairs, 1) = 2
. - (C2)
is_unique(source_target_pairs[:, 0])
. - (C3)
is_unique(source_target_pairs[:, 1])
. - (C4)
0 <= source_target_pairs < N
, doveN
è definito come:num_replicas
se viene utilizzatocross_replica
.num_partitions
se viene utilizzatocross_partition
.
- (C5)
type(result) = type(operand)
.
Esempi
// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]
compare
Semantica
Esegue un confronto a livello di elemento dei tensori lhs
e rhs
secondo comparison_direction
e compare_type
e produce un tensore result
.
I valori di comparison_direction
e compare_type
hanno la seguente
semantica:
Per i tipi di elementi booleani e interi:
EQ
:lhs = rhs
.NE
:lhs != rhs
.GE
:lhs >= rhs
.GT
:lhs > rhs
.LE
:lhs <= rhs
.LT
:lhs < rhs
.
Per i tipi di elementi in virgola mobile con compare_type = FLOAT
, l'operazione implementa le seguenti operazioni IEEE-754:
EQ
:compareQuietEqual
.NE
:compareQuietNotEqual
.GE
:compareQuietGreaterEqual
.GT
:compareQuietGreater
.LE
:compareQuietLessEqual
.LT
:compareQuietLess
.
Per i tipi di elementi con virgola mobile con compare_type = TOTALORDER
, l'operazione utilizza la combinazione di operazioni totalOrder
e compareQuietEqual
da IEEE-754. Questa funzionalità sembra inutilizzata, perciò abbiamo intenzione di rimuoverla (#584).
Per i tipi di elementi complessi, il confronto lessicografico di coppie (real, imag)
viene eseguito utilizzando i valori comparison_direction
e compare_type
forniti.
L'imposizione di un ordinamento sui numeri complessi implica una semantica sorprendente,
quindi in futuro prevediamo di rimuovere il supporto per i numeri complessi
quando comparison_direction
è GE
, GT
, LE
o LT
(#560).
Per i tipi quantizzati, esegue dequantize_compare(lhs, rhs,
comparison_direction)
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore o tensore quantizzato per tensore | (C1-C3) |
(I2) | rhs |
tensore o tensore quantizzato per tensore | (C1-C2) |
(I3) | comparison_direction |
enum di EQ , NE , GE , GT , LE e LT |
|
(I4) | compare_type |
enum di FLOAT , TOTALORDER , SIGNED e UNSIGNED |
(C3) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo booleano | (C2) |
Vincoli
- (C1)
baseline_element_type(lhs) = baseline_element_type(rhs)
. - (C2)
shape(lhs) = shape(rhs) = shape(result)
. - (C3)
compare_type
è definito come:SIGNED
seis_signed_integer(element_type(lhs))
.UNSIGNED
seis_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs))
.FLOAT
oTOTALORDER
seis_float(element_type(lhs))
.FLOAT
seis_complex(element_type(lhs))
.
Esempi
// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
comparison_direction = #stablehlo<comparison_direction LT>,
compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]
complesso
Semantica
Esegue la conversione a livello di elemento in un valore complesso da una coppia di valori reali e immaginari, lhs
e rhs
, e produce un tensore result
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di tipo f32 o f64 |
(C1-C3) |
(I2) | rhs |
tensore di tipo f32 o f64 |
(C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo complesso | (C2), (C3) |
Vincoli
- (C1)
type(lhs) = type(rhs)
. - (C2)
shape(result) = shape(lhs)
. - (C3)
element_type(result)
ha il tipocomplex<E>
, doveE = element_type(lhs)
.
Esempi
// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]
concatenate
Semantica
Concatena inputs
lungo la dimensione dimension
nello stesso ordine degli argomenti dati e produce un tensore result
. A livello più formale, result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1]
, dove:
id = d0 + ... + dk-1 + kd
.d
è uguale adimension
ed0
, ... sonod
a dimensione di dimensione diinputs
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | inputs |
numero variadici di tensori o tensori quantizzati per tensore | (C1-C6) |
(I2) | dimension |
costante di tipo si64 |
(C2), (C4), (C6) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C5-C6) |
Vincoli
- (C1)
same(element_type(inputs...))
. - (C2)
same(shape(inputs...))
eccettodim(inputs..., dimension)
. - (C3)
0 < size(inputs)
. - (C4)
0 <= dimension < rank(inputs[0])
. - (C5)
element_type(result) = element_type(inputs[0])
. - (C6)
shape(result) = shape(inputs[0])
eccetto:dim(result, dimension) = dim(inputs[0], dimension) + ...
.
Esempi
// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
costante
Semantica
Genera un tensore output
da una costante value
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | value |
costante | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
output |
tensore o quantizzato | (C1) |
Vincoli
- (C1)
type(value) = type(output)
.
Esempi
%output = "stablehlo.constant"() {
value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]
effettuare una conversione
Semantica
Esegue una conversione a livello di elemento da un tipo di elemento a un altro sul
tensore operand
e produce un tensore result
.
Per le conversioni di tipo boolean-to-any-supported-type, il valore false
viene
convertito in zero, mentre il valore true
viene convertito in uno. Per le conversioni any-supported-type-to-boolean, un valore zero viene convertito in false
, mentre i valori diversi da zero vengono convertiti in true
. Vedi di seguito come funziona
per i tipi complessi.
Per le conversioni che implicano il valore intero per numero intero, da intero a virgola mobile o da virgola mobile a virgola mobile, se il valore di origine può essere esattamente rappresentato nel tipo di destinazione, il valore risultante è quella rappresentazione esatta. In caso contrario, il comportamento è da definire (#180).
Per le conversioni che coinvolgono floating-point-to-integer, la parte frazionata viene troncata. Se il valore troncato non può essere rappresentato nel tipo di destinazione, il comportamento è da definire (#180).
Le conversioni che coinvolgono le conversioni da complesse a complesse seguono lo stesso comportamento delle conversioni da punto a virgola mobile per la conversione di parti reali e immaginarie.
Per le conversioni complex-to-any-other-type e any-other-type-to-complex, il valore immaginario di origine viene ignorato o il valore immaginario di destinazione viene azzerato, rispettivamente. La conversione della parte reale segue le conversioni in virgola mobile.
In linea di principio, questa operazione potrebbe esprimere la dequantizzazione (conversione da tensori quantizzati a tensori regolari), quantizzazione (conversione da tensori regolari a tensori quantizzati) e riquantizzazione (conversione tra tensori quantizzati), ma al momento esistono operazioni dedicate: uniform_dequantize
per il primo caso d'uso e uniform_quantize
per il secondo e il terzo caso d'uso. In futuro, queste due operazioni potrebbero essere unite
in convert
(#1576).
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore | (C1) |
Vincoli
- (C1)
shape(operand) = shape(result)
.
Esempi
// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]
convoluzione
Semantica
Calcola il punto dei prodotti tra le finestre di lhs
e le sezioni di rhs
e produce result
. Il seguente diagramma mostra come gli elementi in result
vengono calcolati da lhs
e rhs
utilizzando un esempio concreto.
In modo più formale, valuta la seguente riorganizzazione degli input in termini di lhs
per poter esprimere finestre di lhs
:
lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension))
.lhs_window_strides = lhs_shape(1, window_strides, 1)
.lhs_padding = lhs_shape([0, 0], padding, [0, 0])
.lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)
.lhs_window_dilations = lhs_shape(1, rhs_dilation, 1)
.
Questa riformulazione utilizza le seguenti funzioni helper:
lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])
.result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])
.permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]
dovej[d] = i[permutation[d]]
.
Se feature_group_count = 1
e batch_group_count = 1
, per tutti
output_spatial_index
in index_space(dim(result, output_spatial_dimensions...))
,
result[result_shape(:, output_spatial_index, :)] = dot_product
dove:
padding_value = constant(0, element_type(lhs))
.padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)
.lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides
.lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations)
.reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true])
. Questa funzionalità sembra inutilizzata, perciò abbiamo intenzione di rimuoverla (#1181).dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension])
.
Se feature_group_count > 1
:
lhses = split(lhs, feature_group_count, input_feature_dimension)
.rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)
.results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
.
Se batch_group_count > 1
:
lhses = split(lhs, batch_group_count, input_batch_dimension)
.rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)
.results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
.
Per i tipi quantizzati, esegue dequantize_op_quantize(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs,
type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore o tensore quantizzato per tensore | (C1), (C10-C11), (C14) (C25), (C27-C30) |
(I2) | rhs |
tensore o quantizzato | (C1), (C14-C16), (C25), (C27-C32) |
(I3) | window_strides |
Costante del tensore unidimensionale di tipo si64 |
(C2-C3), (C25) |
(I4) | padding |
Costante del tensore bidimensionale di tipo si64 |
(C4), (C25) |
(I5) | lhs_dilation |
Costante del tensore unidimensionale di tipo si64 |
(C5-C6), (C25) |
(I6) | rhs_dilation |
Costante del tensore unidimensionale di tipo si64 |
(C7-C8), (C25) |
(I7) | window_reversal |
Costante del tensore unidimensionale di tipo i1 |
(C9) |
(I8) | input_batch_dimension |
costante di tipo si64 |
(C10), (C13), (C25) |
(I9) | input_feature_dimension |
costante di tipo si64 |
(C11), (C13-C14) |
(I10) | input_spatial_dimensions |
Costante del tensore unidimensionale di tipo si64 |
(C12), (C13), (C25) |
(I11) | kernel_input_feature_dimension |
costante di tipo si64 |
(C14), (C18) |
(I12) | kernel_output_feature_dimension |
costante di tipo si64 |
(C15-C16), (C18), (C25), (C32) |
(I13) | kernel_spatial_dimensions |
Costante del tensore unidimensionale di tipo si64 |
(C17-C18), (C25) |
(I14) | output_batch_dimension |
costante di tipo si64 |
(C20), (C25) |
(I15) | output_feature_dimension |
costante di tipo si64 |
(C20), (C25), (C33) |
(I16) | output_spatial_dimensions |
Costante del tensore unidimensionale di tipo si64 |
(C19-C20), (C25) |
(I17) | feature_group_count |
costante di tipo si64 |
(C11), (C14), (C16), (C21), (C23) |
(I18) | batch_group_count |
costante di tipo si64 |
(C10), (C15), (C22), (C23), (C25) |
(I19) | precision_config |
numero variadico di enumerazioni di DEFAULT , HIGH e HIGHEST |
(C24) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o quantizzato | (C25-C28), (C30-C31), (C33) |
Vincoli
- (C1)
N = rank(lhs) = rank(rhs)
. - (C2)
size(window_strides) = N - 2
. - (C3)
0 < window_strides
. - (C4)
shape(padding) = [N - 2, 2]
. - (C5)
size(lhs_dilation) = N - 2
. - (C6)
0 < lhs_dilation
. - (C7)
size(rhs_dilation) = N - 2
. - (C8)
0 < rhs_dilation
. - (C9)
size(window_reversal) = N - 2
. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
. - (C12)
size(input_spatial_dimensions) = N - 2
. - (C13) Dati
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
:is_unique(input_dimensions)
.0 <= input_dimensions < N
.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
. - (C17)
size(kernel_spatial_dimensions) = N - 2
. - (C18) Dati
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
:is_unique(kernel_dimensions)
.0 <= kernel_dimensions < N
.
- (C19)
size(output_spatial_dimensions) = N - 2
. - (C20) Dati
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
:is_unique(output_dimensions)
.0 <= output_dimensions < N
.
- (C21)
0 < feature_group_count
. - (C22)
0 < batch_group_count
. - (C23)
feature_group_count = 1 or batch_group_count = 1
. - (C24)
size(precision_config) = 2
. - (C25) Per
dim(result, result_dim)
si intende:dim(lhs, input_batch_dimension) / batch_group_count
seresult_dim = output_batch_dimension
.dim(rhs, kernel_output_feature_dimension)
seresult_dim = output_feature_dimension
.num_windows
altrimenti, dove:output_spatial_dimensions[spatial_dim] = result_dim
.lhs_dim = input_spatial_dimensions[spatial_dim]
.rhs_dim = kernel_spatial_dimensions[spatial_dim]
.dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
.dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
.num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
- (C26)
rank(result) = N
. - Se l'operazione utilizza tensori non quantizzati:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
.
- (C27)
- Se l'operazione utilizza tensori quantizzati:
- (C28)
is_quantized_tensor(lhs) and is_quantized_tensor(rhs) and is_quantized_tensor(result)
. - (C29)
storage_type(lhs) = storage_type(rhs)
. - (C30)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C31) Se
is_per_tensor_quantized(rhs)
,is_per_tensor_quantized(result)
. - (C32) Se
is_per_axis_quantized(rhs)
,quantization_dimension(rhs) = kernel_output_feature_dimension
. - (C33) Se
is_per_axis_quantized(result)
,quantization_dimension(result) = output_feature_dimension
.
- (C28)
Esempi
// %lhs: [[
// [
// [1], [2], [5], [6]
// ],
// [
// [3], [4], [7], [8]
// ],
// [
// [10], [11], [14], [15]
// ],
// [
// [12], [13], [16], [17]
// ]
// ]]
//
// %rhs : [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strides = dense<4> : tensor<2xi64>,
padding = dense<0> : tensor<2x2xi64>,
lhs_dilation = dense<2> : tensor<2xi64>,
rhs_dilation = dense<1> : tensor<2xi64>,
window_reversal = dense<false> : tensor<2xi1>,
// In the StableHLO dialect, dimension numbers are encoded via:
// `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
// "b" is batch dimension, "f" is feature dimension,
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" are spatial dimensions.
dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi32>, tensor<3x3x1x1xi32>) -> tensor<1x2x2x1xi32>
// %result: [[
// [[10], [26]],
// [[46], [62]]
// ]]
coseno
Semantica
Esegue un'operazione coseno a livello di elemento sul tensore operand
e produce un tensore result
. A seconda del tipo di elemento, procedi nel seguente modo:
- Per i numeri in virgola mobile:
cos
dallo standard IEEE-754. - Per i numeri complessi: coseno complesso.
- Per i tipi quantizzati:
dequantize_op_quantize(cosine, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]
count_leading_zeros
Semantica
Esegue il conteggio a livello di elemento del numero di zero bit iniziali nel tensore operand
e produce un tensore result
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo intero | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo intero | (C1) |
Vincoli
- (C1)
type(operand) = type(result)
.
Esempi
// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]
custom_call
Semantica
Incapsula un'operazione call_target_name
definita dall'implementazione che prende
inputs
e called_computations
e produce results
. has_side_effect
,
backend_config
e api_version
possono essere utilizzati per fornire metadati aggiuntivi
definiti dall'implementazione.
Al momento, questa operazione contiene una raccolta di metadati abbastanza disorganizzata che riflette l'evoluzione organica della sua operazione di controparte nel compilatore XLA. In futuro, abbiamo in programma di unificare questi metadati (#741).
Input
Etichetta | Nome | Tipo |
---|---|---|
(I1) | inputs |
numero variadico di valori |
(I2) | call_target_name |
costante di tipo string |
(I3) | has_side_effect |
costante di tipo i1 |
(I4) | backend_config |
costante di tipo string |
(I5) | api_version |
costante di tipo si32 |
(I6) | called_computations |
numero variadico di costanti di tipo string |
Output
Nome | Tipo |
---|---|
results |
numero variadico di valori |
Esempi
%results = "stablehlo.custom_call"(%input0) {
call_target_name = "foo",
has_side_effect = false,
backend_config = "bar",
api_version = 1 : i32,
called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>
divisione
Semantica
Esegue la divisione per elemento dei tensori del dividendo lhs
e del divisore rhs
e produce un tensore result
. A seconda del tipo di elemento, procedi nel seguente modo:
- Per i numeri interi: divisione intera che produce il quoziente algebrico, scartando qualsiasi parte frazionaria.
- Per i numeri in virgola mobile:
division
dallo standard IEEE-754. - Per i numeri complessi: divisione complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(divide, lhs, rhs, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di tipo intero, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
(I2) | rhs |
tensore di tipo intero, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo intero, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Esempi
// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]
dot_general
Semantica
Calcola che punteggiano i prodotti tra le sezioni di lhs
e le sezioni di rhs
e produce un tensore result
.
In modo più formale, result[result_index] = dot_product
, dove:
lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions]
.rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions]
.result_batching_index + result_lhs_index + result_rhs_index = result_index
dovesize(result_batching_index) = size(lhs_batching_dimensions)
,size(result_lhs_index) = size(lhs_result_dimensions)
esize(result_rhs_index) = size(rhs_result_dimensions)
.transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions)
.transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :])
.reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions))
.transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions)
.transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :])
.reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions))
.dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y))
.
Per i tipi quantizzati, esegue dequantize_op_quantize(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs, type(result))
.
Specifica solo la semantica per la quantizzazione per tensore. La quantizzazione per asse è in corso (#1574). In futuro, inoltre, potremmo valutare l'aggiunta del supporto per la quantizzazione ibrida (#1575).
precision_config
controlla il compromesso tra velocità e precisione per i calcoli sui backend degli acceleratori. Può essere uno dei seguenti (al momento la semantica di questi valori di enum è sottospecificata, ma abbiamo intenzione di affrontarla in #755):
DEFAULT
: calcolo più veloce, ma meno accurata dell'approssimazione al numero originale.HIGH
: calcolo più lento, ma approssimazione più precisa al numero originale.HIGHEST
: calcolo più lento, ma con un'approssimazione più precisa al numero originale.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore o tensore quantizzato per tensore | (C5-C6), (C9-C10), (C12-C16) |
(I2) | rhs |
tensore o tensore quantizzato per tensore | (C7-C10), (C12) |
(I3) | lhs_batching_dimensions |
Costante del tensore unidimensionale di tipo si64 |
(C1), (C3), (C5), (C9), (C12) |
(I4) | rhs_batching_dimensions |
Costante del tensore unidimensionale di tipo si64 |
(C1), (C4), (C7), (C9) |
(I5) | lhs_contracting_dimensions |
Costante del tensore unidimensionale di tipo si64 |
(C2), (C3), (C6), (C10) |
(I6) | rhs_contracting_dimensions |
Costante del tensore unidimensionale di tipo si64 |
(C2), (C4), (C8), (C10) |
(I7) | precision_config |
numero variadico di enumerazioni di DEFAULT , HIGH e HIGHEST |
(C11) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C12), (C14), (C16) |
Vincoli
- (C1)
size(lhs_batching_dimensions) = size(rhs_batching_dimensions)
. - (C2)
size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions)
. - (C3)
is_unique(lhs_batching_dimensions + lhs_contracting_dimensions)
. - (C4)
is_unique(rhs_batching_dimensions + rhs_contracting_dimensions)
. - (C5)
0 <= lhs_batching_dimensions < rank(lhs)
. - (C6)
0 <= lhs_contracting_dimensions < rank(lhs)
. - (C7)
0 <= rhs_batching_dimensions < rank(rhs)
. - (C8)
0 <= rhs_contracting_dimensions < rank(rhs)
. - (C9)
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
. - (C10)
dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...)
. - (C11)
size(precision_config) = 2
. - (C12)
shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions)
. - Se l'operazione utilizza tensori non quantizzati:
- (C13)
element_type(lhs) = element_type(rhs)
.
- (C13)
- Se l'operazione utilizza tensori quantizzati:
- (C14)
is_quantized(lhs) and is_quantized(rhs) and is_quantized(result)
. - (C15)
storage_type(lhs) = storage_type(rhs)
. - (C16)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C17)
zero_points(rhs) = 0
.
- (C14)
Esempi
// %lhs: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
// %rhs: [
// [[1, 0],
// [0, 1]],
// [[1, 0],
// [0, 1]]
// ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [1]
>,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
dynamic_slice
Semantica
Estrae una sezione da operand
utilizzando indici iniziali calcolati in modo dinamico
e produce un tensore result
. start_indices
contiene gli indici iniziali della sezione per ogni dimensione soggetta a potenziale aggiustamento, mentre slice_sizes
contiene le dimensioni della sezione per ogni dimensione. Dal punto di vista più formale,
result[result_index] = operand[operand_index]
dove:
adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes)
.operand_index = adjusted_start_indices + result_index
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato per tensore | (C1), (C2), (C4) |
(I2) | start_indices |
numero variadico di tensori 0-dimensionali di tipo intero | (C2), (C3) |
(I3) | slice_sizes |
Costante del tensore unidimensionale di tipo si64 |
(C2), (C4), (C5) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C1), (C5) |
Vincoli
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(slice_sizes) = rank(operand)
. - (C3)
same(type(start_indices...))
. - (C4)
0 <= slice_sizes <= shape(operand)
. - (C5)
shape(result) = slice_sizes
.
Esempi
// %operand: [
// [0, 0, 1, 1],
// [0, 0, 1, 1],
// [0, 0, 0, 0],
// [0, 0, 0, 0]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
slice_sizes = dense<[2, 2]> : tensor<2xi64>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
// [1, 1],
// [1, 1]
// ]
dynamic_update_slice
Semantica
Genera un tensore result
uguale al tensore operand
, tranne per il fatto che la sezione che inizia da start_indices
viene aggiornata con i valori in update
.
Più formalmente, con result[result_index]
si intende:
update[update_index]
se0 <= update_index < shape(update)
dove:adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update))
.update_index = result_index - adjusted_start_indices
.
operand[result_index]
in caso contrario.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato per tensore | (C1-C4), (C6) |
(I2) | update |
tensore o tensore quantizzato per tensore | (C2), (C3), (C6) |
(I3) | start_indices |
numero variadico di tensori 0-dimensionali di tipo intero | (C4), (C5) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
type(operand) = type(result)
. - (C2)
element_type(update) = element_type(operand)
. - (C3)
rank(update) = rank(operand)
. - (C4)
size(start_indices) = rank(operand)
. - (C5)
same(type(start_indices...))
. - (C6)
0 <= shape(update) <= shape(operand)
.
Esempi
// %operand: [
// [1, 1, 0, 0],
// [1, 1, 0, 0],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
// %update: [
// [1, 1],
// [1, 1]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
: (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
esponenziale
Semantica
Esegue un'operazione esponenziale a livello di elemento sul tensore operand
e produce un tensore result
. A seconda del tipo di elemento, procedi nel seguente modo:
- Per i numeri in virgola mobile:
exp
dallo standard IEEE-754. - Per i numeri complessi: esponenziale complesso.
- Per i tipi quantizzati:
dequantize_op_quantize(exponential, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]
exponential_minus_one
Semantica
Esegue esponenziale esponenziale a livello di elemento meno un'operazione sul tensore operand
e produce un tensore result
. A seconda del tipo di elemento, procedi nel seguente modo:
- Per i numeri in virgola mobile:
expm1
dallo standard IEEE-754. - Per i numeri complessi: esponenziale complesso meno uno.
- Per i tipi quantizzati:
dequantize_op_quantize(exponential_minus_one, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]
fft
Semantica
Esegue le trasformazioni di Fourier diretta e inversa per input/output reali e complessi.
fft_type
è uno dei seguenti:
FFT
: inoltro di FFT complesso-complesso.IFFT
: FFT da complesso a complesso inversa.RFFT
: inoltro di FFT reale a complesso.IRFFT
: FFT reale-complesso inversa (ad esempio, richiede complesso, restituisce risultati reali).
Più formalmente, data la funzione fft
che prende tensori unidimensionali di tipi complessi come input, produce tensori unidimensionali degli stessi tipi dell'output e calcola la trasformata discreta di Fourier:
Per fft_type = FFT
, result
è definito come il risultato finale di una serie di calcoli L in cui L = size(fft_length)
. Ad esempio, per L = 3
:
result1[i0, ..., :] = fft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
.result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Inoltre, data la funzione ifft
che ha lo stesso tipo di firma e calcola l'inverso di fft
:
Per fft_type = IFFT
, result
è definito come l'inverso dei calcoli per fft_type = FFT
. Ad esempio, per L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
.result[i0, ..., :] = ifft(result2[i0, ..., :])
.
Inoltre, data la funzione rfft
, che accetta tensori unidimensionali dei tipi di rappresentazione in virgola mobile, produce tensori unidimensionali di tipi complessi con la stessa semantica in virgola mobile e funziona nel seguente modo:
rfft(real_operand) = truncated_result
dovecomplex_operand... = (real_operand..., 0.0)
.complex_result = fft(complex_operand)
.truncated_result = complex_result[:(rank(complex_result) / 2 + 1)]
.
(Quando viene calcolata la trasformata di Fourier discreta per gli operandi reali, i primi
N/2 + 1
elementi del risultato definiscono in modo univoco il resto del risultato,
quindi il risultato di rfft
viene troncato per evitare il calcolo di elementi ridondanti).
Per fft_type = RFFT
, result
è definito come il risultato finale di una serie di calcoli L in cui L = size(fft_length)
. Ad esempio, per L = 3
:
result1[i0, ..., :] = rfft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
.result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Infine, data la funzione irfft
che ha la stessa firma del tipo e calcola l'inverso di rfft
:
Per fft_type = IRFFT
, result
è definito come l'inverso dei calcoli per fft_type = RFFT
. Ad esempio, per L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
.result[i0, ..., :] = irfft(result2[i0, ..., :])
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o complesso | (C1), (C2), (C4), (C5) |
(I2) | fft_type |
enum di FFT , IFFT , RFFT e IRFFT |
(C2), (C5) |
(I3) | fft_length |
Costante del tensore unidimensionale di tipo si64 |
(C1), (C3), (C4) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo in virgola mobile o complesso | (C2), (C4), (C5) |
Vincoli
- (C1)
size(fft_length) <= rank(operand)
. - (C2) La relazione tra i tipi di elementi
operand
eresult
varia:- Se
fft_type = FFT
,element_type(operand)
eelement_type(result)
hanno lo stesso tipo complesso. - Se
fft_type = IFFT
,element_type(operand)
eelement_type(result)
hanno lo stesso tipo complesso. - Se
fft_type = RFFT
,element_type(operand)
è un tipo con virgola mobile eelement_type(result)
è un tipo complesso con la stessa semantica in virgola mobile. - Se
fft_type = IRFFT
,element_type(operand)
è un tipo complesso eelement_type(result)
è un tipo con virgola mobile con la stessa semantica in virgola mobile.
- Se
- (C3)
1 <= size(fft_length) <= 3
. - (C4) Se tra
operand
eresult
, esiste un tensorereal
di tipo in virgola mobile, quindishape(real)[-size(fft_length):] = fft_length
. - (C5)
shape(result) = shape(operand)
eccetto:- Se
fft_type = RFFT
,dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1
. - Se
fft_type = IRFFT
,dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1
.
- Se
Esempi
// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
fft_type = #stablehlo<fft_type FFT>,
fft_length = dense<4> : tensor<1xi64>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]
floor
Semantica
Esegue piano per elemento di operand
tensore e produce un tensore result
.
Implementa l'operazione roundToIntegralTowardNegative
dalla specifica IEEE-754. Per i tipi quantizzati, esegue dequantize_op_quantize(floor, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]
raccogliere
Semantica
Raccoglie sezioni dal tensore operand
dagli offset specificati in start_indices
e produce un tensore result
.
Il seguente diagramma mostra come gli elementi in result
vengono mappati sugli elementi in
operand
utilizzando un esempio concreto. Il diagramma seleziona alcuni indici result
di esempio e spiega in dettaglio a quali indici operand
corrispondono.
Più formalmente, result[result_index] = operand[operand_index]
dove:
batch_dims = [d for d in axes(result) and d not in offset_dims]
.batch_index = result_index[batch_dims...]
.- Per
start_index
si intende:start_indices[bi0, ..., :, ..., biN]
, dovebi
sono singoli elementi inbatch_index
e:
è inserito nell'indiceindex_vector_dim
, seindex_vector_dim
<rank(start_indices)
.[start_indices[batch_index]]
in caso contrario.
- Per
d_operand
inaxes(operand)
,full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])
sed_operand = start_index_map[d_start]
.full_start_index[d_operand] = 0
in caso contrario.
offset_index = result_index[offset_dims...]
.full_offset_index = [oi0, ..., 0, ..., oiN]
, doveoi
sono singoli elementi dioffset_index
e0
è inserito negli indici dicollapsed_slice_dims
.operand_index = full_start_index + full_offset_index
.
Se indices_are_sorted
è true
, l'implementazione può presupporre che
start_indices
siano ordinati rispetto a start_index_map
, altrimenti il
comportamento non è definito. Più formalmente, per tutti i i1 < i2
di indices(result)
,
full_start_index(i1) <= full_start_index(i2)
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato per tensore | (C1), (C7), (C10-C12), (C14) |
(I2) | start_indices |
tensore di tipo intero | (C2), (C3), (C13) |
(I3) | offset_dims |
Costante del tensore unidimensionale di tipo si64 |
(C1), (C4-C5), (C13) |
(I4) | collapsed_slice_dims |
Costante del tensore unidimensionale di tipo si64 |
(C1), (C6-C8), (C13) |
(I5) | start_index_map |
Costante del tensore unidimensionale di tipo si64 |
(C3), (C9), (C10) |
(I6) | index_vector_dim |
costante di tipo si64 |
(C2), (C3), (C13) |
(I7) | slice_sizes |
Costante del tensore unidimensionale di tipo si64 |
(C8), (C11-C13) |
(I8) | indices_are_sorted |
costante di tipo i1 |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C5), (C13-C14) |
Vincoli
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims)
. - (C2)
0 <= index_vector_dim <= rank(start_indices)
. - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
. - (C5)
0 <= offset_dims < rank(result)
. - (C6)
is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims)
. - (C7)
0 <= collapsed_slice_dims < rank(operand)
. - (C8)
slice_sizes[collapsed_slice_dims...] <= 1
. - (C9)
is_unique(start_index_map)
. - (C10)
0 <= start_index_map < rank(operand)
. - (C11)
size(slice_sizes) = rank(operand)
. - (C12)
0 <= slice_sizes <= shape(operand)
. - (C13)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
dove:batch_dim_sizes = shape(start_indices)
tranne che le dimensioni distart_indices
corrispondenti aindex_vector_dim
non sono incluse.offset_dim_sizes = shape(slice_sizes)
tranne per il fatto che le dimensioni delle dimensioni inslice_sizes
corrispondenti acollapsed_slice_dims
non sono incluse.combine
collocabatch_dim_sizes
su assi corrispondenti abatch_dims
eoffset_dim_sizes
su assi corrispondenti aoffset_dims
.
- (C14)
element_type(operand) = element_type(result)
.
Esempi
// %operand: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %start_indices: [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 2]]
// ]
%result = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stablehlo.gather<
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vector_dim = 2>,
slice_sizes = dense<[1, 2, 2]> : tensor<3xi64>,
indices_are_sorted = false
} : (tensor<3x4x2xi32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xi32>
// %result: [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[9, 10], [11, 12]],
// [[11, 12], [13, 14]],
// [[17, 18], [19, 20]]
// ]
// ]
get_dimension_size
Semantica
Genera le dimensioni dell'elemento dimension
specificato di operand
. In modo più formale,
result = dim(operand, dimension)
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore | (C1) |
(I2) | dimension |
costante di tipo si64 |
(C1) |
Output
Nome | Tipo |
---|---|
result |
Tensore 0-dimensionale di tipo si32 |
Vincoli
- (C1)
0 <= dimension < rank(operand)
.
Esempi
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3
get_tuple_element
Semantica
Estrae l'elemento nella posizione index
della tupla operand
e produce un elemento result
. Più formalmente, result = operand[index]
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tuple | (C1), (C2) |
(I2) | index |
costante di tipo si32 |
(C1), (C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
qualsiasi tipo supportato | (C2) |
Vincoli
- (C1)
0 <= index < size(operand)
. - (C2)
type(result) = tuple_element_types(operand)[index]
.
Esempi
// %operand: ([1.0, 2.0], (3))
%result = "stablehlo.get_tuple_element"(%operand) {
index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]
if
Semantica
Restituisce l'output dall'esecuzione di una sola funzione di true_branch
o false_branch
, a seconda del valore di pred
. Più formalmente, result =
pred ? true_branch() : false_branch()
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | pred |
Tensore 0-dimensionale di tipo i1 |
|
(I2) | true_branch |
funzione | (C1-C3) |
(I3) | false_branch |
funzione | (C1), (C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadico di tensori, tensori quantizzati o token | (C3) |
Vincoli
- (C1)
input_types(true_branch) = input_types(false_branch) = []
. - (C2)
output_types(true_branch) = output_types(false_branch)
. - (C3)
type(results...) = output_types(true_branch)
.
Esempi
// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
"stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
"stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10
immaginazione
Semantica
Estrae la parte immaginaria, per elemento, da operand
e produce un tensore result
. Più formalmente, per ogni elemento x
:
imag(x) = is_complex(x) ? imaginary_part(x) :
constant(0, element_type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o complesso | (C1), (C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo in virgola mobile | (C1), (C2) |
Vincoli
- (C1)
shape(result) = shape(operand)
. - (C2)
element_type(result)
è definito come:complex_element_type(element_type(operand))
seis_complex(operand)
.element_type(operand)
in caso contrario.
Esempi
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]
annuncio in-feed
Semantica
Legge i dati dal feed e produce results
.
La semantica di infeed_config
è definita dall'implementazione.
results
è costituito da valori di payload che vengono prima e un token che arriva per ultimo. In futuro, abbiamo in programma di suddividere il payload e il token in due output separati per migliorare la chiarezza (#670).
Input
Etichetta | Nome | Tipo |
---|---|---|
(I1) | token |
token |
(I2) | infeed_config |
costante di tipo string |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadico di tensori, tensori quantizzati o token | (C1-C3) |
Vincoli
- (C1)
0 < size(results)
. - (C2)
is_empty(result[:-1])
ois_tensor(type(results[:-1]))
. - (C3)
is_token(type(results[-1]))
.
Esempi
// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]
iota
Semantica
Riempie un tensore output
con valori in ordine crescente a partire da zero
lungo la dimensione iota_dimension
. In modo più formale,
output[result_index] = constant(is_quantized(output) ?
quantize(result_index[iota_dimension], element_type(output)) :
result_index[iota_dimension], element_type(output))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | iota_dimension |
si64 |
(C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
output |
tensore di tipo intero, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
0 <= iota_dimension < rank(output)
.
Esempi
%output = "stablehlo.iota"() {
iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
%output = "stablehlo.iota"() {
iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4]
// ]
is_finite
Semantica
Esegue un controllo a livello di elemento se il valore in x
è finito (ovvero non è né
+Inf, -Inf né NaN) e produce un tensore y
. Implementa l'operazione isFinite
dalla specifica IEEE-754. Per i tipi quantizzati, il risultato è sempre true
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | x |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
y |
tensore di tipo booleano | (C1) |
Vincoli
- (C1)
shape(x) = shape(y)
.
Esempi
// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]
log
Semantica
Esegue un'operazione logaritmica per elemento sul tensore operand
e produce un tensore result
. A seconda del tipo di elemento, procedi nel seguente modo:
- Per i numeri in virgola mobile:
log
dallo standard IEEE-754. - Per i numeri complessi: logaritmo complesso.
- Per i tipi quantizzati:
dequantize_op_quantize(log, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]
log_plus_one
Semantica
Esegue il logaritmo a livello di elemento più un'operazione sul tensore operand
e produce un tensore result
. A seconda del tipo di elemento, procedi nel seguente modo:
- Per i numeri in virgola mobile:
logp1
dallo standard IEEE-754. - Per i numeri complessi: logaritmo complesso più uno.
- Per i tipi quantizzati:
dequantize_op_quantize(log_plus_one, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
logistica
Semantica
Esegue un'operazione logistica per elementi sul tensore operand
e produce un tensore result
. A seconda del tipo di elemento, procedi nel seguente modo:
- Per i numeri in virgola mobile:
division(1, addition(1, exp(-x)))
dallo standard IEEE-754. - Per i numeri complessi: logistica complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(logistic, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]
mappa
Semantica
Applica una funzione della mappa computation
a inputs
lungo il dimensions
e produce un tensore result
.
Più formalmente, result[result_index] = computation(inputs...[result_index])
.
Tieni presente che al momento dimensions
non sono utilizzati e probabilmente verranno rimossi in futuro (#487).
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | inputs |
numero variadici di tensori o tensori quantizzati per tensore | (C1-C4) |
(I2) | dimensions |
Costante del tensore unidimensionale di tipo si64 |
(C3) |
(I3) | computation |
funzione | (C4) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C1), (C4) |
Vincoli
- (C1)
shape(inputs...) = shape(result)
. - (C2)
0 < size(inputs) = N
. - (C3)
dimensions = range(rank(inputs[0]))
. - (C4)
computation
ha il tipo(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>
, doveEi = element_type(inputs[i])
eE' = element_type(result)
.
Esempi
// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
stablehlo.return %0 : tensor<i64>
}) {
dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]
massima
Semantica
Esegue un'operazione massima per elemento sui tensori lhs
e rhs
e produce un tensore result
. A seconda del tipo di elemento, procedi nel seguente modo:
- Per i valori booleani: OR logico.
- Per i numeri interi: massimo numero intero.
- Per i numeri in virgola mobile:
maximum
dallo standard IEEE-754. - Per i numeri complessi: il valore lessicografico massimo per la coppia
(real, imaginary)
. L'imposizione di un ordinamento sui numeri complessi comporta una semantica sorprendente, quindi in futuro abbiamo intenzione di rimuovere il supporto per i numeri complessi per questa operazione (#560). - Per i tipi quantizzati:
dequantize_op_quantize(maximum, lhs, rhs, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore o tensore quantizzato per tensore | (C1) |
(I2) | rhs |
tensore o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Esempi
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]
minima
Semantica
Esegue un'operazione minima per elemento sui tensori lhs
e rhs
e produce un tensore result
. A seconda del tipo di elemento, procedi nel seguente modo:
- Per i valori booleani: AND logico.
- Per i numeri interi: minimo numero intero.
- Per i numeri in virgola mobile:
minimum
dallo standard IEEE-754. - Per i numeri complessi: il minimo lessicografico per la coppia
(real, imaginary)
. L'imposizione di un ordinamento sui numeri complessi comporta una semantica sorprendente, quindi in futuro abbiamo intenzione di rimuovere il supporto per i numeri complessi per questa operazione (#560). - Per i tipi quantizzati:
dequantize_op_quantize(minimum, lhs, rhs, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore o tensore quantizzato per tensore | (C1) |
(I2) | rhs |
tensore o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Esempi
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]
moltiplicazione
Semantica
Esegue il prodotto a livello di elemento di due tensori lhs
e rhs
e produce un tensore result
. A seconda del tipo di elemento, procedi nel seguente modo:
- Per i valori booleani: AND logico.
- Per i numeri interi: moltiplicazione di numeri interi.
- Per i numeri in virgola mobile:
multiplication
dallo standard IEEE-754. - Per i numeri complessi: moltiplicazione complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(multiply, lhs, rhs, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore o tensore quantizzato per tensore | (C1) |
(I2) | rhs |
tensore o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]
negare
Semantica
Esegue la negazione a livello di elemento del tensore operand
e produce un tensore result
. A seconda del tipo di elemento, procedi nel seguente modo:
- Per i numeri interi firmati: negazione di numeri interi.
- Per numeri interi senza segno: bitcast in numero intero con segno, negazione di numeri interi, bitcast di nuovo su numero intero senza segno.
- Per i numeri in virgola mobile:
negate
dallo standard IEEE-754. - Per i numeri complessi: negazione complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(negate, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo intero, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo intero, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]
// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]
non
Semantica
Esegue NOT per elementi del tensore operand
e produce un tensore result
.
A seconda del tipo di elemento, procedi nel seguente modo:
- Per i valori booleani: NOT logico.
- Per i numeri interi: NOT a livello di bit.
Argomenti
Nome | Tipo | Vincoli |
---|---|---|
operand |
tensore di tipo booleano o intero | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo booleano o intero | (C1) |
Vincoli
- (C1)
type(operand) = type(result)
.
Esempi
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]
// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]
optimization_barrier
Semantica
Garantisce che le operazioni che producono operand
vengano eseguite prima di qualsiasi operazione che dipendono da result
e impedisce alle trasformazioni del compilatore di spostare le operazioni attraverso la barriera. A parte questo, l'operazione è
un'identità, ovvero result = operand
.
Argomenti
Nome | Tipo | Vincoli |
---|---|---|
operand |
numero variadici di tensori, tensori o token quantizzati per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
numero variadici di tensori, tensori o token quantizzati per tensore | (C1) |
Vincoli
- (C1)
type(operand...) = type(result...)
.
Esempi
// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0
o
Semantica
Esegue la funzione OR a livello di elemento di due tensori lhs
e rhs
e produce un tensore result
. A seconda del tipo di elemento, procedi nel seguente modo:
- Per i valori booleani: OR logico.
- Per i numeri interi: OR a livello di bit.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di tipo intero o booleano | (C1) |
(I2) | rhs |
tensore di tipo intero o booleano | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo intero o booleano | (C1) |
Vincoli
- (C1)
type(lhs) = type(rhs) = type(result)
.
Esempi
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]
outfeed
Semantica
Scrive inputs
nell'outfeed e produce un token result
.
La semantica di outfeed_config
è definita dall'implementazione.
Input
Etichetta | Nome | Tipo |
---|---|---|
(I1) | inputs |
numero variadico di tensori o tensori quantizzati |
(I2) | token |
token |
(I3) | outfeed_config |
costante di tipo string |
Output
Nome | Tipo |
---|---|
result |
token |
Esempi
%result = "stablehlo.outfeed"(%inputs0, %token) {
outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token
cuscinetto
Semantica
Espande operand
inserendo una spaziatura interna attorno al tensore e tra gli elementi del tensore con il valore padding_value
specificato.
edge_padding_low
e edge_padding_high
specificano la quantità di spaziatura interna aggiunta rispettivamente alla fascia bassa (accanto all'indice 0) e alla fascia alta (accanto all'indice più alto) di ogni dimensione. La quantità di spaziatura interna può essere negativa, dove il valore assoluto di spaziatura interna negativa indica il numero di elementi da rimuovere dalla dimensione specificata.
interior_padding
specifica la quantità di spaziatura interna aggiunta tra due elementi qualsiasi in ogni dimensione che non può essere negativa. La spaziatura interna interna viene eseguita prima della spaziatura interna del bordo, in modo tale che quella negativa dei bordi rimuoverà gli elementi dall'operando con riempimento interno.
Più formalmente, con result[result_index]
si intende:
operand[operand_index]
seresult_index = edge_padding_low + operand_index * (interior_padding + 1)
.padding_value
in caso contrario.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato per tensore | (C1), (C2), (C4) |
(I2) | padding_value |
Tensore 0-dimensionale o tensore quantizzato per tensore | (C1) |
(I3) | edge_padding_low |
Costante del tensore unidimensionale di tipo si64 |
(C1), (C4) |
(I4) | edge_padding_high |
Costante del tensore unidimensionale di tipo si64 |
(C1), (C4) |
(I5) | interior_padding |
Costante del tensore unidimensionale di tipo si64 |
(C2-C4) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C3-C6) |
Vincoli
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
. - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
. - (C3)
0 <= interior_padding
. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
.
Esempi
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
edge_padding_low = dense<[0, 1]> : tensor<2xi64>,
edge_padding_high = dense<[2, 1]> : tensor<2xi64>,
interior_padding = dense<[1, 2]> : tensor<2xi64>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
partition_id
Semantica
Produce partition_id
del processo corrente.
Output
Nome | Tipo |
---|---|
result |
Tensore 0-dimensionale di tipo ui32 |
Esempi
%result = "stablehlo.partition_id"() : () -> tensor<ui32>
popcnt
Semantica
Esegue il conteggio a livello di elemento del numero di bit impostato nel tensore operand
e produce un tensore result
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo intero | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo intero | (C1) |
Vincoli
- (C1)
type(operand) = type(result)
.
Esempi
// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]
potenza
Semantica
Esegue l'esponenziale a livello di elemento del tensore lhs
per un tensore rhs
e produce un tensore result
. A seconda del tipo di elemento, procedi nel seguente modo:
- Per i numeri interi: esponenzialità dei numeri interi.
- Per i numeri in virgola mobile:
pow
dallo standard IEEE-754. - Per i numeri complessi: esponenzialità complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(power, lhs, rhs, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di tipo intero, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
(I2) | rhs |
tensore di tipo intero, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo intero, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]
reale
Semantica
Estrae la parte reale, a livello di elemento, da operand
e produce un tensore result
. Più formalmente, per ogni elemento x
:
real(x) = is_complex(x) ? real_part(x) : x
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o complesso | (C1), (C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo in virgola mobile | (C1), (C2) |
Vincoli
- (C1)
shape(result) = shape(operand)
. - (C2)
element_type(result)
è definito come:complex_element_type(element_type(operand))
seis_complex(operand)
.element_type(operand)
in caso contrario.
Esempi
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]
recv
Semantica
Riceve i dati da un canale con channel_id
e produce results
.
Se is_host_transfer
è true
, l'operazione trasferisce i dati dall'host. In caso contrario, i dati vengono trasferiti da un altro dispositivo. Ciò significa che è
definito dall'implementazione. Questo flag duplica le informazioni fornite in
channel_type
, pertanto in futuro prevediamo di conservarne solo uno
(#666).
results
è costituito da valori di payload che vengono prima e un token che arriva per ultimo. In futuro, abbiamo in programma di suddividere il payload e il token in due output separati per migliorare la chiarezza (#670).
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | token |
token |
(C4) |
(I2) | channel_id |
costante di tipo si64 |
|
(I3) | channel_type |
enum di DEVICE_TO_DEVICE e HOST_TO_DEVICE |
(C1) |
(I4) | is_host_transfer |
costante di tipo i1 |
(C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadico di tensori, tensori quantizzati o token | (C2-C4) |
Vincoli
- (C1)
channel_type
è definito come:HOST_TO_DEVICE
seis_host_transfer = true
,DEVICE_TO_DEVICE
in caso contrario.
- (C2)
0 < size(results)
. - (C3)
is_empty(result[:-1])
ois_tensor(type(results[:-1]))
. - (C4)
is_token(type(results[-1]))
.
Esempi
%results0, %results1 = "stablehlo.recv"(%token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
reduce
Semantica
Applica una funzione di riduzione body
a inputs
e init_values
lungo il
dimensions
e produce results
tensori.
L'ordine delle riduzioni è definito dall'implementazione, il che significa che body
e init_values
devono formare un monoide per garantire che l'operazione produca gli stessi risultati per tutti gli input in tutte le implementazioni. Tuttavia, questa condizione non si applica a molte riduzioni popolari. Ad esempio, l'aggiunta in virgola mobile per body
e zero per init_values
in realtà non formano un monoide perché l'aggiunta in virgola mobile non è associativa.
Più formalmente, results...[j0, ..., jR-1] = reduce(input_slices_converted)
dove:
input_slices = inputs...[j0, ..., :, ..., jR-1]
, in cui gli elementi:
sono inseriti indimensions
.input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...)
.init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)
.reduce(input_slices_converted) = exec(schedule)
per un albero binarioschedule
dove:exec(node) = body(exec(node.left), exec(node.right))
.exec(leaf) = leaf.value
.
schedule
è una struttura binaria completa definita dall'implementazione il cui attraversamento in ordine è costituito da:- Valori
input_slices_converted...[index]
, per tutti iindex
inindex_space(input_slices_converted)
nell'ordine lessicografico crescente diindex
. - È alternato a una quantità definita dall'implementazione di
init_values_converted
nelle posizioni definite per l'implementazione.
- Valori
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | inputs |
numero variadici di tensori o tensori quantizzati per tensore | (C1-C4), (C6), (C7) |
(I2) | init_values |
numero variadico di tensori 0-dimensionali o tensori quantizzati per tensore | (C2), (C3) |
(I3) | dimensions |
Costante del tensore unidimensionale di tipo si64 |
(C4), (C5), (C7) |
(I4) | body |
funzione | (C6) |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadici di tensori o tensori quantizzati per tensore | (C3), (C7), (C8) |
Vincoli
- (C1)
same(shape(inputs...))
. - (C2)
element_type(inputs...) = element_type(init_values...)
. - (C3)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C4)
0 <= dimensions < rank(inputs[0])
. - (C5)
is_unique(dimensions)
. - (C6)
body
ha il tipo(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
doveis_promotable(element_type(inputs[i]), Ei)
. - (C7)
shape(results...) = shape(inputs...)
, tranne per il fatto che le dimensioni diinputs...
corrispondenti adimensions
non sono incluse. - (C8)
element_type(results[i]) = Ei
per tutti ii
in[0,N)
.
Esempi
// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
dimensions = dense<1> : tensor<1xi64>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]
reduce_precision
Semantica
Esegue la conversione a livello di elemento di operand
in un altro tipo con virgola mobile che utilizza exponent_bits
e mantissa_bits
, tornando al tipo originale con virgola mobile e produce un tensore output
.
In modo più formale:
- I bit di mantissa del valore originale vengono aggiornati in modo da arrotondare il valore originale al valore più vicino rappresentabile con
mantissa_bits
utilizzando la semanticaroundToIntegralTiesToEven
. - Se
mantissa_bits
è inferiore al numero di bit di mantissa del valore originale, i bit di mantissa vengono troncati amantissa_bits
. - Quindi, se i bit esponenti del risultato intermedio non rientrano nell'intervallo fornito da
exponent_bits
, il risultato intermedio supera il limite all'infinito utilizzando il segno originale o torna a zero utilizzando il segno originale. - Per i tipi quantizzati, esegue
dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
(I2) | exponent_bits |
costante di tipo si32 |
(C2) |
(I3) | mantissa_bits |
costante di tipo si32 |
(C3) |
Output
Nome | Tipo | Vincoli |
---|---|---|
output |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(output)
. - (C2)
1 <= exponent_bits
. - (C3)
0 <= mantissa_bits
.
Esempi
// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
exponent_bits = 5 : i32,
mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]
reduce_scatter
Semantica
All'interno di ogni gruppo di processi nella griglia di processo StableHLO, esegue la riduzione,
utilizzando computations
, sui valori del tensore operand
di ogni processo,
suddivide il risultato della riduzione in scatter_dimension
in parti e disperde
le parti divise tra i processi per produrre result
.
L'operazione suddivide la griglia del processo StableHLO in process_groups
, che è
definito come segue:
cross_replica(replica_groups)
sechannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
sechannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
sechannel_id > 0 and use_global_device_ids = true
.
Successivamente, entro ogni process_group
:
reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation)
.parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension)
.result@receiver = parts@sender[receiver_index]
per tutti isender
inprocess_group
, dovereceiver_index = process_group.index(receiver)
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato per tensore | (C1), (C2), (C7), (C8) |
(I2) | scatter_dimension |
costante di tipo si64 |
(C1), (C2), (C8) |
(I3) | replica_groups |
Costante del tensore bidimensionale di tipo si64 |
(C3-C5) |
(I4) | channel_id |
costante di tipo si64 |
(C6) |
(I5) | use_global_device_ids |
costante di tipo i1 |
(C6) |
(I6) | computation |
funzione | (C7) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C8-C9) |
Vincoli
- (C1)
dim(operand, scatter_dimension) % dim(process_groups, 1) = 0
. - (C2)
0 <= scatter_dimension < rank(operand)
. - (C3)
is_unique(replica_groups)
. - (C4)
size(replica_groups)
è definito come:num_replicas
se viene utilizzatocross_replica
.num_replicas
se viene utilizzatocross_replica_and_partition
.num_processes
se viene utilizzatoflattened_ids
.
- (C5)
0 <= replica_groups < size(replica_groups)
. - (C6) Se
use_global_device_ids = true
, allorachannel_id > 0
. - (C7)
computation
ha il tipo(tensor<E>, tensor<E>) -> (tensor<E>)
, doveis_promotable(element_type(operand), E)
. - (C8)
shape(result) = shape(operand)
tranne:dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1)
.
- (C9)
element_type(result) = E
.
Esempi
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
// [18, 20]]
// %result@(1, 0): [[14, 16],
// [22, 24]]
reduce_window
Semantica
Applica una funzione di riduzione body
a finestre di inputs
e init_values
e produce results
.
Il seguente diagramma mostra come vengono calcolati gli elementi in results...
da inputs...
utilizzando un esempio concreto.
In modo più formale,
results...[result_index] = reduce(windows, init_values, axes(inputs...), body)
(vedi reduce) dove:
padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1)
.window_start = result_index * window_strides
.window_end = window_start + (window_dimensions - 1) * window_dilations + 1
.windows = slice(padded_inputs..., window_start, window_end, window_dilations)
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | inputs |
numero variadici di tensori o tensori quantizzati per tensore | (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15) |
(I2) | init_values |
numero variadico di tensori 0-dimensionali o tensori quantizzati per tensore | (C1), (C13) |
(I3) | window_dimensions |
Costante del tensore unidimensionale di tipo si64 |
(C4), (C5), (C15) |
(I4) | window_strides |
Costante del tensore unidimensionale di tipo si64 |
(C6), (C7), (C15) |
(I5) | base_dilations |
Costante del tensore unidimensionale di tipo si64 |
(C8), (C9), (C15) |
(I6) | window_dilations |
Costante del tensore unidimensionale di tipo si64 |
(C10), (C11), (C15) |
(I7) | padding |
Costante del tensore bidimensionale di tipo si64 |
(C12), (C15) |
(I8) | body |
funzione | (C13) |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadici di tensori o tensori quantizzati per tensore | (C1), (C14-C16) |
Vincoli
- (C1)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C2)
same(shape(inputs...))
. - (C3)
element_type(inputs...) = element_type(init_values...)
. - (C4)
size(window_dimensions) = rank(inputs[0])
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(inputs[0])
. - (C7)
0 < window_strides
. - (C8)
size(base_dilations) = rank(inputs[0])
. - (C9)
0 < base_dilations
. - (C10)
size(window_dilations) = rank(inputs[0])
. - (C11)
0 < window_dilations
. - (C12)
shape(padding) = [rank(inputs[0]), 2]
. - (C13)
body
ha il tipo(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
doveis_promotable(element_type(inputs[i]), Ei)
. - (C14)
same(shape(results...))
. - (C15)
shape(results[0]) = num_windows
dove:dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1
.padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]
.dilated_window_shape = (window_dimensions - 1) * window_dilations + 1
.is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape
.num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1
.
- (C16)
element_type(results[i]) = Ei
per tutti ii
in[0,N)
.
Esempi
// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = dense<[2, 1]> : tensor<2xi64>,
window_strides = dense<[4, 1]> : tensor<2xi64>,
base_dilations = dense<[2, 1]> : tensor<2xi64>,
window_dilations = dense<[3, 1]> : tensor<2xi64>,
padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]
resto
Semantica
Esegue il resto del dividendo lhs
e il divisore rhs
dei tensori in base agli elementi e produce un tensore result
.
Più formalmente, il segno del risultato viene preso dal dividendo e il valore assoluto del risultato è sempre inferiore al valore assoluto del divisore.
Il resto viene calcolato come lhs - d * rhs
, dove d
è dato da:
- Per i numeri interi:
stablehlo.divide(lhs, rhs)
. - Per i numeri in virgola mobile:
division(lhs, rhs)
da IEEE-754 con attributo di arrotondamentoroundTowardZero
. - Per i numeri complessi: da definire (#997).
- Per i tipi quantizzati:
dequantize_op_quantize(remainder, lhs, rhs, type(result))
.
Per i tipi di elementi con virgola mobile, questa operazione è in contrasto con l'operazione remainder
della specifica IEEE-754 in cui d
è un valore integrale più vicino al valore esatto di lhs/rhs
con legami a pari.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di tipo intero, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
(I2) | rhs |
tensore di tipo intero, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo intero, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]
replica_id
Semantica
Produce replica_id
del processo corrente.
Output
Nome | Tipo |
---|---|
result |
Tensore 0-dimensionale di tipo ui32 |
Esempi
%result = "stablehlo.replica_id"() : () -> tensor<ui32>
rimodellare
Semantica
Esegue la riforma del tensore operand
in un tensore result
. Concettualmente, equivale a mantenere la stessa rappresentazione canonica, ma modificando potenzialmente la forma, ad esempio da tensor<2x3xf32>
a tensor<3x2xf32>
o tensor<6xf32>
.
Più formalmente, result[result_index] = operand[operand_index]
dove
result_index
e operand_index
hanno la stessa posizione nell'ordine lessicografico
di index_space(result)
e index_space(operand)
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o quantizzato | (C1-C3) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o quantizzato | (C1-C3) |
Vincoli
- (C1)
element_type(result)
è dato da:element_type(operand)
, se!is_per_axis_quantized(operand)
.element_type(operand)
eccettoquantization_dimension(operand)
equantization_dimension(result)
potrebbero differire, altrimenti.
- (C2)
size(operand) = size(result)
. - (C3) Se
is_per_axis_quantized(operand)
:reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
.reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
Esempi
// %operand: [[1, 2, 3], [4, 5, 6]]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]
inverti
Semantica
Inverte l'ordine degli elementi in operand
lungo il dimensions
specificato e produce un tensore result
. Dal punto di vista più formale,
result[result_index] = operand[operand_index]
dove:
operand_index[d] = dim(result, d) - result_index[d] - 1
sed
indimensions
.operand_index[d] = result_index[d]
in caso contrario.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato per tensore | (C1), (C3) |
(I2) | dimensions |
Costante del tensore unidimensionale di tipo si64 |
(C2), (C3) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C1), (C3) |
Vincoli
- (C1)
type(operand) = type(result)
. - (C2)
is_unique(dimensions)
. - (C3)
0 <= dimensions < rank(result)
.
Esempi
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
dimensions = dense<1> : tensor<1xi64>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]
rng
Semantica
Genera numeri casuali utilizzando l'algoritmo rng_distribution
e produce un
tensore result
di una data forma shape
.
Se rng_distribution = UNIFORM
, i numeri casuali vengono generati seguendo la distribuzione uniforme nell'intervallo [a, b)
. Se a >= b
,
il comportamento non è definito.
Se rng_distribution = NORMAL
, i numeri casuali vengono generati seguendo la distribuzione normale con media = a
e deviazione standard = b
.
Se b < 0
, il comportamento non è definito.
Il modo esatto in cui vengono generati i numeri casuali è definito dall'implementazione. Ad esempio, possono essere o meno deterministici e possono utilizzare o meno lo stato nascosto.
Nelle conversazioni con molti stakeholder, questa operazione si è rivelata effettivamente obsoleta, quindi in futuro abbiamo intenzione di valutare la possibilità di rimuoverla (#597).
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | a |
Tensore 0-dimensionale di tipo intero, booleano o in virgola mobile | (C1), (C2) |
(I2) | b |
Tensore 0-dimensionale di tipo intero, booleano o in virgola mobile | (C1), (C2) |
(I3) | shape |
Costante del tensore unidimensionale di tipo si64 |
(C3) |
(I4) | rng_distribution |
enum di UNIFORM e NORMAL |
(C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo intero, booleano o in virgola mobile | (C1-C3) |
Vincoli
- (C1)
element_type(a) = element_type(b) = element_type(result)
. - (C2) Se
rng_distribution = NORMAL
,is_float(a)
. - (C3)
shape(result) = shape
.
Esempi
// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
// [1, 0, 1],
// [1, 1, 1],
// [0, 0, 0]
// ]
rng_bit_generator
Semantica
Restituisce un output
riempito con bit casuali uniformi e uno stato di output aggiornato output_state
utilizzando l'algoritmo del generatore di numeri pseudocasuale rng_algorithm
in base a uno stato iniziale initial_state
. È garantito che l'output sia una funzione deterministica di initial_state
, ma non è garantito che sia deterministico tra le implementazioni.
rng_algorithm
è uno dei seguenti:
DEFAULT
: algoritmo definito dall'implementazione.THREE_FRY
: variante definita dall'implementazione dell'algoritmo Threefry.*PHILOX
: variante dell'algoritmo Philox definita dall'implementazione.*
* Vedi: Salmon et al. SC 2011. Numeri casuali paralleli: facile come 1, 2, 3.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | rng_algorithm |
enum di DEFAULT , THREE_FRY e PHILOX |
(C2) |
(I2) | initial_state |
Tensore unidimensionale di tipo ui64 |
(C1), (C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
output_state |
Tensore unidimensionale di tipo ui64 |
(C1) |
output |
tensore di tipo intero o in virgola mobile |
Vincoli
- (C1)
type(initial_state) = type(output_state)
. - (C2)
size(initial_state)
è definito come:- definita dall'implementazione se
rng_algorithm = DEFAULT
. 2
serng_algorithm = THREE_FRY
.2
o3
serng_algorithm = PHILOX
.
- definita dall'implementazione se
Esempi
// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
// [9236835810183407956, 16087790271692313299],
// [18212823393184779219, 2658481902456610144]
// ]
round_nearest_afz
Semantica
Esegue l'arrotondamento per elemento al numero intero più vicino, separando i legami da zero, sul tensore operand
e produce un tensore result
. Implementa l'operazione roundToIntegralTiesToAway
dalla specifica IEEE-754. Per
i tipi quantizzati, esegue
dequantize_op_quantize(round_nearest_afz, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]
round_nearest_even
Semantica
Esegue l'arrotondamento per elemento al numero intero più vicino, spezzando i legami con il numero intero pari, sul tensore operand
e produce un tensore result
. Implementa l'operazione roundToIntegralTiesToEven
dalla specifica IEEE-754. Per i tipi quantizzati, esegue dequantize_op_quantize(round_nearest_even, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
rsqrt
Semantica
Esegue un'operazione di radice quadrata reciproca rispetto agli elementi sul tensore operand
e produce un tensore result
. A seconda del tipo di elemento, procedi nel seguente modo:
- Per i numeri in virgola mobile:
rSqrt
dallo standard IEEE-754. - Per i numeri complessi: radice quadrata reciproca complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(rsqrt, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]
scatter
Semantica
Genera tensori results
che sono uguali a inputs
tensori, tranne per il fatto che
diverse sezioni specificate da scatter_indices
vengono aggiornate con i valori
updates
utilizzando update_computation
.
Il seguente diagramma mostra come gli elementi in updates...
vengono mappati sugli elementi in
results...
utilizzando un esempio concreto. Il diagramma seleziona alcuni indici updates...
di esempio e spiega in dettaglio a quali indici results...
corrispondono.
Più formalmente, per tutti e update_index
in index_space(updates[0])
:
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]
.update_scatter_index = update_index[update_scatter_dims...]
.- Per
start_index
si intende:scatter_indices[si0, ..., :, ..., siN]
, dovesi
sono singoli elementi inupdate_scatter_index
e:
è inserito nell'indiceindex_vector_dim
, seindex_vector_dim
<rank(scatter_indices)
.[scatter_indices[update_scatter_index]]
in caso contrario.
- Per
d_input
inaxes(inputs[0])
,full_start_index[d_input] = start_index[d_start]
sed_input = scatter_dims_to_operand_dims[d_start]
.full_start_index[d_input] = 0
in caso contrario.
update_window_index = update_index[update_window_dims...]
.full_window_index = [wi0, ..., 0, ..., wiN]
, dovewi
sono singoli elementi diupdate_window_index
e0
è inserito negli indici diinserted_window_dims
.result_index = full_start_index + full_window_index
.
Detto ciò, results = exec(schedule, inputs)
, dove:
schedule
è una permutazione diindex_space(updates[0])
definita dall'implementazione.exec([update_index, ...], results) = exec([...], updated_results)
dove:- Se
result_index
è nei limiti dishape(results...)
updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
updated_values = update_computation(results...[result_index], updates_converted)
updated_results
è una copia diresults
conresults...[result_index]
impostato suupdated_values...
.- In caso contrario
updated_results = results
.
- Se
exec([], results) = results
.
Se indices_are_sorted
è true
, l'implementazione può presumere che i valori scatter_indices
siano ordinati rispetto a scatter_dims_to_operand_dims
, altrimenti il comportamento non è definito. Più formalmente, per tutti i i1 < i2
di
indices(result)
, full_start_index(i1)
<= full_start_index(i2)
.
Se unique_indices
è true
, l'implementazione può presupporre che tutti gli indici sparsi result_index
siano univoci. Se unique_indices
è true
, ma i valori di dispersione degli indici non sono univoci, il comportamento non è definito.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | inputs |
numero variadici di tensori o tensori quantizzati per tensore | (C1), (C2), (C4-C6), (C10), (C13), (C15-C16) |
(I2) | scatter_indices |
tensore di tipo intero | (C4), (C11), (C14) |
(I3) | updates |
numero variadici di tensori o tensori quantizzati per tensore | (C3-C6), (C8) |
(I4) | update_window_dims |
Costante del tensore unidimensionale di tipo si64 |
(C2), (C4), (C7), (C8) |
(I5) | inserted_window_dims |
Costante del tensore unidimensionale di tipo si64 |
(C2), (C4), (C9), (C10) |
(I6) | scatter_dims_to_operand_dims |
Costante del tensore unidimensionale di tipo si64 |
(C11-C13) |
(I7) | index_vector_dim |
costante di tipo si64 |
(C4), (C11), (C14) |
(I8) | indices_are_sorted |
costante di tipo i1 |
|
(I9) | unique_indices |
costante di tipo i1 |
|
(I10) | update_computation |
funzione | (C15) |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadici di tensori o tensori quantizzati per tensore | (C15-C17) |
Vincoli
- (C1)
same(shape(inputs...))
. - (C2)
rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims)
. - (C3)
same(shape(updates...))
. - (C4)
shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes)
dove:update_scatter_dim_sizes = shape(scatter_indices)
tranne che la dimensione discatter_indices
corrispondente aindex_vector_dim
non è inclusa.update_window_dim_sizes <= shape(inputs[0])
tranne per il fatto che le dimensioni ininputs[0]
corrispondenti ainserted_window_dims
non sono incluse.combine
collocaupdate_scatter_dim_sizes
su assi corrispondenti aupdate_scatter_dims
eupdate_window_dim_sizes
su assi corrispondenti aupdate_window_dims
.
- (C5)
0 < size(inputs) = size(updates) = N
. - (C6)
element_type(updates...) = element_type(inputs...)
. - (C7)
is_unique(update_window_dims) and is_sorted(update_window_dims)
. - (C8)
0 <= update_window_dims < rank(updates[0])
. - (C9)
is_unique(inserted_window_dims) and is_sorted(update_window_dims)
. - (C10)
0 <= inserted_window_dims < rank(inputs[0])
. - (C11)
size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1
. - (C12)
is_unique(scatter_dims_to_operand_dims)
. - (C13)
0 <= scatter_dims_to_operand_dims < rank(inputs[0])
. - (C14)
0 <= index_vector_dim <= rank(scatter_indices)
. - (C15)
update_computation
ha il tipo(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
, doveis_promotable(element_type(inputs[i]), Ei)
. - (C16)
shape(inputs...) = shape(results...)
. - (C17)
element_type(results[i]) = Ei
per tutti ii
in[0,N)
.
Esempi
// %input: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10], [11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %scatter_indices: [[[0, 2], [1, 0], [2, 1]], [[0, 1], [1, 0], [0, 9]]]
// %update: [
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
// ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [2, 3],
inserted_window_dims = [0],
scatter_dims_to_operand_dims = [1, 0],
index_vector_dim = 2>,
indices_are_sorted = false,
unique_indices = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<2x3x2x2xi64>) -> tensor<3x4x2xi64>
// %result: [
// [[1, 2], [5, 6], [7, 8], [7, 8]],
// [[10, 11], [12, 13], [14, 15], [16, 17]],
// [[18, 19], [20, 21], [21, 22], [23, 24]]
// ]
seleziona
Semantica
Genera un tensore result
in cui ogni elemento viene selezionato da on_true
o on_false
in base al valore dell'elemento corrispondente di pred
.
Più formalmente, result[result_index] = pred_element ? on_true[result_index] :
on_false[result_index]
, dove pred_element = rank(pred) = 0 ? pred[] :
pred[result_index]
. Per i tipi quantizzati, esegue dequantize_select_quantize(pred, on_true, on_false, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | pred |
tensore di tipo i1 |
(C1) |
(I2) | on_true |
tensore o tensore quantizzato per tensore | (C1-C2) |
(I3) | on_false |
tensore o tensore quantizzato per tensore | (C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C2) |
Vincoli
- (C1)
rank(pred) = 0 or shape(pred) = shape(on_true)
. - (C2)
baseline_type(on_true) = baseline_type(on_false) = baseline_type(result)
.
Esempi
// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]
select_and_scatter
Semantica
Distribuisce i valori del tensore source
utilizzando scatter
in base al risultato di reduce_window
del tensore input
utilizzando select
e produce un tensore result
.
Il seguente diagramma mostra come gli elementi in result
vengono calcolati da operand
e source
utilizzando un esempio concreto.
In modo più formale:
selected_values = reduce_window_without_init(...)
con i seguenti input:- 'inputs = [operando].
window_dimensions
,window_strides
epadding
utilizzati così come sono.base_dilations = windows_dilations = 1
.body
è definito come:
def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>: return select(arg0, arg1) ? arg0 : arg1;
dove
E = element_type(operand)
ereduce_window_without_init
funzionano esattamente comereduce_window
, tranne per il fatto cheschedule
del valorereduce
sottostante (vedi riduzione) non include valori init. Al momento non è specificato cosa succede se la finestra corrispondente non contiene valori (#731).result[result_index] = reduce([source_values], [init_value], [0], scatter)
dove:source_values = [source[source_index] for source_index in source_indices]
.selected_index(source_index) = operand_index
seselected_values[source_index]
ha l'elementooperand
dioperand_index
.source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index]
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato per tensore | (C1-C4), (C6), (C8-C11) |
(I2) | source |
tensore o tensore quantizzato per tensore | (C1), (C2) |
(I3) | init_value |
Tensore 0-dimensionale o tensore quantizzato per tensore | (C3) |
(I4) | window_dimensions |
Costante del tensore unidimensionale di tipo si64 |
(C2), (C4), (C5) |
(I5) | window_strides |
Costante del tensore unidimensionale di tipo si64 |
(C2), (C6), (C7) |
(I6) | padding |
Costante del tensore bidimensionale di tipo si64 |
(C2), (C8) |
(I7) | select |
funzione | (C9) |
(I8) | scatter |
funzione | (C10) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C11-C12) |
Vincoli
- (C1)
element_type(operand) = element_type(source)
. - (C2)
shape(source) = num_windows
dove:padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1]
.is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape
.num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1
.
- (C3)
element_type(init_value) = element_type(operand)
. - (C4)
size(window_dimensions) = rank(operand)
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(operand)
. - (C7)
0 < window_strides
. - (C8)
shape(padding) = [rank(operand), 2]
. - (C9)
select
ha il tipo(tensor<E>, tensor<E>) -> tensor<i1>
, doveE = element_type(operand)
. - (C10)
scatter
ha il tipo(tensor<E>, tensor<E>) -> tensor<E>
, doveis_promotable(element_type(operand), E)
. - (C11)
shape(operand) = shape(result)
. - (C12)
element_type(result) = E
.
Esempi
// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]
Invia
Semantica
Invia inputs
a un canale channel_id
e produce un token result
.
Se is_host_transfer
è true
, l'operazione trasferisce i dati all'host. In caso contrario, i dati vengono trasferiti su un altro dispositivo. Ciò significa che è
definito dall'implementazione. Questo flag duplica le informazioni fornite in
channel_type
, pertanto in futuro prevediamo di conservarne solo uno
(#666).
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | inputs |
numero variadico di tensori o tensori quantizzati | |
(I2) | token |
token |
|
(I3) | channel_id |
costante di tipo si64 |
|
(I4) | channel_type |
enum di DEVICE_TO_DEVICE e DEVICE_TO_HOST |
(C1) |
(I5) | is_host_transfer |
costante di tipo i1 |
(C1) |
Output
Nome | Tipo |
---|---|
result |
token |
Vincoli
- (C1)
channel_type
è definito come:DEVICE_TO_HOST
seis_host_transfer = true
,DEVICE_TO_DEVICE
in caso contrario.
Esempi
%result = "stablehlo.send"(%operand, %token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token
shift_left
Semantica
Esegue un'operazione di spostamento a sinistra per elemento sul tensore lhs
di un numero di bit rhs
e produce un tensore result
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di tipo intero | (C1) |
(I2) | rhs |
tensore di tipo intero | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo intero | (C1) |
Vincoli
- (C1)
type(lhs) = type(rhs) = type(result)
.
Esempi
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]
shift_right_arithmetic
Semantica
Esegue un'operazione aritmetica di spostamento verso destra per elementi sul tensore lhs
per un numero di bit rhs
e produce un tensore result
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di tipo intero | (C1) |
(I2) | rhs |
tensore di tipo intero | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo intero | (C1) |
Vincoli
- (C1)
type(lhs) = type(rhs) = type(result)
.
Esempi
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]
shift_right_logical
Semantica
Esegue un'operazione logica di spostamento verso destra in base agli elementi sul tensore lhs
in base al numero di bit rhs
e produce un tensore result
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di tipo intero | (C1) |
(I2) | rhs |
tensore di tipo intero | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo intero | (C1) |
Vincoli
- (C1)
type(lhs) = type(rhs) = type(result)
.
Esempi
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]
firmare
Semantica
Restituisce il segno dell'elemento operand
e produce un tensore result
.
Più formalmente, per ogni elemento x
, la semantica può essere espressa utilizzando la sintassi Python come segue:
def sign(x):
if is_integer(x):
if compare(x, 0, LT, SIGNED): return -1
if compare(x, 0, EQ, SIGNED): return 0
return 1
elif is_float(x):
if is_nan(x): return NaN
if compare(x, -0.0, EQ, FLOAT): return -0.0
if compare(x, +0.0, EQ, FLOAT): return +0.0
if compare(x, 0.0, LT, FLOAT): return -1.0
return 1.0
elif is_complex(x):
if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
return divide(x, convert(abs(x), type(x)))
Per i tipi quantizzati, esegue dequantize_op_quantize(sign, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di numero intero firmato, in virgola mobile, complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di numero intero firmato, in virgola mobile, complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
seno
Semantica
Esegue un'operazione seno a livello di elemento sul tensore operand
e produce un tensore result
. A seconda del tipo di elemento, procedi nel seguente modo:
- Per i numeri in virgola mobile:
sin
dallo standard IEEE-754. - Per i numeri complessi: seno complesso.
- Per i tipi quantizzati:
dequantize_op_quantize(sine, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]
sezione
Semantica
Estrae una sezione da operand
utilizzando indici iniziali calcolati in modo statico e produce un tensore result
. start_indices
contiene gli indici iniziali della sezione per ogni dimensione, limit_indices
contiene gli indici finali (esclusivi) per la sezione di ogni dimensione e strides
contiene gli incrementi per ogni dimensione.
Più formalmente, result[result_index] = operand[operand_index]
dove
operand_index = start_indices + result_index * strides
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato per tensore | (C1-C3), (C5) |
(I2) | start_indices |
Costante del tensore unidimensionale di tipo si64 |
(C2), (C3), (C5) |
(I3) | limit_indices |
Costante del tensore unidimensionale di tipo si64 |
(C2), (C3), (C5) |
(I4) | strides |
Costante del tensore unidimensionale di tipo si64 |
(C2), (C4) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C1), (C5) |
Vincoli
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(limit_indices) = size(strides) = rank(operand)
. - (C3)
0 <= start_indices <= limit_indices <= shape(operand)
. - (C4)
0 < strides
. - (C5)
shape(result) = ceil((limit_indices - start_indices) / strides)
.
Esempi
// %operand: [
// [0, 0, 0, 0],
// [0, 0, 1, 1],
// [0, 0, 1, 1]
// ]
%result = "stablehlo.slice"(%operand) {
start_indices = dense<[1, 2]> : tensor<2xi64>,
limit_indices = dense<[3, 4]> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
// [1, 1],
// [1, 1]
// ]
ordinare
Semantica
Ordina le sezioni unidimensionali di inputs
lungo la dimensione dimension
,
in base a un valore comparator
e produce results
.
A differenza degli input simili in altre operazioni, dimension
consente valori negativi,
con la semantica descritta di seguito. In futuro, questa impostazione potrebbe non essere consentita per motivi di coerenza (#1377).
Se is_stable
è true, l'ordinamento è stabile, ovvero l'ordine relativo degli elementi considerati uguali dal comparatore viene conservato. Nel caso in cui sia presente un singolo input, due elementi e1
e e2
vengono considerati uguali dal comparatore se e solo se comparator(e1, e2) = comparator(e2, e1) = false
. Vedi la formalizzazione di seguito per
come si generalizza a più input.
Più formalmente, per tutti e result_index
in index_space(results[0])
:
adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension
.result_slice = [ri0, ..., :, ..., riR-1]
, doveriN
sono singoli elementi diresult_index
e:
è inserito inadjusted_dimension
.inputs_together = (inputs[0]..., ..., inputs[N-1]...)
.results_together[result_slice] = sort(inputs_together[result_slice], comparator_together)
.- dove
sort
ordina una sezione unidimensionale in ordine non decrescente prevedendo checomparator_together
restituiscatrue
se l'argomento a sinistra è inferiore all'argomento del secondo a destra. def comparator_together(lhs_together, rhs_together): args = [] for (lhs_el, rhs_el) in zip(lhs_together, rhs_together): args.append(lhs_el) args.append(rhs_el) return comparator(*args)
(results[0]..., ..., results[N-1]...) = results_together
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | inputs |
numero variadici di tensori o tensori quantizzati per tensore | (C1-C5) |
(I2) | dimension |
costante di tipo si64 |
(C4) |
(I3) | is_stable |
costante di tipo i1 |
|
(I4) | comparator |
funzione | (C5) |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadici di tensori o tensori quantizzati per tensore | (C2), (C3) |
Vincoli
- (C1)
0 < size(inputs)
. - (C2)
type(inputs...) = type(results...)
. - (C3)
same(shape(inputs...) + shape(results...))
. - (C4)
-R <= dimension < R
, doveR = rank(inputs[0])
. - (C5)
comparator
ha il tipo(tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>
, doveEi = element_type(inputs[i])
.
Esempi
// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
dimension = 0 : i64,
is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]
sqrt
Semantica
Esegue un'operazione di radice quadrata per elemento sul tensore operand
e produce un
tensore result
. A seconda del tipo di elemento, procedi nel seguente modo:
- Per i numeri in virgola mobile:
squareRoot
dallo standard IEEE-754. - Per i numeri complessi: radice quadrata complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(sqrt, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]
subtract
Semantica
Esegue la sottrazione a livello di elemento di due tensori lhs
e rhs
e produce un tensore result
. A seconda del tipo di elemento, procedi nel seguente modo:
- Per i numeri interi: sottrazione di numeri interi.
- Per i numeri in virgola mobile:
subtraction
dallo standard IEEE-754. - Per i numeri complessi: sottrazione complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(subtract, lhs, rhs, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di tipo intero, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
(I2) | rhs |
tensore di tipo intero, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo intero, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Esempi
// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]
Tanh
Semantica
Esegue un'operazione di tangente iperbolica per elemento sul tensore operand
e produce un tensore result
. A seconda del tipo di elemento, procedi nel seguente modo:
- Per i numeri in virgola mobile:
tanh
dallo standard IEEE-754. - Per i numeri complessi: tangente iperbolica complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(tanh, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]
trasponi
Semantica
Rimuove le dimensioni del tensore operand
utilizzando permutation
e produce un tensore result
. A livello più formale, result[result_index] = operand[operand_index]
dove result_index[d] = operand_index[permutation[d]]
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o quantizzato | (C1-C4) |
(I2) | permutation |
Costante del tensore unidimensionale di tipo si64 |
(C2-C4) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o quantizzato | (C1), (C3-C4) |
Vincoli
- (C1)
element_type(result)
è dato da:element_type(operand)
, se!is_per_axis_quantized(operand)
.element_type(operand)
eccettoquantization_dimension(operand)
equantization_dimension(result)
potrebbero differire, altrimenti.
- (C2)
permutation
è una permutazione dirange(rank(operand))
. - (C3)
shape(result) = dim(operand, permutation...)
. - (C4) Se
is_per_axis_quantized(result)
,quantization_dimension(operand) = permutation(quantization_dimension(result))
.
Esempi
// %operand: [
// [[1,2], [3,4], [5,6]],
// [[7,8], [9,10], [11,12]]
// ]
%result = "stablehlo.transpose"(%operand) {
permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
// [[1,7], [3,9], [5,11]],
// [[2,8], [4,10], [6,12]]
// ]
triangular_solve
Semantica
Risolve batch di sistemi di equazioni lineari con matrici di coefficienti triangolari inferiore o superiore.
Più formalmente, dati a
e b
, result[i0, ..., iR-3, :, :]
è la soluzione
per op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :]
quando left_side
è
true
o x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :]
quando
left_side
è false
, risolvendo la variabile x
dove op(a)
è determinato
da transpose_a
, che può essere uno dei seguenti:
NO_TRANSPOSE
: esegui l'operazione utilizzandoa
così com'è.TRANSPOSE
: esegui un'operazione sulla trasposizione dia
.ADJOINT
: esegui un'operazione sulla trasposizione coniugata dia
.
I dati di input vengono letti solo dal triangolo inferiore a
, se lower
è true
o dal triangolo superiore di a
. I dati di output vengono restituiti nello stesso triangolo;
i valori nell'altro triangolo sono definiti nell'implementazione.
Se unit_diagonal
è true, l'implementazione può presumere che gli elementi diagonali di a
siano uguali a 1, altrimenti il comportamento non è definito.
Per i tipi quantizzati, esegue dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower,
unit_diagonal, transpose_a), a, b, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | a |
tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore | (C1-C3) |
(I2) | b |
tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore | (C1-C4) |
(I3) | left_side |
costante di tipo i1 |
(C3) |
(I4) | lower |
costante di tipo i1 |
|
(I5) | unit_diagonal |
costante di tipo i1 |
|
(I6) | transpose_a |
enum di NO_TRANSPOSE , TRANSPOSE e ADJOINT |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_element_type(a) = baseline_element_type(b)
. - (C2)
2 <= rank(a) = rank(b) = R
. - (C3) La relazione tra
shape(a)
eshape(b)
viene definita come segue:shape(a)[:-3] = shape(b)[:-3]
.dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1)
.
- (C4)
baseline_type(b) = baseline_type(result)
.
Esempi
// %a = [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
// %b = [
// [2.0, 0.0, 0.0],
// [4.0, 8.0, 0.0],
// [6.0, 10.0, 12.0]
// ]
%result = "stablehlo.triangular_solve"(%a, %b) {
left_side = true,
lower = true,
unit_diagonal = false,
transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
// [2.0, 0.0, 0.0],
// [0.0, 2.0, 0.0],
// [0.0, 0.0, 2.0]
// ]
tuple
Semantica
Genera una tupla result
dai valori val
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | val |
numero variadico di valori | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tuple | (C1) |
Vincoli
- (C1)
result
ha il tipotuple<E0, ..., EN-1>
, doveEi = type(val[i])
.
Esempi
// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))
uniform_dequantize
Semantica
Esegue la conversione a livello di elemento del tensore quantizzato operand
in un tensore in virgola mobile result
in base ai parametri di quantizzazione definiti dal tipo operand
.
Più formalmente, result = dequantize(operand)
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore quantizzato | (C1), (C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo in virgola mobile | (C1), (C2) |
Vincoli
- (C1)
shape(operand) = shape(result)
. - (C2)
element_type(result) = expressed_type(operand)
.
Esempi
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]
uniform_quantize
Semantica
Esegue la conversione a livello di elemento di un tensore a virgola mobile o di un tensore quantizzato operand
in un tensore quantizzato result
in base ai parametri di quantizzazione definiti dal tipo result
.
In modo più formale,
- Se
is_float(operand)
:result = quantize(operand, type(result))
.
- Se
is_quantized(operand)
:float_result = dequantize(operand)
.result = quantize(float_result, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o quantizzato | (C1), (C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore quantizzato | (C1), (C2) |
Vincoli
- (C1)
shape(operand) = shape(result)
. - (C2)
expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand)
.
Esempi
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]
// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]
mentre
Semantica
Restituisce l'output dall'esecuzione della funzione body
0 o più volte, mentre la funzione cond
restituisce true
. In termini più formali, la semantica può essere espressa
utilizzando la sintassi Python:
internal_state = operand
while cond(*internal_state):
internal_state = body(*internal_state)
results = internal_state
Il comportamento di un loop infinito è da definire (#383).
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
numero variadico di tensori, tensori quantizzati o token | (C1-C3) |
(I2) | cond |
funzione | (C1) |
(I3) | body |
funzione | (C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadico di tensori, tensori quantizzati o token | (C3) |
Vincoli
- (C1)
cond
ha il tipo(T0, ..., TN-1) -> tensor<i1>
, doveTi = type(operand[i])
. - (C2)
body
ha il tipo(T0, ..., TN-1) -> (T0, ..., TN-1)
, doveTi = type(operand[i])
. - (C3)
type(results...) = type(operand...)
.
Esempi
// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%cond = "stablehlo.compare"(%arg0, %ten) {
comparison_direction = #stablehlo<comparison_direction LT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %cond : tensor<i1>
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%new_sum = stablehlo.add %arg1, %one : tensor<i64>
%new_i = stablehlo.add %arg0, %one : tensor<i64>
stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10
Xor
Semantica
Esegue XOR per elementi di due tensori lhs
e rhs
e produce un tensore result
. A seconda del tipo di elemento, procedi nel seguente modo:
- Per i valori booleani: XOR logico.
- Per i numeri interi: XOR a livello di bit.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di tipo booleano o intero | (C1) |
(I2) | rhs |
tensore di tipo booleano o intero | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo booleano o intero | (C1) |
Vincoli
- (C1)
type(lhs) = type(rhs) = type(result)
.
Esempi
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]
Attuazione
Esecuzione sequenziale
Viene eseguito un programma StableHLO fornendo valori di input alla funzione main
e calcolando i valori di output. I valori di output di una funzione vengono calcolati eseguendo il grafico delle operazioni radicate nell'operazione return
corrispondente.
L'ordine di esecuzione è definito dall'implementazione purché sia in linea con Dataflow, ovvero se le operazioni vengono eseguite prima dei loro utilizzi. In StableHLO, tutte le operazioni con effetti collaterali consumano un token e ne producono uno (più token possono essere multiplexati in un token tramite after_all
), quindi anche l'ordine di esecuzione degli effetti collaterali è in linea con Dataflow. Possibili ordini di esecuzione del programma
di esempio sopra riportato sono %0
→ %1
→ %2
→ %3
→ %4
→ return
o %3
→ %0
→
%1
→ %2
→ %4
→ return
.
Più formalmente, un processo StableHLO è una combinazione di:
1) un programma StableHLO, 2) stati delle operazioni (non ancora eseguiti,
già eseguiti) e 3) valori intermedi su cui sta lavorando il processo.
Il processo inizia con i valori di input della funzione main
, avanza nel grafico delle operazioni che aggiornano gli stati delle operazioni e i valori intermedi e termina con i valori di output. È da definire un'ulteriore formalizzazione (#484).
Esecuzione parallela
I programmi StableHLO possono essere eseguiti in parallelo, organizzati in una griglia di processi 2D di num_replicas
per num_partitions
, entrambi di tipo ui32
.
Nella griglia dei processi StableHLO, sono in esecuzione contemporaneamente num_replicas * num_partitions
processi StableHLO. Ogni processo ha un valore process_id = (replica_id, partition_id)
univoco, dove replica_id
in replica_ids = range(num_replicas)
e partition_id
in partition_ids = range(num_partitions)
sono entrambi di tipo ui32
.
La dimensione della griglia dei processi è nota in modo statico per ogni programma (in futuro, prevediamo di renderla una parte esplicita dei programmi StableHLO #650) e la posizione all'interno della griglia dei processi è nota in modo statico per ogni processo. Ogni processo ha accesso alla propria posizione all'interno della griglia dei processi tramite le operazioni replica_id
e partition_id
.
All'interno della griglia dei processi, i programmi possono essere tutti uguali (nello stile "Programma singolo, più dati"), possono essere tutti diversi (nello stile "Programmi multipli e Più dati") o viceversa. In futuro, abbiamo in programma di introdurre il supporto per altre espressioni idiomatiche per la definizione di programmi StableHLO paralleli, tra cui GSPMD (#619).
All'interno della griglia dei processi, i processi sono per lo più indipendenti l'uno dall'altro: hanno stati operativi distinti, valori di input/intermedio/output separati e la maggior parte delle operazioni viene eseguita separatamente tra i processi, ad eccezione di un numero ridotto di operazioni collettive descritte di seguito.
Dato che l'esecuzione della maggior parte delle operazioni utilizza solo valori dello stesso processo, di solito è inequivocabile fare riferimento a questi valori tramite i loro nomi.
Tuttavia, quando si descrive la semantica delle operazioni collettive, ciò non è sufficiente e dà origine alla notazione name@process_id
per fare riferimento al valore name
all'interno di un determinato processo. (Da questo punto di vista, name
non qualificato può essere
visualizzato come una forma abbreviata di name@(replica_id(), partition_id())
).
L'ordine di esecuzione nei processi è definito dall'implementazione, ad eccezione della sincronizzazione introdotta dalla comunicazione point-to-point e dalle operazioni collettive come descritto di seguito.
Comunicazione da punto a punto
I processi StableHLO possono comunicare tra loro tramite i canali StableHLO. Un canale è rappresentato da un ID positivo di tipo si64
. Attraverso varie operazioni, è possibile inviare valori ai canali e riceverli dai canali.
È ancora da definire un'ulteriore formalizzazione, ad esempio da dove provengono questi ID canale, in che modo i programmi dei processi vengono a conoscenza di questi ID e quale tipo di sincronizzazione viene loro introdotto (#484).
Comunicazione in streaming
Ogni processo StableHLO ha accesso a due interfacce di streaming:
- Infeed che è possibile leggere.
- Outfeed in cui è possibile scrivere.
A differenza dei canali, che sono utilizzati per comunicare tra i processi e quindi hanno processi a entrambi gli estremi, gli annunci in-feed e in uscita hanno l'altra implementazione definita.
È da definire un'ulteriore formalizzazione, ad esempio in che modo la comunicazione in streaming influenza l'ordine di esecuzione e il tipo di sincronizzazione introdotto da quest'ultima (#484).
Operazioni collettive
In StableHLO sono presenti sei operazioni collettive: all_gather
, all_reduce
,
all_to_all
, collective_broadcast
, collective_permute
e
reduce_scatter
. Tutte queste operazioni suddividono i processi nella griglia di processi StableHLO in gruppi di processi StableHLO ed eseguono un calcolo congiunto all'interno di ogni gruppo di processi, indipendentemente dagli altri gruppi di processi.
All'interno di ogni gruppo di processi, le operazioni collettive possono introdurre una barriera di sincronizzazione. Un'ulteriore formalizzazione, ad esempio l'elaborazione di quando avviene esattamente questa sincronizzazione, in che modo i processi arrivano a questa barriera e cosa succede in caso contrario, è da definire (#484).
Se il gruppo di processi prevede una comunicazione tra partizioni, ad esempio se esistono processi nel gruppo i cui ID partizioni sono diversi, l'esecuzione dell'operazione collettiva richiede un canale, mentre l'operazione collettiva deve fornire un valore channel_id
positivo di tipo si64
. La comunicazione con replica incrociata non
ha bisogno di canali.
I calcoli eseguiti dalle operazioni collettive sono specifici delle singole operazioni e sono descritti nelle singole sezioni delle operazioni sopra riportate. Tuttavia, le strategie in base alle quali la griglia di processo viene suddivisa in gruppi di processi sono condivise tra queste operazioni e sono descritte in questa sezione. Più formalmente, StableHLO supporta le seguenti quattro strategie.
cross_replica
Solo le comunicazioni con replica incrociata si verificano all'interno di ogni gruppo di processi. Questa strategia prende replica_groups
, un elenco di elenchi di ID replica, e calcola un prodotto cartesiano di replica_groups
in base a partition_ids
. replica_groups
deve avere elementi univoci e coprire tutti i replica_ids
. In modo più formale, usando
la sintassi Python:
def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
for partition_id in partition_ids:
process_group = []
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Ad esempio, per replica_groups = [[0, 1], [2, 3]]
e num_partitions = 2
,
cross_replica
produrrà
[[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]]
.
cross_partition
Solo le comunicazioni tra partizioni si verificano all'interno di ogni gruppo di processi. Questa strategia prende partition_groups
, un elenco di elenchi di ID partizioni, e calcola un prodotto cartesiano di partition_groups
in base a replica_ids
.
partition_groups
deve avere elementi univoci e coprire tutti i partition_ids
.
Utilizzando la sintassi Python in modo più formale:
def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
for partition_group in partition_groups:
for replica_id in replica_ids:
process_group = []
for partition_id in partition_group:
process_group.append((replica_id, partition_id))
yield process_group
Ad esempio, per partition_groups = [[0, 1]]
e num_replicas = 4
,
cross_partition
produrrà
[[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]]
.
cross_replica_and_partition
Le comunicazioni con replica incrociata e tra partizioni possono verificarsi all'interno di ciascun
gruppo di processi. Questa strategia utilizza replica_groups
, un elenco di elenchi di ID di replica, e calcola i prodotti cartesiani di ogni replica_group
per partition_ids
. replica_groups
deve avere elementi univoci e coprire tutti i replica_ids
. Utilizzando la sintassi Python in modo più formale:
def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
process_group = []
for partition_id in partition_ids:
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Ad esempio, per replica_groups = [[0, 1], [2, 3]]
e num_partitions = 2
,
cross_replica_and_partition
produrrà
[[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]]
.
flattened_ids
Questa strategia prende flattened_id_groups
, un elenco di elenchi di ID di processo "appiattiti" sotto forma di replica_id * num_partitions + partition_id
, e li trasforma in ID di processo. flattened_id_groups
deve avere elementi univoci
e coprire tutti i process_ids
. Utilizzando la sintassi Python in modo più formale:
def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
for flattened_id_group in flattened_id_groups:
process_group = []
for flattened_id in flattened_id_group:
replica_id = flattened_id // num_partitions
partition_id = flattened_id % num_partitions
process_group.append((replica_id, partition_id))
yield process_group
Ad esempio, per flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]
,
num_replicas = 4
e num_partitions = 2
, flattened_ids
produrrà
[[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]]
.
Accuratezza
Al momento StableHLO non fornisce garanzie sull'accuratezza numerica, ma questa impostazione potrebbe cambiare in futuro (#1156).
Errori
I programmi StableHLO sono convalidati attraverso una vasta serie di vincoli per le singole operazioni, che esclude molte classi di errori prima dell'esecuzione. Tuttavia, le condizioni di errore sono comunque possibili, ad esempio tramite overflow di numeri interi, accessi fuori dai limiti e così via. Se non esplicitamente indicati, tutti questi errori generano un comportamento definito dall'implementazione, che però potrebbe cambiare in futuro (#1157).
Come eccezione a questa regola, le eccezioni in virgola mobile nei programmi StableHLO hanno un comportamento ben definito. Le operazioni che comportano eccezioni definite dallo standard IEEE-754 (operazione non valida, divisione per zero, overflow, underflow o eccezioni inesatte) producono risultati predefiniti (come definiti nello standard) e continuano l'esecuzione senza aumentare il flag di stato corrispondente; in modo simile alla gestione delle eccezioni raiseNoFlag
dallo standard. Le eccezioni per le operazioni non standard (ad es. l'aritmetica complessa e alcune funzioni trascendenti) sono definite dall'implementazione.
Notazione
Per descrivere la sintassi, questo documento utilizza il sapore ISO modificato della sintassi EBNF (ISO/IEC 14977:1996,
Wikipedia), con due modifiche: 1) le regole vengono definite utilizzando ::=
anziché =
,
2) la concatenazione viene espressa utilizzando la giustapposizione anziché ,
.
Per descrivere la semantica (ad es. all'interno delle sezioni "Tipi", "Costanti" e "Operazioni"), utilizziamo formule basate sulla sintassi Python estesa con il supporto per l'espressione concisa delle operazioni di array, come descritto di seguito. Questo metodo funziona bene per piccoli snippet di codice, ma in rari casi quando sono necessari snippet di codice più grandi, utilizziamo la sintassi Python vanilla, che viene sempre introdotta esplicitamente.
Formule
Diamo un'occhiata al funzionamento delle formule in base a un esempio della specifica dot_general
. Uno dei vincoli per questa operazione è simile al seguente:
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
.
I nomi utilizzati in questa formula provengono da due origini: 1) funzioni globali, ovvero dim
, 2) definizioni dei membri dell'elemento del programma corrispondente, ovvero gli input lhs
, lhs_batching_dimensions
, rhs
e rhs_batching_dimensions
definiti nella sezione "Input" di dot_general
.
Come accennato in precedenza, la sintassi di questa formula è basata su Python con alcune estensioni orientate alla concisione. Per dare un senso alla formula, trasformiamo la sintassi Python vaniglia.
A) In queste formule, stiamo utilizzando =
per rappresentare l'uguaglianza, quindi il primo passaggio per ottenere la sintassi Python è sostituire =
con ==
, come segue: dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...)
.
B) Inoltre, queste formule supportano i puntini di sospensione (...
) che trasformano le espressioni scalari
in espressioni tensore. In breve, f(xs...)
significa approssimativamente "per ogni
scala x
nel tensore xs
, calcola uno scalare f(x)
, quindi restituisci tutti
questi risultati scalari insieme come risultato tensore". Nella sintassi Python vaniglia, la nostra formula di esempio si trasforma in: [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] ==
[dim(rhs, dim2) for dim2 in rhs_batching_dimensions]
.
Grazie alle ellissi, spesso è possibile evitare di lavorare al livello dei singoli scalari. Tuttavia, in alcuni casi difficili, potrebbe essere utilizzata una sintassi semi-informale di livello inferiore, come nella formula start_indices[bi0, ..., :, ..., biN]
della specifica gather
. Al servizio della concisione, non forniamo un formalismo esatto per la traduzione di tale sintassi in Python vanilla, nella speranza che sia ancora intuitivamente comprensibile caso per caso.
Facci sapere se alcune formule sembrano opache e proveremo a migliorarle.
Noterai inoltre che le formule usano i puntini di sospensione per espandere tutti i tipi di elenchi, inclusi tensori, elenchi di tensori (che, ad esempio, possono derivare da un numero variabile di tensori) e così via. Questa è un'altra area in cui non forniamo un formalismo esatto (ad es. gli elenchi non fanno nemmeno parte del sistema di tipo StableHLO) e si basano invece sulla comprensibilità intuitiva.
C) L'ultimo mezzo di notazionale degno di nota che utilizziamo è la trasmissione implicita. L'opset StableHLO non supporta la trasmissione implicita, ma le formule lo supportano, anche al servizio della concisione. In breve, se uno scalare viene utilizzato in un contesto in cui è previsto un tensore, lo scalare viene trasmesso nella forma prevista.
Per continuare con l'esempio dot_general
, ecco un altro vincolo:
0 <= lhs_batching_dimensions < rank(lhs)
. Secondo quanto definito nella specifica dot_general
, lhs_batching_dimensions
è un tensore, tuttavia 0
e rank(lhs)
sono valori scalari. Dopo aver applicato la trasmissione implicita, la formula diventerà [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)]
.
Quando viene applicata a una determinata operazione dot_general
, questa formula viene valutata
in base a un tensore di valori booleani. Quando le formule vengono utilizzate come vincoli, il vincolo viene applicato se la formula restituisce true
o un tensore che ha solo elementi true
.
Nomi.
Nelle formule, l'ambito lessicale include: 1) funzioni globali, 2) definizioni di membri,
3) definizioni locali. Di seguito è riportato un elenco delle funzioni globali. L'elenco delle definizioni degli elementi dipende dall'elemento del programma a cui viene applicata la notazione:
- Per le operazioni, le definizioni dei membri includono i nomi introdotti nelle sezioni "Input" e "Output".
- Per tutto il resto, le definizioni dei membri includono le parti strutturali dell'elemento del programma, chiamate in base ai corrispondenti non terminali EBNF. Nella maggior parte dei casi, i nomi di queste parti strutturali si ottengono convertendo i nomi dei non terminali in maiuscole e minuscole (ad es.
IntegerLiteral
=>integer_literal
), ma a volte i nomi vengono abbreviati (ad es.QuantizationStorageType
=>storage_type
). In questo caso, i nomi vengono introdotti in modo esplicito in modo simile alle sezioni "Input" / "Output" nelle specifiche delle operazioni. - Inoltre, le definizioni dei membri includono sempre
self
per fare riferimento all'elemento di programma corrispondente.
Valori
Quando le formule vengono valutate, funzionano con i seguenti tipi di valori:
1) Value
(valori effettivi, ad es. dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
;
saranno sempre noti i tipi),
2) Placeholder
(valori futuri, ad es. lhs
, rhs
o result
; i valori effettivi non sono ancora noti, ma sono noti solo i tipi),
3) Type
(tipi definiti nella sezione "Tipi"),
4) Function
(funzioni globali come definite nella sezione "Funzione").
A seconda del contesto, i nomi potrebbero fare riferimento a valori diversi. Più precisamente, la sezione "Semantica" per le operazioni (e gli equivalenti per altri elementi del programma) definisce la logica di runtime, pertanto tutti gli input sono disponibili come Value
.
Al contrario, la sezione "Vincoli" per le operazioni (e gli equivalenti) definisce la logica "tempo di compilazione", ovvero qualcosa che in genere viene eseguito prima del runtime, quindi solo gli input costanti sono disponibili come Value
e gli altri input sono disponibili solo come Placeholder
.
Nomi. | In "Semantica" | In "Vincoli" |
---|---|---|
Funzioni globali | Function |
Function |
Input costanti | Value |
Value |
Input non costanti | Value |
Placeholder |
Output | Value |
Placeholder |
Definizioni locali | Dipende dalla definizione | Dipende dalla definizione |
Prendiamo in considerazione un'operazione transpose
di esempio:
%result = "stablehlo.transpose"(%operand) {
permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
Per questa operazione, permutation
è una costante, quindi è disponibile come Value
sia nella semantica sia nei vincoli. Al contrario, operand
e result
sono
disponibili come Value
nella semantica, ma solo come Placeholder
nei vincoli.
Funzioni
Costruzione dei tipi
Non esistono funzioni che possono essere utilizzate per creare i tipi. Utilizziamo invece
direttamente la sintassi dei tipi perché è in genere più concisa. Ad esempio,
(tensor<E>, tensor<E>) -> (tensor<E>)
anziché function_type(
[tensor_type([], E), tensor_type([], E)], [tensor_type([], E)])
.
Funzioni sui tipi
element_type
è definito sui tipi di tensore e sui tipi di tensore quantizzato e restituisce, rispettivamente, la parteTensorElementType
oQuantizedTensorElementType
del valoreTensorType
oQuantizedTensorType
corrispondente.
def element_type(x: Value | Placeholder | Type):
if type(x) == TensorType:
return tensor_element_type(x)
if type(x) == QuantizedTensorType:
return quantized_tensor_element_type(x)
if type(x) is not Type:
return element_type(type(x))
is_per_axis_quantized(x: Value | Placeholder | Type) -> Value
è una scorciatoia peris_quantized(x) and quantization_dimension(x) is not None
.is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value
è una scorciatoia peris_quantized(x) and quantization_dimension(x) is None
.is_promotable(x: Type, y: Type) -> bool
controlla se è possibile promuovere il tipox
al tipoy
. Quandox
ey
sonoQuantizedTensorElementType
, la promozione viene applicata solo astorage_type
. Questa versione specifica della promozione viene attualmente utilizzata nel contesto del calcolo della riduzione (fai riferimento a RFC per ulteriori dettagli).
def is_promotable(x: Type, y: Type) -> Value:
is_same_type = (is_bool(x) and is_bool(y)) or
(is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
(is_complex(x) and is_complex(y)) or
(is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
if is_same_type == False:
return False
if is_integer(x) or is_float(x):
return bitwidth(x) <= bitwidth(y)
if is_complex(x):
return bitwidth(element_type(x)) <= bitwidth(element_type(y))
if is_quantized(x):
return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))
return false
is_quantized(x: Value | Placeholder | Type) -> Value
è una scorciatoia peris_quantized_tensor_element_type(x)
.is_type_name(x: Value | Placeholder | Type) -> Value
. Disponibile per tutti i tipi. Ad esempio,is_float(x)
restituiscetrue
sex
è unFloatType
. Sex
è un valore o un segnaposto, questa funzione è una scorciatoia peris_type_name(type(x))
.max_value(x: Type) -> Value
restituisce il valore massimo diTensorElementType
. Sex
non è unTensorElementType
, restituisceNone
.min_value(x: Type) -> Value
restituisce il valore minimo possibile di unTensorElementType
. Sex
non è unTensorElementType
, restituisceNone
.member_name(x: Value | Placeholder | Type) -> Any
. Disponibile per tutte le definizioni per i membrimember_name
di tutti i tipi. Ad esempio,tensor_element_type(x)
restituisce la parteTensorElementType
di unTensorType
corrispondente. Sex
è un valore o un segnaposto, questa funzione è una scorciatoia permember_name(type(x))
. Sex
non è un tipo con un membro appropriato oppure un valore o un segnaposto di questo tipo, restituisceNone
.
Costruzione dei valori
operation_name(*xs: Value | Type) -> Value
. Disponibile per tutte le operazioni. Ad esempio,add(lhs, rhs)
prende i due valori tensorelhs
erhs
e restituisce l'output della valutazione dell'operazioneadd
con questi input. Per alcune operazioni, ad esempiobroadcast_in_dim
, i tipi di output sono "load-bearing", ovvero necessari per valutare un'operazione. In questo caso, la funzione prende questi tipi come argomenti.
Funzione sui valori
Sono disponibili tutti gli operatori e le funzioni di Python. Ad esempio, sia le notazioni di abbonamento che slicing di Python sono disponibili per l'indicizzazione in tensori, tensori quantizzati e tuple.
to_destination_type(x: Value, destination_type: Type) -> Value
viene definito sui tensori e restituisce il valore convertito dix
in base atype(x)
edestination_type
come segue:
def to_destination_type(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x
if is_quantized(destination_type):
if is_quantized(type(x)):
return quantize(x, destination_type)
assert is_float(type(x))
return quantize(x, destination_type)
if is_quantized(type(x)):
assert destination_type = expressed_type(type(x))
return dequantize(type(x))
return convert(x, destination_type)
È in corso una discussione iniziale sull'unione delle operazioni convert
, uniform_quantize
e uniform_dequantize
(#1576).
Dopo l'unione non abbiamo bisogno della funzione di cui sopra e possiamo utilizzare il nome dell'operazione per convert
.
is_nan(x: Value) -> Value
è definito sui tensori e restituiscetrue
se tutti gli elementi dix
sonoNaN
ofalse
, altrimenti. Sex
non è un tensore, restituisceNone
.is_sorted(x: Value) -> Value
è definito sui tensori e restituiscetrue
se gli elementi dix
sono ordinati in ordine crescente rispetto all'ordine lessicografico crescente dei relativi indici ofalse
negli altri casi. Sex
non è un tensore, restituisceNone
.is_unique(x: Value) -> Value
viene definito sui tensori e restituiscetrue
sex
non ha elementi duplicati ofalse
in caso contrario. Sex
non è un tensore, restituisceNone
.member_name(x: Value) -> Any
è definito per tutte le definizioni dei membrimember_name
di tutti i valori. Ad esempio,real_part(x)
restituisce la parteRealPart
di unComplexConstant
corrispondente. Sex
non è un valore con un membro appropriato, restituisceNone
.same(x: Value) -> Value
è definito sui tensori e restituiscetrue
se gli elementi dix
sono tutti uguali tra loro, altrimentifalse
. Se il tensore non ha elementi, viene conteggiato come "tutti uguali tra loro", ad esempio la funzione restituiscetrue
. Sex
non è un tensore, restituisceNone
.split(x: Value, num_results: Value, axis: Value) -> Value
è definito sui tensori e restituiscenum_results
sezioni dix
lungo l'asseaxis
. Sex
non è un tensore odim(x, axis) % num_results != 0
, restituisceNone
.
Calcoli delle forme
axes(x: Value | Placeholder | Type) -> Value
è una scorciatoia perrange(rank(x))
.dim(x: Value | Placeholder | Type, axis: Value) -> Value
è una scorciatoia pershape(x)[axis]
.dims(x: Value | Placeholder | Type, axes: List) -> List
è una scorciatoia perlist(map(lambda axis: dim(x, axis), axes))
.index_space(x: Value | Placeholder | Type) -> Value
viene definito sui tensori e restituisce gli indicisize(x)
per il valoreTensorType
corrispondente ordinato in ordine lessicografico crescente, ad esempio[0, ..., 0]
,[0, ..., 1]
, ...,shape(x) - 1
. Sex
non è un tipo di tensore, un tipo di tensore quantizzato oppure un valore o un segnaposto di uno di questi tipi, restituisceNone
.rank(x: Value | Placeholder | Type) -> Value
è una scorciatoia persize(shape(x))
.shape(x: Value | Placeholder | Type) -> Value
è definito nella sezione "Funzioni sui tipi" tramitemember_name
.size(x: Value | Placeholder | Type) -> Value
è una scorciatoia perreduce(lambda x, y: x * y, shape(x))
.
Calcoli di quantizzazione
def baseline_element_type(x: Value | Placeholder | Type) -> Type
è una scorciatoia perelement_type(baseline_type(x))
.baseline_type
viene definito sui tipi di tensori e sui tipi di tensori quantizzati e li trasforma in una "base di riferimento", ovvero un tipo con la stessa forma ma con i parametri di quantizzazione del tipo di elemento reimpostati sui valori predefiniti. Questo metodo è utile per confrontare in modo uniforme i tipi di tensore e quantizzato, il che è necessario con una certa frequenza. Per i tipi quantizzati, questo consente di confrontare i tipi ignorando i parametri di quantizzazione, ovveroshape
,storage_type
,expressed_type
,storage_min
,storage_max
equantization_dimension
(per il tipo quantizzato per asse) devono corrispondere tutti, mascales
ezero points
potrebbero differire.
def baseline_type(x: Value | Placeholder | Type) -> Type:
if type(x) == TensorType:
return x
if type(x) == QuantizedTensorType:
element_type = quantized_tensor_element_type(x)
baseline_element_type = QuantizedTensorElementType(
storage_type = storage_type(element_type),
storage_min = storage_min(element_type),
storage_max = storage_max(element_type),
expressed_type = expressed_type(element_type),
quantization_dimension = quantization_dimension(element_type),
scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
return QuantizedTensorType(shape(x), baseline_element_type)
if type(x) is not Type:
return baseline_element_type(type(x))
dequantize
è definito su tipi di tensori quantizzati e li trasforma in tipi di tensori in virgola mobile. Ciò avviene convertendo gli elementi quantizzati che rappresentano i valori interi del tipo di archiviazione in valori in virgola mobile corrispondenti del tipo espresso utilizzando il punto zero e la scala associati al tipo di elemento quantizzato.
def compute_zero_points(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
zero_points[i] = zero_points(quantized_type)[i[d]]
return zero_points
def compute_scales(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
type(result_type))
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
scales[i] = scales(quantized_type)[i[d]]
return scales
def dequantize(x: Value) -> Value:
assert is_quantized(x)
x_storage = bitcast_convert(x, storage_type(x))
x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
x_expressed_sub = convert(x_storage_sub, expressed_type(x))
return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
quantize
è definito sui tipi di tensori in virgola mobile e li trasforma in tipi di tensori quantizzati. Ciò avviene convertendo i valori in virgola mobile del tipo espresso in valori interi corrispondenti del tipo di archiviazione utilizzando il punto zero e la scala associati al tipo di elemento quantizzato.
def quantize(x: Value, type: Type) -> Value:
assert is_float(x) and is_quantized(type)
x_expressed_rounded = round_nearest_even(x / compute_scales(type, type(x)))
x_storage_rounded = convert(x_expressed_rounded, storage_type(type))
x_storage_add = x_storage_rounded + compute_zero_points(type, type(x_storage_rounded))
x_storage = clamp(storage_min(type), x_storage_add, storage_max(type))
return bitcast_convert(x_storage, type)
dequantize_op_quantize
viene utilizzato per specificare i calcoli a livello di elemento sui tensori quantizzati. Dequantizza, ad esempio trasforma gli elementi quantizzati nei loro tipi espressi, quindi esegue un'operazione e quindi quantizza, ad esempio trasforma i risultati nei rispettivi tipi di archiviazione. Al momento, questa funzione funziona solo per la quantizzazione per tensore. La quantizzazione per asse è in corso (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
inputs = inputs_and_output_type[:-1]
output_type = inputs_and_output_type[-1]
float_inputs = map(dequantize, inputs)
float_result = op(*float_inputs)
return quantize(float_result, output_type)
def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
inputs = inputs_and_output_type[:-3]
float_inputs = map(dequantize, inputs)
float_results = op(*float_inputs)
return map(quantize, float_results, inputs_and_output_type[-3:])
def dequantize_compare(lhs, rhs, comparison_direction):
float_lhs = dequantize(lhs)
float_rhs = dequantize(rhs)
return compare(float_lhs, float_rhs, comparison_direction, FLOAT)
def dequantize_select_quantize(pred, on_true, on_false, output_type):
float_on_true = dequantize(on_true)
float_on_false = dequantize(on_false)
float_result = select(pred, float_on_true, float_on_false)
return quantize(float_result, output_type)
Calcoli a griglia
cross_partition(replica_groups: Value) -> Value
. Consulta la sezione "cross_replica" in alto.cross_replica(replica_groups: Value) -> Value
. Consulta la sezione "cross_replica" in alto.cross_replica_and_partition(replica_groups: Value) -> Value
. Consulta la sezione "cross_replica_and_partition" sopra.flattened_ids(replica_groups: Value) -> Value
. Consulta la sezione "flattened_id" in alto.