StableHLO è un insieme di operazioni per le operazioni di alto livello (HLO) nei modelli di machine learning (ML). StableHLO funziona come livello di portabilità tra diversi framework ML e compilatori ML: i framework ML che producono programmi StableHLO sono compatibili con i compilatori ML che utilizzano programmi StableHLO.
Il nostro obiettivo è semplificare e accelerare lo sviluppo di ML creando una maggiore interoperabilità tra vari framework di ML (come TensorFlow, JAX e PyTorch) e compilatori di ML (come XLA e IREE). A questo scopo, il documento fornisce una specifica per il linguaggio di programmazione StableHLO.
Questa specifica contiene tre sezioni principali. Innanzitutto, la sezione Programmi descrive la struttura dei programmi StableHLO, composti da funzioni StableHLO, a loro volta costituite da operazioni StableHLO. All'interno di questa struttura, la sezione Op specifica la semantica delle singole operazioni. La sezione Esecuzione fornisce la semantica per tutte queste operazioni che vengono eseguite insieme all'interno di un programma. Infine, la sezione Notazione illustra la notazione utilizzata nella specifica.
Per visualizzare le specifiche di una release precedente di StableHLO, apri il repository nella release con tag di tuo interesse. Ad esempio, la specifica StableHLO v0.19.0. Per visualizzare le modifiche apportate a ogni aggiornamento minore di StableHLO, consulta il log della versione in VhloDialect.td.
Programmi
Program ::= {Func}
I programmi StableHLO sono costituiti da un numero arbitrario di funzioni StableHLO.
Di seguito è riportato un programma di esempio con una funzione @main
che ha 3 input
(%image
, %weights
e %bias
) e 1 output. Il corpo della funzione
ha 6 operazioni.
func.func @main(
%image: tensor<28x28xf32>,
%weights: tensor<784x10xf32>,
%bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
%0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
%1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
%2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
%3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
%4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
"func.return"(%4): (tensor<1x10xf32>) -> ()
}
Funzioni
Func ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput ::= ValueType
FuncBody ::= {Op}
Le funzioni StableHLO (chiamate anche funzioni con nome) hanno un identificatore, input/output e un corpo. In futuro, prevediamo di introdurre metadati aggiuntivi per le funzioni per ottenere una migliore compatibilità con HLO (#425, #626, #740, #744).
Identificatori
FuncId ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
| '%' letter {letter | digit}
letter ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit ::= '0' | ... | '9'
Gli identificatori StableHLO sono simili a quelli di molti linguaggi di programmazione, con due peculiarità: 1) tutti gli identificatori hanno sigilli che distinguono i diversi tipi di identificatori, 2) gli identificatori di valore possono essere completamente numerici per semplificare la generazione di programmi StableHLO.
Tipi
Type ::= ValueType | NonValueType
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
I tipi StableHLO sono classificati in tipi di valore (chiamati anche tipi di prima classe) che rappresentano i valori StableHLO e tipi non di valore che descrivono altri elementi del programma. I tipi di HLO stabili sono simili a quelli in molti linguaggi di programmazione, ma la loro peculiarità principale è rappresentata dalla natura specifica del dominio di StableHLO, che genera risultati insoliti (ad esempio, i tipi scalari non sono tipi di valore).
TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'
I tipi di tensori rappresentano i tensori, ovvero gli array multidimensionali. Hanno una forma e un tipo di elemento, dove una forma rappresenta dimensioni non negative o sconosciute nell'ordine crescente delle dimensioni corrispondenti (chiamate anche assi) numerate da 0
a R-1
. Il
numero di dimensioni R
è chiamato rango. Ad esempio, tensor<2x3xf32>
è un tipo di tensore con forma 2x3
e tipo di elemento f32
. Ha due dimensioni
(o, in altre parole, due assi), la dimensione 0 e la dimensione 1, le cui dimensioni
sono 2 e 3. Il suo ranking è 2.
Le forme possono essere parzialmente o completamente sconosciute (dinamiche), ad esempio tensor<?x2xf64>
è parzialmente sconosciuta e tensor<?x?xf64>
è completamente sconosciuta. Le dimensioni dinamiche sono rappresentate utilizzando un ?
. Le forme non possono essere non classificate.
In futuro, prevediamo di estendere i tipi di tensori oltre le dimensioni e i tipi di elementi, ad esempio per includere i layout (#629) e la sparsità (#1078).
QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
QuantizationStorageType
['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
':' QuantizationExpressedType
[':' QuantizationDimension]
',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerLiteral
QuantizationStorageMax ::= IntegerLiteral
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerLiteral
QuantizationParameters ::= QuantizationParameter
| '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale [':' QuantizationZeroPoint]
QuantizationScale ::= FloatLiteral
QuantizationZeroPoint ::= IntegerLiteral
Nome | Tipo | Vincoli |
---|---|---|
storage_type |
tipo di numero intero | (C1-C3), (C8) |
storage_min |
costante intera | (C1), (C3), (C7) |
storage_max |
costante intera | (C2), (C3), (C7) |
expressed_type |
tipo con virgola mobile | (C4) |
quantization_dimension |
costante intera facoltativa | (C10-C12) |
scales |
numero variadico di costanti in virgola mobile | (C4-C6), (C9), (C10), (C13) |
zero_points |
numero variadico di costanti intere | (C7-C9) |
I tipi di elementi quantizzati rappresentano i valori interi di un tipo di archiviazione nell'intervallo compreso tra storage_min
e storage_max
(inclusi) che corrispondono ai valori in virgola mobile di un tipo espresso. Per un determinato valore intero i
,
il corrispondente valore in virgola mobile f
può essere calcolato come
f = (i - zero_point) * scale
, dove scale
e zero_point
sono chiamati
parametri di quantizzazione. storage_min
e storage_max
sono facoltativi nella grammatica, ma hanno rispettivamente i valori predefiniti min_value(storage_type)
e max_value(storage_type)
. I tipi di elementi quantizzati hanno i seguenti vincoli:
- (C1)
type(storage_min) = storage_type
. - (C2)
type(storage_max) = storage_type
. - (C3)
min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type)
. - (C4)
type(scales...) = expressed_type
. - (C5)
0 < scales
. - (C6)
is_finite(scales...)
. - (C7)
storage_min <= zero_points <= storage_max
. - (C8)
type(zero_points...) = storage_type
. - (C9)
size(scales) = size(zero_points)
. - (C10) Se
is_empty(quantization_dimension)
, allorasize(scales) = 1
. - (C11)
0 <= quantization_dimension
.
Al momento, QuantizationScale
è una costante a virgola mobile, ma c'è un forte interesse per le scale basate su numeri interi, rappresentate con moltiplicatori e spostamenti. Abbiamo in programma di esplorare questa opzione nel prossimo futuro
(#1404).
È in corso una discussione sulla semantica di QuantizationZeroPoint
, incluso il tipo, i valori e se possono esserci uno o più punti zero in un tipo di tensore quantizzato. In base ai risultati di questa discussione, la specifica relativa ai punti zero potrebbe cambiare in futuro (#1405).
Un'altra discussione in corso riguarda la semantica di QuantizationStorageMin
e QuantizationStorageMax
per determinare se debbano essere imposti vincoli su questi valori e sui valori dei tensori quantizzati
(#1406).
Infine, prevediamo di esplorare la rappresentazione di scale e punti zero sconosciuti, in modo simile a come prevediamo di esplorare la rappresentazione delle dimensioni di dimensioni sconosciute (#1407).
I tipi di tensori quantizzati rappresentano i tensori con elementi quantizzati. Questi tensori sono esattamente come i tensori normali, ad eccezione del fatto che i loro elementi hanno tipi di elementi quantizzati anziché tipi di elementi normali.
Nei tensori quantizzati, la quantizzazione può essere per tensore, ovvero avere
un scale
e zero_point
per l'intero tensore o essere per asse,
il che significa avere più scales
e zero_points
, una coppia per sezione di
una particolare dimensione quantization_dimension
. In modo più formale, in un tensore t
con quantizzazione per asse, ci sono dim(t, quantization_dimension)
sezioni
del quantization_dimension
: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :]
,
ecc. Tutti gli elementi del i
esimo slice utilizzano scales[i]
e zero_points[i]
come
i loro parametri di quantizzazione. I tipi di tensori quantizzati hanno i seguenti vincoli:
- Per la quantizzazione per tensore:
- Nessun vincolo aggiuntivo.
- Per la quantizzazione per asse:
- (C12)
quantization_dimension < rank(self)
. - (C13)
dim(self, quantization_dimension) = size(scales)
.
- (C12)
TokenType ::= 'token'
I tipi di token rappresentano i token, ovvero valori opachi prodotti e consumati da alcune operazioni. I token vengono utilizzati per imporre l'ordine di esecuzione sulle operazioni, come descritto nella sezione Esecuzione.
TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]
I tipi di tuple rappresentano tuple, ovvero elenchi eterogenei. Le tuple sono una funzionalità legacy
che esiste solo per la compatibilità con l'HLO. In HLO, le tuple vengono utilizzate per rappresentare input e output con parametri variabili. In StableHLO, gli input e gli output variabili sono supportati in modo nativo e l'unico utilizzo delle tuple in StableHLO è rappresentare in modo completo l'ABI HLO, dove ad esempio T
, tuple<T>
e
tuple<tuple<T>>
possono essere sostanzialmente diversi a seconda di una determinata
implementazione. In futuro, prevediamo di apportare modifiche all'ABI HLO
che potrebbero consentirci di rimuovere i tipi di tuple da StableHLO
(#598).
TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3'
| 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2'
| 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
I tipi di elementi rappresentano gli elementi dei tipi di tensori. A differenza di molti linguaggi di programmazione, questi tipi non sono di prima classe in StableHLO. Ciò significa che
i programmi SttableHLO non possono rappresentare direttamente i valori di questi tipi.
Di conseguenza, i valori scalari di tipo T
sono idiomatici
con valori tensoriali 0-dimensionali di tipo tensor<T>
.
- Il tipo booleano rappresenta i valori booleani
true
efalse
. - I tipi di numeri interi possono essere con segno (
si
) o senza segno (ui
) e avere una delle larghezze di bit supportate (2
,4
,8
,16
,32
o64
). I tipisiN
con segno rappresentano valori interi da-2^(N-1)
a2^(N-1)-1
inclusi, mentre i tipiuiN
senza segno rappresentano valori interi da0
a2^N-1
inclusi. - I tipi a virgola mobile possono essere uno dei seguenti:
f8E3M4
,f8E4M3
ef8E5M2
numeri con virgola mobile a 8 bit conformi alle convenzioni IEEE-754.- Tipi
f8E4M3FN
ef8E5M2
corrispondenti rispettivamente alle codificheE4M3
eE5M2
del formato FP8 descritto in Formati FP8 per il deep learning. - Tipi
f8E4M3FNUZ
ef8E5M2FNUZ
corrispondenti alle codificheE4M3
eE5M2
dei formati FP8 descritti in Formati numerici a 8 bit per reti neurali profonde. - Tipo
f8E4M3B11FNUZ
corrispondente alla codificaE4M3
dei formati FP8 descritti in Addestramento in virgola mobile ibrido a 8 bit (HFP8) e inferenza per reti neurali profonde. - Tipo
bf16
corrispondente al formatobfloat16
descritto in BFloat16: il segreto delle alte prestazioni su Cloud TPU. - Tipi
f16
,f32
ef64
corrispondenti ai formati rispettivamentebinary16
("metà precisione"),binary32
("precisione singola") ebinary64
("precisione doppia") descritti nello standard IEEE 754. - Il tipo
tf32
corrisponde al formatoTensorFloat32 e ha un supporto limitato in StableHLO. - Tipi MX (microscalabilità)
f4E2M1FN
,f6E2M3FN
,f6E3M2FN
ef8E8M0FNU
descritti nella Specifica dei formati di microscalabilità di OCP.
- I tipi complessi rappresentano valori complessi che hanno una parte reale e una parte immaginaria dello stesso tipo di elemento. I tipi complessi supportati sono
complex<f32>
(entrambe le parti sono di tipof32
) ecomplex<f64>
(entrambe le parti sono di tipof64
).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]
I tipi di funzioni rappresentano sia le funzioni con nome che quelle anonime. Hanno tipi di input (l'elenco di tipi a sinistra di ->
) e tipi di output (l'elenco di tipi a destra di ->
). In molti linguaggi di programmazione, i tipi di funzione sono di prima classe, ma non in StableHLO.
StringType ::= 'string'
Il tipo di stringa rappresenta sequenze di byte. A differenza di molti linguaggi di programmazione, il tipo di stringa non è di prima classe in StableHLO e viene utilizzato solo per specificare i metadati statici per gli elementi del programma.
Operazioni
Le operazioni StableHLO (chiamate anche ops) rappresentano un insieme chiuso di operazioni di alto livello nei modelli di machine learning. Come discusso sopra, la sintassi di StableHLO è fortemente ispirata a MLIR, che non è necessariamente l'alternativa più ergonomica, ma è probabilmente la più adatta per lo scopo di StableHLO di creare una maggiore interoperabilità tra i framework ML e i compilatori ML.
Op ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic ::= 'abs' | 'add' | ...
Le operazioni SttableHLO (chiamate anche ops) hanno un nome,
input/output e una firma. Il nome è costituito dal prefisso stablehlo.
e da un mnemonico che identifica in modo univoco una delle operazioni supportate. Di seguito è riportato un elenco completo di tutte le operazioni supportate.
OpInputs ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue ::= ValueId
OpInputFuncs ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs ::= [OpOutput {',' OpOutput} '=']
OpOutput ::= ValueId
Le operazioni consumano input e producono output. Gli input sono classificati in valori di input (calcolati durante l'esecuzione), funzioni di input (fornite in modo statico, perché in StableHLO le funzioni non sono valori di prima classe) e attributi di input (forniti anche in modo statico). Il tipo di input e output
consumati e prodotti da un'operazione dipende dal relativo mnemonico. Ad esempio, l'operatore add
utilizza 2 valori di input e produce 1 valore di output. In confronto, l'operazione select_and_scatter
richiede 3 valori di input, 2 funzioni di input e 3 attributi di input.
OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused ::= '^' digit {digit}
| '^' letter {letter | digit}
Le funzioni di input (chiamate anche funzioni anonime) sono molto simili alle funzioni con nome, tranne per il fatto che: 1) non hanno un identificatore (da qui il nome "anonimo"), 2) non dichiarano i tipi di output (i tipi di output sono dedotti dall'operazione return
all'interno della funzione).
La sintassi per le funzioni di input include una parte attualmente non utilizzata (vedi la produzione Unused
sopra) che è presente per la compatibilità con MLIR. In MLIR,
esiste un concetto più generale di "regioni" che possono avere più "blocchi"
di operazioni collegate tra loro tramite operazioni di salto. Questi blocchi hanno ID che corrispondono
all'ambiente di produzione Unused
, in modo che possano essere distinti tra loro.
StableHLO non ha operazioni di salto, quindi la parte corrispondente della sintassi MLIR non viene utilizzata (ma è ancora presente).
OpInputAttr ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName ::= letter {letter | digit}
OpInputAttrValue ::= Constant
Gli attributi di input hanno un nome e un valore che è una delle costanti supportate. Sono il modo principale per specificare metadati statici per gli elementi del programma. Ad esempio, l'operatore concatenate
utilizza l'attributo dimension
per
specificare la dimensione lungo la quale vengono concatenati i relativi valori di input. Allo stesso modo,
l'operazione slice
utilizza più attributi, come start_indices
e limit_indices
,
per specificare i limiti utilizzati per suddividere il valore di input.
Al momento, i programmi StableHLO in uso a volte contengono attributi che non sono descritti in questo documento. In futuro, prevediamo di integrare questi attributi nell'insieme di operazioni StableHLO o di vietarne la visualizzazione nei programmi StableHLO. Nel frattempo, ecco l'elenco di questi attributi:
layout
(#629).mhlo.frontend_attributes
(#628).mhlo.sharding
(#619).output_operand_aliases
(#740).- Metadati sulla località (#594).
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'
La firma dell'operazione è costituita dai tipi di tutti i valori di input (l'elenco dei tipi sul lato sinistro di ->
) e dai tipi di tutti i valori di output (l'elenco dei tipi sul lato destro di ->
). Rigorosamente parlando, i tipi di input sono ridondanti e anche i tipi di output sono quasi sempre ridondanti (perché per la maggior parte delle operazioni StableHLO, i tipi di output possono essere dedotti dagli input). Tuttavia, la firma op fa deliberatamente parte della sintassi di StableHLO per la compatibilità con MLIR.
Di seguito è riportato un esempio di operazione il cui mnemonico è select_and_scatter
. Utilizza 3 valori di input (%operand
, %source
e %init_value
), 2 funzioni di input e 3 attributi di input (window_dimensions
, window_strides
e padding
). Nota come la firma dell'operazione includa solo i tipi dei relativi valori di input (ma non i tipi di funzioni e attributi di input forniti in linea).
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>
Costanti
Constant ::= BooleanConstant
| IntegerConstant
| FloatConstant
| ComplexConstant
| TensorConstant
| QuantizedTensorConstant
| StringConstant
| EnumConstant
Le costanti StableHLO hanno un valore letterale e un tipo che insieme rappresentano un valore StableHLO. In genere, il tipo fa parte della sintassi della costante, tranne quando è inequivocabile (ad es. una costante booleana ha inequivocabile il tipo i1
, mentre una costante intera può avere più tipi possibili).
BooleanConstant ::= BooleanLiteral
BooleanLiteral ::= 'true' | 'false'
Le costanti booleane rappresentano i valori booleani true
e false
. Le costanti booleane hanno tipo i1
.
IntegerConstant ::= IntegerLiteral ':' IntegerType
IntegerLiteral ::= ['-' | '+'] DecimalDigits
| ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit ::= '0' | ... | '9'
hexadecimalDigit ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'
Le costanti intere rappresentano valori interi tramite stringhe che utilizzano la notazione decimale o esadecimale. Altre basi, ad esempio binarie o ottali, non sono supportate. Le costanti intere hanno i seguenti vincoli:
- (C1)
is_wellformed(integer_literal, integer_type)
.
FloatConstant ::= FloatLiteral ':' FloatType
FloatLiteral ::= SignPart IntegerPart FractionalPart ScientificPart
| '0x' [HexadecimalDigits]
SignPart ::= ['-' | '+']
IntegerPart ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]
Le costanti con virgola mobile rappresentano valori con virgola mobile tramite stringhe che utilizzano la notazione decimale o scientifica. Inoltre, la notazione esadecimale può essere utilizzata per specificare direttamente i bit sottostanti nel formato a virgola mobile del tipo corrispondente. Le costanti in virgola mobile presentano i seguenti vincoli:
- (C1) Se viene utilizzata la notazione non esadecimale,
is_wellformed(float_literal, float_type)
. - (C2) Se viene utilizzata la notazione esadecimale,
size(hexadecimal_digits) = num_bits(float_type) / 4
.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral ::= '(' RealPart ',' ImaginaryPart ')'
RealPart ::= FloatLiteral
ImaginaryPart ::= FloatLiteral
Le costanti complesse rappresentano valori complessi utilizzando elenchi di una parte reale (viene visualizzata per prima) e una parte immaginaria (viene visualizzata per seconda). Ad esempio,
(1.0, 0.0) : complex<f32>
rappresenta 1.0 + 0.0i
e
(0.0, 1.0) : complex<f32>
rappresenta 0.0 + 1.0i
. L'ordine in cui queste parti vengono archiviate in memoria è definito dall'implementazione. Le costanti complesse hanno i seguenti vincoli:
- (C1)
is_wellformed(real_part, complex_element_type(complex_type))
. - (C2)
is_wellformed(imaginary_part, complex_element_type(complex_type))
.
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral
Le costanti tensore rappresentano i valori dei tensori utilizzando elenchi nidificati specificati tramite la notazione NumPy. Ad esempio, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
rappresenta un valore tensoriale con la seguente mappatura dagli indici agli elementi:{0, 0} => 1
, {0, 1} => 2
, {0, 2} => 3
, {1, 0} => 4
, {1, 1} => 5
,{1, 2} => 6
. L'ordine in cui questi elementi vengono archiviati in memoria è definito dall'implementazione. Le costanti tensore hanno i seguenti vincoli:
- (C1)
has_syntax(tensor_literal, element_type(tensor_type))
, dove:has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type)
.has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type)
.
- (C2)
has_shape(tensor_literal, shape(tensor_type))
, dove:has_shape(element_literal: Syntax, []) = true
.has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:])
.- altrimenti
false
.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
Le costanti di tensore quantizzati rappresentano i valori dei tensori quantizzati utilizzando la stessa notazione delle costanti di tensore, con elementi specificati come costanti del loro tipo di archiviazione. Le costanti tensore quantizzate hanno i seguenti vincoli:
- (C1)
has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type))
. - (C2)
has_shape(quantized_tensor_literal, shape(quantized_tensor_type))
.
StringConstant ::= StringLiteral
StringLiteral ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))
Le lettere di stringa sono costituite da byte specificati utilizzando caratteri ASCII e sequenze di escape. Sono indipendenti dalla codifica, pertanto l'interpretazione di questi
byte è definita dall'implementazione. I valori letterali stringa sono di tipo string
.
Operazioni
abs
Semantica
Esegue l'operazione di valore assoluto elemento per elemento sul tensore operand
e produce un
result
. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per numeri interi con segno: modulo intero.
- Per i valori float:
abs
da IEEE-754. - Per i numeri complessi: modulo complesso.
- Per i tipi quantizzati:
dequantize_op_quantize(abs, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di numero intero firmato, in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1-C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo intero con segno o con virgola mobile o tensore quantizzato per tensore | (C1-C2) |
Vincoli
- (C1)
shape(result) = shape(operand)
. - (C2)
baseline_element_type(result)
è definito come:complex_element_type(element_type(operand))
seis_complex(operand)
.baseline_element_type(operand)
in caso contrario.
Esempi
// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]
add
Semantica
Esegue l'addizione elemento per elemento di due tensori lhs
e rhs
e produce un
result
. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i valori booleani: OR logico.
- Per gli interi: addizione di numeri interi.
- Per i valori float:
addition
da IEEE-754. - Per i numeri complessi: addizione complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(add, lhs, rhs, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore o tensore quantizzato | (C1-C6) |
(I2) | rhs |
tensore o tensore quantizzato | (C1-C5), (C7) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato | (C1-C7) |
Vincoli
- Se l'operazione utilizza tensori non quantizzati:
- (C1)
type(lhs) = type(rhs) = type(result)
.
- (C1)
- Se l'operazione utilizza tensori quantizzati:
- (C2)
is_quantized(lhs) and is_quantized(rhs) and is_quantized(result)
. - (C3)
storage_type(lhs) = storage_type(rhs) = storage_type(result)
. - (C4)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C5)
(is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result)
. - (C6) Se
is_per_axis_quantized(lhs)
, alloraquantization_dimension(lhs) = quantization_dimension(result)
. - (C7) Se
is_per_axis_quantized(rhs)
, alloraquantization_dimension(rhs) = quantization_dimension(result)
.
- (C2)
Esempi
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]
after_all
Semantica
Garantisce che le operazioni che producono inputs
vengano eseguite prima di qualsiasi
operazione che dipende da result
. L'esecuzione di questa operazione non fa nulla,
esiste solo per stabilire le dipendenze dei dati da result
a inputs
.
Input
Etichetta | Nome | Tipo |
---|---|---|
(I1) | inputs |
numero variadico di token |
Output
Nome | Tipo |
---|---|
result |
token |
Esempi
// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token
all_gather
Semantica
All'interno di ogni gruppo di processi nella griglia di processi StableHLO, concatena i valori
dei tensori operands
di ogni processo lungo all_gather_dim
e produce
i tensori results
.
L'operazione suddivide la griglia del processo StableHLO in process_groups
, che è
definita come segue:
cross_replica(replica_groups)
sechannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
sechannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
sechannel_id > 0 and use_global_device_ids = true
.
Successivamente, all'interno di ogni process_group
:
operands...@receiver = [operand@sender for sender in process_group]
per tuttireceiver
inprocess_group
.results...@process = concatenate(operands...@process, all_gather_dim)
per tuttiprocess
inprocess_group
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operands |
numero variadico di tensori o tensori quantizzati per tensore | (C1), (C6) |
(I2) | all_gather_dim |
costante di tipo si64 |
(C1), (C6) |
(I3) | replica_groups |
Costante tensoriale 2D di tipo si64 |
(C2-C4) |
(I4) | channel_id |
costante di tipo si64 |
(C5) |
(I5) | use_global_device_ids |
costante di tipo i1 |
(C5) |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadico di tensori o tensori quantizzati per tensore | (C6) |
Vincoli
- (C1)
0 <= all_gather_dim < rank(operands...)
. - (C2)
is_unique(replica_groups)
. - (C3)
size(replica_groups)
è definito come:num_replicas
se si utilizzacross_replica
.num_replicas
se viene utilizzatocross_replica_and_partition
.num_processes
se viene utilizzatoflattened_ids
.
- (C4)
0 <= replica_groups < size(replica_groups)
. - (C5) Se
use_global_device_ids = true
, allorachannel_id > 0
. - (C6)
type(results...) = type(operands...)
eccetto:dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1)
.
Esempi
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
all_reduce
Semantica
All'interno di ogni gruppo di processi nella griglia di processi StableHLO, applica una funzione di riduzione computation
ai valori dei tensori operands
di ogni processo e produce tensori results
.
L'operazione suddivide la griglia del processo StableHLO in process_groups
, che è
definita come segue:
cross_replica(replica_groups)
sechannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
sechannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
sechannel_id > 0 and use_global_device_ids = true
.
In seguito, all'interno di ogni process_group
:
results...@process[result_index] = exec(schedule)
per un albero binarioschedule
dove:exec(node)
=computation(exec(node.left), exec(node.right))
.exec(leaf)
=leaf.value
.
schedule
è un albero binario definito dall'implementazione il cui attraversamento dell'ordine èto_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0]))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operands |
numero variadico di tensori o tensori quantizzati per tensore | (C5), (C6) |
(I2) | replica_groups |
numero variadico di costanti tensoriali 1-dimensionali di tipo si64 |
(C1-C3) |
(I3) | channel_id |
costante di tipo si64 |
(C4) |
(I4) | use_global_device_ids |
costante di tipo i1 |
(C4) |
(I5) | computation |
funzione | (C5) |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadico di tensori o tensori quantizzati per tensore | (C6-C7) |
Vincoli
- (C1)
is_unique(replica_groups)
. - (C2) Per
size(replica_groups)
si intende:num_replicas
se viene utilizzatocross_replica
.num_replicas
se viene utilizzatocross_replica_and_partition
.num_processes
se viene utilizzatoflattened_ids
.
- (C3)
0 <= replica_groups < size(replica_groups)
. - (C4) Se
use_global_device_ids = true
, allorachannel_id > 0
. - (C5)
computation
è di tipo(tensor<E>, tensor<E>) -> (tensor<E>)
doveis_promotable(element_type(operand), E)
. - (C6)
shape(results...) = shape(operands...)
. - (C7)
element_type(results...) = E
.
Esempi
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]
all_to_all
Semantica
All'interno di ogni gruppo di processi nella griglia di processi StableHLO, suddivide i valori dei tensori operands
lungo split_dimension
in parti, distribuisce le parti suddivise tra i processi, concatena le parti sparse lungo concat_dimension
e produce tensori results
.
L'operazione divide la griglia di processo StableHLO in process_groups
, che è
definita come segue:
cross_replica(replica_groups)
sechannel_id <= 0
.cross_partition(replica_groups)
sechannel_id > 0
.
In seguito, all'interno di ogni process_group
:
split_parts...@sender = split(operands...@sender, split_count, split_dimension)
per tutti isender
inprocess_group
.scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group]
dovereceiver_index = process_group.index(receiver)
.results...@process = concatenate(scattered_parts...@process, concat_dimension)
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operands |
numero variadico di tensori o tensori quantizzati per tensore | (C1-C3), (C9) |
(I2) | split_dimension |
costante di tipo si64 |
(C1), (C2), (C9) |
(I3) | concat_dimension |
costante di tipo si64 |
(C3), (C9) |
(I4) | split_count |
costante di tipo si64 |
(C2), (C4), (C8), (C9) |
(I5) | replica_groups |
Costante tensoriale 2D di tipo si64 |
(C5-C8) |
(I6) | channel_id |
costante di tipo si64 |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadico di tensori o tensori quantizzati per tensore | (C9) |
Vincoli
- (C1)
0 <= split_dimension < rank(operands...)
. - (C2)
dim(operands..., split_dimension) % split_count = 0
. - (C3)
0 <= concat_dimension < rank(operands...)
. - (C4)
0 < split_count
. - (C5)
is_unique(replica_groups)
. - (C6)
size(replica_groups)
è definito come:num_replicas
se si utilizzacross_replica
.num_partitions
se viene utilizzatocross_partition
.
- (C7)
0 <= replica_groups < size(replica_groups)
. - (C8)
dim(replica_groups, 1) = split_count
. - (C9)
type(results...) = type(operands...)
tranne sesplit_dimension != concat_dimension
:dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count
.dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count
.
Esempi
// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
// [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
// [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
// channel_id = 0
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]
e
Semantica
Esegue l'operatore AND a livello di elemento di due tensori lhs
e rhs
e produce un tensore result
. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i valori booleani: AND logico.
- Per i numeri interi: AND a livello di bit.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di tipo booleano o intero | (C1) |
(I2) | rhs |
tensore di un tipo booleano o intero | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di un tipo booleano o intero | (C1) |
Vincoli
- (C1)
type(lhs) = type(rhs) = type(result)
.
Esempi
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]
atan2
Semantica
Esegue l'operazione atan2 elemento per elemento sui tensori lhs
e rhs
e produce un
result
tensore. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i numeri in virgola mobile:
atan2
dallo standard IEEE-754. - Per i numeri complessi: atan2 complesso.
- Per i tipi quantizzati:
dequantize_op_quantize(atan2, lhs, rhs, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
(I2) | rhs |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Esempi
// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]
batch_norm_grad
Semantica
Calcola i gradienti di diversi input di retropropagazione di batch_norm_training
da grad_output
e produce tensori grad_operand
, grad_scale
e grad_offset
. Più formalmente, questa operazione può essere espressa come una decomposizione nelle operazioni StableHLO esistenti utilizzando la sintassi Python come segue:
def compute_sum(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
return sum
def compute_mean(operand, feature_index):
sum = compute_sum(operand, feature_index)
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
# Broadcast inputs to type(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance`
# Intermediate values will be useful for computing gradients
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
# Use the implementation from batchnorm_expander.cc in XLA
# Temporary variables have exactly the same names as in the C++ code
elements_per_feature = broadcast_in_dim(
constant(divide(size(operand), dim(operand, feature_index)),
element_type(grad_output)),
[], type(operand))
i1 = multiply(grad_output, elements_per_feature)
i2 = broadcast_in_dim(
compute_sum(grad_output, feature_index), [feature_index], type(operand))
i3 = broadcast_in_dim(
compute_sum(multiply(grad_output, centered_operand), feature_index),
[feature_index], type(operand))
i4 = multiply(i3, centered_operand)
i5 = divide(i4, add(variance_bcast, epsilon_bcast))
i6 = subtract(subtract(i1, i2), i5)
grad_operand =
multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
grad_scale =
compute_sum(multiply(grad_output, normalized_operand), feature_index)
grad_offset = compute_sum(grad_output, feature_index)
return grad_operand, grad_scale, grad_offset
Per i tipi quantizzati, esegue
dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean,
variance, grad_output: batch_norm_grad(operand, scale, mean, variance,
grad_output, epsilon, feature_index), operand, scale, mean, variance,
grad_output, type(grad_operand), type(grad_scale), type(feature_index))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo a virgola mobile o tensore quantizzato per tensore | (C1-C3), (C5) |
(I2) | scale |
Tensore 1D di tipo quantizzato a virgola mobile o per tensore | (C2), (C4), (C5) |
(I3) | mean |
Tensore 1D di tipo quantizzato a virgola mobile o per tensore | (C2), (C4) |
(I4) | variance |
Tensore 1D di tipo quantizzato a virgola mobile o per tensore | (C2), (C4) |
(I5) | grad_output |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C2), (C3) |
(I6) | epsilon |
costante di tipo f32 |
|
(I7) | feature_index |
costante di tipo si64 |
(C1), (C5) |
Output
Nome | Tipo | Vincoli |
---|---|---|
grad_operand |
tensore di tipo a virgola mobile o tensore quantizzato per tensore | (C2), (C3) |
grad_scale |
Tensore 1D di tipo quantizzato a virgola mobile o per tensore | (C2), (C4) |
grad_offset |
Tensore 1D di tipo quantizzato a virgola mobile o per tensore | (C2), (C4) |
Vincoli
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,mean
,variance
,grad_output
,grad_operand
,grad_scale
egrad_offset
hanno lo stessobaseline_element_type
. - (C3)
operand
,grad_output
egrad_operand
hanno la stessa forma. - (C4)
scale
,mean
,variance
,grad_scale
egrad_offset
hanno la stessa forma. - (C5)
size(scale) = dim(operand, feature_index)
.
Esempi
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
// [[0.1, 0.1], [0.1, 0.1]],
// [[0.1, 0.1], [0.1, 0.1]]
// ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
// [[0.0, 0.0], [0.0, 0.0]],
// [[0.0, 0.0], [0.0, 0.0]]
// ]
// %grad_scale: [0.0, 0.0]
// %grad_offset: [0.4, 0.4]
batch_norm_inference
Semantica
Normalizza il tensore operand
in tutte le dimensioni tranne la dimensione feature_index
e produce un tensore result
. Più formalmente, questa operazione può essere espressa come una decomposizione nelle operazioni StableHLO esistenti utilizzando la sintassi Python come segue:
def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
# Broadcast inputs to shape(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance` instead of
# computing them like `batch_norm_training` does.
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
return add(multiply(scale_bcast, normalized_operand), offset_bcast)
Per i tipi quantizzati, esegue
dequantize_op_quantize(lambda operand, scale, offset, mean, variance:
batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index), operand, scale, offset, mean, variance, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1-C7) |
(I2) | scale |
Tensore 1D di tipo quantizzato a virgola mobile o per tensore | (C2), (C3) |
(I3) | offset |
Tensore 1D di tipo quantizzato a virgola mobile o per tensore | (C2), (C4) |
(I4) | mean |
Tensore 1D di tipo quantizzato a virgola mobile o per tensore | (C5) |
(I5) | variance |
Tensore 1D di tipo quantizzato a virgola mobile o per tensore | (C2), (C6) |
(I6) | epsilon |
costante di tipo f32 |
|
(I7) | feature_index |
costante di tipo si64 |
(C1), (C3-C6) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo a virgola mobile o tensore quantizzato per tensore | (C2), (C7) |
Vincoli
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,mean
,variance
eresult
hanno lo stessobaseline_element_type
. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(mean) = dim(operand, feature_index)
. - (C6)
size(variance) = dim(operand, feature_index)
. - (C7)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
batch_norm_training
Semantica
Calcola la media e la varianza in tutte le dimensioni tranne la dimensione feature_index
e normalizza il tensore operand
producendo i tensori output
, batch_mean
e batch_var
. In modo più formale, questa operazione può essere espressa come una decomposizione delle operazioni StableHLO esistenti utilizzando la sintassi Python come segue:
def compute_mean(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def compute_variance(operand, feature_index):
mean = compute_mean(operand, feature_index)
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
centered_operand = subtract(operand, mean_bcast)
return compute_mean(mul(centered_operand, centered_operand), feature_index)
def batch_norm_training(operand, scale, offset, epsilon, feature_index):
mean = compute_mean(operand, feature_index)
variance = compute_variance(operand, feature_index)
return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index),
mean, variance
Per i tipi quantizzati, esegue
dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset:
batch_norm_training(operand, scale, offset, epsilon, feature_index), operand,
scale, offset, type(output), type(batch_mean), type(batch_var))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo a virgola mobile o tensore quantizzato per tensore | (C1) |
(I2) | scale |
Tensore 1D di tipo a virgola mobile o quantizzato per tensore | (C2), (C3) |
(I3) | offset |
Tensore 1D di tipo a virgola mobile o quantizzato per tensore | (C2), (C4) |
(I4) | epsilon |
costante di tipo f32 |
(C1), (C3-C6) |
(I5) | feature_index |
costante di tipo si64 |
(C1), (C3-C6) |
Output
Nome | Tipo | Vincoli |
---|---|---|
output |
tensore di tipo a virgola mobile o tensore quantizzato per tensore | (C7) |
batch_mean |
Tensore 1D di tipo a virgola mobile o quantizzato per tensore | (C2), (C5) |
batch_var |
Tensore 1D di tipo a virgola mobile o quantizzato per tensore | (C2), (C6) |
Vincoli
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,batch_mean
,batch_var
eoutput
hanno lo stessobaseline_element_type
. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(batch_mean) = dim(operand, feature_index)
. - (C6)
size(batch_var) = dim(operand, feature_index)
. - (C7)
baseline_type(output) = baseline_type(operand)
.
Esempi
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
(tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]
bitcast_convert
Semantica
Esegue un'operazione di bitcast sul tensore operand
e produce un tensore result
dove i bit dell'intero tensore operand
vengono reinterpretati utilizzando il
tipo del tensore result
.
In modo più formale, dati E = element_type(operand)
, E' = element_type(result)
e R = rank(operand)
:
- Se
num_bits(E') < num_bits(E)
,bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])
. - Se
num_bits(E') > num_bits(E)
,bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])
. - Se
num_bits(E') = num_bits(E)
,bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])
.
bits
restituisce la rappresentazione in memoria di un determinato valore e il suo comportamento è definito dall'implementazione perché la rappresentazione esatta dei tensori è definita dall'implementazione e anche la rappresentazione esatta dei tipi di elementi è definita dall'implementazione.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato | (C1-C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato | (C1-C2) |
Vincoli
- (C1) Dati
E = is_quantized(operand) ? storage_type(operand) : element_type(operand)
,E' = is_quantized(result) ? storage_type(result) : element_type(result)
eR = rank(operand)
:- Se
num_bits(E') = num_bits(E)
,shape(result) = shape(operand)
. - Se
num_bits(E') < num_bits(E)
: rank(result) = R + 1
.dim(result, i) = dim(operand, i)
per tutti i0 <= i < R
.dim(result, R) * num_bits(E') = num_bits(E)
.- Se
num_bits(E') > num_bits(E)
: rank(result) = R - 1
.dim(result, i) = dim(operand, i)
per tutti i0 <= i < R
.dim(operand, R - 1) * num_bits(E) = num_bits(E')
.
- Se
- (C2) Se
is_complex(operand) or is_complex(result)
, allorais_complex(operand) and is_complex(result)
.
Esempi
// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation
broadcast_in_dim
Semantica
Espande le dimensioni e/o il ranking di un tensore di input duplicando i dati
nel tensore operand
e produce un tensore result
. Più formalmente,
result[result_index] = operand[operand_index]
dove per tutti i d
in
axes(operand)
:
operand_index[d] = 0
sedim(operand, d) = 1
.operand_index[d] = result_index[broadcast_dimensions[d]]
in caso contrario.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato | (C1-C2), (C5-C6) |
(I2) | broadcast_dimensions |
Costante tensoriale 1-dimensionale di tipo si64 |
(C2-C6) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato | (C1), (C3), (C5-C6) |
Vincoli
- (C1)
element_type(result)
è dato da:element_type(operand)
, se!is_per_axis_quantized(operand)
.element_type(operand)
ad eccezione dei seguenti casi:quantization_dimension(operand)
,scales(operand)
ezero_points(operand)
possono differire daquantization_dimension(result)
,scales(result)
ezero_points(result)
risp., altrimenti.
- (C2)
size(broadcast_dimensions) = rank(operand)
. - (C3)
0 <= broadcast_dimensions < rank(result)
. - (C4)
is_unique(broadcast_dimensions)
. - (C5) Per tutti i valori
d
inaxes(operand)
:dim(operand, d) = 1
odim(operand, d) = dim(result, broadcast_dimensions[d])
.
- (C6) Se
is_per_axis_quantized(result)
:quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
.- Se
dim(operand, quantization_dimension(operand)) = 1
, thenscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
.
Esempi
// %operand: [
// [1, 2, 3]
// ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
richiesta
Semantica
Genera l'output dall'esecuzione esattamente di una funzione da branches
a seconda del valore di index
. Più formalmente, result = selected_branch()
dove:
selected_branch = branches[index]
se0 <= index < size(branches)
.selected_branch = branches[-1]
in caso contrario.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | index |
Tensore di dimensione 0 di tipo si32 |
|
(I2) | branches |
numero variadico di funzioni | (C1-C4) |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadico di tensori, tensori quantizzati o token | (C4) |
Vincoli
- (C1)
0 < size(branches)
. - (C2)
input_types(branches...) = []
. - (C3)
same(output_types(branches...))
. - (C4)
type(results...) = output_types(branches[0])
.
Esempi
// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
"stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
"stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]
Cbrt
Semantica
Esegue l'operazione di radice cubica elemento per elemento sul tensore operand
e produce un
result
tensore. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i valori float:
rootn(x, 3)
da IEEE-754. - Per i numeri complessi: radice cubica complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(cbrt, operand, type(result))
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]
ceil
Semantica
Esegue il tetto elemento per elemento del tensore operand
e produce un tensore result
.
Implementa l'operazione roundToIntegralTowardPositive
della specifica IEEE-754. Per i tipi quantizzati, esegue
dequantize_op_quantize(ceil, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo a virgola mobile o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo a virgola mobile o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]
cholesky
Semantica
Calcola la decomposizione di Cholesky di un batch di matrici.
Più formalmente, per tutti i valori i
in index_space(result)
,
result[i0, ..., iR-3, :, :]
è una decomposizione di Cholesky di
a[i0, ..., iR-3, :, :]
, sotto forma di matrice triangolare inferiore
(se lower
è true
) o triangolare superiore (se lower
è false
).
I valori di output nel triangolo opposto, ovvero il triangolo superiore rigido o
il triangolo inferiore stretto corrispondentemente, sono definiti dall'implementazione.
Se esiste i
in cui la matrice di input non è una matrice hermitiana positiva definita, il comportamento non è definito.
Per i tipi quantizzati, esegue
dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | a |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1-C3) |
(I2) | lower |
Costante tensoriale di dimensione 0 di tipo i1 |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(a) = baseline_type(result)
. - (C2)
2 <= rank(a)
. - (C3)
dim(a, -2) = dim(a, -1)
.
Esempi
// %a: [
// [1.0, 2.0, 3.0],
// [2.0, 20.0, 26.0],
// [3.0, 26.0, 70.0]
// ]
%result = "stablehlo.cholesky"(%a) {
lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
clampare
Semantica
Fissa ogni elemento del tensore operand
tra un valore minimo e massimo
e produce un tensore result
. In modo più formale, result[result_index] =
minimum(maximum(operand[result_index], min_element), max_element)
,
dove min_element = rank(min) = 0 ? min[] : min[result_index]
,
max_element = rank(max) = 0 ? max[] : max[result_index]
. Per i tipi quantizzati,
esegue dequantize_op_quantize(clamp, min, operand, max, type(result))
.
Imporre un ordinamento su numeri complessi comporta una semantica sorprendente, quindi in futuro abbiamo in programma di rimuovere il supporto dei numeri complessi per questa operazione (#560).
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | min |
tensore quantizzato per tensore o per tensore | (C1), (C3) |
(I2) | operand |
tensore quantizzato per tensore o per tensore | (C1-C4) |
(I3) | max |
tensore o tensore quantizzato per tensore | (C2), (C3) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C4) |
Vincoli
- (C1)
rank(min) = 0 or shape(min) = shape(operand)
. - (C2)
rank(max) = 0 or shape(max) = shape(operand)
. - (C3)
baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max)
. - (C4)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]
collective_broadcast
Semantica
All'interno di ogni gruppo di processi nella griglia dei processi StableHLO, invia il valore del tensore operand
dal processo di origine ai processi di destinazione e produce un tensore result
.
L'operazione divide la griglia di processo StableHLO in process_groups
, che è
definita come segue:
cross_replica(replica_groups)
sechannel_id <= 0
.cross_partition(replica_groups)
sechannel_id > 0
.
Successivamente, result@process
è dato da:
operand@process_groups[i, 0]
se esiste uni
tale che il processo sia inprocess_groups[i]
.broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
in caso contrario.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato per tensore | (C3) |
(I2) | replica_groups |
numero variadico di costanti tensoriali 1-dimensionali di tipo si64 |
(C1), (C2) |
(I3) | channel_id |
costante di tipo si64 |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C3) |
Vincoli
- (C1)
is_unique(replica_groups)
. - (C2)
0 <= replica_groups < N
doveN
è definito come:num_replicas
se viene utilizzatocross_replica
.num_partitions
se viene utilizzatocross_partition
.
- (C3)
type(result) = type(operand)
.
Esempi
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
collective_permute
Semantica
All'interno di ogni gruppo di processi nella griglia di processi StableHLO, invia il valore del
operand
tensore dal processo di origine al processo di destinazione e produce un
operand
tensore.result
L'operazione suddivide la griglia del processo StableHLO in process_groups
, che è
definita come segue:
cross_replica(source_target_pairs)
sechannel_id <= 0
.cross_partition(source_target_pairs)
sechannel_id > 0
.
Successivamente, result@process
è dato da:
operand@process_groups[i, 0]
, se esiste uni
tale cheprocess_groups[i, 1] = process
.broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
in caso contrario.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato per tensore | (C5) |
(I2) | source_target_pairs |
Costante tensoriale 2D di tipo si64 |
(C1-C4) |
(I3) | channel_id |
costante di tipo si64 |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
dim(source_target_pairs, 1) = 2
. - (C2)
is_unique(source_target_pairs[:, 0])
. - (C3)
is_unique(source_target_pairs[:, 1])
. - (C4)
0 <= source_target_pairs < N
, doveN
è definito come:num_replicas
se si utilizzacross_replica
.num_partitions
se viene utilizzatocross_partition
.
- (C5)
type(result) = type(operand)
.
Esempi
// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]
confronta
Semantica
Esegue il confronto elemento per elemento dei tensori lhs
e rhs
in base a comparison_direction
e compare_type
e produce un tensore result
.
I valori di comparison_direction
e compare_type
hanno la seguente
semantica:
Per i tipi di elementi booleani e interi:
EQ
:lhs = rhs
.NE
:lhs != rhs
.GE
:lhs >= rhs
.GT
:lhs > rhs
.LE
:lhs <= rhs
.LT
:lhs < rhs
.
Per i tipi di elementi in virgola mobile con compare_type = FLOAT
, l'operazione implementa
le seguenti operazioni IEEE-754:
EQ
:compareQuietEqual
.NE
:compareQuietNotEqual
.GE
:compareQuietGreaterEqual
.GT
:compareQuietGreater
.LE
:compareQuietLessEqual
.LT
:compareQuietLess
.
Per i tipi di elementi in virgola mobile con compare_type = TOTALORDER
, l'operatore
utilizza la combinazione di operazioni totalOrder
e compareQuietEqual
da
IEEE-754.
Per i tipi di elementi complessi, il confronto lessicografico delle coppie (real, imag)
viene eseguito utilizzando gli attributi comparison_direction
e compare_type
forniti.
Imporre un ordinamento su numeri complessi comporta una semantica sorprendente,
quindi in futuro abbiamo in programma di rimuovere il supporto per i numeri complessi quando comparison_direction
è GE
, GT
, LE
o LT
(#560).
Per i tipi quantizzati, esegue dequantize_compare(lhs, rhs,
comparison_direction)
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore o tensore quantizzato per tensore | (C1-C3) |
(I2) | rhs |
tensore o tensore quantizzato per tensore | (C1-C2) |
(I3) | comparison_direction |
enum di EQ , NE , GE , GT , LE e LT |
|
(I4) | compare_type |
enum di FLOAT , TOTALORDER , SIGNED e UNSIGNED |
(C3) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo booleano | (C2) |
Vincoli
- (C1)
baseline_element_type(lhs) = baseline_element_type(rhs)
. - (C2)
shape(lhs) = shape(rhs) = shape(result)
. - (C3)
compare_type
è definito come:SIGNED
seis_signed_integer(element_type(lhs))
.UNSIGNED
seis_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs))
.FLOAT
oTOTALORDER
seis_float(element_type(lhs))
.FLOAT
seis_complex(element_type(lhs))
.
Esempi
// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
comparison_direction = #stablehlo<comparison_direction LT>,
compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]
complesso
Semantica
Esegue la conversione elemento per elemento in un valore complesso da una coppia di valori reali e immaginari, lhs
e rhs
, e produce un tensore result
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di tipo f32 o f64 |
(C1-C3) |
(I2) | rhs |
tensore di tipo f32 o f64 |
(C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo complesso | (C2), (C3) |
Vincoli
- (C1)
type(lhs) = type(rhs)
. - (C2)
shape(result) = shape(lhs)
. - (C3)
element_type(result)
ha il tipocomplex<E>
doveE = element_type(lhs)
.
Esempi
// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]
composito
Semantica
Incapsula un'operazione composta da altre operazioni StableHLO,
che prende inputs
e composite_attributes
e produce results
. La
semantica dell'operazione è implementata dall'attributo decomposition
. L'operazione composite
può essere sostituita con la sua decomposizione senza modificare la semantica del programma. Nei casi in cui l'inserimento in linea della decomposizione non fornisca la stessa semantica dell'operatore, preferisci utilizzare custom_call
.
Il campo version
(il valore predefinito è 0
) viene utilizzato per indicare quando cambia la semantica di un composito.
Input
Etichetta | Nome | Tipo |
---|---|---|
(I1) | inputs |
numero variadico di valori |
(I2) | name |
costante di tipo string |
(I3) | composite_attributes |
dizionario attributi |
(I4) | decomposition |
costante di tipo string |
(I5) | version |
costante di tipo si32 |
Output
Nome | Tipo |
---|---|
results |
numero di valori variadico |
Vincoli
- (C1)
is_namespaced_op_name(name)
- (C2)
is_defined_in_parent_scope(decomposition)
- (C3)
types(inputs...) == input_types(decomposition)
- (C4)
types(results...) == output_types(decomposition)
Esempi
%results = "stablehlo.composite"(%input0, %input1) {
name = "my_namespace.my_op",
composite_attributes = {
my_attribute = "my_value"
},
decomposition = @my_op,
version = 1 : i32
} : (tensor<f32>, tensor<f32>) -> tensor<f32>
concatenate
Semantica
Concatena inputs
lungo la dimensione dimension
nello stesso ordine degli argomenti specificati e produce un tensore result
. In forma più formale,
result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1]
, dove:
id = d0 + ... + dk-1 + kd
.d
è uguale adimension
ed0
, ... sono led
a dimensioni delle dimensioniinputs
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | inputs |
numero variadico di tensori o tensori quantizzati per tensore | (C1-C6) |
(I2) | dimension |
costante di tipo si64 |
(C2), (C4), (C6) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C5-C6) |
Vincoli
- (C1)
same(element_type(inputs...))
. - (C2)
same(shape(inputs...))
ad eccezione didim(inputs..., dimension)
. - (C3)
0 < size(inputs)
. - (C4)
0 <= dimension < rank(inputs[0])
. - (C5)
element_type(result) = element_type(inputs[0])
. - (C6)
shape(result) = shape(inputs[0])
ad eccezione di:dim(result, dimension) = dim(inputs[0], dimension) + ...
.
Esempi
// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
costante
Semantica
Produce un tensore output
da una costante value
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | value |
costante | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
output |
tensore o tensore quantizzato | (C1) |
Vincoli
- (C1)
type(value) = type(output)
.
Esempi
%output = "stablehlo.constant"() {
value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]
convertire
Semantica
Esegue una conversione elemento per elemento da un tipo di elemento a un altro sul
operand
tensore e produce un result
tensore.
Per le conversioni boolean-to-any-supported-type, il valore false
viene
convertito in zero e il valore true
in uno. Per le conversioni any-supported-type-to-boolean, un valore pari a zero viene convertito in false
, mentre i valori diversi da zero vengono convertiti in true
. Continua a leggere per sapere come funziona
per i tipi complessi.
Per le conversioni che coinvolgono intero a intero, intero a virgola mobile o virgola mobile a virgola mobile, se il valore di origine può essere rappresentato esattamente nel tipo di destinazione, il valore del risultato è la rappresentazione esatta. In caso contrario, il comportamento è da decidere (#180).
Per le conversioni che coinvolgono floating-point-to-integer, la parte frazionaria viene troncata. Se il valore troncato non può essere rappresentato nel tipo di destinazione, il comportamento è TBD (#180).
La conversione da complesso a complesso segue lo stesso comportamento delle conversioni da virgola mobile a virgola mobile per la conversione delle parti reali e immaginarie.
Per le conversioni da complesso a qualsiasi altro tipo e da qualsiasi altro tipo a complesso, il valore immaginario di origine viene ignorato o il valore immaginario di destinazione viene impostato su zero. La conversione della parte reale segue le conversioni a virgola mobile.
In linea di principio, questa operazione potrebbe esprimere la dequantizzazione (conversione da tensori quantizzati a tensori regolari), la quantizzazione (conversione da tensori regolari a tensori quantizzati) e la ricontizzazione (conversione tra tensori quantizzati), ma al momento abbiamo operazioni dedicate per questo: uniform_dequantize
per il primo caso d'uso e uniform_quantize
per il secondo e il terzo caso d'uso. In futuro, queste due operazioni potrebbero essere unite in convert
(#1576).
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore | (C1) |
Vincoli
- (C1)
shape(operand) = shape(result)
.
Esempi
// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]
convoluzione
Semantica
Calcola i prodotti scalari tra finestre di lhs
e sezioni di rhs
e produce
result
. Il seguente diagramma mostra come vengono calcolati gli elementi in result
da
lhs
e rhs
utilizzando un esempio concreto.
Più formalmente, considera la seguente riorganizzazione degli input in termini di lhs
per poter esprimere finestre di lhs
:
lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension))
.lhs_window_strides = lhs_shape(1, window_strides, 1)
.lhs_padding = lhs_shape([0, 0], padding, [0, 0])
.lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)
.lhs_window_dilations = lhs_shape(1, rhs_dilation, 1)
.
Questo nuovo inquadramento utilizza le seguenti funzioni di supporto:
lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])
.result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])
.permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]
dovej[d] = i[permutation[d]]
.
Se feature_group_count = 1
e batch_group_count = 1
, per tutti
output_spatial_index
in index_space(dim(result, output_spatial_dimensions...))
,
result[result_shape(:, output_spatial_index, :)] = dot_product
dove:
padding_value = constant(0, element_type(lhs))
.padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)
.lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides
.lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations)
.reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true])
. Questa funzionalità sembra non essere utilizzata, pertanto in futuro prevediamo di rimuoverla (#1181).dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension])
.
Se feature_group_count > 1
:
lhses = split(lhs, feature_group_count, input_feature_dimension)
.rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)
.results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
.
Se batch_group_count > 1
:
lhses = split(lhs, batch_group_count, input_batch_dimension)
.rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)
.results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
.
Per i tipi quantizzati, esegue dequantize_op_quantize(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs,
type(result))
.
Per i tipi quantizzati ibridi, esegue hybrid_dequantize_then_op(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs)
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore quantizzato per tensore o per tensore | (C1), (C10-C11), (C14) (C25), (C27-C28), (C31-C32), (C34) |
(I2) | rhs |
tensore o tensore quantizzato | (C1), (C14-C16), (C25), (C27-C29), (C31-C34) |
(I3) | window_strides |
Costante tensore monodimensionale di tipo si64 |
(C2-C3), (C25) |
(I4) | padding |
Costante tensoriale 2D di tipo si64 |
(C4), (C25) |
(I5) | lhs_dilation |
Costante tensore monodimensionale di tipo si64 |
(C5-C6), (C25) |
(I6) | rhs_dilation |
Costante tensoriale 1-dimensionale di tipo si64 |
(C7-C8), (C25) |
(I7) | window_reversal |
Costante tensoriale 1-dimensionale di tipo i1 |
(C9) |
(I8) | input_batch_dimension |
costante di tipo si64 |
(C10), (C13), (C25) |
(I9) | input_feature_dimension |
costante di tipo si64 |
(C11), (C13-C14) |
(I10) | input_spatial_dimensions |
Costante tensoriale 1-dimensionale di tipo si64 |
(C12), (C13), (C25) |
(I11) | kernel_input_feature_dimension |
costante di tipo si64 |
(C14), (C18) |
(I12) | kernel_output_feature_dimension |
costante di tipo si64 |
(C15-C16), (C18), (C25), (C29) |
(I13) | kernel_spatial_dimensions |
Costante tensoriale 1-dimensionale di tipo si64 |
(C17-C18), (C25) |
(I14) | output_batch_dimension |
costante di tipo si64 |
(C20), (C25) |
(I15) | output_feature_dimension |
costante di tipo si64 |
(C20), (C25), (C30) |
(I16) | output_spatial_dimensions |
Costante tensoriale 1-dimensionale di tipo si64 |
(C19-C20), (C25) |
(I17) | feature_group_count |
costante di tipo si64 |
(C11), (C14), (C16), (C21), (C23) |
(I18) | batch_group_count |
costante di tipo si64 |
(C10), (C15), (C22), (C23), (C25) |
(I19) | precision_config |
numero variadico di enum di DEFAULT , HIGH e HIGHEST |
(C24) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato | (C25-C28), (C30), (C32-34) |
Vincoli
- (C1)
N = rank(lhs) = rank(rhs)
. - (C2)
size(window_strides) = N - 2
. - (C3)
0 < window_strides
. - (C4)
shape(padding) = [N - 2, 2]
. - (C5)
size(lhs_dilation) = N - 2
. - (C6)
0 < lhs_dilation
. - (C7)
size(rhs_dilation) = N - 2
. - (C8)
0 < rhs_dilation
. - (C9)
size(window_reversal) = N - 2
. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
. - (C12)
size(input_spatial_dimensions) = N - 2
. - (C13) Dato
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
:is_unique(input_dimensions)
.0 <= input_dimensions < N
.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
. - (C17)
size(kernel_spatial_dimensions) = N - 2
. - (C18) Dato
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
:is_unique(kernel_dimensions)
.0 <= kernel_dimensions < N
.
- (C19)
size(output_spatial_dimensions) = N - 2
. - (C20) In base a
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
:is_unique(output_dimensions)
.0 <= output_dimensions < N
.
- (C21)
0 < feature_group_count
. - (C22)
0 < batch_group_count
. - (C23)
feature_group_count = 1 or batch_group_count = 1
. - (C24)
size(precision_config) = 2
. - (C25) Per
dim(result, result_dim)
si intende:dim(lhs, input_batch_dimension) / batch_group_count
seresult_dim = output_batch_dimension
.dim(rhs, kernel_output_feature_dimension)
seresult_dim = output_feature_dimension
.num_windows
altrimenti, dove:output_spatial_dimensions[spatial_dim] = result_dim
.lhs_dim = input_spatial_dimensions[spatial_dim]
.rhs_dim = kernel_spatial_dimensions[spatial_dim]
.dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
.dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
.num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
- (C26)
rank(result) = N
. - Se l'operazione utilizza tensori non quantizzati:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
.
- (C27)
- Se l'operazione utilizza tensori quantizzati:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C29) Se
is_per_axis_quantized(rhs)
, poiquantization_dimension(rhs) = kernel_output_feature_dimension
. - (C30) Se
is_per_axis_quantized(result)
, thenquantization_dimension(result) = output_feature_dimension
. - Se
is_quantized(lhs)
: - (C31)
storage_type(lhs) = storage_type(rhs)
. - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C33) Se
is_per_tensor_quantized(rhs)
, thenis_per_tensor_quantized(result)
. - Se
!is_quantized(lhs)
: - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C28)
Esempi
// %lhs: [[
// [
// [1], [2], [5], [6]
// ],
// [
// [3], [4], [7], [8]
// ],
// [
// [10], [11], [14], [15]
// ],
// [
// [12], [13], [16], [17]
// ]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strides = array<i64: 4, 4>,
padding = dense<0> : tensor<2x2xi64>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
// In the StableHLO dialect, dimension numbers are encoded via:
// `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
// "b" is batch dimension, "f" is feature dimension,
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" are spatial dimensions.
dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
batch_group_count = 1 : i64,
feature_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[10], [26]],
// [[46], [62]]
// ]]
coseno
Semantica
Esegue l'operazione di coseno elemento per elemento sul tensore operand
e produce un
result
tensore. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i valori float:
cos
da IEEE-754. - Per i numeri complessi: coseno complesso.
- Per i tipi quantizzati:
dequantize_op_quantize(cosine, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]
count_leading_zeros
Semantica
Esegue il conteggio elemento per elemento del numero di bit iniziali pari nel tensore operand
e produce un tensore result
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo intero | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo intero | (C1) |
Vincoli
- (C1)
type(operand) = type(result)
.
Esempi
// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]
custom_call
Semantica
Incapsula un'operazione call_target_name
definita dall'implementazione che prende inputs
e called_computations
e produce results
. has_side_effect
,
backend_config
e api_version
possono essere utilizzati per fornire metadati aggiuntivi definiti dall'implementazione.
Al momento, questa operazione contiene una raccolta di metadati piuttosto disorganizzata che riflette l'evoluzione organica dell'operazione di controparte nel compilatore XLA. In futuro, prevediamo di unificare questi metadati (#741).
Input
Etichetta | Nome | Tipo |
---|---|---|
(I1) | inputs |
numero variadico di valori |
(I2) | call_target_name |
costante di tipo string |
(I3) | has_side_effect |
costante di tipo i1 |
(I4) | backend_config |
costante di tipo string o dizionario di attributi |
(I5) | api_version |
costante di tipo si32 |
(I6) | called_computations |
numero variadico di costanti di tipo string |
Output
Nome | Tipo |
---|---|
results |
numero variadico di valori |
Esempi
%results = "stablehlo.custom_call"(%input0) {
call_target_name = "foo",
has_side_effect = false,
backend_config = {bar = 42 : i32},
api_version = 4 : i32,
called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>
divisione
Semantica
Esegue la divisione elemento per elemento dei tensori dividendo lhs
e divisore rhs
e
produce un tensore result
. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i numeri interi: divisione di numeri interi che produce il quoziente algebrico con qualsiasi parte frazionaria scartata.
- Per i valori float:
division
da IEEE-754. - Per i numeri complessi: divisione complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(divide, lhs, rhs, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di tipo intero, in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
(I2) | rhs |
tensore di tipo intero, in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di numeri interi, in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Esempi
// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]
dot_general
Semantica
Calcola i prodotti scalari tra i vari slice di lhs
e i vari slice di rhs
e produce un
result
tensore.
In forma più formale, result[result_index] = dot_product
, dove:
lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions]
.rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions]
.result_batching_index + result_lhs_index + result_rhs_index = result_index
dovesize(result_batching_index) = size(lhs_batching_dimensions)
,size(result_lhs_index) = size(lhs_result_dimensions)
esize(result_rhs_index) = size(rhs_result_dimensions)
.transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions)
.transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :])
.reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions))
.transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions)
.transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :])
.reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions))
.dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y))
.
Per i tipi quantizzati, esegue dequantize_op_quantize(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs, type(result))
.
Per i tipi quantizzati ibridi, esegue hybrid_dequantize_then_op(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs)
.
precision_config
controlla il compromesso tra velocità e accuratezza per i calcoli sui backend degli acceleratori. Può essere uno dei seguenti valori (al momento, la semantica di questi valori enum è sottospecificata, ma prevediamo di risolvere questo problema in #755):
DEFAULT
: calcolo più rapido, ma approssimazione meno accurata del numero originale.HIGH
: calcolo più lento, ma approssimazione più accurata del numero originale.HIGHEST
: calcolo più lento, ma approssimazione più accurata del numero originale.
Un DotAlgorithm
definisce le proprietà principali dell'algoritmo utilizzato per implementare l'operazione di moltiplicazione, che definisce anche la precisione. Se i campi dell'attributo algoritmo sono impostati, precision_config
deve essere DEFAULT
. DotAlgorithms
non hanno un valore predefinito, poiché i parametri predefiniti sono definiti dall'implementazione. Di conseguenza, tutti i campi dell'algoritmo dei punti possono essere impostati su None
per specificare un algoritmo dei punti vuoto, che utilizzerà invece il valore precision_config
.
I campi DotAlgorithm
includono:
lhs_precision_type
erhs_precision_type
, le precisioni a cui vengono arrotondati il primo e il secondo membro dell'operazione. I tipi di precisione sono indipendenti dai tipi di archiviazione degli input e dell'output.accumulation_type
la precisione utilizzata per l'accumulo.lhs_component_count
,rhs_component_count
enum_primitive_operations
vengono applicati quando utilizziamo un algoritmo che decompone il primo e/o il secondo membro in più componenti ed esegue più operazioni di moltiplicazione "primitive" su questi valori, in genere per emulare una precisione più elevata (ad es. Sfruttare il tipo di dati di IA bfloat16 per calcoli di maggiore precisione: bf16_6x tf32_3x e così via). Per gli algoritmi senza decomposizione, questi valori devono essere impostati su1
.allow_imprecise_accumulation
per specificare se l'accumulo con una precisione inferiore è consentito per alcuni passaggi (ad es.CUBLASLT_MATMUL_DESC_FAST_ACCUM
).
Attributi DotAlgorithm
di esempio:
// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false}
// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
rhs_precision_type = bf16,
accumulation_type = f32,
lhs_component_count = 3,
rhs_component_count = 3,
num_primitive_operations = 6,
allow_imprecise_accumulation = false}
// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
rhs_precision_type = f8e5m2,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = true}
Spetta alle implementazioni decidere quali combinazioni sono supportate. In generale, non è garantito che ogni algoritmo sia supportato su ogni tipo di acceleratore dal consumatore di StableHLO. Se un determinato algoritmo non è supportato, deve essere generato un errore anziché passare a un'alternativa. La verifica StableHLO viene eseguita secondo il criterio del massimo impegno, impedendo l'uso di algoritmi non supportati su nessun hardware.
Consulta xla_data.proto > Algorithm
per alcuni valori dell'algoritmo supportati. Il ticket 2483 mostra il piano per la creazione di un documento
centralizzato sugli algoritmi supportati dal backend.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore o tensore quantizzato per tensore | (C5-C6), (C9-C10), (C12-C14), (C17-C18), (C20) |
(I2) | rhs |
tensore o tensore quantizzato | (C7-C10), (C12-C20) |
(I3) | lhs_batching_dimensions |
Costante tensoriale 1-dimensionale di tipo si64 |
(C1), (C3), (C5), (C9), (C12) |
(I4) | rhs_batching_dimensions |
Costante tensoriale 1-dimensionale di tipo si64 |
(C1), (C4), (C7), (C9) |
(I5) | lhs_contracting_dimensions |
Costante tensoriale 1-dimensionale di tipo si64 |
(C2), (C3), (C6), (C10) |
(I6) | rhs_contracting_dimensions |
Costante tensoriale 1-dimensionale di tipo si64 |
(C2), (C4), (C8), (C10), (C16) |
(I7) | precision_config |
numero variadico di enum di DEFAULT , HIGH e HIGHEST |
(C11), (C21) |
(I8) | lhs_precision_type |
FloatType o TensorFloat32 | (C21) |
(I9) | rhs_precision_type |
FloatType o TensorFloat32 | (C21) |
(I10) | accumulation_type |
FloatType o TensorFloat32 | (C21) |
(I11) | lhs_component_count |
costante di tipo si32 |
(C21), (C22) |
(I12) | rhs_component_count |
costante di tipo si32 |
(C21), (C23) |
(I13) | num_primitive_operations |
costante di tipo si32 |
(C21), (C24) |
(I14) | allow_imprecise_accumulation |
costante di tipo bool |
(C21) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato | (C12), (C14), (C18-C20) |
Vincoli
- (C1)
size(lhs_batching_dimensions) = size(rhs_batching_dimensions)
. - (C2)
size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions)
. - (C3)
is_unique(lhs_batching_dimensions + lhs_contracting_dimensions)
. - (C4)
is_unique(rhs_batching_dimensions + rhs_contracting_dimensions)
. - (C5)
0 <= lhs_batching_dimensions < rank(lhs)
. - (C6)
0 <= lhs_contracting_dimensions < rank(lhs)
. - (C7)
0 <= rhs_batching_dimensions < rank(rhs)
. - (C8)
0 <= rhs_contracting_dimensions < rank(rhs)
. - (C9)
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
. - (C10)
dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...)
. - (C11)
size(precision_config) = 2
. - (C12)
shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions)
. - Se l'operazione utilizza tensori non quantizzati:
- (C13)
element_type(lhs) = element_type(rhs)
.
- (C13)
- Se l'operazione utilizza tensori quantizzati:
- (C14)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C15)
zero_points(rhs) = 0
. - (C16) Se
is_per_axis_quantized(rhs)
, significaquantization_dimension(rhs)
non inrhs_contracting_dimensions
. - Se
is_quantized(lhs)
: - (C17)
storage_type(lhs) = storage_type(rhs)
. - (C18)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C19) Se
is_per_tensor_quantized(rhs)
, thenis_per_tensor_quantized(result)
. - Se
!is_quantized(lhs)
: - (C20)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C14)
- Se
!is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation)
:- (C21)
precision_config... = DEFAULT
. - (C22)
0 < lhs_component_count
. - (C23)
0 < rhs_component_count
. - (C24)
0 < num_primitive_operations
.
- (C21)
Esempi
// %lhs: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
// %rhs: [
// [[1, 0],
// [0, 1]],
// [[1, 0],
// [0, 1]]
// ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [1]
>,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
algorithm = #stablehlo.dot_algorithm<
lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false
>
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
dynamic_broadcast_in_dim
Semantica
Questa operazione è funzionalmente identica all'operazione
broadcast_in_dim, ma la forma del risultato viene specificata dinamicamente tramite output_dimensions
.
L'operazione accetta anche gli attributi facoltativi known_expanding_dimensions
, known_nonexpanding_dimensions
per esprimere conoscenze statiche sul comportamento di espansione delle dimensioni.
Se non specificato, si presume che tutte le dimensioni possano espandersi.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato | (C1-C2), (C5-C6), (C9) |
(I2) | output_dimensions |
Tensore 1-dimensionale di tipo intero | (C7) |
(I3) | broadcast_dimensions |
Tensore costante unidimensionale di tipo intero | (C2-C6) |
(I4) | known_expanding_dimensions |
Tensore costante unidimensionale di tipo intero | (C8-C9) |
(I5) | known_nonexpanding_dimensions |
Tensore costante unidimensionale di tipo intero | (C8-C9) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato | (C1), (C3), (C5-C7) |
Vincoli
- (C1)
element_type(result)
è dato da:element_type(operand)
, se!is_per_axis_quantized(operand)
.element_type(operand)
ad eccezione dei seguenti casi:quantization_dimension(operand)
,scales(operand)
ezero_points(operand)
possono differire daquantization_dimension(result)
,scales(result)
ezero_points(result)
risp., altrimenti.
- (C2)
size(broadcast_dimensions) = rank(operand)
. - (C3)
0 <= broadcast_dimensions < rank(result)
. - (C4)
is_unique(broadcast_dimensions)
. - (C5) Per tutti i valori
d
inaxes(operand)
:dim(operand, d) = 1
odim(operand, d) = dim(result, broadcast_dimensions[d])
.
- (C6) Se
is_per_axis_quantized(result)
:quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
.- Se
dim(operand, quantization_dimension(operand)) = 1
, thenscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
.
- (C7)
size(output_dimensions) = rank(result)
. - (C8)
is_unique(known_expanding_dimensions + known_nonexpanding_dimensions)
. - (C9)
0 <= known_expanding_dimensions < rank(operand)
. - (C10)
0 <= known_nonexpanding_dimensions < rank(operand)
.
Esempi
// %operand: [
// [1, 2, 3]
// ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
broadcast_dimensions = array<i64: 2, 1>,
known_expanding_dimensions = array<i64: 0>,
known_nonexpanding_dimensions = array<i64: 1>
} : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
dynamic_conv
Semantica
Questa operazione è funzionalmente identica all'operazione op di convoluzione, ma il padding viene specificato dinamicamente tramite padding
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore o tensore quantizzato per tensore | (C1), (C10-C11), (C14) (C25), (C26-C27), (C30-C31), (C33) |
(I2) | rhs |
tensore o tensore quantizzato | (C1), (C14-C16), (C26-C28), (C30-C33) |
(I3) | padding |
Tensore 2D di tipo intero | (C4) |
(I4) | window_strides |
Costante tensoriale 1-dimensionale di tipo si64 |
(C2-C3) |
(I5) | lhs_dilation |
Costante tensore monodimensionale di tipo si64 |
(C5-C6) |
(I6) | rhs_dilation |
Costante tensore monodimensionale di tipo si64 |
(C7-C8) |
(I7) | window_reversal |
Costante tensoriale 1-dimensionale di tipo i1 |
(C9) |
(I8) | input_batch_dimension |
costante di tipo si64 |
(C10), (C13) |
(I9) | input_feature_dimension |
costante di tipo si64 |
(C11), (C13-C14) |
(I10) | input_spatial_dimensions |
Costante tensoriale 1-dimensionale di tipo si64 |
(C12), (C13) |
(I11) | kernel_input_feature_dimension |
costante di tipo si64 |
(C14), (C18) |
(I12) | kernel_output_feature_dimension |
costante di tipo si64 |
(C15-C16), (C18), (C28) |
(I13) | kernel_spatial_dimensions |
Costante tensoriale 1-dimensionale di tipo si64 |
(C17-C18) |
(I14) | output_batch_dimension |
costante di tipo si64 |
(C20) |
(I15) | output_feature_dimension |
costante di tipo si64 |
(C20), (C29) |
(I16) | output_spatial_dimensions |
Costante tensoriale 1-dimensionale di tipo si64 |
(C19-C20) |
(I17) | feature_group_count |
costante di tipo si64 |
(C11), (C14), (C16), (C21), (C23) |
(I18) | batch_group_count |
costante di tipo si64 |
(C10), (C15), (C22), (C23) |
(I19) | precision_config |
numero variadico di enum di DEFAULT , HIGH e HIGHEST |
(C24) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato | (C25-C27), (C29), (C31-C33) |
Vincoli
- (C1)
N = rank(lhs) = rank(rhs)
. - (C2)
size(window_strides) = N - 2
. - (C3)
0 < window_strides
. - (C4)
shape(padding) = [N - 2, 2]
. - (C5)
size(lhs_dilation) = N - 2
. - (C6)
0 < lhs_dilation
. - (C7)
size(rhs_dilation) = N - 2
. - (C8)
0 < rhs_dilation
. - (C9)
size(window_reversal) = N - 2
. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
. - (C12)
size(input_spatial_dimensions) = N - 2
. - (C13) Dato
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
:is_unique(input_dimensions)
.0 <= input_dimensions < N
.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
. - (C17)
size(kernel_spatial_dimensions) = N - 2
. - (C18) Dato
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
:is_unique(kernel_dimensions)
.0 <= kernel_dimensions < N
.
- (C19)
size(output_spatial_dimensions) = N - 2
. - (C20) In base a
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
:is_unique(output_dimensions)
.0 <= output_dimensions < N
.
- (C21)
0 < feature_group_count
. - (C22)
0 < batch_group_count
. - (C23)
feature_group_count = 1 or batch_group_count = 1
. - (C24)
size(precision_config) = 2
. - (C25) Per
dim(result, result_dim)
si intende:dim(lhs, input_batch_dimension) / batch_group_count
seresult_dim = output_batch_dimension
.dim(rhs, kernel_output_feature_dimension)
seresult_dim = output_feature_dimension
.num_windows
altrimenti, dove:output_spatial_dimensions[spatial_dim] = result_dim
.lhs_dim = input_spatial_dimensions[spatial_dim]
.rhs_dim = kernel_spatial_dimensions[spatial_dim]
.dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
.dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
.num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
- (C26)
rank(result) = N
. - Se l'operazione utilizza tensori non quantizzati:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
.
- (C27)
- Se l'operazione utilizza tensori quantizzati:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C29) Se
is_per_axis_quantized(rhs)
, poiquantization_dimension(rhs) = kernel_output_feature_dimension
. - (C30) Se
is_per_axis_quantized(result)
, thenquantization_dimension(result) = output_feature_dimension
. - Se
is_quantized(lhs)
: - (C31)
storage_type(lhs) = storage_type(rhs)
. - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C33) Se
is_per_tensor_quantized(rhs)
, thenis_per_tensor_quantized(result)
. - Se
!is_quantized(lhs)
: - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C28)
Esempi
// %lhs: [[
// [[1], [2], [5], [6]],
// [[3], [4], [7], [8]],
// [[10], [11], [14], [15]],
// [[12], [13], [16], [17]]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
// %padding: [[1, 1],
// [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
window_strides = array<i64: 4, 4>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
dimension_numbers = #stablehlo.conv<raw
input_batch_dimension = 0,
input_feature_dimension = 3,
input_spatial_dimensions = [0, 1],
kernel_input_feature_dimension = 2,
kernel_output_feature_dimension = 3,
kernel_spatial_dimensions = [0, 1],
output_batch_dimension = 0,
output_feature_dimension = 3,
output_spatial_dimensions = [1, 2]
>,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[1], [5]],
// [[10], [14]]
// ]]
dynamic_gather
Semantica
Questa operazione è funzionalmente identica all'operazione
gather, con slice_sizes
specificato dinamicamente come valore.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato per tensore | (C1), (C7), (C10-C12), (C14) |
(I2) | start_indices |
tensore di tipo intero | (C2), (C3), (C13) |
(I3) | slice_sizes |
Tensore monodimensionale di tipo intero | (C8), (C11-C13) |
(I4) | offset_dims |
Costante tensoriale 1-dimensionale di tipo si64 |
(C1), (C4-C5), (C13) |
(I5) | collapsed_slice_dims |
Costante tensoriale 1-dimensionale di tipo si64 |
(C1), (C6-C8), (C13) |
(I6) | start_index_map |
Costante tensoriale 1-dimensionale di tipo si64 |
(C3), (C9), (C10) |
(I7) | index_vector_dim |
costante di tipo si64 |
(C2), (C3), (C13) |
(I8) | indices_are_sorted |
costante di tipo i1 |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C5), (C13-C14) |
Vincoli
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims)
. - (C2)
0 <= index_vector_dim <= rank(start_indices)
. - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
. - (C5)
0 <= offset_dims < rank(result)
. - (C6)
is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims)
. - (C7)
0 <= collapsed_slice_dims < rank(operand)
. - (C8)
slice_sizes[collapsed_slice_dims...] <= 1
. - (C9)
is_unique(start_index_map)
. - (C10)
0 <= start_index_map < rank(operand)
. - (C11)
size(slice_sizes) = rank(operand)
. - (C12)
0 <= slice_sizes <= shape(operand)
. - (C13)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
dove:batch_dim_sizes = shape(start_indices)
ad eccezione del fatto che la dimensionestart_indices
corrispondente aindex_vector_dim
non è inclusa.offset_dim_sizes = shape(slice_sizes)
, ad eccezione del fatto che le dimensioni delle dimensioni inslice_sizes
corrispondenti acollapsed_slice_dims
non sono incluse.combine
inseriscebatch_dim_sizes
sugli assi corrispondenti abatch_dims
eoffset_dim_sizes
sugli assi corrispondenti aoffset_dims
.
- (C14)
element_type(operand) = element_type(result)
.
Esempi
// %operand: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %start_indices: [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 2]]
// ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
dimension_numbers = #stablehlo.gather<
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vector_dim = 2>,
indices_are_sorted = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi64>
// %result: [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[9, 10], [11, 12]],
// [[11, 12], [13, 14]],
// [[17, 18], [19, 20]]
// ]
// ]
dynamic_iota
Semantica
Questa operazione è identica dal punto di vista funzionale all'operazione iota, ma la forma dei risultati viene specificata in modo dinamico tramite output_shape
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | output_shape |
Tensore monodimensionale di tipo intero | (C1), (C2) |
(I2) | iota_dimension |
si64 |
(C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo intero, in virgola mobile o complesso o tensore quantizzato per tensore | (C2) |
Vincoli
- (C1)
0 <= iota_dimension < size(output_shape)
. - (C2)
rank(result) = size(output_shape)
.
Esempi
%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
iota_dimension = 0 : i64
} : (tensor<2xi64>) -> tensor<4x5xi64>
// %result: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
dynamic_pad
Semantica
Questa operazione è funzionalmente identica all'operazione
pad,
ma con edge_padding_low
, edge_padding_high
e interior_padding
specificati dinamicamente come valori.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato per tensore | (C1), (C2), (C4) |
(I2) | padding_value |
Tensore 0-dimensionale o tensore quantizzato per tensore | (C1) |
(I3) | edge_padding_low |
Tensore 1-dimensionale di tipo intero | (C1), (C4) |
(I4) | edge_padding_high |
Tensore 1-dimensionale di tipo intero | (C1), (C4) |
(I5) | interior_padding |
Tensore 1-dimensionale di tipo intero | (C2-C4) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C3-C6) |
Vincoli
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
. - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
. - (C3)
0 <= interior_padding
. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
.
Esempi
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
%edge_padding_low, %edge_padding_high, %interior_padding
) : (tensor<2x3xi64>, tensor<i64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x9xi64>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
dynamic_reshape
Semantica
Questa operazione è funzionalmente identica all'operazione
reshape, ma la forma del risultato viene specificata dinamicamente tramite output_shape
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato | (C1-C3) |
(I2) | output_shape |
Tensore 1-dimensionale di tipo intero | (C4) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato | (C1-C4) |
Vincoli
- (C1)
element_type(result)
è dato da:element_type(operand)
, se!is_per_axis_quantized(operand)
.element_type(operand)
ad eccezione del fatto chequantization_dimension(operand)
equantization_dimension(result)
possono differire, altrimenti.
- (C2)
size(operand) = size(result)
. - (C3) Se
is_per_axis_quantized(operand)
:reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
.reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
- (C4)
size(output_shape) = rank(result)
.
Esempi
// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
// %result: [[1, 2], [3, 4], [5, 6]]
dynamic_slice
Semantica
Estrae uno slice da operand
utilizzando gli indici di inizio calcolati dinamicamente
e produce un tensore result
. start_indices
contengono gli indici iniziali del segmento per ogni dimensione soggetta a un potenziale aggiustamento e slice_sizes
contengono le dimensioni del segmento per ogni dimensione. In forma più formale,
result[result_index] = operand[operand_index]
dove:
adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes)
.operand_index = adjusted_start_indices + result_index
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato per tensore | (C1), (C2), (C4) |
(I2) | start_indices |
numero variadico dei tensori 0-dimensionali di tipo intero | (C2), (C3) |
(I3) | slice_sizes |
Costante tensore monodimensionale di tipo si64 |
(C2), (C4), (C5) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C1), (C5) |
Vincoli
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(slice_sizes) = rank(operand)
. - (C3)
same(type(start_indices...))
. - (C4)
0 <= slice_sizes <= shape(operand)
. - (C5)
shape(result) = slice_sizes
.
Esempi
// %operand: [
// [0, 0, 1, 1],
// [0, 0, 1, 1],
// [0, 0, 0, 0],
// [0, 0, 0, 0]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
slice_sizes = array<i64: 2, 2>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
// [1, 1],
// [1, 1]
// ]
dynamic_update_slice
Semantica
Genera un tensore result
uguale al tensore operand
, tranne per il fatto che
la sezione che inizia da start_indices
viene aggiornata con i valori in update
.
Più formalmente, result[result_index]
è definito come:
update[update_index]
se0 <= update_index < shape(update)
dove:adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update))
.update_index = result_index - adjusted_start_indices
.
operand[result_index]
in caso contrario.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato per tensore | (C1-C4), (C6) |
(I2) | update |
tensore quantizzato per tensore o per tensore | (C2), (C3), (C6) |
(I3) | start_indices |
numero variadico di tensori di tipo intero a zero dimensioni | (C4), (C5) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
type(operand) = type(result)
. - (C2)
element_type(update) = element_type(operand)
. - (C3)
rank(update) = rank(operand)
. - (C4)
size(start_indices) = rank(operand)
. - (C5)
same(type(start_indices...))
. - (C6)
0 <= shape(update) <= shape(operand)
.
Esempi
// %operand: [
// [1, 1, 0, 0],
// [1, 1, 0, 0],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
// %update: [
// [1, 1],
// [1, 1]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
: (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
esponenziale
Semantica
Esegue un'operazione esponenziale elemento per elemento sul tensore operand
e produce un
result
tensore. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i valori float:
exp
da IEEE-754. - Per i numeri complessi: esponenziale complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(exponential, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]
esponenziale_meno_uno
Semantica
Esegue un'operazione esponenziale meno uno elemento per elemento sul tensore operand
e produce un tensore result
. A seconda del tipo di elemento:
- Per i valori float:
expm1
da IEEE-754. - Per i numeri complessi: esponenziale complessa meno uno.
- Per i tipi quantizzati:
dequantize_op_quantize(exponential_minus_one, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]
fft
Semantica
Esegue la trasformata di Fourier diretta e inversa per input/output reali e complessi.
fft_type
corrisponde a uno dei seguenti:
FFT
: inoltro di FFT da complessi a complessi.IFFT
: FFT complesso-complesso inverso.RFFT
: FFT da reale a complesso in avanti.IRFFT
: FFT inverso da reale a complesso (ovvero prende un numero complesso e restituisce un numero reale).
Più formalmente, data la funzione fft
che prende come input tensori monodimensionali di tipi complessi, produce tensori monodimensionali degli stessi tipi dell'output e calcola la trasformata discreta di Fourier:
Per fft_type = FFT
, result
è definito come il risultato finale di una serie di calcoli L
in cui L = size(fft_length)
. Ad esempio, per L = 3
:
result1[i0, ..., :] = fft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
.result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Inoltre, data la funzione ifft
che ha la stessa firma del tipo e
calcola l'inverso di fft
:
Per fft_type = IFFT
, result
è definito come l'inverso dei calcoli per fft_type = FFT
. Ad esempio, per L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
.result[i0, ..., :] = ifft(result2[i0, ..., :])
.
Inoltre, data la funzione rfft
che accetta tensori 1D di tipi a virgola mobile, produce tensori 1D di tipi complessi della stessa semantica a virgola mobile e funziona come segue:
rfft(real_operand) = truncated_result
dovecomplex_operand... = (real_operand..., 0.0)
.complex_result = fft(complex_operand)
.truncated_result = complex_result[:(rank(complex_result) / 2 + 1)]
.
Quando la trasformata di Fourier discreta viene calcolata per operandi reali, i primi N/2 + 1
elementi del risultato definiscono in modo inequivocabile il resto del risultato, pertanto il risultato di rfft
viene troncato per evitare di calcolare elementi ridondanti.
Per fft_type = RFFT
, result
è definito come il risultato finale di una serie di calcoli L in cui L = size(fft_length)
. Ad esempio, per L = 3
:
result1[i0, ..., :] = rfft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
.result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Infine, data la funzione irfft
che ha la stessa firma di tipo e calcola l'inverso di rfft
:
Per fft_type = IRFFT
, result
è definito come l'inverso dei calcoli
per fft_type = RFFT
. Ad esempio, per L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
.result[i0, ..., :] = irfft(result2[i0, ..., :])
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o complesso | (C1), (C2), (C4), (C5) |
(I2) | fft_type |
enum di FFT , IFFT , RFFT e IRFFT |
(C2), (C5) |
(I3) | fft_length |
Costante tensore monodimensionale di tipo si64 |
(C1), (C3), (C4) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo in virgola mobile o complesso | (C2), (C4), (C5) |
Vincoli
- (C1)
size(fft_length) <= rank(operand)
. - (C2) La relazione tra i tipi di elementi
operand
eresult
varia:- Se
fft_type = FFT
,element_type(operand)
eelement_type(result)
hanno lo stesso tipo complesso. - Se
fft_type = IFFT
,element_type(operand)
eelement_type(result)
hanno lo stesso tipo complesso. - Se
fft_type = RFFT
,element_type(operand)
è un tipo a virgola mobile eelement_type(result)
è un tipo complesso della stessa semantica a virgola mobile. - Se
fft_type = IRFFT
,element_type(operand)
è un tipo complesso eelement_type(result)
è un tipo a virgola mobile con la stessa semantica a virgola mobile.
- Se
- (C3)
1 <= size(fft_length) <= 3
. - (C4) Se tra
operand
eresult
è presente un tensorereal
di tipo virgola mobile, allorashape(real)[-size(fft_length):] = fft_length
. - (C5)
shape(result) = shape(operand)
eccetto:- Se
fft_type = RFFT
,dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1
. - Se
fft_type = IRFFT
,dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1
.
- Se
Esempi
// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
fft_type = #stablehlo<fft_type FFT>,
fft_length = array<i64: 4>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]
piano
Semantica
Esegue il pavimento a livello di elemento del tensore operand
e produce un tensore result
.
Implementa l'operazione roundToIntegralTowardNegative
dalla specifica IEEE-754. Per i tipi quantizzati, esegue
dequantize_op_quantize(floor, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo a virgola mobile o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo a virgola mobile o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]
raccogliere
Semantica
Raccoglie le sezioni del tensore operand
dagli offset specificati in start_indices
e produce un tensore result
.
Il seguente diagramma mostra come gli elementi in result
vengono mappati agli elementi in
operand
utilizzando un esempio concreto. Il diagramma sceglie alcuni esempi di indici result
e spiega in dettaglio a quali indici operand
corrispondono.
In modo più formale, result[result_index] = operand[operand_index]
dove:
batch_dims = [d for d in axes(result) and d not in offset_dims]
.batch_index = result_index[batch_dims...]
.start_index
è definito come:start_indices[bi0, ..., :, ..., biN]
dovebi
sono singoli elementi inbatch_index
e:
viene inserito all'indiceindex_vector_dim
, seindex_vector_dim
<rank(start_indices)
.[start_indices[batch_index]]
in caso contrario.
- Per
d_operand
inaxes(operand)
,full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])
sed_operand = start_index_map[d_start]
.full_start_index[d_operand] = 0
in caso contrario.
- Per
d_operand
inaxes(operand)
,full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)]
sed_operand = operand_batching_dims[i_batching]
ed_start = start_indices_batching_dims[i_batching]
.full_batching_index[d_operand] = 0
in caso contrario.
offset_index = result_index[offset_dims...]
.full_offset_index = [oi0, ..., 0, ..., oiN]
, doveoi
sono singoli elementi dioffset_index
e0
viene inserito negli indici dicollapsed_slice_dims
eoperand_batching_dims
.operand_index = full_start_index + full_batching_index + full_offset_index
.
Se indices_are_sorted
è true
, l'implementazione può presumere che
start_indices
siano ordinati rispetto a start_index_map
, altrimenti il
comportamento non è definito. Più formalmente, per tutti i i1 < i2
di indices(result)
,
full_start_index(i1) <= full_start_index(i2)
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato per tensore | (C1), (C8), (C11), (C17), (C19-C21), (C23) |
(I2) | start_indices |
tensore di tipo intero | (C2-C3), (C14), (C17), (C22) |
(I3) | offset_dims |
Costante tensore monodimensionale di tipo si64 |
(C1), (C4-C5), (C22) |
(I4) | collapsed_slice_dims |
Costante tensoriale 1-dimensionale di tipo si64 |
(C1), (C6-C9), (C22) |
(I5) | operand_batching_dims |
Costante tensoriale 1-dimensionale di tipo si64 |
(C1), (C6), (C10-C12), (C16-C18), (C22) |
(I6) | start_indices_batching_dims |
Costante tensoriale 1-dimensionale di tipo si64 |
(C13-C17) |
(I7) | start_index_map |
Costante tensoriale 1-dimensionale di tipo si64 |
(C3), (C18-C19) |
(I8) | index_vector_dim |
costante di tipo si64 |
(C2-C3), (C15), (C22) |
(I9) | slice_sizes |
Costante tensoriale 1-dimensionale di tipo si64 |
(C9), (C12), (C20-C22) |
(I10) | indices_are_sorted |
costante di tipo i1 |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C5), (C22-C23) |
Vincoli
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims)
. - (C2)
0 <= index_vector_dim <= rank(start_indices)
. - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
. - (C5)
0 <= offset_dims < rank(result)
. - (C6)
is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))
- (C7)
is_sorted(collapsed_slice_dims)
. - (C8)
0 <= collapsed_slice_dims < rank(operand)
. - (C9)
slice_sizes[collapsed_slice_dims...] <= 1
. - (C10)
is_sorted(operand_batching_dims)
. - (C11)
0 <= operand_batching_dims < rank(operand)
. - (C12)
slice_sizes[operand_batching_dims...] <= 1
. - (C13)
is_unique(start_indices_batching_dims)
. - (C14)
0 <= start_indices_batching_dims < rank(start_indices)
. - (C15)
index_vector_dim not in start_indices_batching_dims
. - (C16)
size(operand_batching_dims) == size(start_indices_batching_dims)
. - (C17)
dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...)
. - (C18)
is_unique(concatenate(start_index_map, operand_batching_dims))
. - (C19)
0 <= start_index_map < rank(operand)
. - (C20)
size(slice_sizes) = rank(operand)
. - (C21)
0 <= slice_sizes <= shape(operand)
. - (C22)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
dove:batch_dim_sizes = shape(start_indices)
, ad eccezione del fatto che la dimensione dimensionestart_indices
corrispondente aindex_vector_dim
non è inclusa.offset_dim_sizes = slice_sizes
, ad eccezione del fatto che le dimensioni delle dimensioni inslice_sizes
corrispondenti acollapsed_slice_dims
eoperand_batching_dims
non sono incluse.combine
posizionabatch_dim_sizes
sugli assi corrispondenti abatch_dims
eoffset_dim_sizes
sugli assi corrispondenti aoffset_dims
.
- (C23)
element_type(operand) = element_type(result)
.
Esempi
// %operand: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %start_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
%result = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stablehlo.gather<
offset_dims = [3, 4],
collapsed_slice_dims = [1],
operand_batching_dims = [0],
start_indices_batching_dims = [1],
start_index_map = [2, 1],
index_vector_dim = 3>,
slice_sizes = array<i64: 1, 1, 2, 2>,
indices_are_sorted = false
} : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32>
// %result: [
// [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[33, 34], [35, 36]],
// [[35, 36], [37, 38]],
// [[41, 42], [43, 44]]
// ]
// ],
// [
// [
// [[1, 2], [3, 4]],
// [[13, 14], [15, 16]],
// [[21, 22], [23, 24]]
// ],
// [
// [[43, 44], [45, 46]],
// [[33, 34], [35, 36]],
// [[27, 28], [29, 30]]
// ]
// ]
// ]
get_dimension_size
Semantica
Genera la dimensione del dimension
specificato del operand
. Più formalmente,
result = dim(operand, dimension)
. La semantica riguarda solo il componente forma del tipo. Il tipo di elemento può essere qualsiasi.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato | (C1) |
(I2) | dimension |
costante di tipo si64 |
(C1) |
Output
Nome | Tipo |
---|---|
result |
Tensore 0dimensionale di tipo si32 |
Vincoli
- (C1)
0 <= dimension < rank(operand)
.
Esempi
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3
get_tuple_element
Semantica
Estrae l'elemento nella posizione index
della tupla operand
e produce un
result
. Più formalmente, result = operand[index]
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tupla | (C1), (C2) |
(I2) | index |
costante di tipo si32 |
(C1), (C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
Qualsiasi tipo supportato | (C2) |
Vincoli
- (C1)
0 <= index < size(operand)
. - (C2)
type(result) = tuple_element_types(operand)[index]
.
Esempi
// %operand: ([1.0, 2.0], (3))
index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]
se
Semantica
Genera l'output dall'esecuzione esattamente di una funzione da true_branch
o
false_branch
a seconda del valore di pred
. In termini più formali, result =
pred ? true_branch() : false_branch()
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | pred |
Tensore di dimensione 0 di tipo i1 |
|
(I2) | true_branch |
funzione | (C1-C3) |
(I3) | false_branch |
funzione | (C1), (C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadico di tensori, tensori quantizzati o token | (C3) |
Vincoli
- (C1)
input_types(true_branch) = input_types(false_branch) = []
. - (C2)
output_types(true_branch) = output_types(false_branch)
. - (C3)
type(results...) = output_types(true_branch)
.
Esempi
// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
"stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
"stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10
imag
Semantica
Estrae la parte immaginaria, elemento per elemento, da operand
e produce un
result
tensore. In modo più formale, per ogni elemento x
:
imag(x) = is_complex(x) ? imaginary_part(x) :
constant(0, element_type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o complesso | (C1), (C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo floating-point | (C1), (C2) |
Vincoli
- (C1)
shape(result) = shape(operand)
. - (C2) Per
element_type(result)
si intende:complex_element_type(element_type(operand))
seis_complex(operand)
.element_type(operand)
in caso contrario.
Esempi
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]
infeed
Semantica
Legge i dati dall'infeed e produce results
.
La semantica di infeed_config
è definita dall'implementazione.
results
è costituito da valori di payload che vengono visualizzati per primi e da un token che arriva per ultimo. In futuro, prevediamo di suddividere il payload e il token in due output distinti per maggiore chiarezza (#670).
Input
Etichetta | Nome | Tipo |
---|---|---|
(I1) | token |
token |
(I2) | infeed_config |
costante di tipo string |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadico di tensori, tensori quantizzati o token | (C1-C3) |
Vincoli
- (C1)
0 < size(results)
. - (C2)
is_empty(result[:-1])
ois_tensor(type(results[:-1]))
. - (C3)
is_token(type(results[-1]))
.
Esempi
// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]
Iota
Semantica
Riempie un tensore output
con valori in ordine crescente a partire da zero
lungo la dimensione iota_dimension
. In termini più formali,
output[output_index] = constant(is_quantized(output) ?
quantize(output_index[iota_dimension], element_type(output)) :
output_index[iota_dimension], element_type(output))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | iota_dimension |
si64 |
(C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
output |
tensore di numeri interi, in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
0 <= iota_dimension < rank(output)
.
Esempi
%output = "stablehlo.iota"() {
iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
%output = "stablehlo.iota"() {
iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4]
// ]
is_finite
Semantica
Esegue la verifica a livello di elemento se il valore in x
è finito (ovvero non è
+Inf, -Inf, né NaN) e produce un tensore y
. Implementa l'operazione isFinite
della specifica IEEE-754. Per i tipi quantizzati, il risultato è sempre true
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | x |
tensore di tipo a virgola mobile o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
y |
tensore di tipo booleano | (C1) |
Vincoli
- (C1)
shape(x) = shape(y)
.
Esempi
// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]
log
Semantica
Esegue l'operazione di logaritmo elemento per elemento sul tensore operand
e produce un
result
tensore. A seconda del tipo di elemento:
- Per i valori float:
log
da IEEE-754. - Per i numeri complessi: logaritmo complesso.
- Per i tipi quantizzati:
dequantize_op_quantize(log, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]
log_plus_one
Semantica
Esegue un'operazione di logaritmo elemento per elemento più uno sul tensore operand
e produce un tensore result
. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i valori float:
logp1
da IEEE-754. - Per i numeri complessi: logaritmo complesso più 1.
- Per i tipi quantizzati:
dequantize_op_quantize(log_plus_one, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
logistica
Semantica
Esegue un'operazione logistica elemento per elemento sul tensore operand
e produce un
result
tensore. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i numeri in virgola mobile:
division(1, addition(1, exp(-x)))
dallo standard IEEE-754. - Per i numeri complessi: logistica complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(logistic, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]
mappa
Semantica
Applica una funzione di mappatura computation
a inputs
lungo dimensions
e
produce un tensore result
.
In termini più formali, result[result_index] = computation(inputs...[result_index])
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | inputs |
numero variadico di tensori o tensori quantizzati per tensore | (C1-C4) |
(I2) | dimensions |
Costante tensore monodimensionale di tipo si64 |
(C3) |
(I3) | computation |
funzione | (C4) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C1), (C4) |
Vincoli
- (C1)
shape(inputs...) = shape(result)
. - (C2)
0 < size(inputs) = N
. - (C3)
dimensions = range(rank(inputs[0]))
. - (C4)
computation
ha il tipo(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>
, doveEi = element_type(inputs[i])
eE' = element_type(result)
.
Esempi
// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
stablehlo.return %0 : tensor<i64>
}) {
dimensions = array<i64: 0, 1>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]
massimo
Semantica
Esegue l'operazione massima elemento per elemento sui tensori lhs
e rhs
e produce un
result
tensore. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i valori booleani: OR logico.
- Per gli interi: numero intero massimo.
- Per i numeri in virgola mobile:
maximum
dallo standard IEEE-754. - Per i numeri complessi: massimo lessicografico per la coppia
(real, imaginary)
. L'imposizione di un ordine ai numeri complessi comporta una semantica sorprendente, pertanto in futuro prevediamo di rimuovere il supporto per i numeri complessi per questa operazione (#560). - Per i tipi quantizzati:
dequantize_op_quantize(maximum, lhs, rhs, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore quantizzato per tensore o per tensore | (C1) |
(I2) | rhs |
tensore quantizzato per tensore o per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Esempi
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]
minimo
Semantica
Esegue l'operazione minima elemento per elemento sui tensori lhs
e rhs
e produce un
result
tensore. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i valori booleani: AND logico.
- Per gli interi: numero intero minimo.
- Per i valori float:
minimum
da IEEE-754. - Per i numeri complessi: minimo lessicografico per la coppia
(real, imaginary)
. L'imposizione di un ordine ai numeri complessi comporta una semantica sorprendente, pertanto in futuro prevediamo di rimuovere il supporto per i numeri complessi per questa operazione (#560). - Per i tipi quantizzati:
dequantize_op_quantize(minimum, lhs, rhs, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore quantizzato per tensore o per tensore | (C1) |
(I2) | rhs |
tensore quantizzato per tensore o per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Esempi
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]
moltiplicazione
Semantica
Esegue il prodotto elemento per elemento di due tensori lhs
e rhs
e produce un
result
tensore. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i valori booleani: AND logico.
- Per gli interi: moltiplicazione di numeri interi.
- Per i valori float:
multiplication
da IEEE-754. - Per i numeri complessi: moltiplicazione complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(multiply, lhs, rhs, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore quantizzato per tensore o per tensore | (C1) |
(I2) | rhs |
tensore quantizzato per tensore o per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]
negare
Semantica
Esegue la negazione elemento per elemento del tensore operand
e produce un
result
. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per gli interi con segno: negazione dell'intero.
- Per i numeri interi non firmati: bitcast a numero intero firmato, negazione di un numero intero, bitcast di nuovo a numero intero non firmato.
- Per i valori float:
negate
da IEEE-754. - Per i numeri complessi: negazione complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(negate, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo intero, in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di numeri interi, in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]
// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]
non
Semantica
Esegue la NOT elemento per elemento del tensore operand
e produce un tensore result
.
A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i valori booleani: NOT logico.
- Per gli interi: NON a livello di bit.
Argomenti
Nome | Tipo | Vincoli |
---|---|---|
operand |
tensore di un tipo booleano o intero | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di un tipo booleano o intero | (C1) |
Vincoli
- (C1)
type(operand) = type(result)
.
Esempi
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]
// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]
optimization_barrier
Semantica
Garantisce che le operazioni che producono operand
vengano eseguite prima di qualsiasi operazione che dipende da result
e impedisce alle trasformazioni del compilatore di spostare le operazioni attraverso la barriera. A parte questo, l'operazione è
un'identità, ad esempio result = operand
.
Argomenti
Nome | Tipo | Vincoli |
---|---|---|
operand |
numero variadico di tensori, tensori o token quantiizzati per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
numero variadico di tensori, tensori o token quantiizzati per tensore | (C1) |
Vincoli
- (C1)
type(operand...) = type(result...)
.
Esempi
// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0
o
Semantica
Esegue l'operazione OR elemento per elemento di due tensori lhs
e rhs
e produce un
result
. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i valori booleani: OR logico.
- Per gli interi: OR a livello di bit.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di tipo intero o booleano | (C1) |
(I2) | rhs |
tensore di tipo intero o booleano | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di numeri interi o booleani | (C1) |
Vincoli
- (C1)
type(lhs) = type(rhs) = type(result)
.
Esempi
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]
uscita
Semantica
Scrive inputs
nell'outfeed e produce un token result
.
La semantica di outfeed_config
è definita dall'implementazione.
Input
Etichetta | Nome | Tipo |
---|---|---|
(I1) | inputs |
numero variadico di tensori o tensori quantizzati |
(I2) | token |
token |
(I3) | outfeed_config |
costante di tipo string |
Output
Nome | Tipo |
---|---|
result |
token |
Esempi
%result = "stablehlo.outfeed"(%input0, %token) {
outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token
pad
Semantica
Espande operand
inserendo una spaziatura interna attorno al tensore e tra gli elementi
del tensore con l'elemento padding_value
specificato.
edge_padding_low
e edge_padding_high
specificano la quantità di spaziatura interna aggiunta rispettivamente all'estremità inferiore (accanto all'indice 0) e all'estremità superiore (accanto all'indice più alto) di ogni dimensione. La quantità di spaziatura interna può essere negativa, dove il valore assoluto di spaziatura interna negativa indica il numero di elementi da rimuovere dalla dimensione specificata.
interior_padding
specifica la quantità di spaziatura interna aggiunta tra due elementi in ogni dimensione, che non può essere negativa. Il padding interno avviene
prima del padding dei bordi, in modo che il padding dei bordi negativo rimuova gli elementi dall'operando con padding interno.
In modo più formale, result[result_index]
è definito come:
operand[operand_index]
seresult_index = edge_padding_low + operand_index * (interior_padding + 1)
.padding_value
in caso contrario.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato per tensore | (C1), (C2), (C4) |
(I2) | padding_value |
Tensore 0-dimensionale o tensore quantizzato per tensore | (C1) |
(I3) | edge_padding_low |
Costante tensore monodimensionale di tipo si64 |
(C1), (C4) |
(I4) | edge_padding_high |
Costante tensoriale 1-dimensionale di tipo si64 |
(C1), (C4) |
(I5) | interior_padding |
Costante tensoriale 1-dimensionale di tipo si64 |
(C2-C4) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C3-C6) |
Vincoli
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
. - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
. - (C3)
0 <= interior_padding
. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
.
Esempi
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
edge_padding_low = array<i64: 0, 1>,
edge_padding_high = array<i64: 2, 1>,
interior_padding = array<i64: 1, 2>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
partition_id
Semantica
Produce partition_id
del processo corrente.
Output
Nome | Tipo |
---|---|
result |
Tensore di dimensione 0 di tipo ui32 |
Esempi
%result = "stablehlo.partition_id"() : () -> tensor<ui32>
popcnt
Semantica
Esegue il conteggio elemento per elemento del numero di bit impostati nel tensore operand
e produce un tensore result
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo intero | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo intero | (C1) |
Vincoli
- (C1)
type(operand) = type(result)
.
Esempi
// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]
potenza
Semantica
Esegue l'esponenzializazione elemento per elemento del tensore lhs
per il tensore rhs
e produce un tensore result
. A seconda del tipo di elemento:
- Per gli interi: potenza di un numero intero.
- Per i valori float:
pow
da IEEE-754. - Per i numeri complessi: esponenziale complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(power, lhs, rhs, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di tipo intero, in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
(I2) | rhs |
tensore di numeri interi, in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di numeri interi, in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]
reale
Semantica
Estrae la parte reale, elemento per elemento, da operand
e produce un tensore result
. In modo più formale, per ogni elemento x
:
real(x) = is_complex(x) ? real_part(x) : x
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o complesso | (C1), (C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo floating-point | (C1), (C2) |
Vincoli
- (C1)
shape(result) = shape(operand)
. - (C2) Per
element_type(result)
si intende:complex_element_type(element_type(operand))
seis_complex(operand)
.element_type(operand)
in caso contrario.
Esempi
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]
recv
Semantica
Riceve i dati da un canale con channel_id
e produce results
.
Se is_host_transfer
è true
, l'operazione trasferisce i dati dall'host. In caso contrario, trasferisce i dati da un altro dispositivo. Ciò significa che è definito dall'implementazione. Questo flag duplica le informazioni fornite in
channel_type
, pertanto in futuro prevediamo di conservarne solo una
(#666).
results
è costituito da valori di payload che vengono visualizzati per primi e da un token che arriva per ultimo. In futuro, prevediamo di suddividere il payload e il token in due output distinti per maggiore chiarezza (#670).
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | token |
token |
(C4) |
(I2) | channel_id |
costante di tipo si64 |
|
(I3) | channel_type |
enum di DEVICE_TO_DEVICE e HOST_TO_DEVICE |
(C1) |
(I4) | is_host_transfer |
costante di tipo i1 |
(C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadico di tensori, tensori quantizzati o token | (C2-C4) |
Vincoli
- (C1)
channel_type
è definito come:HOST_TO_DEVICE
seis_host_transfer = true
,DEVICE_TO_DEVICE
in caso contrario.
- (C2)
0 < size(results)
. - (C3)
is_empty(result[:-1])
ois_tensor(type(results[:-1]))
. - (C4)
is_token(type(results[-1]))
.
Esempi
%results0, %results1 = "stablehlo.recv"(%token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
reduce
Semantica
Applica una funzione di riduzione body
a inputs
e init_values
lungo
dimensions
e produce results
tensori.
L'ordine delle riduzioni è definito dall'implementazione, il che significa che body
e
init_values
devono formare un monoide per garantire che l'operazione produca
gli stessi risultati per tutti gli input in tutte le implementazioni. Tuttavia, questa condizione
non vale per molte riduzioni comuni. Ad esempio, l'addizione a virgola mobile per body
e lo zero per init_values
non formano effettivamente un monoide perché l'addizione a virgola mobile non è associativa.
In modo più formale, results...[j0, ..., jR-1] = reduce(input_slices_converted)
dove:
input_slices = inputs...[j0, ..., :, ..., jR-1]
, dove:
sono inseriti indimensions
.input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...)
.init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)
.reduce(input_slices_converted) = exec(schedule)
per un albero binarioschedule
dove:exec(node) = body(exec(node.left), exec(node.right))
.exec(leaf) = leaf.value
.
schedule
è un albero binario completo definito dall'implementazione il cui traversale in ordine consiste in:- Valori
input_slices_converted...[index]
per tutti i valoriindex
inindex_space(input_slices_converted)
nell'ordine lessicografico crescente diindex
. - Intercalati con una quantità di
init_values_converted
definita dall'implementazione in posizioni definite dall'implementazione.
- Valori
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | inputs |
numero variadico di tensori o tensori quantizzati per tensore | (C1-C4), (C6), (C7) |
(I2) | init_values |
numero variadico di tensori di dimensione 0 o tensori quantizzati per tensore | (C2), (C3) |
(I3) | dimensions |
Costante tensoriale 1-dimensionale di tipo si64 |
(C4), (C5), (C7) |
(I4) | body |
funzione | (C6) |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadico di tensori o tensori quantizzati per tensore | (C3), (C7), (C8) |
Vincoli
- (C1)
same(shape(inputs...))
. - (C2)
element_type(inputs...) = element_type(init_values...)
. - (C3)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C4)
0 <= dimensions < rank(inputs[0])
. - (C5)
is_unique(dimensions)
. - (C6)
body
è di tipo(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
doveis_promotable(element_type(inputs[i]), Ei)
. - (C7)
shape(results...) = shape(inputs...)
, ad eccezione del fatto che le dimensioni diinputs...
corrispondenti adimensions
non sono incluse. - (C8)
element_type(results[i]) = Ei
per tutti ii
in[0,N)
.
Esempi
// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]
reduce_precision
Semantica
Esegue la conversione elemento per elemento di operand
in un altro tipo a virgola mobile
che utilizza exponent_bits
e mantissa_bits
e poi di nuovo al tipo a virgola mobile
originale e produce un tensore output
.
In termini più formali:
- I bit della frazione decimale del valore originale vengono aggiornati per arrotondare il valore originale al valore più vicino rappresentabile con
mantissa_bits
utilizzando la semanticaroundToIntegralTiesToEven
. - Se
mantissa_bits
è inferiore al numero di mantissa del valore originale, i bit vengono troncati amantissa_bits
. - Poi, se i bit dell'esponente del risultato intermedio non rientrano nell'intervallo fornito da
exponent_bits
, il risultato intermedio genera un overflow verso infinito utilizzando il segno originale o un underflow verso zero utilizzando il segno originale. - Per i tipi quantizzati, esegue
dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo a virgola mobile o tensore quantizzato per tensore | (C1) |
(I2) | exponent_bits |
costante di tipo si32 |
(C2) |
(I3) | mantissa_bits |
costante di tipo si32 |
(C3) |
Output
Nome | Tipo | Vincoli |
---|---|---|
output |
tensore di tipo a virgola mobile o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(output)
. - (C2)
1 <= exponent_bits
. - (C3)
0 <= mantissa_bits
.
Esempi
// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
exponent_bits = 5 : i32,
mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]
reduce_scatter
Semantica
All'interno di ogni gruppo di processi nella griglia di processi StableHLO, esegue la riduzione,
utilizzando computations
, sui valori del tensore operand
di ogni processo,
suddivide il risultato della riduzione lungo scatter_dimension
in parti e disperde
le parti divise tra i processi per produrre result
.
L'operazione suddivide la griglia del processo StableHLO in process_groups
, che è
definita come segue:
cross_replica(replica_groups)
sechannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
sechannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
sechannel_id > 0 and use_global_device_ids = true
.
Successivamente, all'interno di ogni process_group
:
reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation)
.parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension)
.result@receiver = parts@sender[receiver_index]
per tutti isender
inprocess_group
, dovereceiver_index = process_group.index(receiver)
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato per tensore | (C1), (C2), (C7), (C8) |
(I2) | scatter_dimension |
costante di tipo si64 |
(C1), (C2), (C8) |
(I3) | replica_groups |
Costante tensoriale 2D di tipo si64 |
(C3-C5) |
(I4) | channel_id |
costante di tipo si64 |
(C6) |
(I5) | use_global_device_ids |
costante di tipo i1 |
(C6) |
(I6) | computation |
funzione | (C7) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore quantizzato per tensore o per tensore | (C8-C9) |
Vincoli
- (C1)
dim(operand, scatter_dimension) % dim(process_groups, 1) = 0
. - (C2)
0 <= scatter_dimension < rank(operand)
. - (C3)
is_unique(replica_groups)
. - (C4) Per
size(replica_groups)
si intende:num_replicas
se viene utilizzatocross_replica
.num_replicas
se viene utilizzatocross_replica_and_partition
.num_processes
se viene utilizzatoflattened_ids
.
- (C5)
0 <= replica_groups < size(replica_groups)
. - (C6) Se
use_global_device_ids = true
, allorachannel_id > 0
. - (C7)
computation
è di tipo(tensor<E>, tensor<E>) -> (tensor<E>)
doveis_promotable(element_type(operand), E)
. - (C8)
shape(result) = shape(operand)
tranne:dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1)
.
- (C9)
element_type(result) = E
.
Esempi
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
// [18, 20]]
// %result@(1, 0): [[14, 16],
// [22, 24]]
reduce_window
Semantica
Applica una funzione di riduzione body
a finestre di inputs
e init_values
e produce results
.
Il seguente diagramma mostra come vengono calcolati gli elementi in results...
da
inputs...
utilizzando un esempio concreto.
In modo più formale,
results...[result_index] = reduce(windows, init_values, axes(inputs...), body)
(vedi reduce) dove:
padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1)
.window_start = result_index * window_strides
.window_end = window_start + (window_dimensions - 1) * window_dilations + 1
.windows = slice(padded_inputs..., window_start, window_end, window_dilations)
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | inputs |
numero variadico di tensori o tensori quantizzati per tensore | (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15) |
(I2) | init_values |
numero variadico di tensori di dimensione 0 o tensori quantizzati per tensore | (C1), (C13) |
(I3) | window_dimensions |
Costante tensore monodimensionale di tipo si64 |
(C4), (C5), (C15) |
(I4) | window_strides |
Costante tensore monodimensionale di tipo si64 |
(C6), (C7), (C15) |
(I5) | base_dilations |
Costante tensoriale 1-dimensionale di tipo si64 |
(C8), (C9), (C15) |
(I6) | window_dilations |
Costante tensoriale 1-dimensionale di tipo si64 |
(C10), (C11), (C15) |
(I7) | padding |
Costante tensoriale 2D di tipo si64 |
(C12), (C15) |
(I8) | body |
funzione | (C13) |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadico di tensori o tensori quantizzati per tensore | (C1), (C14-C16) |
Vincoli
- (C1)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C2)
same(shape(inputs...))
. - (C3)
element_type(inputs...) = element_type(init_values...)
. - (C4)
size(window_dimensions) = rank(inputs[0])
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(inputs[0])
. - (C7)
0 < window_strides
. - (C8)
size(base_dilations) = rank(inputs[0])
. - (C9)
0 < base_dilations
. - (C10)
size(window_dilations) = rank(inputs[0])
. - (C11)
0 < window_dilations
. - (C12)
shape(padding) = [rank(inputs[0]), 2]
. - (C13)
body
è di tipo(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
doveis_promotable(element_type(inputs[i]), Ei)
. - (C14)
same(shape(results...))
. - (C15)
shape(results[0]) = num_windows
dove:dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1
.padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]
.dilated_window_shape = (window_dimensions - 1) * window_dilations + 1
.is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape
.num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1
.
- (C16)
element_type(results[i]) = Ei
per tutti ii
in[0,N)
.
Esempi
// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 2, 1>,
window_strides = array<i64: 4, 1>,
base_dilations = array<i64: 2, 1>,
window_dilations = array<i64: 3, 1>,
padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]
resto
Semantica
Esegue il resto dei tensori di dividend lhs
e divisore rhs
e produce un tensore result
.
In modo più formale, il segno del risultato viene preso dal dividendo e il valore assoluto del risultato è sempre inferiore al valore assoluto del divisore.
Il resto viene calcolato come lhs - d * rhs
, dove d
è dato da:
- Per gli interi:
stablehlo.divide(lhs, rhs)
. - Per i valori float:
division(lhs, rhs)
da IEEE-754 con attributo di arrotondamentoroundTowardZero
. - Per i numeri complessi: da definire (#997).
- Per i tipi quantizzati:
dequantize_op_quantize(remainder, lhs, rhs, type(result))
.
Per i tipi di elementi con virgola mobile, questa operazione è in contrasto con
l'operazione remainder
della specifica IEEE-754, in cui d
è un valore integrale
più vicino al valore esatto di lhs/rhs
con legami a pari.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di tipo intero, in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
(I2) | rhs |
tensore di tipo intero, in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo intero, in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]
replica_id
Semantica
Produce replica_id
del processo corrente.
Output
Nome | Tipo |
---|---|
result |
Tensore di dimensione 0 di tipo ui32 |
Esempi
%result = "stablehlo.replica_id"() : () -> tensor<ui32>
rimodellare
Semantica
Esegue la trasformazione del tensore operand
in un tensore result
. A livello concettuale, equivale a mantenere la stessa rappresentazione canonica, ma potenzialmente a modificarne la forma, ad esempio da tensor<2x3xf32>
a tensor<3x2xf32>
o tensor<6xf32>
.
In termini più formali, result[result_index] = operand[operand_index]
dove
result_index
e operand_index
hanno la stessa posizione nell'ordine alfabetico
di index_space(result)
e index_space(operand)
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato | (C1-C3) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato | (C1-C3) |
Vincoli
- (C1)
element_type(result)
è dato da:element_type(operand)
, se!is_per_axis_quantized(operand)
.element_type(operand)
ad eccezione del fatto chequantization_dimension(operand)
equantization_dimension(result)
possono differire, altrimenti.
- (C2)
size(operand) = size(result)
. - (C3) Se
is_per_axis_quantized(operand)
:reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
.reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
Esempi
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]
inverti
Semantica
Inverte l'ordine degli elementi in operand
lungo dimensions
specificato
e produce un tensore result
. Più formalmente,
result[result_index] = operand[operand_index]
dove:
operand_index[d] = dim(result, d) - result_index[d] - 1
sed
indimensions
.operand_index[d] = result_index[d]
in caso contrario.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore quantizzato per tensore o per tensore | (C1), (C3) |
(I2) | dimensions |
Costante tensore monodimensionale di tipo si64 |
(C2), (C3) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore quantizzato per tensore o per tensore | (C1), (C3) |
Vincoli
- (C1)
type(operand) = type(result)
. - (C2)
is_unique(dimensions)
. - (C3)
0 <= dimensions < rank(result)
.
Esempi
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]
rng
Semantica
Genera numeri casuali utilizzando l'algoritmo rng_distribution
e produce un tensore result
di una determinata forma shape
.
Se rng_distribution = UNIFORM
, i numeri casuali vengono generati
secondo la distribuzione uniforme nell'intervallo [a, b)
. Se a >= b
,
il comportamento non è definito.
Se rng_distribution = NORMAL
, i numeri casuali vengono generati in base alla distribuzione normale con media = a
e deviazione standard = b
.
Se b < 0
, il comportamento non è definito.
Il modo esatto in cui vengono generati i numeri casuali è definito dall'implementazione. Ad esempio, possono essere o meno deterministici e possono o meno utilizzare lo stato nascosto.
In conversazioni con molti stakeholder, è emerso che questa operazione è stata ritirata, pertanto in futuro prevediamo di valutarne la rimozione (#597).
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | a |
Tensore di dimensione 0 di tipo intero, booleano o con rappresentazione in virgola mobile | (C1), (C2) |
(I2) | b |
Tensore di dimensione 0 di tipo intero, booleano o con rappresentazione in virgola mobile | (C1), (C2) |
(I3) | shape |
Costante tensoriale 1-dimensionale di tipo si64 |
(C3) |
(I4) | rng_distribution |
enum di UNIFORM e NORMAL |
(C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di numeri interi, booleani o in virgola mobile | (C1-C3) |
Vincoli
- (C1)
element_type(a) = element_type(b) = element_type(result)
. - (C2) Se
rng_distribution = NORMAL
, allorais_float(a)
. - (C3)
shape(result) = shape
.
Esempi
// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
// [1, 0, 1],
// [1, 1, 1],
// [0, 0, 0]
// ]
rng_bit_generator
Semantica
Restituisce un valore output
riempito con bit casuali uniformi e uno stato di output aggiornato
output_state
utilizzando l'algoritmo del generatore di numeri pseudocasuali rng_algorithm
avendo uno stato iniziale initial_state
. L'output è garantito come funzione deterministica di initial_state
, ma non è garantito come deterministico tra le implementazioni.
rng_algorithm
è uno dei seguenti valori:
DEFAULT
: algoritmo definito dall'implementazione.THREE_FRY
: variante dell'algoritmo Threefry definita dall'implementazione.*PHILOX
: variante dell'algoritmo Philox definita dall'implementazione.*
* Vedi: Salmon et al. SC 2011. Numeri casuali paralleli: semplici come 1, 2, 3.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | rng_algorithm |
enum di DEFAULT , THREE_FRY e PHILOX |
(C2) |
(I2) | initial_state |
Tensore 1-dimensionale di tipo ui64 |
(C1), (C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
output_state |
Tensore 1-dimensionale di tipo ui64 |
(C1) |
output |
tensore di tipo intero o con rappresentazione in virgola mobile |
Vincoli
- (C1)
type(initial_state) = type(output_state)
. - (C2) Per
size(initial_state)
si intende:- l'implementazione definita se
rng_algorithm = DEFAULT
. 2
serng_algorithm = THREE_FRY
.2
o3
serng_algorithm = PHILOX
.
- l'implementazione definita se
Esempi
// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
// [9236835810183407956, 16087790271692313299],
// [18212823393184779219, 2658481902456610144]
// ]
round_nearest_afz
Semantica
Esegue l'arrotondamento elemento per elemento verso l'intero più vicino, rompendo i pareggi lontano da zero, sul tensore operand
e produce un tensore result
. Implementa
l'operazione roundToIntegralTiesToAway
dalla specifica IEEE-754. Per i tipi quantizzati, esegue dequantize_op_quantize(round_nearest_afz, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo a virgola mobile o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo a virgola mobile o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]
round_nearest_even
Semantica
Esegue l'arrotondamento elemento per elemento verso l'intero più vicino, rompendo i pareggi
verso l'intero pari, sul tensore operand
e produce un tensore result
. Implementa l'operazione roundToIntegralTiesToEven
della specifica IEEE-754. Per i tipi quantizzati, esegue
dequantize_op_quantize(round_nearest_even, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo a virgola mobile o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo a virgola mobile o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
rsqrt
Semantica
Esegue l'operazione di radice quadrata reciproca elemento per elemento sul tensore operand
e produce un tensore result
. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i numeri in virgola mobile:
rSqrt
dallo standard IEEE-754. - Per i numeri complessi: radice quadrata reciproca complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(rsqrt, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]
spargere
Semantica
Produce tensori results
uguali ai tensori inputs
, tranne per il fatto che diversi slice specificati da scatter_indices
vengono aggiornati con i valori updates
utilizzando update_computation
.
Il seguente diagramma mostra come gli elementi in updates...
vengono mappati agli elementi in
results...
utilizzando un esempio concreto. Il diagramma seleziona alcuni esempi di indici updates...
e spiega in dettaglio a quali indici results...
corrispondono.
Più formalmente, per tutti i update_index
di index_space(updates[0])
:
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]
.update_scatter_index = update_index[update_scatter_dims...]
.start_index
è definito come:scatter_indices[si0, ..., :, ..., siN]
, dovesi
sono singoli elementi inupdate_scatter_index
e:
viene inserito nell'indiceindex_vector_dim
, seindex_vector_dim
<rank(scatter_indices)
.[scatter_indices[update_scatter_index]]
in caso contrario.
- Per
d_input
inaxes(inputs[0])
,full_start_index[d_input] = start_index[d_start]
sed_input = scatter_dims_to_operand_dims[d_start]
.full_start_index[d_input] = 0
in caso contrario.
- Per
d_input
inaxes(inputs[0])
:full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)]
sed_input = input_batching_dims[i_batching]
ed_start = scatter_indices_batching_dims[i_batching]
.full_batching_index[d_input] = 0
in caso contrario.
update_window_index = update_index[update_window_dims...]
.full_window_index = [wi0, ..., 0, ..., wiN]
, dovewi
sono singoli elementi diupdate_window_index
e0
viene inserito negli indici diinserted_window_dims
einput_batching_dims
.result_index = full_start_index + full_batching_index + full_window_index
.
Dato che, results = exec(schedule, inputs)
, dove:
schedule
è una permutazione definita dall'implementazione diindex_space(updates[0])
.exec([update_index, ...], results) = exec([...], updated_results)
where:- Se
result_index
è compreso nei limiti pershape(results...)
updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
updated_values = update_computation(results...[result_index], updates_converted)
updated_results
è una copia diresults
conresults...[result_index]
impostato suupdated_values...
.- In caso contrario
updated_results = results
.
- Se
exec([], results) = results
.
Se indices_are_sorted
è true
, l'implementazione può presumere che
scatter_indices
siano ordinati in base a scatter_dims_to_operand_dims
,
altrimenti il comportamento non è definito. Più formalmente, per tutti i i1 < i2
di
indices(result)
, full_start_index(i1)
<= full_start_index(i2)
.
Se unique_indices
è true
, l'implementazione può presumere che tutti gli indici result_index
a cui viene eseguita la distribuzione siano univoci. Se unique_indices
è
true
, ma gli indici a cui viene eseguita la distribuzione non sono univoci, il comportamento è
undefined.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | inputs |
numero variadico di tensori o tensori quantizzati per tensore | (C1), (C2), (C4-C6), (C11), (C13), (C18), (C21), (C23-C24) |
(I2) | scatter_indices |
tensore di tipo intero | (C4), (C15), (C19), (C22) |
(I3) | updates |
numero variadico di tensori o tensori quantizzati per tensore | (C3-C6), (C8) |
(I4) | update_window_dims |
Costante tensoriale 1-dimensionale di tipo si64 |
(C2), (C4), (C7-C8) |
(I5) | inserted_window_dims |
Costante tensoriale 1-dimensionale di tipo si64 |
(C2), (C4), (C9-C11) |
(I6) | input_batching_dims |
Costante tensoriale 1-dimensionale di tipo si64 |
(C2), (C4), (C9), (C12-13), (C17-18), (C20) |
(I7) | scatter_indices_batching_dims |
Costante tensoriale 1-dimensionale di tipo si64 |
(C14-C18) |
(I8) | scatter_dims_to_operand_dims |
Costante tensore monodimensionale di tipo si64 |
(C19-C21) |
(I9) | index_vector_dim |
costante di tipo si64 |
(C4), (C16), (C19), (C22) |
(I10) | indices_are_sorted |
costante di tipo i1 |
|
(I11) | unique_indices |
costante di tipo i1 |
|
(I12) | update_computation |
funzione | (C23) |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadico di tensori o tensori quantizzati per tensore | (C24-C25) |
Vincoli
- (C1)
same(shape(inputs...))
. - (C2) "rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims))
- size(input_batching_dims)`.
- (C3)
same(shape(updates...))
. - (C4)
shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes)
dove:update_scatter_dim_sizes = shape(scatter_indices)
, ad eccezione del fatto che la dimensionescatter_indices
corrispondente aindex_vector_dim
non è inclusa.update_window_dim_sizes <= shape(inputs[0])
, ad eccezione del fatto che le dimensioni ininputs[0]
corrispondenti ainserted_window_dims
einput_batching_dims
non sono incluse.combine
inserisceupdate_scatter_dim_sizes
sugli assi corrispondenti aupdate_scatter_dims
eupdate_window_dim_sizes
sugli assi corrispondenti aupdate_window_dims
.
- (C5)
0 < size(inputs) = size(updates) = N
. - (C6)
element_type(updates...) = element_type(inputs...)
. - (C7)
is_unique(update_window_dims) and is_sorted(update_window_dims)
. - (C8)
0 <= update_window_dims < rank(updates[0])
. - (C9)
is_unique(concatenate(inserted_window_dims, input_batching_dims))
- (C10)
is_sorted(inserted_window_dims)
. - (C11)
0 <= inserted_window_dims < rank(inputs[0])
. - (C12)
is_sorted(input_batching_dims)
. - (C13)
0 <= input_batching_dims < rank(inputs[0]))
. - (C14)
is_unique(scatter_indices_batching_dims)
. - (C15)
0 <= scatter_indices_batching_dims < rank(scatter_indices)
. - (C16)
index_vector_dim not in scatter_indices_batching_dims
. - (C17)
size(input_batching_dims) == size(scatter_indices_batching_dims)
. - (C18)
dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...)
. - (C19)
size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1
. - (C20)
is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims))
. - (C21)
0 <= scatter_dims_to_operand_dims < rank(inputs[0])
. - (C22)
0 <= index_vector_dim <= rank(scatter_indices)
. - (C23)
update_computation
è di tipo(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
, doveis_promotable(element_type(inputs[i]), Ei)
. - (C24)
shape(inputs...) = shape(results...)
. - (C25)
element_type(results[i]) = Ei
per tutti ii
in[0,N)
.
Esempi
// %input: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %scatter_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
// %update: [
// [
// [[1, 1], [1, 1], [1, 1]],
// [[1, 1], [1, 1], [1, 1]]
// ],
// [
// [[1, 1], [1, 1], [1, 1]],
// [[1, 1], [1, 1], [1, 1]]
// ]
// ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [3, 4],
inserted_window_dims = [1],
input_batching_dims = [0],
scatter_indices_batching_dims = [1],
scatter_dims_to_operand_dims = [2, 1],
index_vector_dim = 3>,
indices_are_sorted = false,
unique_indices = false
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
// %result: [
// [
// [[3, 4], [6, 7], [6, 7], [7, 8]],
// [[9, 10],[11, 12], [15, 16], [17, 18]],
// [[17, 18], [19, 20], [22, 23], [24, 25]]
// ],
// [
// [[25, 26], [28, 29], [30, 31], [31, 32]],
// [[35, 36], [38, 39], [38, 39], [39, 40]],
// [[41, 42], [44, 45], [46, 47], [47, 48]]
// ]
// ]
seleziona
Semantica
Produce un tensore result
in cui ogni elemento viene selezionato dal tensore on_true
o
on_false
in base al valore dell'elemento corrispondente di pred
.
In modo più formale, result[result_index] = pred_element ? on_true[result_index] :
on_false[result_index]
, dove pred_element = rank(pred) = 0 ? pred[] :
pred[result_index]
. Per i tipi quantizzati, esegue
dequantize_select_quantize(pred, on_true, on_false, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | pred |
tensore di tipo i1 |
(C1) |
(I2) | on_true |
tensore o tensore quantizzato per tensore | (C1-C2) |
(I3) | on_false |
tensore quantizzato per tensore o per tensore | (C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C2) |
Vincoli
- (C1)
rank(pred) = 0 or shape(pred) = shape(on_true)
. - (C2)
baseline_type(on_true) = baseline_type(on_false) = baseline_type(result)
.
Esempi
// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]
select_and_scatter
Semantica
Disperde i valori del tensore source
utilizzando scatter
in base al
risultato di reduce_window
del tensore input
utilizzando select
e produce
un tensore result
.
Il seguente diagramma mostra come vengono calcolati gli elementi in result
da
operand
e source
utilizzando un esempio concreto.
In termini più formali:
selected_values = reduce_window_without_init(...)
con i seguenti input:inputs = [operand].
window_dimensions
,window_strides
epadding
utilizzati così come sono.base_dilations = windows_dilations = 1
.body
è definito come:
def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>: return select(arg0, arg1) ? arg0 : arg1;
dove
E = element_type(operand)
ereduce_window_without_init
funzionano esattamente comereduce_window
, ad eccezione del fatto cheschedule
del valorereduce
sottostante (vedi Reduce) non include valori init. Al momento non è specificato cosa succede se la finestra corrispondente non contiene valori (#731).result[result_index] = reduce([source_values], [init_value], [0], scatter)
dove:source_values = [source[source_index] for source_index in source_indices]
.selected_index(source_index) = operand_index
seselected_values[source_index]
contiene l'elementooperand
daoperand_index
.source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index]
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore quantizzato per tensore o per tensore | (C1-C4), (C6), (C8-C11) |
(I2) | source |
tensore quantizzato per tensore o per tensore | (C1), (C2) |
(I3) | init_value |
tensore 0-dimensionale o tensore quantizzato per tensore | (C3) |
(I4) | window_dimensions |
Costante tensore monodimensionale di tipo si64 |
(C2), (C4), (C5) |
(I5) | window_strides |
Costante tensoriale 1-dimensionale di tipo si64 |
(C2), (C6), (C7) |
(I6) | padding |
Costante tensoriale 2D di tipo si64 |
(C2), (C8) |
(I7) | select |
funzione | (C9) |
(I8) | scatter |
funzione | (C10) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C11-C12) |
Vincoli
- (C1)
element_type(operand) = element_type(source)
. - (C2)
shape(source) = num_windows
dove:padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1]
.is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape
.num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1
.
- (C3)
element_type(init_value) = element_type(operand)
. - (C4)
size(window_dimensions) = rank(operand)
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(operand)
. - (C7)
0 < window_strides
. - (C8)
shape(padding) = [rank(operand), 2]
. - (C9)
select
è di tipo(tensor<E>, tensor<E>) -> tensor<i1>
doveE = element_type(operand)
. - (C10)
scatter
è di tipo(tensor<E>, tensor<E>) -> tensor<E>
doveis_promotable(element_type(operand), E)
. - (C11)
shape(operand) = shape(result)
. - (C12)
element_type(result) = E
.
Esempi
// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 3, 1>,
window_strides = array<i64: 2, 1>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]
Invia
Semantica
Invia inputs
a un canale channel_id
e produce un token result
.
Se is_host_transfer
è true
, l'operazione trasferisce i dati all'true
. Altrimenti, i dati vengono trasferiti su un altro dispositivo. Ciò significa che è definito dall'implementazione. Questo flag duplica le informazioni fornite in
channel_type
, quindi in futuro prevediamo di conservarne solo una
(#666).
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | inputs |
numero variadico di tensori o tensori quantizzati | |
(I2) | token |
token |
|
(I3) | channel_id |
costante di tipo si64 |
|
(I4) | channel_type |
enum di DEVICE_TO_DEVICE e DEVICE_TO_HOST |
(C1) |
(I5) | is_host_transfer |
costante di tipo i1 |
(C1) |
Output
Nome | Tipo |
---|---|
result |
token |
Vincoli
- (C1)
channel_type
è definito come:DEVICE_TO_HOST
seis_host_transfer = true
,DEVICE_TO_DEVICE
in caso contrario.
Esempi
%result = "stablehlo.send"(%operand, %token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token
shift_left
Semantica
Esegue lo spostamento a sinistra dell'elemento sul tensore lhs
per un numero di rhs
bit e produce un tensore result
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di tipo intero | (C1) |
(I2) | rhs |
tensore di tipo intero | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo intero | (C1) |
Vincoli
- (C1)
type(lhs) = type(rhs) = type(result)
.
Esempi
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]
shift_right_arithmetic
Semantica
Esegue un'operazione aritmetica di spostamento a destra degli elementi sul tensore lhs
per
rhs
di bit e produce un tensore result
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di tipo intero | (C1) |
(I2) | rhs |
tensore di tipo intero | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo intero | (C1) |
Vincoli
- (C1)
type(lhs) = type(rhs) = type(result)
.
Esempi
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]
shift_right_logical
Semantica
Esegue un'operazione di spostamento a destra logica elemento per elemento sul tensore lhs
per un numero di bit rhs
e produce un tensore result
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di tipo intero | (C1) |
(I2) | rhs |
tensore di tipo intero | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo intero | (C1) |
Vincoli
- (C1)
type(lhs) = type(rhs) = type(result)
.
Esempi
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]
firmare
Semantica
Restituisce il segno dell'elemento operand
e produce un tensore result
.
In modo più formale, per ogni elemento x
, la semantica può essere espressa utilizzando la sintassi di Python come segue:
def sign(x):
if is_integer(x):
if compare(x, 0, LT, SIGNED): return -1
if compare(x, 0, EQ, SIGNED): return 0
return 1
elif is_float(x):
if is_nan(x): return NaN
if compare(x, -0.0, EQ, FLOAT): return -0.0
if compare(x, +0.0, EQ, FLOAT): return +0.0
if compare(x, 0.0, LT, FLOAT): return -1.0
return 1.0
elif is_complex(x):
if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
return divide(x, convert(abs(x), type(x)))
Per i tipi quantizzati, esegue
dequantize_op_quantize(sign, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo intero con segno, in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo intero con segno, in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
seno
Semantica
Esegue un'operazione seno a livello di elemento sul tensore operand
e produce un tensore
result
. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i valori float:
sin
da IEEE-754. - Per i numeri complessi: seno complesso.
- Per i tipi quantizzati:
dequantize_op_quantize(sine, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]
slice
Semantica
Estrae uno slice da operand
utilizzando gli indici di inizio calcolati in modo statico
e produce un tensore result
. start_indices
contengono gli indici iniziali del
sezionamento per ogni dimensione, limit_indices
contengono gli indici finali
(esclusivi) del sezionamento per ogni dimensione e strides
contengono gli incrementi
per ogni dimensione.
Più formalmente, result[result_index] = operand[operand_index]
dove
operand_index = start_indices + result_index * strides
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore quantizzato per tensore o per tensore | (C1-C3), (C5) |
(I2) | start_indices |
Costante tensoriale 1-dimensionale di tipo si64 |
(C2), (C3), (C5) |
(I3) | limit_indices |
Costante tensoriale 1-dimensionale di tipo si64 |
(C2), (C3), (C5) |
(I4) | strides |
Costante tensore monodimensionale di tipo si64 |
(C2), (C4) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato per tensore | (C1), (C5) |
Vincoli
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(limit_indices) = size(strides) = rank(operand)
. - (C3)
0 <= start_indices <= limit_indices <= shape(operand)
. - (C4)
0 < strides
. - (C5)
shape(result) = ceil((limit_indices - start_indices) / strides)
.
Esempi
// %operand: [
// [0, 0, 0, 0],
// [0, 0, 1, 1],
// [0, 0, 1, 1]
// ]
%result = "stablehlo.slice"(%operand) {
start_indices = array<i64: 1, 2>,
limit_indices = array<i64: 3, 4>,
strides = array<i64: 1, 1>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
// [1, 1],
// [1, 1]
// ]
ordinare
Semantica
Ordina insieme gli slice 1D di inputs
lungo la dimensione dimension
,
in base a un comparator
e produce results
.
A differenza di input simili in altre operazioni, dimension
consente valori negativi, con la semantica descritta di seguito. In futuro, questa operazione potrebbe non essere consentita
per motivi di coerenza
(#1377).
Se is_stable
è true, l'ordinamento è stabile, ovvero l'ordine relativo degli elementi considerati uguali dal comparatore viene mantenuto. Nel caso in cui sia presente un singolo input, i due elementi e1
e e2
sono considerati uguali dal comparatore se e solo se comparator(e1, e2) = comparator(e2, e1) = false
. Consulta la formalizzazione di seguito per scoprire come si generalizza a più input.
In modo più formale, per tutti i valori result_index
in index_space(results[0])
:
adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension
.result_slice = [ri0, ..., :, ..., riR-1]
, doveriN
sono singoli elementi inresult_index
e:
viene inserito inadjusted_dimension
.inputs_together = (inputs[0]..., ..., inputs[N-1]...)
.results_together[result_slice] = sort(inputs_together[result_slice], comparator_together)
.- dove
sort
ordina uno slice unidimensionale in ordine non decrescente, prevedendo checomparator_together
restituiscatrue
se l'argomento a sinistra è minore del secondo argomento a destra. def comparator_together(lhs_together, rhs_together): args = [] for (lhs_el, rhs_el) in zip(lhs_together, rhs_together): args.append(lhs_el) args.append(rhs_el) return comparator(*args)
(results[0]..., ..., results[N-1]...) = results_together
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | inputs |
numero variadico di tensori o tensori quantizzati per tensore | (C1-C5) |
(I2) | dimension |
costante di tipo si64 |
(C4) |
(I3) | is_stable |
costante di tipo i1 |
|
(I4) | comparator |
funzione | (C5) |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadico di tensori o tensori quantizzati per tensore | (C2), (C3) |
Vincoli
- (C1)
0 < size(inputs)
. - (C2)
type(inputs...) = type(results...)
. - (C3)
same(shape(inputs...) + shape(results...))
. - (C4)
-R <= dimension < R
, doveR = rank(inputs[0])
. - (C5)
comparator
è di tipo(tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>
, doveEi = element_type(inputs[i])
.
Esempi
// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
dimension = 0 : i64,
is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]
sqrt
Semantica
Esegue l'operazione di radice quadrata elemento per elemento sul tensore operand
e produce un
result
tensore. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i valori float:
squareRoot
da IEEE-754. - Per i numeri complessi: radice quadrata complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(sqrt, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]
subtract
Semantica
Esegue la sottrazione elemento per elemento di due tensori lhs
e rhs
e produce un
result
tensore. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per gli interi: sottrazione di numeri interi.
- Per i valori float:
subtraction
da IEEE-754. - Per i numeri complessi: sottrazione complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(subtract, lhs, rhs, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di tipo intero, in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
(I2) | rhs |
tensore di numeri interi, in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di numeri interi, in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Esempi
// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]
tan
Semantica
Esegue l'operazione tangente elemento per elemento sul tensore operand
e produce un
result
tensore. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i valori float:
tan
da IEEE-754. - Per i numeri complessi: tangente complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(tan, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.tan"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [
// [0.0, 1.63312e+16],
// [0.0, 5.44375e+15]
// ]
tanh
Semantica
Esegue l'operazione di tangente iperbolica elemento per elemento sul tensore operand
e produce un tensore result
. A seconda del tipo di elemento:
- Per i valori float:
tanh
da IEEE-754. - Per i numeri complessi: tangente iperbolica complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(tanh, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]
trasponi
Semantica
Permuta le dimensioni del tensore operand
utilizzando permutation
e produce un
result
tensore. In forma più formale, result[result_index] = operand[operand_index]
dove result_index[d] = operand_index[permutation[d]]
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato | (C1-C4) |
(I2) | permutation |
Costante tensoriale 1-dimensionale di tipo si64 |
(C2-C4) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato | (C1), (C3-C4) |
Vincoli
- (C1)
element_type(result)
è dato da:element_type(operand)
, se!is_per_axis_quantized(operand)
.element_type(operand)
, ad eccezione del fatto chequantization_dimension(operand)
equantization_dimension(result)
potrebbero essere diversi.
- (C2)
permutation
è una permutazione dirange(rank(operand))
. - (C3)
shape(result) = dim(operand, permutation...)
. - (C4) Se
is_per_axis_quantized(result)
, thenquantization_dimension(operand) = permutation(quantization_dimension(result))
.
Esempi
// %operand: [
// [[1,2], [3,4], [5,6]],
// [[7,8], [9,10], [11,12]]
// ]
%result = "stablehlo.transpose"(%operand) {
permutation = array<i64: 2, 1, 0>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
// [[1,7], [3,9], [5,11]],
// [[2,8], [4,10], [6,12]]
// ]
triangular_solve
Semantica
Risolve batch di sistemi di equazioni lineari con matrici di coefficienti inferiori o superiori.
In modo più formale, dati a
e b
, result[i0, ..., iR-3, :, :]
è la soluzione
a op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :]
quando left_side
è
true
o x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :]
quando
left_side
è false
, risolvendo per la variabile x
dove op(a)
è determinato
da transpose_a
, che può essere uno dei seguenti:
NO_TRANSPOSE
: esegui l'operazione utilizzandoa
così com'è.TRANSPOSE
: esegui l'operazione sulla matrice trasposta dia
.ADJOINT
: esegui l'operazione sulla trasposta coniugata dia
.
I dati di input vengono letti solo dal triangolo inferiore di a
, se lower
è true
o dal
triangolo superiore di a
, in caso contrario. I dati di output vengono restituiti nello stesso
triangolo; i valori nell'altro sono definiti dall'implementazione.
Se unit_diagonal
è vero, l'implementazione può assumere che gli elementi diagonali di a
siano uguali a 1, altrimenti il comportamento non è definito.
Per i tipi quantizzati, esegue
dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower,
unit_diagonal, transpose_a), a, b, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | a |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1-C3) |
(I2) | b |
tensore di un tensore con virgola mobile o complesso o di tipo quantizzato per tensore | (C1-C4) |
(I3) | left_side |
costante di tipo i1 |
(C3) |
(I4) | lower |
costante di tipo i1 |
|
(I5) | unit_diagonal |
costante di tipo i1 |
|
(I6) | transpose_a |
enum di NO_TRANSPOSE , TRANSPOSE e ADJOINT |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_element_type(a) = baseline_element_type(b)
. - (C2)
2 <= rank(a) = rank(b) = R
. - (C3) La relazione tra
shape(a)
eshape(b)
è definita come segue:shape(a)[:-3] = shape(b)[:-3]
.dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1)
.
- (C4)
baseline_type(b) = baseline_type(result)
.
Esempi
// %a = [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
// %b = [
// [2.0, 0.0, 0.0],
// [4.0, 8.0, 0.0],
// [6.0, 10.0, 12.0]
// ]
%result = "stablehlo.triangular_solve"(%a, %b) {
left_side = true,
lower = true,
unit_diagonal = false,
transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
// [2.0, 0.0, 0.0],
// [0.0, 2.0, 0.0],
// [0.0, 0.0, 2.0]
// ]
tupla
Semantica
Genera una tupla result
dai valori val
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | val |
numero di valori variadico | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tupla | (C1) |
Vincoli
- (C1)
result
è di tipotuple<E0, ..., EN-1>
doveEi = type(val[i])
.
Esempi
// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))
uniform_dequantize
Semantica
Esegue la conversione elemento per elemento del tensore quantizzato operand
in un
tensore a virgola mobile result
in base ai parametri di quantizzazione definiti
dal tipo operand
.
In termini più formali, result = dequantize(operand)
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore quantizzato | (C1), (C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo floating-point | (C1), (C2) |
Vincoli
- (C1)
shape(operand) = shape(result)
. - (C2)
element_type(result) = expressed_type(operand)
.
Esempi
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]
uniform_quantize
Semantica
Esegue la conversione a livello di elementi del tensore a virgola mobile o del tensore quantizzato
operand
in un tensore quantizzato result
in base ai parametri
di quantizzazione definiti dal tipo result
.
In termini più formali,
- Se
is_float(operand)
:result = quantize(operand, type(result))
.
- Se
is_quantized(operand)
:float_result = dequantize(operand)
.result = quantize(float_result, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo a virgola mobile o quantizzato | (C1), (C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore quantizzato | (C1), (C2) |
Vincoli
- (C1)
shape(operand) = shape(result)
. - (C2)
expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand)
.
Esempi
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]
// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]
mentre
Semantica
Produce l'output dell'esecuzione della funzione body
0 o più volte mentre la funzione cond
restituisce true
. In modo più formale, la semantica può essere espressa utilizzando la sintassi di Python come segue:
internal_state = operand
while cond(*internal_state):
internal_state = body(*internal_state)
results = internal_state
Il comportamento di un loop infinito è da decidere (#383).
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
numero variadico di tensori, tensori quantizzati o token | (C1-C3) |
(I2) | cond |
funzione | (C1) |
(I3) | body |
funzione | (C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadico di tensori, tensori quantizzati o token | (C3) |
Vincoli
- (C1)
cond
ha il tipo(T0, ..., TN-1) -> tensor<i1>
, doveTi = type(operand[i])
. - (C2)
body
è di tipo(T0, ..., TN-1) -> (T0, ..., TN-1)
, doveTi = type(operand[i])
. - (C3)
type(results...) = type(operand...)
.
Esempi
// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%cond = "stablehlo.compare"(%arg0, %ten) {
comparison_direction = #stablehlo<comparison_direction LT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %cond : tensor<i1>
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%new_sum = stablehlo.add %arg1, %one : tensor<i64>
%new_i = stablehlo.add %arg0, %one : tensor<i64>
stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10
xor
Semantica
Esegue un XOR a livello di elemento di due tensori lhs
e rhs
e produce un tensore result
. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i valori booleani: XOR logico.
- Per gli interi: XOR a livello di bit.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di tipo booleano o intero | (C1) |
(I2) | rhs |
tensore di un tipo booleano o intero | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di un tipo booleano o intero | (C1) |
Vincoli
- (C1)
type(lhs) = type(rhs) = type(result)
.
Esempi
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]
Interoperabilità dei dialetti
Al momento, i programmi StableHLO in uso a volte contengono operazioni che non sono definite da StableHLO.
Modulo, funzione, chiamata e ritorno
StableHLO utilizza le operazioni MLIR a monte per ModuleOp, FuncOp, CallOp e ReturnOp. Questo è stato fatto per una migliore interoperabilità con la struttura MLIR esistente, poiché molti passaggi utili sono scritti per FuncOp e ModuleOp e molte pipeline di compilazione prevedono la presenza di queste operazioni. A queste operazioni vengono applicate garanzie di compatibilità completa. In caso di modifiche a queste operazioni in modo incompatibile (ovvero la rimozione), verranno aggiunti gli equivalenti StableHLO per preservare la compatibilità.
CHLO
L'opset CHLO contiene operazioni di livello superiore che si decompongono in StableHLO. Al momento non esistono garanzie di compatibilità per CHLO. Per garantire la compatibilità, il passaggio chl-legalize-to-stablehlo deve essere utilizzato prima della serializzazione.
Operazioni forma
È un caso d'uso comune nella community utilizzare determinate operazioni dei dialetti MLIR di base nei programmi StableHLO dinamici per eseguire calcoli delle forme.
In genere, sono incluse operazioni in dialetto shape
come shape_of
o num_elements
, operazioni in dialetto tensor
come dim
o from_elements
e il tipo index
integrato.
Dynamism RFC > O2
indica che questi elementi sono fuori ambito, tuttavia è incluso il supporto per i tipi index
per scopi di interoperabilità. Non sono previste garanzie di compatibilità per queste operazioni o questi tipi. Il passaggio shape-legalize-to-stablehlo
può essere utilizzato per convertire queste operazioni in operazioni StableHLO completamente supportate.
Operazioni ritirate
Esistono diverse operazioni StableHLO ereditate da MHLO che sono deprecate e in procinto di essere ritirate da StableHLO. I dettagli completi su queste rimozioni sono disponibili nella pagina relativa alla pulizia SttableHLO v1.0 n. 2283. Il problema del tracker per queste ritiri è #2340.
Queste operazioni rientrano in alcune categorie:
- Categoria "Non in HLO" delle operazioni StableHLO: inizialmente facevano parte del set di operazioni StableHLO, ma in seguito sono state ritenute non adatte:
broadcast
,create_token
,cross-replica-sum
,dot
,einsum
,torch_index_select
,unary_einsum
(#3). - Operazioni inutilizzate: queste operazioni potrebbero essere state utili a un certo punto, ma non sono state sviluppate o le pipeline che le utilizzano sono state ristrutturate in modo da non richiederle più. Sono inclusi i confronti
map
,tuple
(#598),get_tuple_element
,rng
,complex
#560 e la convezionewindow_reversal
(#1181).
Alcune di queste operazioni possono essere rimosse facilmente perché possono essere espresse utilizzando operazioni esistenti (broadcast
, create_token
, cross-replica-sum
, dot
,unary_einsum
) e verranno rimosse al termine del periodo di compatibilità esistente (6 mesi). Altre sono ancora in fase di esplorazione per la rimozione (einsum
,
get_tuple_element
, map
, rng
torch_index_select
, tuple
, complex
confronti, window_reversal
). In attesa del feedback della community,
queste operazioni verranno rimosse o aggiunte allo spec con il supporto completo. Finché
queste operazioni future non saranno note, sono garantiti solo 6 mesi di compatibilità.
Esecuzione
Esecuzione sequenziale
Un programma StableHLO viene eseguito fornendo valori di input alla funzione main
e calcolando i valori di output. I valori di output di una funzione vengono calcolati eseguendo il grafico delle operazioni rooted nell'operazione return
corrispondente.
L'ordine di esecuzione è definito dall'implementazione, a condizione che sia in linea con il flusso di dati, ovvero se le operazioni vengono eseguite prima del loro utilizzo. In StableHLO, tutte le operazioni con effetti collaterali consumano un token e producono un token (più token possono essere multiplexati in un token tramite after_all
), pertanto l'ordine di esecuzione degli effetti collaterali è in linea anche con il flusso di dati. Ad esempio, nel programma riportato di seguito
sono possibili due ordini di esecuzione: %0
→ %1
→ %2
→ return
e
%1
→ %0
→ %2
→ return
.
func.func @main() -> tensor<f64> {
%0 = stablehlo.constant dense<1.0> : tensor<f64>
%1 = stablehlo.constant dense<2.0> : tensor<f64>
%2 = stablehlo.add %0, %1 : tensor<f64>
return %2 : tensor<f64>
}
In termini più formali, un processo StableHLO è una combinazione di:
1) un programma StableHLO, 2) stati di operazione (non ancora eseguiti,
già eseguiti) e 3) valori intermedi su cui il processo sta lavorando.
Il processo inizia con i valori di input della funzione main
, procede attraverso il grafico delle operazioni aggiornando gli stati delle operazioni e i valori intermedi e termina con i valori di output. Ulteriori dettagli sulla formalizzazione sono in fase di definizione
(#484).
Esecuzione parallela
I programmi StableHLO possono essere eseguiti in parallelo, organizzati in una griglia di processi 2D di num_replicas
per num_partitions
, entrambi di tipo ui32
.
Nella Griglia dei processi StableHLO, num_replicas * num_partitions
di processi StableHLO
sono in esecuzione contemporaneamente. Ogni processo ha un process_id = (replica_id, partition_id)
univoco, dove replica_id
in replica_ids = range(num_replicas)
e partition_id
in partition_ids = range(num_partitions)
hanno entrambi il tipo ui32
.
Le dimensioni della griglia di processi sono note in modo statico per ogni programma (in futuro prevediamo di renderla parte esplicita dei programmi StableHLO
#650) e la posizione
all'interno della griglia di processi è nota in modo statico per ogni processo. Ogni processo ha accesso alla propria posizione all'interno della griglia dei processi tramite le operazioni replica_id
e partition_id
.
All'interno della griglia del processo, i programmi possono essere tutti uguali (in stile "Programma unico, più dati"), possono essere tutti diversi (in stile "Programma più multiplo, più dati") o avere una via di mezzo. In futuro, prevediamo di introdurre il supporto per altri modi di definire programmi StableHLO paralleli, tra cui GSPMD (#619).
All'interno della griglia dei processi, i processi sono per lo più indipendenti tra loro: hanno stati delle operazioni separati, valori di input/intermedi/output separati e la maggior parte delle operazioni viene eseguita separatamente tra i processi, con l'eccezione di un piccolo numero di operazioni collettive descritte di seguito.
Poiché l'esecuzione della maggior parte delle operazioni utilizza solo i valori dello stesso
processo, di solito è inequivocabile fare riferimento a questi valori tramite i relativi nomi.
Tuttavia, quando si descrive la semantica delle operazioni collettive, questo non è sufficiente e dà origine alla notazione name@process_id
per fare riferimento al valore name
all'interno di un determinato processo. Da questo punto di vista, name
non qualificato può essere interpretato come una scorciatoia per name@(replica_id(), partition_id())
.
L'ordine di esecuzione tra i processi è definito dall'implementazione, ad eccezione della sincronizzazione introdotta dalla comunicazione point-to-point e dalle operazioni collettive come descritto di seguito.
Comunicazione punto a punto
I processi StableHLO possono comunicare tra loro tramite
canali StableHLO. Un canale è rappresentato da un ID positivo di tipo
si64
. Attraverso varie operazioni, è possibile inviare valori ai canali
e riceverli dai canali.
Ulteriori dettagli, ad esempio la provenienza di questi ID canale, la modalità di rilevamento da parte dei programmi di elaborazione e il tipo di sincronizzazione introdotto, sono in fase di definizione (#484).
Comunicazione in streaming
Ogni processo StableHLO ha accesso a due interfacce di streaming:
- Infeed che possono essere letti.
- Outfeed in cui è possibile scrivere.
A differenza dei canali, che vengono utilizzati per comunicare tra i processi e quindi hanno processi a entrambi i lati, gli infeed e gli outfeed hanno l'altro lato definito dall'implementazione.
Un'ulteriore formalizzazione, ad esempio il modo in cui la comunicazione in streaming influenza l'ordine di esecuzione e il tipo di sincronizzazione introdotta dalla stessa, è da definire (#484).
Operazioni collettive
In StableHLO sono disponibili sei operazioni collettive: all_gather
, all_reduce
,
all_to_all
, collective_broadcast
, collective_permute
e
reduce_scatter
. Tutte queste operazioni suddividono i processi nella griglia di processi StableHLO in gruppi di processi StableHLO ed eseguono un calcolo congiunto all'interno di ciascun gruppo di processi, indipendentemente dagli altri gruppi di processi.
All'interno di ciascun gruppo di processi, le operazioni collettive possono introdurre una barriera di sincronizzazione. Ulteriori formalizzazioni, ad esempio l'elaborazione di quando avviene esattamente questa sincronizzazione, di come esattamente i processi arrivano a questa barriera e di cosa succede se non lo fanno, sono in fase di definizione (#484).
Se il gruppo di processi prevede la comunicazione tra partizioni, ovvero se nel gruppo di processi sono presenti processi i cui ID partizione sono diversi, l'esecuzione dell'operazione collettiva richiede un canale e l'operazione collettiva deve fornire un channel_id
positivo di tipo si64
. La comunicazione tra repliche non richiede canali.
I calcoli eseguiti dalle operazioni collettive sono specifici per le singole operazioni e sono descritti nelle sezioni delle singole operazioni sopra. Tuttavia, le strategie con cui la griglia di processi è suddivisa in gruppi di processi sono condivise tra queste operazioni e sono descritte in questa sezione. In termini più formali, StableHLO supporta le seguenti quattro strategie.
cross_replica
Solo le comunicazioni con repliche diverse avvengono all'interno di ciascun gruppo di processi. Questa strategia prende replica_groups
, un elenco di elenchi di ID replica, e calcola un prodotto cartesiano di replica_groups
per partition_ids
. replica_groups
deve avere elementi univoci e coprire tutti i replica_ids
. Più formalmente, utilizzando la sintassi di Python:
def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
for partition_id in partition_ids:
process_group = []
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Ad esempio, per replica_groups = [[0, 1], [2, 3]]
e num_partitions = 2
,
cross_replica
produrrà
[[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]]
.
cross_partition
All'interno di ogni gruppo di processi avvengono solo comunicazioni tra partizioni. Questa strategia utilizza partition_groups
, un elenco di elenchi di ID partizione, e calcola un prodotto cartesiano partition_groups
in base a replica_ids
.
partition_groups
deve avere elementi univoci e coprire tutti i partition_ids
.
In modo più formale, utilizzando la sintassi di Python:
def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
for partition_group in partition_groups:
for replica_id in replica_ids:
process_group = []
for partition_id in partition_group:
process_group.append((replica_id, partition_id))
yield process_group
Ad esempio, per partition_groups = [[0, 1]]
e num_replicas = 4
,
cross_partition
genererà
[[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]]
.
cross_replica_and_partition
Le comunicazioni cross-replica e cross-partition possono avvenire sia all'interno di ciascun gruppo di processi. Questa strategia prende replica_groups
, un elenco di elenchi di ID replica, e calcola i prodotti cartesiani di ogni replica_group
per partition_ids
. replica_groups
deve contenere elementi unici e coprire tutti i
replica_ids
. In modo più formale, utilizzando la sintassi di Python:
def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
process_group = []
for partition_id in partition_ids:
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Ad esempio, per replica_groups = [[0, 1], [2, 3]]
e num_partitions = 2
,
cross_replica_and_partition
produrrà
[[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]]
.
flattened_ids
Questa strategia prende flattened_id_groups
, un elenco di elenchi di ID processo "appiattiti" nel formato replica_id * num_partitions + partition_id
, e li trasforma in ID processo. flattened_id_groups
deve avere elementi univoci
e coprire tutti i process_ids
. In modo più formale, usando la sintassi Python:
def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
for flattened_id_group in flattened_id_groups:
process_group = []
for flattened_id in flattened_id_group:
replica_id = flattened_id // num_partitions
partition_id = flattened_id % num_partitions
process_group.append((replica_id, partition_id))
yield process_group
Ad esempio, per flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]
,
num_replicas = 4
e num_partitions = 2
, flattened_ids
genererà
[[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]]
.
Accuratezza
Al momento, StableHLO non fornisce garanzie sull'accuratezza numerica, ma la situazione potrebbe cambiare in futuro (#1156).
Semântica di esecuzione dell'operazione quantizzata
L'interpretazione delle operazioni quantizzate StableHLO può variare a seconda delle funzionalità e dei requisiti hardware. Ad esempio, alcuni hardware potrebbero scegliere di interpretare le operazioni quantizzate utilizzando una strategia di "dequantizzazione, esecuzione di un'operazione a virgola mobile e infine quantizzazione". Altri possono eseguire l'intero calcolo con aritmetica di numeri interi. Di conseguenza, l'interpretazione delle operazioni StableHLO quantizzate è determinata esclusivamente dall'implementazione specifica. L'interpretazione della quantizzazione ibrida (#1575) deve essere basata sulla sua semantica come prescritto nella specifica (tramite 1792).
Errori
I programmi StableHLO vengono convalidati tramite un ampio insieme di vincoli per le singole operazioni, che escludono molte classi di errori prima del tempo di esecuzione. Tuttavia, le condizioni di errore sono comunque possibili, ad esempio attraverso overflow di numeri interi, accessi fuori dai limiti e così via. Se non vengono richiamati esplicitamente, tutti questi errori comportano un comportamento definito dall'implementazione, che può però cambiare in futuro (#1157).
Eccezioni di virgola mobile
Come eccezione a questa regola, le eccezioni in virgola mobile nei programmi StableHLO hanno un comportamento ben definito. Le operazioni che generano eccezioni definite dallo standard IEEE-754 (operazione non valida, divisione per zero, overflow, underflow o eccezioni imprecise) producono risultati predefiniti (come definito nello standard) e continuano l'esecuzione senza attivare il flag di stato corrispondente; in modo simile alla gestione delle eccezioni raiseNoFlag
dello standard. Le eccezioni per le operazioni non standard (ad es. aritmetica complessa e determinate funzioni trascendentali) sono definite dall'implementazione.
Mancata corrispondenza delle forme
StableHLO supporta i tensori con forma dinamica. Tuttavia, le forme devono essere conformi al compilatore, altrimenti il comportamento non è definito. StableHLO non fornisce esplicitamente un'operazione che può asserire che un tensore abbia una determinata forma in fase di runtime. La generazione del codice corretto è responsabilità del produttore.
Come esempio specifico, il programma riportato di seguito è valido. Tuttavia, in fase di esecuzione, le forme esatte di %arg0
e %arg1
dovranno essere le stesse, altrimenti il comportamento del programma non è definito:
func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
return %0 : tensor<?xi32>
}
Notazione
Per descrivere la sintassi, questo documento utilizza la sintassi EBNF con il tipo ISO modificato (ISO/IEC 14977:1996, Wikipedia), con due modifiche: 1) le regole sono definite utilizzando ::=
anziché =
,
2) la concatenazione viene espressa tramite giustapposizione anziché ,
.
Per descrivere la semantica (ad es. nelle sezioni "Tipi", "Costanti" e "Operazioni"), utilizziamo formule basate sulla sintassi di Python estesa con il supporto per esprimere in modo conciso le operazioni sugli array come descritto di seguito. Questo approccio funziona bene per piccoli snippet di codice, ma in rari casi in cui sono necessari snippet di codice più grandi, utilizziamo la sintassi di Python standard, che viene sempre introdotta esplicitamente.
Formule
Vediamo come funzionano le formule in base a un esempio della dot_general
specifica. Uno dei vincoli per questa operazione è il seguente:
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
.
I nomi utilizzati in questa formula provengono da due origini: 1) funzioni globali,
ovvero dim
, 2) definizioni di membri dell'elemento di programma corrispondente, ovvero
input lhs
, lhs_batching_dimensions
, rhs
e rhs_batching_dimensions
definiti nella sezione "Input" di dot_general
.
Come accennato sopra, la sintassi di questa formula è basata su Python con alcune estensioni orientate alla concisione. Per comprendere la formula, trasformiamola in sintassi Python standard.
A) In queste formule utilizziamo =
per rappresentare l'uguaglianza, quindi il primo passo per ottenere la sintassi di Python è sostituire =
con ==
, come segue:
dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...)
.
B) Inoltre, queste formule supportano i puntini di sospensione (...
) che trasformano le espressioni scalari in espressioni tensoriali. In breve, f(xs...)
significa approssimativamente "per ogni scalare x
nel tensore xs
, calcola un scalare f(x)
e poi restituisci tutti questi risultati scalari insieme come risultato del tensore". Nella sintassi Python standard,
la nostra formula di esempio diventa:
[dim(lhs, dim1) for dim1 in lhs_batching_dimensions] ==
[dim(rhs, dim2) for dim2 in rhs_batching_dimensions]
.
Grazie alle ellissi, spesso è possibile evitare di lavorare al livello dei singoli scalari. Tuttavia, in alcuni casi complicati, è possibile utilizzare una sintassi semi-informale di livello inferiore, come nella formula start_indices[bi0, ..., :, ..., biN]
della specifica gather
. Al servizio della concisione, non forniamo un formalismo esatto per tradurre tale sintassi in Python vanilla, nella speranza che sia comunque comprensibile in modo intuitivo caso per caso.
Facci sapere se alcune formule specifiche sembrano poco chiare e cercheremo di migliorarle.
Inoltre, noterai che le formule utilizzano i tre puntini per espandere tutti i tipi di elenchi, inclusi i tensori, gli elenchi di tensori (che ad esempio possono derivare da un numero variabile di tensori) e così via. Questa è un'altra area in cui non forniamo un formalismo esatto (ad esempio, gli elenchi non fanno nemmeno parte del sistema di tipi StableHLO) e ci basiamo invece sulla comprensibilità intuitiva.
C) L'ultimo mezzo di notazione degno di nota che utilizziamo è la trasmissione implicita. Sebbene l'opset StableHLO non supporti la trasmissione implicita, le formule supportano anche la concisione. In breve, se un valore scalare viene utilizzato in un contesto in cui è previsto un tensore, viene trasmesso alla forma prevista.
Per continuare con l'esempio dot_general
, ecco un altro vincolo:
0 <= lhs_batching_dimensions < rank(lhs)
. Come definito nella specifica dot_general
, lhs_batching_dimensions
è un tensore, ma sia 0
che
rank(lhs)
sono scalari. Dopo aver applicato la trasmissione implicita, la formula diventerà [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)]
.
Se applicata a una determinata operazione dot_general
, questa formula avrà come risultato un tensore di valori booleani. Quando le formule vengono utilizzate come vincoli, il vincolo è valido se la formula restituisce true
o un tensore con solo elementi true
.
Nomi
Nelle formule, l'ambito lessicale include: 1) funzioni globali, 2) definizioni di membri,
3) Definizioni locali. Di seguito è riportato l'elenco delle funzioni globali. L'elenco delle definizioni degli elementi dipende dall'elemento del programma a cui viene applicata la notazione:
- Per le operazioni, le definizioni degli elementi includono i nomi introdotti nelle sezioni "Input" e "Output".
- Per tutto il resto, le definizioni dei membri includono parti strutturali dell'elemento del programma, denominate in base ai non terminali EBNF corrispondenti. Nella maggior parte dei casi, i nomi di queste parti strutturali vengono ottenuti convertendo i nomi dei non terminali in snake case (ad es.
IntegerLiteral
=>integer_literal
), ma a volte i nomi vengono abbreviati durante la procedura (ad es.QuantizationStorageType
=>storage_type
), nel qual caso i nomi vengono introdotti esplicitamente in modo simile alle sezioni "Input" / "Output" nelle specifiche di funzionamento. - Inoltre, le definizioni dei membri includono sempre
self
per fare riferimento all'elemento del programma corrispondente.
Valori
Quando le formule vengono valutate, funzionano con i seguenti tipi di valori:
1) Value
(valori effettivi, ad es. dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
;
si conoscono sempre i relativi tipi),
2) Placeholder
(valori futuri, ad es. lhs
, rhs
o result
; i relativi
valori effettivi non sono ancora noti, sono noti solo i relativi tipi),
3) Type
(tipi come definiti nella sezione "Tipi"),
4) Function
(funzioni globali come definite nella sezione "Funzioni").
A seconda del contesto, i nomi possono fare riferimento a valori diversi. In modo più specifico, la sezione "Semantica" per le operazioni (e gli equivalenti per altri elementi del programma) definisce la logica di runtime, pertanto tutti gli input sono disponibili come Value
.
Al contrario, la sezione "Restrizioni" per le operazioni (e gli equivalenti) definisce la logica "in fase di compilazione", ovvero qualcosa che in genere viene eseguito prima del runtime, pertanto solo gli input costanti sono disponibili come Value
e gli altri input sono disponibili solo come Placeholder
.
Nomi | In "Semantica" | In "Vincoli" |
---|---|---|
Funzioni globali | Function |
Function |
Input costanti | Value |
Value |
Input non costanti | Value |
Placeholder |
Output | Value |
Placeholder |
Definizioni locali | Dipende dalla definizione | Dipende dalla definizione |
Consideriamo un'operazione transpose
di esempio:
%result = "stablehlo.transpose"(%operand) {
permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
Per questa operazione, permutation
è una costante, quindi è disponibile come Value
nella semantica e nei vincoli. Al contrario, operand
e result
sono disponibili come Value
nella semantica, ma solo come Placeholder
nei vincoli.
Funzioni
Costruzione di tipi
Non esistono funzioni che possono essere utilizzate per costruire tipi. Utilizziamo invece direttamente la sintassi del tipo perché in genere è più concisa. Ad es.
(tensor<E>, tensor<E>) -> (tensor<E>)
anziché function_type(
[tensor_type([], E), tensor_type([], E)], [tensor_type([], E)])
.
Funzioni sui tipi
element_type
è definito sui tipi di tensori e sui tipi di tensori quantizzati e restituisce, rispettivamente, la parteTensorElementType
oQuantizedTensorElementType
diTensorType
oQuantizedTensorType
corrispondente.
def element_type(x: Value | Placeholder | Type):
if type(x) == TensorType:
return tensor_element_type(x)
if type(x) == QuantizedTensorType:
return quantized_tensor_element_type(x)
if type(x) is not Type:
return element_type(type(x))
is_per_axis_quantized(x: Value | Placeholder | Type) -> Value
è una scorciatoia peris_quantized(x) and quantization_dimension(x) is not None
.is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value
è una scorciatoia peris_quantized(x) and quantization_dimension(x) is None
.is_promotable(x: Type, y: Type) -> bool
controlla se il tipox
può essere promosso al tipoy
. Quandox
ey
sonoQuantizedTensorElementType
, la promozione viene applicata solo alstorage_type
. Questa specifica versione della promozione viene attualmente utilizzata nel contesto del calcolo della riduzione (fai riferimento a RFC per ulteriori dettagli).
def is_promotable(x: Type, y: Type) -> Value:
is_same_type = (is_bool(x) and is_bool(y)) or
(is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
(is_complex(x) and is_complex(y)) or
(is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
if is_same_type == False:
return False
if is_integer(x) or is_float(x):
return bitwidth(x) <= bitwidth(y)
if is_complex(x):
return bitwidth(element_type(x)) <= bitwidth(element_type(y))
if is_quantized(x):
return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))
return false
is_quantized(x: Value | Placeholder | Type) -> Value
è una scorciatoia peris_quantized_tensor_element_type(x)
.is_type_name(x: Value | Placeholder | Type) -> Value
. Disponibile per tutti i tipi. Ad esempio,is_float(x)
restituiscetrue
sex
è unFloatType
. Sex
è un valore o un segnaposto, questa funzione è una scorciatoia peris_type_name(type(x))
.max_value(x: Type) -> Value
restituisce il valore massimo di unTensorElementType
. Sex
non è unTensorElementType
, restituisceNone
.min_value(x: Type) -> Value
restituisce il valore minimo possibile di unTensorElementType
. Sex
non è unTensorElementType
, restituisceNone
.member_name(x: Value | Placeholder | Type) -> Any
. Disponibile per tutte le definizioni di membrimember_name
di tutti i tipi. Ad esempio,tensor_element_type(x)
restituisce la parteTensorElementType
di unTensorType
corrispondente. Sex
è un valore o un segnaposto, questa funzione è una scorciatoia permember_name(type(x))
. Sex
non è un tipo con un membro appropriato, un valore o un segnaposto di questo tipo, restituisceNone
.is_empty_algorithm(*args: Type)
controlla se tutti i campi dell'algoritmo dei punti sono impostati suNone
. Questa operazione è necessaria perché gli algoritmi dei punti hanno comportamenti predefiniti definiti per l'implementazione, quindi specificare un valore predefinito sarebbe errato.
Costruzione di valori
operation_name(*xs: Value | Type) -> Value
. Disponibile per tutte le operazioni. Ad esempio,add(lhs, rhs)
prende due valori tensorelhs
erhs
e restituisce l'output della valutazione dell'operazioneadd
con questi input. Per alcune operazioni, ad esempiobroadcast_in_dim
, i tipi di output sono "portanti", ovvero necessari per valutare un'operazione. In questo caso, la funzione assume questi tipi come argomenti.
Funzioni sui valori
Sono disponibili tutti gli operatori e le funzioni di Python. Ad esempio, le notazioni di subscription e slicing di Python sono disponibili per indicizzare tensori, tensori quantizzati e tuple.
to_destination_type(x: Value, destination_type: Type) -> Value
è definito su tensori e restituisce il valore convertito dix
in base atype(x)
edestination_type
come segue:
def to_destination_type(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x
if is_quantized(destination_type):
if is_quantized(type(x)):
return quantize(x, destination_type)
assert is_float(type(x))
return quantize(x, destination_type)
if is_quantized(type(x)):
assert destination_type = expressed_type(type(x))
return dequantize(type(x))
return convert(x, destination_type)
Sono in corso discussioni iniziali sull'unione delle operazioni convert
, uniform_quantize
e
uniform_dequantize
(#1576).
Dopo l'unione non abbiamo bisogno della funzione sopra indicata e possiamo utilizzare il nome dell'operazione
per convert
.
is_nan(x: Value) -> Value
è definito sui tensori e restituiscetrue
se tutti gli elementi dix
sonoNaN
ofalse
in caso contrario. Sex
non è un tensore, restituisceNone
.is_sorted(x: Value) -> Value
è definito sui tensori e restituiscetrue
se gli elementi dix
sono ordinati in ordine crescente rispetto all'ordine alfabetico crescente dei relativi indici ofalse
in caso contrario. Sex
non è un tensore, restituisceNone
.is_unique(x: Value) -> Value
è definito sui tensori e restituiscetrue
sex
non contiene elementi duplicati ofalse
in caso contrario. Sex
non è un tensore, restituisceNone
.member_name(x: Value) -> Any
è definito per tutte le definizioni dei membrimember_name
di tutti i valori. Ad esempio,real_part(x)
restituisce la parteRealPart
di unComplexConstant
corrispondente. Sex
non è un valore con un membro appropriato, restituisceNone
.same(x: Value) -> Value
viene definito sui tensori e restituiscetrue
se gli elementi dix
sono tutti uguali tra loro ofalse
in caso contrario. Se il tensore non ha elementi, viene considerato "tutti uguali tra loro", ovvero la funzione restituiscetrue
. Sex
non è un tensore, restituisceNone
.split(x: Value, num_results: Value, axis: Value) -> Value
è definito su tensori e restituiscenum_results
sezioni dix
lungo l'asseaxis
. Sex
non è un tensore odim(x, axis) % num_results != 0
, restituisceNone
.is_defined_in_parent_scope(x: Value) -> Value
è definito sulle stringhe e restituiscetrue
sex
è il nome di una funzione definita nello stesso ambito della funzione principale dell'operazione pertinente.is_namespaced_op_name(x: Value) -> Value
viene definito sulle stringhe e restituiscetrue
sex
è un nome dell'operazione valido, ovvero rispetta la seguente espressione regolare:[a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+
Calcoli delle forme
axes(x: Value | Placeholder | Type) -> Value
è una scorciatoia perrange(rank(x))
.dim(x: Value | Placeholder | Type, axis: Value) -> Value
è una scorciatoia pershape(x)[axis]
.dims(x: Value | Placeholder | Type, axes: List) -> List
è una scorciatoia perlist(map(lambda axis: dim(x, axis), axes))
.index_space(x: Value | Placeholder | Type) -> Value
è definito sui tensori e restituisce gli indicisize(x)
per ilTensorType
corrispondente ordinato in ordine lessicografico crescente, ovvero[0, ..., 0]
,[0, ..., 1]
, ...,shape(x) - 1
. Sex
non è un tipo di tensore, un tipo di tensore quantizzato, un valore o un segnaposto di uno di questi tipi, restituisceNone
.rank(x: Value | Placeholder | Type) -> Value
è una scorciatoia persize(shape(x))
.shape(x: Value | Placeholder | Type) -> Value
è definito nella sezione "Funzioni su tipi" tramitemember_name
.size(x: Value | Placeholder | Type) -> Value
è una scorciatoia perreduce(lambda x, y: x * y, shape(x))
.
Calcoli di quantizzazione
def baseline_element_type(x: Value | Placeholder | Type) -> Type
è una scorciatoia perelement_type(baseline_type(x))
.baseline_type
è definito sui tipi di tensore e sui tipi di tensore quantizzati e li trasforma in un "valore di riferimento", ovvero un tipo con la stessa forma, ma con i parametri di quantizzazione del tipo di elemento reimpostati sui valori predefiniti. Questo viene utilizzato come un pratico trucco per confrontare in modo uniforme i tipi di tensori e di tensori quantizzati, il che è necessario molto spesso. Per i tipi quantizzati, questa opzione consente di confrontare i tipi ignorando i parametri di quantizzazione, ovveroshape
,storage_type
,expressed_type
,storage_min
,storage_max
equantization_dimension
(per il tipo quantizzato per asse) devono corrispondere, mascales
ezero points
potrebbero differire.
def baseline_type(x: Value | Placeholder | Type) -> Type:
if type(x) == TensorType:
return x
if type(x) == QuantizedTensorType:
element_type = quantized_tensor_element_type(x)
baseline_element_type = QuantizedTensorElementType(
storage_type = storage_type(element_type),
storage_min = storage_min(element_type),
storage_max = storage_max(element_type),
expressed_type = expressed_type(element_type),
quantization_dimension = quantization_dimension(element_type),
scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
return QuantizedTensorType(shape(x), baseline_element_type)
if type(x) is not Type:
return baseline_element_type(type(x))
dequantize
è definito sui tipi di tensori quantizzati e li trasforma in tipi di tensori in virgola mobile. Ciò avviene mediante la conversione degli elementi quantizzati che rappresentano i valori interi del tipo di archiviazione in valori corrispondente a virgola mobile del tipo espresso utilizzando il punto zero e la scala associato al tipo di elemento quantizzato.
def compute_zero_points(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
zero_points[i] = zero_points(quantized_type)[i[d]]
return zero_points
def compute_scales(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
type(result_type))
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
scales[i] = scales(quantized_type)[i[d]]
return scales
def dequantize(x: Value) -> Value:
assert is_quantized(x)
x_storage = bitcast_convert(x, storage_type(x))
x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
x_expressed_sub = convert(x_storage_sub, expressed_type(x))
return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
quantize
è definito sui tipi di tensore a virgola mobile e li trasforma in tipi di tensore quantizzati. Ciò avviene mediante la conversione dei valori in virgola mobile del tipo espresso in valori interi corrispondenti del tipo di archiviazione utilizzando il punto zero e la scala associati al tipo di elemento quantizzato.
def quantize(x: Value, result_type: Type) -> Value:
assert is_float(x) and is_quantized(result_type)
zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
converted_zero_points = convert(zero_points, expressed_type(result_type))
converted_min = convert(storage_min(result_type), expressed_type(result_type))
converted_max = convert(storage_max(result_type), expressed_type(result_type))
x_scaled = x / compute_scales(result_type, type(x))
x_scaled_add_zp = x_scaled + converted_zero_points
x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
x_rounded = round_nearest_even(x_clamped)
return convert(x_rounded, result_type)
dequantize_op_quantize
viene utilizzato per specificare i calcoli elemento per elemento sui tensori quantizzati. Dequantizza, ovvero trasforma gli elementi quantizzati nei relativi tipi espressi, poi esegue un'operazione e infine quantizza, ovvero trasforma nuovamente i risultati nei relativi tipi di archiviazione. Al momento, questa funzione funziona solo per la quantizzazione per tensore. La quantizzazione per asse è in fase di sviluppo (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
inputs = inputs_and_output_type[:-1]
output_type = inputs_and_output_type[-1]
float_inputs = map(dequantize, inputs)
float_result = op(*float_inputs)
return quantize(float_result, output_type)
def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
inputs = inputs_and_output_type[:-3]
float_inputs = map(dequantize, inputs)
float_results = op(*float_inputs)
return map(quantize, float_results, inputs_and_output_type[-3:])
def dequantize_compare(lhs, rhs, comparison_direction):
float_lhs = dequantize(lhs)
float_rhs = dequantize(rhs)
return compare(float_lhs, float_rhs, comparison_direction, FLOAT)
def dequantize_select_quantize(pred, on_true, on_false, output_type):
float_on_true = dequantize(on_true)
float_on_false = dequantize(on_false)
float_result = select(pred, float_on_true, float_on_false)
return quantize(float_result, output_type)
hybrid_dequantize_then_op
viene utilizzato per specificare la quantizzazione solo per i pesi per operazioni ibride che accettano lhs in virgola mobile e rhs in tipi quantizzati. Dequantizza gli input quantizzati nei tipi espressi ed esegue il calcolo in virgola mobile. Il tipo di elemento del tensore lhs float e il tipo espresso del tensore rhs quantizzato devono essere identici.
def hybrid_dequantize_then_op(op, lhs, rhs):
assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
return op(lhs, dequantize(rhs))
Calcoli della griglia
cross_partition(replica_groups: Value) -> Value
. Consulta la sezione "cross_replica" sopra.cross_replica(replica_groups: Value) -> Value
. Consulta la sezione "cross_replica" sopra.cross_replica_and_partition(replica_groups: Value) -> Value
. Consulta la sezione "cross_replica_and_partition" sopra.flattened_ids(replica_groups: Value) -> Value
. Consulta la sezione "flattened_ids" sopra.
Dinamismo
I valori HLO stabili possono avere dimensioni di dimensioni dinamiche, ad esempio tensor<?xi64>
.
Tuttavia, i valori StableHLO non possono avere un numero dinamico di dimensioni (dinamismo non classificato, ad es. tensor<*xi64>
). Gli operandi e i risultati possono utilizzare dimensioni dinamiche, anche se esistono vincoli per le dimensioni. I vincoli verranno verificati staticamente, se possibile, altrimenti vengono posticipati al runtime e le mancate corrispondenze comporteranno un comportamento non definito. Di seguito sono riportati gli esempi.
Mancata corrispondenza della forma per le operazioni una tantum degli elementi
Considera il seguente programma di esempio:
func.func @foo(%arg0: tensor<?xf64>) {
%0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
return
}
Un programma di questo tipo è insolito, perché non è comune conoscere la forma del
risultato ma non la forma dell'input. Tuttavia, si tratta di un programma
StableHLO valido. Non è possibile convalidare in modo statico l'operazione abs
in questo
programma perché la forma esatta dell'operando è sconosciuta. Tuttavia, le forme
sono sicuramente compatibili e questo può essere controllato in modo statico: ?
potrebbe risultare essere 2
in fase di esecuzione e non ci sarebbero problemi. Tuttavia, ?
potrebbe anche essere un altro numero intero, nel qual caso il comportamento non è definito.
Tieni presente che se le dimensioni di una dimensione sono dinamiche nel risultato, questo comportamento non può essere indefinito. Infatti, non esiste una dimensione "attesa", quindi non può esserci una mancata corrispondenza.
Mancata corrispondenza della forma per le operazioni elementari binarie
Considera il seguente programma di esempio:
func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
%0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
return
}
Per quanto riguarda le operazioni binarie a livello di elemento, le forme degli input e il risultato devono corrispondere in fase di runtime. In fase di compilazione, le dimensioni statiche devono essere uguali, altrimenti devono semplicemente essere compatibili. Se qualsiasi dimensione è dinamica negli input, potrebbe esserci un comportamento indefinito in fase di runtime, perché la dimensione dinamica potrebbe non corrispondere a quella corrispondente nell'altro operando (statico o dinamico). Se tutti gli input sono statici, il fatto che il risultato sia dinamico o meno non ha importanza: le dimensioni staticamente note verranno controllate in modo statico e le dimensioni dinamiche non impongono vincoli.
Mancata corrispondenza della forma per le operazioni che prendono la forma dell'output come operando
Considera il seguente programma di esempio:
func.func @foo(%arg0: tensor<2xi32>) {
%0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
return
}
I valori nell'operando della forma in fase di esecuzione devono corrispondere alla forma del risultato,
altrimenti il comportamento non è definito. In altre parole, in fase di esecuzione %arg0
deve avere un valore dense<[3, 4]> : tensor<2xi32>
. Se l'operando di forma è costante, può essere verificato in modo statico. Se la forma del risultato è completamente dinamica, non può esserci una mancata corrispondenza.