StableHLO è un set di operazioni per operazioni ad alto livello (HLO) in macchina di machine learning (ML). StabileHLO funziona come livello di portabilità tra diverse Framework ML e compilatori ML: framework ML che producono programmi StableHLO sono compatibili con i compilatori ML che utilizzano i programmi StableHLO.
Il nostro obiettivo è semplificare e accelerare lo sviluppo ML creando più all'interoperabilità tra vari framework ML (come TensorFlow, JAX e PyTorch) e compilatori ML (come XLA e IREE). Per raggiungere questo obiettivo, fornisce una specifica per il linguaggio di programmazione StableHLO.
Questa specifica contiene tre sezioni principali. In primo luogo, La sezione Programmi descrive la struttura dei programmi StableHLO che consistono in funzioni StableHLO, a loro volta costituite da operazioni StableHLO. All'interno di questa struttura, la sezione Ops specifica la semantica del singole operazioni. La sezione Esecuzione fornisce la semantica per tutti a queste operazioni eseguite insieme all'interno di un programma. Infine, La sezione Notazione illustra la notazione utilizzata in questa sezione. e la specifica del prodotto.
Per visualizzare le specifiche di una release precedente di StableHLO, apri il repository nella uscita taggata di tuo interesse. Ad esempio, la Specifica SttableHLO v0.19.0. Per visualizzare le modifiche che si sono verificate in corrispondenza di ogni bumper della versione secondaria di StableHLO, consulta il log della versione in VhloDialect.td.
Programmi
Program ::= {Func}
I programmi StableHLO sono costituiti da un numero arbitrario di funzioni StableHLO.
Di seguito è riportato un esempio di programma con una funzione @main
che ha 3 input
(%image
, %weights
e %bias
) e 1 output. Il corpo della funzione
ha 6 operazioni.
func.func @main(
%image: tensor<28x28xf32>,
%weights: tensor<784x10xf32>,
%bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
%0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
%1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
%2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
%3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
%4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
"func.return"(%4): (tensor<1x10xf32>) -> ()
}
Funzioni
Func ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput ::= ValueType
FuncBody ::= {Op}
Le funzioni stabileHLO (chiamate anche funzioni con nome) hanno un identificatore, input/output e un corpo. In futuro, prevediamo di Introdurre metadati aggiuntivi per le funzioni al fine di ottenere una migliore compatibilità con HLO (#425, N. 626, #740, #744).
Identificatori
FuncId ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
| '%' letter {letter | digit}
letter ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit ::= '0' | ... | '9'
Gli identificatori StabileHLO sono simili agli identificatori in molti programmi di lingue diverse, con due peculiarità: 1) tutti gli identificatori hanno sigilli che distinguere i diversi tipi di identificatori, 2) gli identificatori dei valori possono essere completamente numerici per semplificare la generazione di programmi StableHLO.
Tipi
Type ::= ValueType | NonValueType
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
I tipi SttableHLO vengono classificati in tipi di valore (chiamati anche tipi di prima classe) che rappresentano valori StableHLO e tipi non di valore che descrivono altri elementi del programma. I tipi HLO stabili sono simili a quelli in molti linguaggi di programmazione, la cui peculiarità principale è il software una natura specifica del dominio, il che porta ad alcuni risultati insoliti (ad es. tipi scalari non sono tipi di valori).
TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'
I tipi di Tensor rappresentano i tensori, ovvero gli array multidimensionali. Hanno un
forma e un tipo di elemento, dove una forma rappresenta un numero non negativo o
dimensioni delle dimensioni sconosciute in ordine crescente delle dimensioni corrispondenti
dimensioni (chiamate anche assi) numerate da 0
a R-1
. La
numero di dimensioni R
è chiamato ranking. Ad esempio, tensor<2x3xf32>
è
un tipo di tensore con forma 2x3
e tipo di elemento f32
. Ha due dimensioni
(o, in altre parole, due assi): la 0° dimensione e la 1° dimensione, le cui dimensioni
sono 2 e 3. Il suo ranking è 2.
Le forme possono essere parzialmente o completamente sconosciute (dinamiche), ad esempio tensor<?x2xf64>
è parzialmente sconosciuto e tensor<?x?xf64>
è completamente sconosciuto. Dinamico
le dimensioni sono rappresentate utilizzando un ?
. Le forme non possono essere rimosse dal ranking.
In futuro, prevediamo di esplorare l'estensione dei tipi di tensori oltre dimensioni di dimensione e tipi di elementi, ad esempio per includere layout (#629) e sparsità (#1078).
QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
QuantizationStorageType
['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
':' QuantizationExpressedType
[':' QuantizationDimension]
',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerConstant
QuantizationStorageMax ::= IntegerConstant
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerConstant
QuantizationParameters ::= QuantizationParameter
| '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale ':' QuantizationZeroPoint
QuantizationScale ::= FloatConstant
QuantizationZeroPoint ::= IntegerConstant
Nome | Tipo | Vincoli |
---|---|---|
storage_type |
tipo di numero intero | (C1-C3), (C8) |
storage_min |
costante intera | (C1), (C3), (C7) |
storage_max |
costante intera | (C2), (C3), (C7) |
expressed_type |
tipo con virgola mobile | (C4) |
quantization_dimension |
costante intera facoltativa | (C10-C12) |
scales |
numero variadico di costanti in virgola mobile | (C4-C6), (C9), (C10), (C13) |
zero_points |
numero variadico di costanti intere | (C7-C9) |
I tipi di elementi quantizzati rappresentano i valori interi di un tipo di archiviazione in
l'intervallo da storage_min
a storage_max
(incluso) che corrisponde
valori in virgola mobile di un tipo espresso. Per un determinato valore intero i
,
il valore in virgola mobile corrispondente f
può essere calcolato
f = (i - zero_point) * scale
, dove scale
e zero_point
sono chiamati
parametri di quantizzazione. storage_min
e storage_max
sono facoltativi
nella grammatica, ma presentano valori predefiniti di min_value(storage_type)
e
max_value(storage_type)
rispettivamente. I tipi di elementi quantizzati hanno
seguenti vincoli:
- (C1)
type(storage_min) = storage_type
. - (C2)
type(storage_max) = storage_type
. - (C3)
min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type)
. - (C4)
type(scales...) = expressed_type
. - (C5)
0 < scales
. - (C6)
is_finite(scales...)
. - (C7)
storage_min <= zero_points <= storage_max
. - (C8)
type(zero_points...) = storage_type
. - (C9)
size(scales) = size(zero_points)
. - (C10) Se
is_empty(quantization_dimension)
, allorasize(scales) = 1
. - (C11)
0 <= quantization_dimension
.
Al momento, QuantizationScale
è una costante in virgola mobile, ma c'è
Forte interesse per le scale basate su numeri interi, rappresentate con moltiplicatori e
senza interruzioni. Abbiamo in programma di esplorare questa funzionalità nel prossimo futuro
(N. 1404).
È in corso una discussione sulla semantica di QuantizationZeroPoint
,
tra cui il tipo, i valori e se possono esserci solo uno
potenzialmente più zeri in un tipo di tensore quantizzato. In base alla
risultati di questa discussione, la specifica relativa allo zero punti può cambiare
in futuro (#1405).
Un'altra discussione in corso riguarda la semantica di QuantizationStorageMin
e QuantizationStorageMax
per determinare se i vincoli devono essere
imposti a questi valori e a quelli dei tensori quantizzati
(N. 1406).
Infine, abbiamo in programma di esplorare la rappresentazione delle scale sconosciute e zero analogamente a come stiamo pianificando di esplorare la rappresentazione (#1407).
I tipi di tensori quantizzati rappresentano i tensori con elementi quantizzati. Questi tensori sono esattamente gli stessi dei tensori normali, con l'eccezione che i loro elementi hanno tipi di elementi quantizzati, invece dei normali tipi di elementi.
Nei tensori quantizzati, la quantizzazione può essere per tensore, ovvero:
uno scale
e zero_point
per l'intero tensore oppure può essere per asse,
ovvero avere più scales
e zero_points
, una coppia per sezione di
una determinata dimensione quantization_dimension
. Più formalmente, in un tensore t
con la quantizzazione per asse, sono presenti dim(t, quantization_dimension)
sezioni
di quantization_dimension
: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :]
,
e così via. Tutti gli elementi nella i
a sezione usano scales[i]
e zero_points[i]
come
i relativi parametri di quantizzazione. I tipi di tensori quantizzati hanno quanto segue
vincoli:
- Per la quantizzazione per tensore:
- Nessun vincolo aggiuntivo.
- Per la quantizzazione per asse:
- (C12)
quantization_dimension < rank(self)
. - (C13)
dim(self, quantization_dimension) = size(scales)
.
- (C12)
TokenType ::= 'token'
I tipi di token rappresentano i token, ovvero valori opaci prodotti e consumati. da alcune operazioni. I token vengono utilizzati per imporre l'ordine di esecuzione sulle operazioni come descritto nella sezione Esecuzione.
TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]
I tipi di tuple rappresentano tuple, ovvero elenchi eterogenei. Le tuple sono un'eredità
che esiste solo per compatibilità con l'HLO. In HLO, le tuple sono
utilizzato per rappresentare input e output variadi. In StableHLO, gli input variadi
sono supportati in modo nativo e l'unico utilizzo di tuple in StableHLO è
rappresentano in modo esaustivo un'ABI HLO, ad esempio T
, tuple<T>
e
tuple<tuple<T>>
può essere significativamente differente a seconda di un particolare
implementazione. In futuro, abbiamo in programma di apportare modifiche ad HLO ABI
che può consentirci di rimuovere i tipi di tuple da StableHLO
(n. 598).
TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ'
| 'f8E4M3B11FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
I tipi di elemento rappresentano gli elementi dei tipi tensoriali. A differenza di molti programmi
lingue diverse, non sono all'avanguardia in StableHLO. Ciò significa che
I programmi StableHLO non possono rappresentare direttamente questi valori (di conseguenza,
è idiomatico rappresentare valori scalari di tipo T
con tensore 0-dimensionale
di tipo tensor<T>
).
- Il tipo booleano rappresenta i valori booleani
true
efalse
. - I tipi di numeri interi possono essere firmati (
si
) o non firmati (ui
) e hanno una delle larghezze in bit supportate (2
,4
,8
,16
,32
o64
). I tipisiN
firmati rappresentano valori interi compresi tra-2^(N-1)
e2^(N-1)-1
I tipi inclusivi euiN
senza segno rappresentano valori interi compresi tra0
e2^N-1
inclusi. - I tipi con rappresentazione in virgola mobile possono essere:
- .
- Tipi
f8E4M3FN
ef8E5M2
corrispondenti rispettivamente al CodificheE4M3
eE5M2
del formato FP8 descritto in Formati dell'FP88 per il deep learning. - Tipi
f8E4M3FNUZ
ef8E5M2FNUZ
corrispondenti aE4M3
eE5M2
delle codifiche dei formati FP8 descritti in Formati numerici a 8 bit per reti neurali profonde. - Tipo
f8E4M3B11FNUZ
corrispondente alla codificaE4M3
dei formati FP8 descritti in Addestramento e inferenza in virgola mobile ibrida a 8 bit (HFP8) per reti neurali profonde. - Tipo
bf16
corrispondente al formatobfloat16
descritto in BFloat16: The secret to high performance on Cloud TPUs. - Tipi
f16
,f32
ef64
corrispondenti rispettivamentebinary16
("precisione dimezzata"),binary32
("precisione singola") ebinary64
("doppia precisione") descritti in standard IEEE 754. - Il tipo
tf32
corrisponde al formato TensorFloat32 e ha un supporto limitato in StableHLO.
- Tipi
- I tipi complessi rappresentano valori complessi che hanno una parte reale.
e una parte immaginaria dello stesso tipo di elemento. Complesso supportato
i tipi sono
complex<f32>
(entrambe le parti sono di tipof32
) ecomplex<f64>
(entrambe le parti sono di tipof64
).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]
I tipi di funzioni rappresentano funzioni sia con nome che anonime. Hanno
tipi di input (l'elenco dei tipi a sinistra di ->
) e tipi di output
(l'elenco dei tipi disponibile sul lato destro di ->
). In molti casi di programmazione
linguaggi di programmazione, i tipi di funzioni sono di prima classe, ma non in StableHLO.
StringType ::= 'string'
Il tipo di stringa rappresenta sequenze di byte. A differenza di molti programmi lingue, il tipo di stringa non è di prima classe in StableHLO e viene utilizzato solo specificare metadati statici per gli elementi del programma.
Operazioni
Le operazioni SttableHLO (chiamate anche operazioni) rappresentano un insieme chiuso di operazioni ad alto livello nei modelli di machine learning. Come già detto, La sintassi stabileHLO è fortemente ispirata da MLIR, che non è necessariamente il modello ma è probabilmente la soluzione più adatta per l'obiettivo di StableHLO, creando una maggiore interoperabilità tra framework ML e compilatori ML.
Op ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic ::= 'abs' | 'add' | ...
Le operazioni SttableHLO (chiamate anche ops) hanno un nome,
input/output e una firma. Il nome è composto dal prefisso stablehlo.
e
un mnemonico che identifica in modo univoco una delle operazioni supportate. Vedi di seguito per
un elenco completo di tutte le operazioni supportate.
OpInputs ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue ::= ValueId
OpInputFuncs ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs ::= [OpOutput {',' OpOutput} '=']
OpOutput ::= ValueId
Le operazioni consumano input e producono output. Gli input sono classificati in
valori di input (calcolati durante l'esecuzione), funzioni di input (fornite
in modo statico, perché in StableHLO le funzioni non sono valori all'avanguardia) e
attributi di input (forniti anche in modo statico). Il tipo di input e output
consumato e prodotto da un'operazione dipende dal suo mnemonico. Ad esempio, add
L'operazione consuma 2 valori di input e produce 1 valore di output. In confronto,
L'operazione select_and_scatter
utilizza 3 valori di input, 2 funzioni di input e
3 attributi di input.
OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused ::= '^' digit {digit}
| '^' letter {letter | digit}
Le funzioni di input (chiamate anche funzioni anonime) sono molto
simili alle funzioni con nome, tranne per il fatto che: 1) non hanno un identificatore (quindi
il nome "anonimo"), 2) non dichiarano tipi di output (i tipi di output sono
dedotto dall'operazione return
all'interno della funzione).
La sintassi delle funzioni di input include una parte attualmente inutilizzata (vedi il
produzione Unused
riportata sopra) per garantire la compatibilità con MLIR. Nell'MLIR,
esiste un concetto più generale di "regioni" che può avere più "blocchi"
operazioni collegate tra loro tramite jump ops. Questi blocchi hanno ID che corrispondono
alla produzione Unused
, in modo che possano essere distinti tra loro.
StableHLO non ha operazioni di salto, quindi la parte corrispondente della sintassi MLIR è
non utilizzato (ma è ancora presente).
OpInputAttr ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName ::= letter {letter | digit}
OpInputAttrValue ::= Constant
Gli attributi di input contengono un nome e un valore che è uno dei supportati
costanti. Sono il modo principale per specificare metadati statici per il programma,
elementi. Ad esempio, l'operazione concatenate
utilizza l'attributo dimension
per
specificare la dimensione lungo la quale i suoi valori di input sono concatenati. Analogamente,
l'operazione slice
utilizza più attributi, come start_indices
e limit_indices
per specificare i limiti utilizzati per suddividere il valore di input.
Al momento, i programmi StableHLO in circolazione a volte contengono attributi non descritti in questo documento. In futuro, prevediamo di assorbono questi attributi nell'opset StableHLO o li impediscano dei programmi StableHLO. Nel frattempo, ecco l'elenco di questi attributi:
layout
(n. 629).mhlo.frontend_attributes
(#628).mhlo.sharding
(n. 619).output_operand_aliases
(#740).- Metadati sulla località (#594).
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'
La firma dell'operazione consiste nei tipi di tutti i valori di input (l'elenco dei tipi disponibili
sul lato sinistro di ->
) e i tipi di tutti i valori di output (l'elenco
digita sul lato destro di ->
). In parole povere, i tipi di input sono
e i tipi di output sono quasi sempre ridondanti (perché
nella maggior parte delle operazioni StableHLO, i tipi di output possono essere dedotti dagli input). Tuttavia, op
la firma è volutamente parte della sintassi StableHLO per la compatibilità con MLIR.
Di seguito è riportata un'operazione di esempio il cui mnemonico è select_and_scatter
. Ne consuma 3
valori di input (%operand
, %source
e %init_value
), 2 funzioni di input
e tre attributi di input (window_dimensions
, window_strides
e padding
).
Nota come la firma dell'operazione include solo i tipi dei suoi valori di input
(ma non i tipi di funzioni e attributi di input forniti in linea).
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>
Costanti
Constant ::= BooleanConstant
| IntegerConstant
| FloatConstant
| ComplexConstant
| TensorConstant
| QuantizedTensorConstant
| StringConstant
| EnumConstant
Le costanti SttableHLO hanno un valore letterale e un tipo che insieme rappresentano
un valore StableHLO. In genere, il tipo fa parte della sintassi costante, tranne
quando è non ambigua (ad es. una costante booleana ha un tipo inequivocabile i1
),
mentre una costante intera può avere più tipi possibili).
BooleanConstant ::= BooleanLiteral
BooleanLiteral ::= 'true' | 'false'
Le costanti booleane rappresentano i valori booleani true
e false
. Valore booleano
costanti sono di tipo i1
.
IntegerConstant ::= IntegerLiteral ':' IntegerType
IntegerLiteral ::= ['-' | '+'] DecimalDigits
| ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit ::= '0' | ... | '9'
hexadecimalDigit ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'
Le costanti interi rappresentano valori interi tramite stringhe che utilizzano valori decimali o una notazione esadecimale. Altre basi, ad esempio binari o ottali, non sono supportati. Le costanti intere hanno i seguenti vincoli:
- (C1)
is_wellformed(integer_literal, integer_type)
.
FloatConstant ::= FloatLiteral ':' FloatType
FloatLiteral ::= SignPart IntegerPart FractionalPart ScientificPart
| '0x' [HexadecimalDigits]
SignPart ::= ['-' | '+']
IntegerPart ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]
Le costanti con virgola mobile rappresentano valori in virgola mobile tramite stringhe che utilizzare la notazione decimale o scientifica. Inoltre, la notazione esadecimale può essere specifica direttamente i bit sottostanti nel formato a virgola mobile di il tipo corrispondente. Le costanti con virgola mobile hanno i seguenti vincoli:
- (C1) Se viene utilizzata una notazione non esadecimale,
is_wellformed(float_literal, float_type)
. - (C2) Se viene utilizzata la notazione esadecimale,
size(hexadecimal_digits) = num_bits(float_type) / 4
.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral ::= '(' RealPart ',' ImaginaryPart ')'
RealPart ::= FloatLiteral
ImaginaryPart ::= FloatLiteral
Le costanti complesse rappresentano valori complessi mediante elenchi di una parte reale
(arriva prima) e una parte immaginaria (arriva seconda). Ad esempio:
(1.0, 0.0) : complex<f32>
rappresenta 1.0 + 0.0i
e
(0.0, 1.0) : complex<f32>
rappresenta 0.0 + 1.0i
. L'ordine in cui
vengono archiviate nella memoria, è definito dall'implementazione. Costanti complesse
presentano i seguenti vincoli:
- (C1)
is_wellformed(real_part, complex_element_type(complex_type))
. - (C2)
is_wellformed(imaginary_part, complex_element_type(complex_type))
.
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral
Le costanti di Tensor rappresentano i valori dei tensori utilizzando elenchi nidificati specificati tramite
Notazione NumPy. Ad esempio, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
rappresenta un valore tensore con la seguente mappatura dagli indici agli elementi:
{0, 0} => 1
, {0, 1} => 2
, {0, 2} => 3
, {1, 0} => 4
, {1, 1} => 5
,
{1, 2} => 6
. L'ordine in cui questi elementi vengono poi archiviati in memoria
dell'implementazione. Le costanti tensore hanno i seguenti vincoli:
- (C1)
has_syntax(tensor_literal, element_type(tensor_type))
, in cui:has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type)
.has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type)
.
- (C2)
has_shape(tensor_literal, shape(tensor_type))
, dove:has_shape(element_literal: Syntax, []) = true
.has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:])
.- altrimenti
false
.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
Le costanti tensoriali quantizzate rappresentano i valori dei tensori quantizzati utilizzando lo stesso la notazione come costanti tensoriali, con gli elementi specificati come costanti tipo di archiviazione. Le costanti tensoriali quantizzate hanno i seguenti vincoli:
- (C1)
has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type))
. - (C2)
has_shape(quantized_tensor_literal, shape(quantized_tensor_type))
.
StringConstant ::= StringLiteral
StringLiteral ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))
I valori letterali stringa sono costituiti da byte specificati utilizzando caratteri ASCII e
sequenze di escape. Sono indipendenti dalla codifica, quindi l'interpretazione di questi
è definito dall'implementazione. I valori letterali stringa hanno il tipo string
.
Operazioni
abs
Semantica
Esegue un'operazione ABS a livello di elemento sul tensore operand
e produce un result
tensore. A seconda del tipo di elemento:
- Per i numeri interi firmati: modulo intero.
- Per i numeri in virgola mobile:
abs
dallo standard IEEE-754. - Per i numeri complessi: modulo complesso.
- Per i tipi quantizzati:
dequantize_op_quantize(abs, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di numero intero firmato, in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1-C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo intero con segno, rappresentazione in virgola mobile o tensore quantizzato per tensore | (C1-C2) |
Vincoli
- (C1)
shape(result) = shape(operand)
. - (C2) Per
baseline_element_type(result)
si intende:complex_element_type(element_type(operand))
seis_complex(operand)
.baseline_element_type(operand)
in caso contrario.
Esempi
// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]
add
Semantica
Esegue l'aggiunta di due tensori lhs
e rhs
a livello di elemento e produce un
tensore result
. A seconda del tipo di elemento:
- Per i valori booleani: OR logico.
- Per i numeri interi: addizione di numeri interi.
- Per i numeri in virgola mobile:
addition
dallo standard IEEE-754. - Per i numeri complessi: addizione complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(add, lhs, rhs, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore o tensore quantizzato | (C1-C6) |
(I2) | rhs |
tensore o tensore quantizzato | (C1-C5), (C7) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato | (C1-C7) |
Vincoli
- Se l'operazione utilizza tensori non quantizzati:
- (C1)
type(lhs) = type(rhs) = type(result)
.
- (C1)
- Se l'operazione utilizza tensori quantizzati:
- (C2)
is_quantized(lhs) and is_quantized(rhs) and is_quantized(result)
. - (C3)
storage_type(lhs) = storage_type(rhs) = storage_type(result)
. - (C4)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C5)
(is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result)
. - (C6) Se
is_per_axis_quantized(lhs)
, alloraquantization_dimension(lhs) = quantization_dimension(result)
. - (C7) Se
is_per_axis_quantized(rhs)
, alloraquantization_dimension(rhs) = quantization_dimension(result)
.
- (C2)
Esempi
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]
after_all
Semantica
Garantisce che le operazioni che producono inputs
vengano eseguite prima che
che dipendono da result
. L'esecuzione di questa operazione non ha alcun effetto,
ma esiste solo per stabilire le dipendenze dati da result
a inputs
.
Input
Etichetta | Nome | Tipo |
---|---|---|
(I1) | inputs |
numero variadico di token |
Output
Nome | Tipo |
---|---|
result |
token |
Esempi
// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token
all_gather
Semantica
All'interno di ogni gruppo di processi nella griglia del processo StableHLO, concatena i valori
dei operands
tensori di ogni processo lungo all_gather_dim
e produce
results
tensori.
L'operazione divide la griglia del processo StableHLO in process_groups
, che
definiti come segue:
cross_replica(replica_groups)
sechannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
sechannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
sechannel_id > 0 and use_global_device_ids = true
.
In seguito, all'interno di ogni process_group
:
operands...@receiver = [operand@sender for sender in process_group]
per tuttireceiver
aprocess_group
.results...@process = concatenate(operands...@process, all_gather_dim)
per tuttiprocess
aprocess_group
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operands |
numero variadico di tensori o tensori quantizzati per tensore | (C1), (C6) |
(I2) | all_gather_dim |
costante di tipo si64 |
(C1), (C6) |
(I3) | replica_groups |
Costante tensore bidimensionale di tipo si64 |
(C2-C4) |
(I4) | channel_id |
costante di tipo si64 |
(C5) |
(I5) | use_global_device_ids |
costante di tipo i1 |
(C5) |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadico di tensori o tensori quantizzati per tensore | (C6) |
Vincoli
- (C1)
0 <= all_gather_dim < rank(operands...)
. - (C2)
is_unique(replica_groups)
. - (C3)
size(replica_groups)
è definito come:num_replicas
se si utilizzacross_replica
.num_replicas
se si utilizzacross_replica_and_partition
.num_processes
se si utilizzaflattened_ids
.
- (C4)
0 <= replica_groups < size(replica_groups)
. - (C5) Se
use_global_device_ids = true
, allorachannel_id > 0
. - (C6)
type(results...) = type(operands...)
eccetto:dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1)
.
Esempi
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
all_reduce
Semantica
All'interno di ciascun gruppo di processi nella griglia dei processi StableHLO, applica una riduzione
funzione computation
ai valori dei tensori operands
di ciascun processo
e produce results
tensori.
L'operazione divide la griglia del processo StableHLO in process_groups
, che
definiti come segue:
cross_replica(replica_groups)
sechannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
sechannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
sechannel_id > 0 and use_global_device_ids = true
.
In seguito, all'interno di ogni process_group
:
results...@process[result_index] = exec(schedule)
per un albero binarioschedule
dove:- .
exec(node)
=computation(exec(node.left), exec(node.right))
.exec(leaf)
=leaf.value
.
schedule
è un albero binario definito dall'implementazione, il cui ordine attraversamento èto_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0]))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operands |
numero variadico di tensori o tensori quantizzati per tensore | (C5), (C6) |
(I2) | replica_groups |
numero variadico delle costanti tensoriali unidimensionali di tipo si64 |
(C1-C3) |
(I3) | channel_id |
costante di tipo si64 |
(C4) |
(I4) | use_global_device_ids |
costante di tipo i1 |
(C4) |
(I5) | computation |
funzione | (C5) |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadico di tensori o tensori quantizzati per tensore | (C6-C7) |
Vincoli
- (C1)
is_unique(replica_groups)
. - (C2) Per
size(replica_groups)
si intende:num_replicas
se si utilizzacross_replica
.num_replicas
se si utilizzacross_replica_and_partition
.num_processes
se si utilizzaflattened_ids
.
- (C3)
0 <= replica_groups < size(replica_groups)
. - (C4) Se
use_global_device_ids = true
, allorachannel_id > 0
. - (C5)
computation
ha il tipo(tensor<E>, tensor<E>) -> (tensor<E>)
doveis_promotable(element_type(operand), E)
. - (C6)
shape(results...) = shape(operands...)
. - (C7)
element_type(results...) = E
.
Esempi
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]
all_to_all
Semantica
All'interno di ogni gruppo di processi nella griglia del processo StableHLO, suddivide i valori di
i tensori operands
lungo split_dimension
in parti, disperde la suddivisione
parti tra i processi, concatena le parti sparse
concat_dimension
e produce results
tensori.
L'operazione divide la griglia del processo StableHLO in process_groups
, che
definiti come segue:
cross_replica(replica_groups)
sechannel_id <= 0
.cross_partition(replica_groups)
sechannel_id > 0
.
In seguito, all'interno di ogni process_group
:
split_parts...@sender = split(operands...@sender, split_count, split_dimension)
per tutti isender
inprocess_group
.scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group]
dovereceiver_index = process_group.index(receiver)
.results...@process = concatenate(scattered_parts...@process, concat_dimension)
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operands |
numero variadico di tensori o tensori quantizzati per tensore | (C1-C3), (C9) |
(I2) | split_dimension |
costante di tipo si64 |
(C1), (C2), (C9) |
(I3) | concat_dimension |
costante di tipo si64 |
(C3), (C9) |
(I4) | split_count |
costante di tipo si64 |
(C2), (C4), (C8), (C9) |
(I5) | replica_groups |
Costante tensore bidimensionale di tipo si64 |
(C5-C8) |
(I6) | channel_id |
costante di tipo si64 |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadico di tensori o tensori quantizzati per tensore | (C9) |
Vincoli
- (C1)
0 <= split_dimension < rank(operands...)
. - (C2)
dim(operands..., split_dimension) % split_count = 0
. - (C3)
0 <= concat_dimension < rank(operands...)
. - (C4)
0 < split_count
. - (C5)
is_unique(replica_groups)
. - (C6)
size(replica_groups)
è definito come:num_replicas
se si utilizzacross_replica
.num_partitions
se si utilizzacross_partition
.
- (C7)
0 <= replica_groups < size(replica_groups)
. - (C8)
dim(replica_groups, 1) = split_count
. - (C9)
type(results...) = type(operands...)
tranne, sesplit_dimension != concat_dimension
:- .
dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count
.dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count
.
Esempi
// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
// [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
// [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
// channel_id = 0
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]
e
Semantica
Esegue l'operatore AND a livello di elemento di due tensori lhs
e rhs
e produce un result
tensore. A seconda del tipo di elemento:
- Per i valori booleani: AND logico.
- Per i numeri interi: AND a livello di bit.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di un tipo booleano o intero | (C1) |
(I2) | rhs |
tensore di un tipo booleano o intero | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di un tipo booleano o intero | (C1) |
Vincoli
- (C1)
type(lhs) = type(rhs) = type(result)
.
Esempi
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]
atan2
Semantica
Esegue l'operazione atan2 a livello di elementi sul tensore lhs
e rhs
e produce un
tensore result
. A seconda del tipo di elemento:
- Per i numeri in virgola mobile:
atan2
dallo standard IEEE-754. - Per i numeri complessi: atan2 complesso.
- Per i tipi quantizzati:
dequantize_op_quantize(atan2, lhs, rhs, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
(I2) | rhs |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Esempi
// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]
batch_norm_grad
Semantica
Calcola i gradienti di diversi input della retropropagazione di batch_norm_training
da grad_output
e produce grad_operand
, grad_scale
e grad_offset
tensori. Più formalmente, questa operazione può essere espressa come una decomposizione
le operazioni StableHLO esistenti utilizzando la sintassi Python come segue:
def compute_sum(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
return sum
def compute_mean(operand, feature_index):
sum = compute_sum(operand, feature_index)
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
# Broadcast inputs to type(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance`
# Intermediate values will be useful for computing gradients
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
# Use the implementation from batchnorm_expander.cc in XLA
# Temporary variables have exactly the same names as in the C++ code
elements_per_feature = broadcast_in_dim(
constant(divide(size(operand), dim(operand, feature_index)),
element_type(grad_output)),
[], type(operand))
i1 = multiply(grad_output, elements_per_feature)
i2 = broadcast_in_dim(
compute_sum(grad_output, feature_index), [feature_index], type(operand))
i3 = broadcast_in_dim(
compute_sum(multiply(grad_output, centered_operand), feature_index),
[feature_index], type(operand))
i4 = multiply(i3, centered_operand)
i5 = divide(i4, add(variance_bcast, epsilon_bcast))
i6 = subtract(subtract(i1, i2), i5)
grad_operand =
multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
grad_scale =
compute_sum(multiply(grad_output, normalized_operand), feature_index)
grad_offset = compute_sum(grad_output, feature_index)
return grad_operand, grad_scale, grad_offset
Per i tipi quantizzati, esegue
dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean,
variance, grad_output: batch_norm_grad(operand, scale, mean, variance,
grad_output, epsilon, feature_index), operand, scale, mean, variance,
grad_output, type(grad_operand), type(grad_scale), type(feature_index))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1-C3), (C5) |
(I2) | scale |
Tensore monodimensionale di tipo quantizzato in virgola mobile o per tensore | (C2), (C4), (C5) |
(I3) | mean |
Tensore monodimensionale di tipo quantizzato in virgola mobile o per tensore | (C2), (C4) |
(I4) | variance |
Tensore monodimensionale di tipo quantizzato in virgola mobile o per tensore | (C2), (C4) |
(I5) | grad_output |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C2), (C3) |
(I6) | epsilon |
costante di tipo f32 |
|
(I7) | feature_index |
costante di tipo si64 |
(C1), (C5) |
Output
Nome | Tipo | Vincoli |
---|---|---|
grad_operand |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C2), (C3) |
grad_scale |
Tensore monodimensionale di tipo quantizzato in virgola mobile o per tensore | (C2), (C4) |
grad_offset |
Tensore monodimensionale di tipo quantizzato in virgola mobile o per tensore | (C2), (C4) |
Vincoli
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,mean
,variance
,grad_output
,grad_operand
,grad_scale
egrad_offset
hanno lo stessobaseline_element_type
. - (C3)
operand
,grad_output
egrad_operand
hanno la stessa forma. - (C4)
scale
,mean
,variance
,grad_scale
egrad_offset
hanno i la stessa forma. - (C5)
size(scale) = dim(operand, feature_index)
.
Esempi
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
// [[0.1, 0.1], [0.1, 0.1]],
// [[0.1, 0.1], [0.1, 0.1]]
// ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
// [[0.0, 0.0], [0.0, 0.0]],
// [[0.0, 0.0], [0.0, 0.0]]
// ]
// %grad_scale: [0.0, 0.0]
// %grad_offset: [0.4, 0.4]
batch_norm_inference
Semantica
Normalizza il tensore operand
in tutte le dimensioni ad eccezione di
feature_index
e produce un tensore result
. Più formalmente, questo
l'operazione può essere espressa come decomposizione in operazioni StableHLO esistenti
utilizzando la sintassi Python come segue:
def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
# Broadcast inputs to shape(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance` instead of
# computing them like `batch_norm_training` does.
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
return add(multiply(scale_bcast, normalized_operand), offset_bcast)
Per i tipi quantizzati, esegue
dequantize_op_quantize(lambda operand, scale, offset, mean, variance:
batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index), operand, scale, offset, mean, variance, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1-C7) |
(I2) | scale |
Tensore monodimensionale di tipo quantizzato in virgola mobile o per tensore | (C2), (C3) |
(I3) | offset |
Tensore monodimensionale di tipo quantizzato in virgola mobile o per tensore | (C2), (C4) |
(I4) | mean |
Tensore monodimensionale di tipo quantizzato in virgola mobile o per tensore | (C5) |
(I5) | variance |
Tensore monodimensionale di tipo quantizzato in virgola mobile o per tensore | (C2), (C6) |
(I6) | epsilon |
costante di tipo f32 |
|
(I7) | feature_index |
costante di tipo si64 |
(C1), (C3-C6) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C2), (C7) |
Vincoli
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,mean
,variance
eresult
hanno i stessobaseline_element_type
. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(mean) = dim(operand, feature_index)
. - (C6)
size(variance) = dim(operand, feature_index)
. - (C7)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
batch_norm_training
Semantica
Calcola la media e la varianza su tutte le dimensioni tranne feature_index
e normalizza il tensore operand
producendo output
, batch_mean
e batch_var
tensori. Più formalmente, questa operazione può essere espressa come
la decomposizione in operazioni StableHLO esistenti utilizzando la sintassi Python
che segue:
def compute_mean(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def compute_variance(operand, feature_index):
mean = compute_mean(operand, feature_index)
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
centered_operand = subtract(operand, mean_bcast)
return compute_mean(mul(centered_operand, centered_operand), feature_index)
def batch_norm_training(operand, scale, offset, epsilon, feature_index):
mean = compute_mean(operand, feature_index)
variance = compute_variance(operand, feature_index)
return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index),
mean, variance
Per i tipi quantizzati, esegue
dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset:
batch_norm_training(operand, scale, offset, epsilon, feature_index), operand,
scale, offset, type(output), type(batch_mean), type(batch_var))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
(I2) | scale |
Tensione monodimensionale di un tensore in virgola mobile o per tensore quantizzato | (C2), (C3) |
(I3) | offset |
Tensione monodimensionale di un tensore in virgola mobile o per tensore quantizzato | (C2), (C4) |
(I4) | epsilon |
costante di tipo f32 |
(C1), (C3-C6) |
(I5) | feature_index |
costante di tipo si64 |
(C1), (C3-C6) |
Output
Nome | Tipo | Vincoli |
---|---|---|
output |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C7) |
batch_mean |
Tensione monodimensionale di un tensore in virgola mobile o per tensore quantizzato | (C2), (C5) |
batch_var |
Tensione monodimensionale di un tensore in virgola mobile o per tensore quantizzato | (C2), (C6) |
Vincoli
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,batch_mean
,batch_var
eoutput
hanno lo stessobaseline_element_type
. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(batch_mean) = dim(operand, feature_index)
. - (C6)
size(batch_var) = dim(operand, feature_index)
. - (C7)
baseline_type(output) = baseline_type(operand)
.
Esempi
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
(tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]
bitcast_convert
Semantica
Esegue un'operazione bitcast sul tensore operand
e produce un tensore result
dove i bit dell'intero tensore operand
vengono reinterpretati utilizzando
tipo del tensore result
.
Più formalmente, sulla base di E = element_type(operand)
, E' = element_type(result)
,
e R = rank(operand)
:
- Se
num_bits(E') < num_bits(E)
,bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])
. - Se
num_bits(E') > num_bits(E)
,bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])
. - Se
num_bits(E') = num_bits(E)
,bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])
.
bits
restituisce la rappresentazione in memoria di un determinato valore e il relativo comportamento
è definita dall'implementazione perché l'esatta rappresentazione dei tensori è
dell'implementazione, e l'esatta rappresentazione dei tipi di elementi è
dell'implementazione.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato | (C1-C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato | (C1-C2) |
Vincoli
- (C1) Dati
E = is_quantized(operand) ? storage_type(operand) : element_type(operand)
,E' = is_quantized(result) ? storage_type(result) : element_type(result)
eR = rank(operand)
:- Se
num_bits(E') = num_bits(E)
,shape(result) = shape(operand)
. - Se
num_bits(E') < num_bits(E)
: rank(result) = R + 1
.dim(result, i) = dim(operand, i)
per tutti i0 <= i < R
.dim(result, R) * num_bits(E') = num_bits(E)
.- Se
num_bits(E') > num_bits(E)
: rank(result) = R - 1
.dim(result, i) = dim(operand, i)
per tutti i0 <= i < R
.dim(operand, R - 1) * num_bits(E) = num_bits(E')
.
- Se
- (C2) Se
is_complex(operand) or is_complex(result)
:is_complex(operand) and is_complex(result)
.
Esempi
// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation
broadcast_in_dim
Semantica
Espande le dimensioni e/o il ranking di un tensore di input duplicando i dati
nel tensore operand
e produce un tensore result
. Più formalmente,
result[result_index] = operand[operand_index]
dove per tutti i d
in
axes(operand)
:
operand_index[d] = 0
sedim(operand, d) = 1
.operand_index[d] = result_index[broadcast_dimensions[d]]
in caso contrario.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato | (C1-C2), (C5-C6) |
(I2) | broadcast_dimensions |
Costante tensore monodimensionale di tipo si64 |
(C2-C6) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato | (C1), (C3), (C5-C6) |
Vincoli
- (C1) Il valore
element_type(result)
è fornito da:element_type(operand)
, se!is_per_axis_quantized(operand)
.element_type(operand)
eccettoquantization_dimension(operand)
,scales(operand)
ezero_points(operand)
potrebbero differire daquantization_dimension(result)
,scales(result)
ezero_points(result)
o in caso contrario.
- (C2)
size(broadcast_dimensions) = rank(operand)
. - (C3)
0 <= broadcast_dimensions < rank(result)
. - (C4)
is_unique(broadcast_dimensions)
. - (C5) Per tutti i valori
d
inaxes(operand)
:- .
dim(operand, d) = 1
odim(operand, d) = dim(result, broadcast_dimensions[d])
.
- (C6) Se
is_per_axis_quantized(result)
:quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
.- Se
dim(operand, quantization_dimension(operand)) = 1
, allorascales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
.
Esempi
// %operand: [
// [1, 2, 3]
// ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
richiesta
Semantica
Genera l'output dall'esecuzione esattamente di una funzione da branches
a seconda del valore di index
. Più formalmente, result = selected_branch()
dove:
selected_branch = branches[index]
se0 <= index < size(branches)
.selected_branch = branches[-1]
in caso contrario.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | index |
Tensore 0dimensionale di tipo si32 |
|
(I2) | branches |
numero variadico di funzioni | (C1-C4) |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadico di tensori, tensori o token quantizzati | (C4) |
Vincoli
- (C1)
0 < size(branches)
. - (C2)
input_types(branches...) = []
. - (C3)
same(output_types(branches...))
. - (C4)
type(results...) = output_types(branches[0])
.
Esempi
// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
"stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
"stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]
Cbrt
Semantica
Esegue un'operazione di radice cubica a livello di elemento sul tensore operand
e produce un
result
tensore. A seconda del tipo di elemento:
- Per i numeri in virgola mobile:
rootn(x, 3)
dallo standard IEEE-754. - Per i numeri complessi: radice cubica complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(cbrt, operand, type(result))
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]
ceil
Semantica
Esegue il controllo degli elementi del tensore operand
e produce un tensore result
.
Implementa l'operazione roundToIntegralTowardPositive
da IEEE-754
e la specifica del prodotto. Per i tipi quantizzati, esegue
dequantize_op_quantize(ceil, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]
Cholesky
Semantica
Calcola la decomposizione di Cholesky di un batch di matrici.
Più formalmente, per tutti i i
di index_space(result)
,
result[i0, ..., iR-3, :, :]
è una decomposizione di Cholesky di
a[i0, ..., iR-3, :, :]
, sotto forma di triangolare inferiore
(se lower
è true
) o triangolare superiore (se lower
è false
).
I valori di output nel triangolo opposto, ad esempio il triangolo superiore o
triangolo inferiore rigorosamente inferiore sono definiti dall'implementazione.
Se esiste i
dove la matrice di input non è una definizione positiva Hermitiana
, il comportamento è indefinito.
Per i tipi quantizzati, esegue
dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | a |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1-C3) |
(I2) | lower |
Costante tensore 0-dimensionale di tipo i1 |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(a) = baseline_type(result)
. - (C2)
2 <= rank(a)
. - (C3)
dim(a, -2) = dim(a, -1)
.
Esempi
// %a: [
// [1.0, 2.0, 3.0],
// [2.0, 20.0, 26.0],
// [3.0, 26.0, 70.0]
// ]
%result = "stablehlo.cholesky"(%a) {
lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
clampare
Semantica
Fissa ogni elemento del tensore operand
tra un minimo e un massimo
e produce un tensore result
. Più formalmente, result[result_index] =
minimum(maximum(operand[result_index], min_element), max_element)
,
dove min_element = rank(min) = 0 ? min[] : min[result_index]
,
max_element = rank(max) = 0 ? max[] : max[result_index]
. Per i tipi quantizzati,
esegue dequantize_op_quantize(clamp, min, operand, max, type(result))
.
Imporre un ordinamento sui numeri complessi implica una semantica sorprendente, quindi in futuro abbiamo intenzione di rimuovere il supporto per i numeri complessi per questa operazione (#560).
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | min |
tensore quantizzato per tensore o per tensore | (C1), (C3) |
(I2) | operand |
tensore quantizzato per tensore o per tensore | (C1-C4) |
(I3) | max |
tensore quantizzato per tensore o per tensore | (C2), (C3) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore quantizzato per tensore o per tensore | (C4) |
Vincoli
- (C1)
rank(min) = 0 or shape(min) = shape(operand)
. - (C2)
rank(max) = 0 or shape(max) = shape(operand)
. - (C3)
baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max)
. - (C4)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]
collective_broadcast
Semantica
All'interno di ciascun gruppo di processi nella griglia dei processi StableHLO, invia il valore del parametro
operand
dal processo di origine ai processi di destinazione e produce una
result
tensore.
L'operazione divide la griglia del processo StableHLO in process_groups
, che
definiti come segue:
cross_replica(replica_groups)
sechannel_id <= 0
.cross_partition(replica_groups)
sechannel_id > 0
.
In seguito, result@process
viene assegnato da:
operand@process_groups[i, 0]
se esiste uni
tale che il processo sia aprocess_groups[i]
.broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
negli altri casi.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore quantizzato per tensore o per tensore | (C3) |
(I2) | replica_groups |
numero variadico delle costanti tensoriali unidimensionali di tipo si64 |
(C1), (C2) |
(I3) | channel_id |
costante di tipo si64 |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore quantizzato per tensore o per tensore | (C3) |
Vincoli
- (C1)
is_unique(replica_groups)
. - (C2)
0 <= replica_groups < N
doveN
è definito come:- .
num_replicas
se si utilizzacross_replica
.num_partitions
se si utilizzacross_partition
.
- (C3)
type(result) = type(operand)
.
Esempi
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
collective_permute
Semantica
All'interno di ciascun gruppo di processi nella griglia dei processi StableHLO, invia il valore del parametro
tensore operand
dal processo di origine a quello target e produce un
result
tensore.
L'operazione divide la griglia del processo StableHLO in process_groups
, che
definiti come segue:
cross_replica(source_target_pairs)
sechannel_id <= 0
.cross_partition(source_target_pairs)
sechannel_id > 0
.
In seguito, result@process
viene assegnato da:
operand@process_groups[i, 0]
, se esiste uni
tale cheprocess_groups[i, 1] = process
.broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
negli altri casi.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore quantizzato per tensore o per tensore | (C5) |
(I2) | source_target_pairs |
Costante tensore bidimensionale di tipo si64 |
(C1-C4) |
(I3) | channel_id |
costante di tipo si64 |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore quantizzato per tensore o per tensore | (C1) |
Vincoli
- (C1)
dim(source_target_pairs, 1) = 2
. - (C2)
is_unique(source_target_pairs[:, 0])
. - (C3)
is_unique(source_target_pairs[:, 1])
. - (C4)
0 <= source_target_pairs < N
, doveN
è definito come:- .
num_replicas
se si utilizzacross_replica
.num_partitions
se si utilizzacross_partition
.
- (C5)
type(result) = type(operand)
.
Esempi
// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]
confronta
Semantica
Esegue un confronto a livello di elementi dei tensori lhs
e rhs
in base a
comparison_direction
e compare_type
e produce un tensore result
.
I valori di comparison_direction
e compare_type
hanno i seguenti valori
semantica:
Per i tipi di elementi booleani e interi:
EQ
:lhs = rhs
.NE
:lhs != rhs
.GE
:lhs >= rhs
.GT
:lhs > rhs
.LE
:lhs <= rhs
.LT
:lhs < rhs
.
Per i tipi di elementi con virgola mobile con compare_type = FLOAT
, l'operazione implementa
le seguenti operazioni IEEE-754:
EQ
:compareQuietEqual
.NE
:compareQuietNotEqual
.GE
:compareQuietGreaterEqual
.GT
:compareQuietGreater
.LE
:compareQuietLessEqual
.LT
:compareQuietLess
.
Per i tipi di elementi con rappresentazione in virgola mobile con compare_type = TOTALORDER
, l'operazione
utilizza la combinazione di operazioni totalOrder
e compareQuietEqual
IEEE 754.
Per i tipi di elementi complessi, il confronto grammaticale delle coppie (real, imag)
è
eseguita utilizzando i comparison_direction
e i compare_type
forniti.
Imporre un ordinamento sui numeri complessi implica una semantica sorprendente,
quindi in futuro abbiamo intenzione di rimuovere il supporto per i numeri complessi
quando comparison_direction
è GE
, GT
, LE
o LT
(n. 560).
Per i tipi quantizzati. esegue dequantize_compare(lhs, rhs,
comparison_direction)
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore quantizzato per tensore o per tensore | (C1-C3) |
(I2) | rhs |
tensore quantizzato per tensore o per tensore | (C1-C2) |
(I3) | comparison_direction |
enum di EQ , NE , GE , GT , LE e LT |
|
(I4) | compare_type |
enum di FLOAT , TOTALORDER , SIGNED e UNSIGNED |
(C3) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo booleano | (C2) |
Vincoli
- (C1)
baseline_element_type(lhs) = baseline_element_type(rhs)
. - (C2)
shape(lhs) = shape(rhs) = shape(result)
. - (C3)
compare_type
è definito come:SIGNED
seis_signed_integer(element_type(lhs))
.UNSIGNED
seis_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs))
.FLOAT
oTOTALORDER
seis_float(element_type(lhs))
.FLOAT
seis_complex(element_type(lhs))
.
Esempi
// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
comparison_direction = #stablehlo<comparison_direction LT>,
compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]
complesso
Semantica
Esegue la conversione a livello di elemento in un valore complesso da una coppia di valori reali e
valori immaginari, lhs
e rhs
, e produce un tensore result
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di tipo f32 o f64 |
(C1-C3) |
(I2) | rhs |
tensore di tipo f32 o f64 |
(C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo complesso | (C2), (C3) |
Vincoli
- (C1)
type(lhs) = type(rhs)
. - (C2)
shape(result) = shape(lhs)
. - (C3)
element_type(result)
ha il tipocomplex<E>
doveE = element_type(lhs)
.
Esempi
// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]
composito
Semantica
Incapsula un'operazione composta (composta) da altre operazioni StableHLO,
prendendo inputs
e composite_attributes
e producendo results
. La
la semantica dell'operazione è implementata dall'attributo decomposition
. La
L'operazione composite
può essere sostituita con la sua decomposizione senza modificare il programma
la semantica. Nei casi in cui la decomposizione incorporata non fornisca lo stesso
semantica dell'operazione, è preferibile usare custom_call
.
Il campo version
(il valore predefinito è 0
) viene utilizzato per indicare quando un elemento
modifica della semantica.
Input
Etichetta | Nome | Tipo |
---|---|---|
(I1) | inputs |
numero variadico di valori |
(I2) | name |
costante di tipo string |
(I3) | composite_attributes |
dizionario attributi |
(I4) | decomposition |
costante di tipo string |
(I5) | version |
costante di tipo si32 |
Output
Nome | Tipo |
---|---|
results |
numero variadico di valori |
Vincoli
- (C1)
is_namespaced_op_name(name)
- (C2)
is_defined_in_parent_scope(decomposition)
- (C3)
types(inputs...) == input_types(decomposition)
- (C4)
types(results...) == output_types(decomposition)
Esempi
%results = "stablehlo.composite"(%input0, %input1) {
name = "my_namespace.my_op",
composite_attributes = {
my_attribute = "my_value"
},
decomposition = @my_op,
version = 1 : i32
} : (tensor<f32>, tensor<f32>) -> tensor<f32>
concatenare
Semantica
Concatena inputs
lungo la dimensione dimension
nello stesso ordine della dimensione specificata
argomenti e produce un tensore result
. Più formalmente,
result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1]
, dove:
id = d0 + ... + dk-1 + kd
.d
è uguale adimension
ed0
, ... sono led
a dimensioni delle dimensioni diinputs
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | inputs |
numero variadico di tensori o tensori quantizzati per tensore | (C1-C6) |
(I2) | dimension |
costante di tipo si64 |
(C2), (C4), (C6) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore quantizzato per tensore o per tensore | (C5-C6) |
Vincoli
- (C1)
same(element_type(inputs...))
. - (C2)
same(shape(inputs...))
trannedim(inputs..., dimension)
. - (C3)
0 < size(inputs)
. - (C4)
0 <= dimension < rank(inputs[0])
. - (C5)
element_type(result) = element_type(inputs[0])
. - (C6)
shape(result) = shape(inputs[0])
eccetto:dim(result, dimension) = dim(inputs[0], dimension) + ...
.
Esempi
// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
costante
Semantica
Produce un tensore output
da una costante value
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | value |
costante | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
output |
tensore o tensore quantizzato | (C1) |
Vincoli
- (C1)
type(value) = type(output)
.
Esempi
%output = "stablehlo.constant"() {
value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]
convertire
Semantica
Esegue una conversione a livello di elemento da un tipo di elemento a un altro.
operand
e produce un tensore result
.
Per le conversioni da boolean-to-any-supported-type, il valore false
è
convertito in zero e il valore true
viene convertito in uno. Per
conversioni any-supported-type-to-boolean, viene convertito un valore pari a zero
false
e i valori diversi da zero vengono convertiti in true
. Scopri di seguito come fare
per i tipi complessi.
Per le conversioni che implicano da intero a numero intero, da intero a virgola mobile o floating-point-to-floating-point, se il valore dell'origine può essere esattamente rappresentato nel tipo di destinazione, il valore del risultato corrisponde una rappresentazione visiva. In caso contrario, il comportamento è da definire (N. 180).
Per le conversioni che coinvolgono floating-point-to-integer, la parte frazionaria è troncato. Se il valore troncato non può essere rappresentato nel tipo di destinazione, Il comportamento è da definire (#180).
Le conversioni da da complessa a complessa seguono lo stesso comportamento di Conversioni floating-point-to-floating-point per convertire conversioni reali e parti immaginarie.
Per le conversioni di tipo complex-to-any-other-type e da complex-to-any-other-type, il valore immaginario di origine viene ignorato o il valore immaginario di destinazione è rispettivamente pari a zero. La conversione della parte reale segue conversioni con rappresentazione in virgola mobile.
In linea di principio, questa operazione potrebbe esprimere la dequantizzazione (conversione da
tensori quantizzati in tensori regolari), la quantizzazione (conversione da
tensori a tensori quantizzati) e la requantizzazione (conversione tra
tensori), ma al momento abbiamo operazioni dedicate -
uniform_dequantize
per il primo caso d'uso e uniform_quantize
per il
secondo e terzo casi d'uso. In futuro, queste due operazioni potrebbero essere unite
in convert
(#1576).
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore | (C1) |
Vincoli
- (C1)
shape(operand) = shape(result)
.
Esempi
// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]
convoluzione
Semantica
Calcola i prodotti scalare tra le finestre di lhs
e le sezioni di rhs
e produce
result
. Il seguente diagramma mostra come vengono calcolati gli elementi in result
a partire da
lhs
e rhs
utilizzando un esempio concreto.
Più formalmente, considera la seguente riformulazione degli input in termini di lhs
Per poter esprimere finestre di lhs
:
lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension))
.lhs_window_strides = lhs_shape(1, window_strides, 1)
.lhs_padding = lhs_shape([0, 0], padding, [0, 0])
.lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)
.lhs_window_dilations = lhs_shape(1, rhs_dilation, 1)
.
Questa riinquadratura utilizza le seguenti funzioni helper:
lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])
.result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])
.permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]
dovej[d] = i[permutation[d]]
.
Se feature_group_count = 1
e batch_group_count = 1
, allora per tutte
output_spatial_index
a index_space(dim(result, output_spatial_dimensions...))
,
result[result_shape(:, output_spatial_index, :)] = dot_product
dove:
padding_value = constant(0, element_type(lhs))
.padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)
.lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides
.lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations)
.reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true])
. Questa funzione sembra non essere utilizzata, pertanto in futuro abbiamo intenzione di rimuoverla (#1181).dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension])
.
Se feature_group_count > 1
:
lhses = split(lhs, feature_group_count, input_feature_dimension)
.rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)
.results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
.
Se batch_group_count > 1
:
lhses = split(lhs, batch_group_count, input_batch_dimension)
.rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)
.results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
.
Per i tipi quantizzati, esegue dequantize_op_quantize(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs,
type(result))
.
Per i tipi quantizzati ibridi, esegue hybrid_dequantize_then_op(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs)
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore quantizzato per tensore o per tensore | (C1), (C10-C11), (C14) (C25), (C27-C28), (C31-C32), (C34) |
(I2) | rhs |
tensore o tensore quantizzato | (C1), (C14-C16), (C25), (C27-C29), (C31-C34) |
(I3) | window_strides |
Costante tensore monodimensionale di tipo si64 |
(C2-C3), (C25) |
(I4) | padding |
Costante tensore bidimensionale di tipo si64 |
(C4), (C25) |
(I5) | lhs_dilation |
Costante tensore monodimensionale di tipo si64 |
(C5-C6), (C25) |
(I6) | rhs_dilation |
Costante tensore monodimensionale di tipo si64 |
(C7-C8), (C25) |
(I7) | window_reversal |
Costante tensore monodimensionale di tipo i1 |
(C9) |
(I8) | input_batch_dimension |
costante di tipo si64 |
(C10), (C13), (C25) |
(I9) | input_feature_dimension |
costante di tipo si64 |
(C11), (C13-C14) |
(I10) | input_spatial_dimensions |
Costante tensore monodimensionale di tipo si64 |
(C12), (C13), (C25) |
(I11) | kernel_input_feature_dimension |
costante di tipo si64 |
(C14), (C18) |
(I12) | kernel_output_feature_dimension |
costante di tipo si64 |
(C15-C16), (C18), (C25), (C29) |
(I13) | kernel_spatial_dimensions |
Costante tensore monodimensionale di tipo si64 |
(C17-C18), (C25) |
(I14) | output_batch_dimension |
costante di tipo si64 |
(C20), (C25) |
(I15) | output_feature_dimension |
costante di tipo si64 |
(C20), (C25), (C30) |
(I16) | output_spatial_dimensions |
Costante tensore monodimensionale di tipo si64 |
(C19-C20), (C25) |
(I17) | feature_group_count |
costante di tipo si64 |
(C11), (C14), (C16), (C21), (C23) |
(I18) | batch_group_count |
costante di tipo si64 |
(C10), (C15), (C22), (C23), (C25) |
(I19) | precision_config |
numero variadico di enumerazioni di DEFAULT , HIGH e HIGHEST |
(C24) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato | (C25-C28), (C30), (C32-34) |
Vincoli
- (C1)
N = rank(lhs) = rank(rhs)
. - (C2)
size(window_strides) = N - 2
. - (C3)
0 < window_strides
. - (C4)
shape(padding) = [N - 2, 2]
. - (C5)
size(lhs_dilation) = N - 2
. - (C6)
0 < lhs_dilation
. - (C7)
size(rhs_dilation) = N - 2
. - (C8)
0 < rhs_dilation
. - (C9)
size(window_reversal) = N - 2
. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
. - (C12)
size(input_spatial_dimensions) = N - 2
. - (C13) In base a
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
:is_unique(input_dimensions)
.0 <= input_dimensions < N
.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
. - (C17)
size(kernel_spatial_dimensions) = N - 2
. - (C18) In base a
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
:is_unique(kernel_dimensions)
.0 <= kernel_dimensions < N
.
- (C19)
size(output_spatial_dimensions) = N - 2
. - (C20) In base a
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
:is_unique(output_dimensions)
.0 <= output_dimensions < N
.
- (C21)
0 < feature_group_count
. - (C22)
0 < batch_group_count
. - (C23)
feature_group_count = 1 or batch_group_count = 1
. - (C24)
size(precision_config) = 2
. - (C25) Per
dim(result, result_dim)
si intende:dim(lhs, input_batch_dimension) / batch_group_count
seresult_dim = output_batch_dimension
.dim(rhs, kernel_output_feature_dimension)
seresult_dim = output_feature_dimension
.num_windows
altrimenti, dove:output_spatial_dimensions[spatial_dim] = result_dim
.lhs_dim = input_spatial_dimensions[spatial_dim]
.rhs_dim = kernel_spatial_dimensions[spatial_dim]
.dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
.dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
.num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
- (C26)
rank(result) = N
. - Se l'operazione utilizza tensori non quantizzati:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
.
- (C27)
- Se l'operazione utilizza tensori quantizzati:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C29) Se
is_per_axis_quantized(rhs)
, poiquantization_dimension(rhs) = kernel_output_feature_dimension
. - (C30) Se
is_per_axis_quantized(result)
:quantization_dimension(result) = output_feature_dimension
. - Se
is_quantized(lhs)
: - (C31)
storage_type(lhs) = storage_type(rhs)
. - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C33) Se
is_per_tensor_quantized(rhs)
:is_per_tensor_quantized(result)
. - Se
!is_quantized(lhs)
: - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C28)
Esempi
// %lhs: [[
// [
// [1], [2], [5], [6]
// ],
// [
// [3], [4], [7], [8]
// ],
// [
// [10], [11], [14], [15]
// ],
// [
// [12], [13], [16], [17]
// ]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strides = array<i64: 4, 4>,
padding = dense<0> : tensor<2x2xi64>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
// In the StableHLO dialect, dimension numbers are encoded via:
// `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
// "b" is batch dimension, "f" is feature dimension,
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" are spatial dimensions.
dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
batch_group_count = 1 : i64,
feature_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[10], [26]],
// [[46], [62]]
// ]]
coseno
Semantica
Esegue un'operazione coseno a livello di elemento sul tensore operand
e produce un
result
tensore. A seconda del tipo di elemento:
- Per i numeri in virgola mobile:
cos
dallo standard IEEE-754. - Per i numeri complessi: coseno complesso.
- Per i tipi quantizzati:
dequantize_op_quantize(cosine, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]
count_leading_zeros
Semantica
Esegue il conteggio a livello di elemento del numero di zero bit iniziali in operand
tensore e produce un tensore result
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo intero | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo intero | (C1) |
Vincoli
- (C1)
type(operand) = type(result)
.
Esempi
// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]
custom_call
Semantica
Incapsula un'operazione definita dall'implementazione call_target_name
che prende
inputs
e called_computations
e produce results
. has_side_effect
,
È possibile utilizzare backend_config
e api_version
per fornire ulteriori
e metadati definiti dall'implementazione.
Al momento, questa operazione contiene una raccolta di dati che riflettono l'evoluzione organica dell'operazione di controparte in il compilatore XLA. In futuro, prevediamo di unificare questi metadati (n. 741).
Input
Etichetta | Nome | Tipo |
---|---|---|
(I1) | inputs |
numero variadico di valori |
(I2) | call_target_name |
costante di tipo string |
(I3) | has_side_effect |
costante di tipo i1 |
(I4) | backend_config |
costante di tipo string o dizionario degli attributi |
(I5) | api_version |
costante di tipo si32 |
(I6) | called_computations |
numero variadico di costanti di tipo string |
Output
Nome | Tipo |
---|---|
results |
numero variadico di valori |
Esempi
%results = "stablehlo.custom_call"(%input0) {
call_target_name = "foo",
has_side_effect = false,
backend_config = {bar = 42 : i32},
api_version = 4 : i32,
called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>
divisione
Semantica
Esegue la divisione in base agli elementi dei tensori del dividendo lhs
e del divisore rhs
.
produce un tensore result
. A seconda del tipo di elemento:
- Per i numeri interi: divisione di numeri interi che produce il quoziente algebrico con qualsiasi parte frazionaria ignorata.
- Per i numeri in virgola mobile:
division
dallo standard IEEE-754. - Per i numeri complessi: divisione complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(divide, lhs, rhs, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di numeri interi, in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
(I2) | rhs |
tensore di numeri interi, in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di numeri interi, in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Esempi
// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]
dot_general
Semantica
Calcola i prodotti scalare tra le sezioni di lhs
e le sezioni di rhs
e produce una
tensore result
.
Più formalmente, result[result_index] = dot_product
, dove:
lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions]
.rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions]
.result_batching_index + result_lhs_index + result_rhs_index = result_index
dovesize(result_batching_index) = size(lhs_batching_dimensions)
,size(result_lhs_index) = size(lhs_result_dimensions)
esize(result_rhs_index) = size(rhs_result_dimensions)
.transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions)
.transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :])
.reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions))
.transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions)
.transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :])
.reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions))
.dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y))
.
Per i tipi quantizzati, esegue dequantize_op_quantize(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs, type(result))
.
Per i tipi quantizzati ibridi, esegue hybrid_dequantize_then_op(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs)
.
precision_config
controlla il compromesso tra velocità e accuratezza per
sui backend degli acceleratori. Può trattarsi di uno dei seguenti valori (nel
momento, la semantica di questi valori enum è sottospecificata, ma
abbiamo intenzione di affrontare questo
#755):
DEFAULT
: il calcolo è più veloce, ma l'approssimazione meno accurata al numero originale.HIGH
: calcolo più lento, ma approssimazione più precisa al numero originale.HIGHEST
: calcolo più lento, ma approssimazione più precisa al numero originale.
Un DotAlgorithm
definisce le proprietà principali dell'algoritmo utilizzato per l'implementazione
l'operazione., che definisce anche la precisione. Se l'attributo dell'algoritmo
, il valore di precision_config
deve essere DEFAULT
. DotAlgorithms
non hanno un valore predefinito, in quanto i parametri predefiniti sono
definito. Di conseguenza, tutti i campi dell'algoritmo dei punti possono essere impostati su None
per specificare un
algoritmo empty. che utilizzerà invece il valore precision_config
.
I campi di DotAlgorithm
includono:
lhs_precision_type
erhs_precision_type
, le precisione utilizzate da LHS e A destra dell'operazione viene arrotondato. I tipi di precisione sono indipendenti tipi di archiviazione di input e output.accumulation_type
la precisione utilizzata per l'accumulo.lhs_component_count
,rhs_component_count
enum_primitive_operations
si applicano quando si parla di un algoritmo che scompone il lato destro e/o sinistro più componenti e svolge più elementi "primitivi" le operazioni "dot" , generalmente per emulare una precisione più alta (ad es. Sfruttamento del tipo di dati dell'intelligenza artificiale bfloat16 per calcoli di precisione: bf16_6x tf32_3x, ecc). Per gli algoritmi senza decomposizione, questi valori deve essere impostato su1
.allow_imprecise_accumulation
per specificare se l'accumulo è con una precisione inferiore è consentito per alcuni passaggi (ad es.CUBLASLT_MATMUL_DESC_FAST_ACCUM
).
Esempi di attributi DotAlgorithm
:
// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false}
// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
rhs_precision_type = bf16,
accumulation_type = f32,
lhs_component_count = 3,
rhs_component_count = 3,
num_primitive_operations = 6,
allow_imprecise_accumulation = false}
// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
rhs_precision_type = f8e5m2,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = true}
Spetta alle implementazioni decidere quali combinazioni sono supportate. Nella in generale, non è garantito che ogni algoritmo sia supportato su ogni tipo di acceleratore da parte del consumatore di StableHLO. Se un determinato algoritmo non viene supportato, occorre segnalare un errore anziché utilizzare alternativa. La verifica stabile HLO fornirà la verifica del miglior sforzo, in modo da impedire algoritmi che non sono supportati su qualsiasi hardware.
Vedi xla_data.proto > Algorithm
per alcuni valori degli algoritmi supportati. Il ticket 2483 illustra il piano per creare un
documento centralizzato sugli algoritmi supportati dal backend.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore quantizzato per tensore o per tensore | (C5-C6), (C9-C10), (C12-C14), (C17-C18), (C20) |
(I2) | rhs |
tensore o tensore quantizzato | (C7-C10), (C12-C20) |
(I3) | lhs_batching_dimensions |
Costante tensore monodimensionale di tipo si64 |
(C1), (C3), (C5), (C9), (C12) |
(I4) | rhs_batching_dimensions |
Costante tensore monodimensionale di tipo si64 |
(C1), (C4), (C7), (C9) |
(I5) | lhs_contracting_dimensions |
Costante tensore monodimensionale di tipo si64 |
(C2), (C3), (C6), (C10) |
(I6) | rhs_contracting_dimensions |
Costante tensore monodimensionale di tipo si64 |
(C2), (C4), (C8), (C10), (C16) |
(I7) | precision_config |
numero variadico di enumerazioni di DEFAULT , HIGH e HIGHEST |
(C11), (C21) |
(I8) | lhs_precision_type |
FloatType o TensorFloat32 | (C21) |
(I9) | rhs_precision_type |
FloatType o TensorFloat32 | (C21) |
(I10) | accumulation_type |
FloatType o TensorFloat32 | (C21) |
(I11) | lhs_component_count |
costante di tipo si32 |
(C21), (C22) |
(I12) | rhs_component_count |
costante di tipo si32 |
(C21), (C23) |
(I13) | num_primitive_operations |
costante di tipo si32 |
(C21), (C24) |
(I14) | allow_imprecise_accumulation |
costante di tipo bool |
(C21) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato | (C12), (C14), (C18-C20) |
Vincoli
- (C1)
size(lhs_batching_dimensions) = size(rhs_batching_dimensions)
. - (C2)
size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions)
. - (C3)
is_unique(lhs_batching_dimensions + lhs_contracting_dimensions)
. - (C4)
is_unique(rhs_batching_dimensions + rhs_contracting_dimensions)
. - (C5)
0 <= lhs_batching_dimensions < rank(lhs)
. - (C6)
0 <= lhs_contracting_dimensions < rank(lhs)
. - (C7)
0 <= rhs_batching_dimensions < rank(rhs)
. - (C8)
0 <= rhs_contracting_dimensions < rank(rhs)
. - (C9)
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
. - (C10)
dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...)
. - (C11)
size(precision_config) = 2
. - (C12)
shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions)
. - Se l'operazione utilizza tensori non quantizzati:
- (C13)
element_type(lhs) = element_type(rhs)
.
- (C13)
- Se l'operazione utilizza tensori quantizzati:
- (C14)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C15)
zero_points(rhs) = 0
. - (C16) Se
is_per_axis_quantized(rhs)
:quantization_dimension(rhs)
non inrhs_contracting_dimensions
. - Se
is_quantized(lhs)
: - (C17)
storage_type(lhs) = storage_type(rhs)
. - (C18)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C19) Se
is_per_tensor_quantized(rhs)
:is_per_tensor_quantized(result)
. - Se
!is_quantized(lhs)
: - (C20)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C14)
- Se
!is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation)
:- (C21)
precision_config... = DEFAULT
. - (C22)
0 < lhs_component_count
. - (C23)
0 < rhs_component_count
. - (C24)
0 < num_primitive_operations
.
- (C21)
Esempi
// %lhs: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
// %rhs: [
// [[1, 0],
// [0, 1]],
// [[1, 0],
// [0, 1]]
// ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [1]
>,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
algorithm = #stablehlo.dot_algorithm<
lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false
>
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
dynamic_broadcast_in_dim
Semantica
Questa operazione è funzionalmente identica a
broadcast_in_dim
ma la forma del risultato viene specificata in modo dinamico tramite output_dimensions
.
L'operazione accetta anche gli attributi facoltativi known_expanding_dimensions
, known_non_expanding_dimensions
per esprimere una conoscenza statica del comportamento di espansione delle dimensioni.
Se non specificato, si presume che tutte le dimensioni possano espandersi.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato | (C1-C2), (C5-C6), (C9) |
(I2) | output_dimensions |
Tensore monodimensionale di tipo intero | (C7) |
(I3) | broadcast_dimensions |
Tensore costante monodimensionale di tipo intero | (C2-C6) |
(I4) | known_expanding_dimensions |
Tensore costante monodimensionale di tipo intero | (C8-C9) |
(I5) | known_non_expanding_dimensions |
Tensore costante monodimensionale di tipo intero | (C8-C9) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato | (C1), (C3), (C5-C7) |
Vincoli
- (C1) Il valore
element_type(result)
è fornito da:element_type(operand)
, se!is_per_axis_quantized(operand)
.element_type(operand)
eccettoquantization_dimension(operand)
,scales(operand)
ezero_points(operand)
potrebbero differire daquantization_dimension(result)
,scales(result)
ezero_points(result)
o in caso contrario.
- (C2)
size(broadcast_dimensions) = rank(operand)
. - (C3)
0 <= broadcast_dimensions < rank(result)
. - (C4)
is_unique(broadcast_dimensions)
. - (C5) Per tutti i valori
d
inaxes(operand)
:- .
dim(operand, d) = 1
odim(operand, d) = dim(result, broadcast_dimensions[d])
.
- (C6) Se
is_per_axis_quantized(result)
:quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
.- Se
dim(operand, quantization_dimension(operand)) = 1
, allorascales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
.
- (C7)
size(output_dimensions) = rank(result)
. - (C8)
is_unique(known_expanding_dimensions + known_non_expanding_dimensions)
. - (C9)
0 <= known_expanding_dimensions < rank(operand)
. - (C10)
0 <= known_non_expanding_dimensions < rank(operand)
.
Esempi
// %operand: [
// [1, 2, 3]
// ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
broadcast_dimensions = array<i64: 2, 1>,
known_expanding_dimensions = array<i64: 0>,
known_non_expanding_dimensions = array<i64: 1>
} : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
dynamic_conv
Semantica
Questa operazione è funzionalmente identica a
convoluzione
ma la spaziatura interna viene specificata in modo dinamico tramite padding
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore quantizzato per tensore o per tensore | (C1), (C10-C11), (C14) (C25), (C26-C27), (C30-C31), (C33) |
(I2) | rhs |
tensore o tensore quantizzato | (C1), (C14-C16), (C26-C28), (C30-C33) |
(I3) | padding |
Tensore bidimensionale di tipo intero | (C4) |
(I4) | window_strides |
Costante tensore monodimensionale di tipo si64 |
(C2-C3) |
(I5) | lhs_dilation |
Costante tensore monodimensionale di tipo si64 |
(C5-C6) |
(I6) | rhs_dilation |
Costante tensore monodimensionale di tipo si64 |
(C7-C8) |
(I7) | window_reversal |
Costante tensore monodimensionale di tipo i1 |
(C9) |
(I8) | input_batch_dimension |
costante di tipo si64 |
(C10), (C13) |
(I9) | input_feature_dimension |
costante di tipo si64 |
(C11), (C13-C14) |
(I10) | input_spatial_dimensions |
Costante tensore monodimensionale di tipo si64 |
(C12), (C13) |
(I11) | kernel_input_feature_dimension |
costante di tipo si64 |
(C14), (C18) |
(I12) | kernel_output_feature_dimension |
costante di tipo si64 |
(C15-C16), (C18), (C28) |
(I13) | kernel_spatial_dimensions |
Costante tensore monodimensionale di tipo si64 |
(C17-C18) |
(I14) | output_batch_dimension |
costante di tipo si64 |
(C20) |
(I15) | output_feature_dimension |
costante di tipo si64 |
(C20), (C29) |
(I16) | output_spatial_dimensions |
Costante tensore monodimensionale di tipo si64 |
(C19-C20) |
(I17) | feature_group_count |
costante di tipo si64 |
(C11), (C14), (C16), (C21), (C23) |
(I18) | batch_group_count |
costante di tipo si64 |
(C10), (C15), (C22), (C23) |
(I19) | precision_config |
numero variadico di enumerazioni di DEFAULT , HIGH e HIGHEST |
(C24) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato | (C25-C27), (C29), (C31-C33) |
Vincoli
- (C1)
N = rank(lhs) = rank(rhs)
. - (C2)
size(window_strides) = N - 2
. - (C3)
0 < window_strides
. - (C4)
shape(padding) = [N - 2, 2]
. - (C5)
size(lhs_dilation) = N - 2
. - (C6)
0 < lhs_dilation
. - (C7)
size(rhs_dilation) = N - 2
. - (C8)
0 < rhs_dilation
. - (C9)
size(window_reversal) = N - 2
. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
. - (C12)
size(input_spatial_dimensions) = N - 2
. - (C13) In base a
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
:is_unique(input_dimensions)
.0 <= input_dimensions < N
.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
. - (C17)
size(kernel_spatial_dimensions) = N - 2
. - (C18) In base a
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
:is_unique(kernel_dimensions)
.0 <= kernel_dimensions < N
.
- (C19)
size(output_spatial_dimensions) = N - 2
. - (C20) In base a
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
:is_unique(output_dimensions)
.0 <= output_dimensions < N
.
- (C21)
0 < feature_group_count
. - (C22)
0 < batch_group_count
. - (C23)
feature_group_count = 1 or batch_group_count = 1
. - (C24)
size(precision_config) = 2
. - (C25) Per
dim(result, result_dim)
si intende:dim(lhs, input_batch_dimension) / batch_group_count
seresult_dim = output_batch_dimension
.dim(rhs, kernel_output_feature_dimension)
seresult_dim = output_feature_dimension
.num_windows
altrimenti, dove:output_spatial_dimensions[spatial_dim] = result_dim
.lhs_dim = input_spatial_dimensions[spatial_dim]
.rhs_dim = kernel_spatial_dimensions[spatial_dim]
.dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
.dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
.num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
- (C26)
rank(result) = N
. - Se l'operazione utilizza tensori non quantizzati:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
.
- (C27)
- Se l'operazione utilizza tensori quantizzati:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C29) Se
is_per_axis_quantized(rhs)
, poiquantization_dimension(rhs) = kernel_output_feature_dimension
. - (C30) Se
is_per_axis_quantized(result)
:quantization_dimension(result) = output_feature_dimension
. - Se
is_quantized(lhs)
: - (C31)
storage_type(lhs) = storage_type(rhs)
. - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C33) Se
is_per_tensor_quantized(rhs)
:is_per_tensor_quantized(result)
. - Se
!is_quantized(lhs)
: - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C28)
Esempi
// %lhs: [[
// [[1], [2], [5], [6]],
// [[3], [4], [7], [8]],
// [[10], [11], [14], [15]],
// [[12], [13], [16], [17]]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
// %padding: [[1, 1],
// [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
window_strides = array<i64: 4, 4>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
dimension_numbers = #stablehlo.conv<raw
input_batch_dimension = 0,
input_feature_dimension = 3,
input_spatial_dimensions = [0, 1],
kernel_input_feature_dimension = 2,
kernel_output_feature_dimension = 3,
kernel_spatial_dimensions = [0, 1],
output_batch_dimension = 0,
output_feature_dimension = 3,
output_spatial_dimensions = [1, 2]
>,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[1], [5]],
// [[10], [14]]
// ]]
dynamic_gather
Semantica
Questa operazione è funzionalmente identica a
raccogliere
, con il valore slice_sizes
specificato in modo dinamico come valore.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore quantizzato per tensore o per tensore | (C1), (C7), (C10-C12), (C14) |
(I2) | start_indices |
tensore di tipo intero | (C2), (C3), (C13) |
(I3) | slice_sizes |
Tensore monodimensionale di tipo intero | (C8), (C11-C13) |
(I4) | offset_dims |
Costante tensore monodimensionale di tipo si64 |
(C1), (C4-C5), (C13) |
(I5) | collapsed_slice_dims |
Costante tensore monodimensionale di tipo si64 |
(C1), (C6-C8), (C13) |
(I6) | start_index_map |
Costante tensore monodimensionale di tipo si64 |
(C3), (C9), (C10) |
(I7) | index_vector_dim |
costante di tipo si64 |
(C2), (C3), (C13) |
(I8) | indices_are_sorted |
costante di tipo i1 |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore quantizzato per tensore o per tensore | (C5), (C13-C14) |
Vincoli
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims)
. - (C2)
0 <= index_vector_dim <= rank(start_indices)
. - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
. - (C5)
0 <= offset_dims < rank(result)
. - (C6)
is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims)
. - (C7)
0 <= collapsed_slice_dims < rank(operand)
. - (C8)
slice_sizes[collapsed_slice_dims...] <= 1
. - (C9)
is_unique(start_index_map)
. - (C10)
0 <= start_index_map < rank(operand)
. - (C11)
size(slice_sizes) = rank(operand)
. - (C12)
0 <= slice_sizes <= shape(operand)
. - (C13)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
dove:batch_dim_sizes = shape(start_indices)
ad eccezione del fatto che le dimensioni distart_indices
corrispondente aindex_vector_dim
non è incluso.offset_dim_sizes = shape(slice_sizes)
, tranne per il fatto che le dimensioni inslice_sizes
corrispondenti acollapsed_slice_dims
non sono inclusi.combine
posizionabatch_dim_sizes
sugli assi corrispondenti abatch_dims
eoffset_dim_sizes
in assi corrispondenti aoffset_dims
.
- (C14)
element_type(operand) = element_type(result)
.
Esempi
// %operand: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %start_indices: [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 2]]
// ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
dimension_numbers = #stablehlo.gather<
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vector_dim = 2>,
indices_are_sorted = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi64>
// %result: [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[9, 10], [11, 12]],
// [[11, 12], [13, 14]],
// [[17, 18], [19, 20]]
// ]
// ]
dynamic_iota
Semantica
Questa operazione è funzionalmente identica a
iota
ma la forma del risultato viene specificata in modo dinamico tramite output_shape
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | output_shape |
Tensore monodimensionale di tipo intero | (C1), (C2) |
(I2) | iota_dimension |
si64 |
(C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di numeri interi, in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C2) |
Vincoli
- (C1)
0 <= iota_dimension < size(output_shape)
. - (C2)
rank(result) = size(output_shape)
.
Esempi
%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
iota_dimension = 0 : i64
} : (tensor<2xi64>) -> tensor<4x5xi64>
// %result: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
dynamic_pad
Semantica
Questa operazione è funzionalmente identica a
pad
op, ma con edge_padding_low
, edge_padding_high
e interior_padding
specificate dinamicamente come valori.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore quantizzato per tensore o per tensore | (C1), (C2), (C4) |
(I2) | padding_value |
tensore 0-dimensionale o tensore quantizzato per tensore | (C1) |
(I3) | edge_padding_low |
Tensore monodimensionale di tipo intero | (C1), (C4) |
(I4) | edge_padding_high |
Tensore monodimensionale di tipo intero | (C1), (C4) |
(I5) | interior_padding |
Tensore monodimensionale di tipo intero | (C2-C4) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore quantizzato per tensore o per tensore | (C3-C6) |
Vincoli
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
. - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
. - (C3)
0 <= interior_padding
. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
.
Esempi
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
%edge_padding_low, %edge_padding_high, %interior_padding
) : (tensor<2x3xi64>, tensor<i64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x9xi64>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
dynamic_reshape
Semantica
Questa operazione è funzionalmente identica a
rimodella
ma la forma del risultato viene specificata in modo dinamico tramite output_shape
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato | (C1-C3) |
(I2) | output_shape |
Tensore monodimensionale di tipo intero | (C4) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato | (C1-C4) |
Vincoli
- (C1) Il valore
element_type(result)
è fornito da:element_type(operand)
, se!is_per_axis_quantized(operand)
.element_type(operand)
eccettoquantization_dimension(operand)
equantization_dimension(result)
potrebbe variare.
- (C2)
size(operand) = size(result)
. - (C3) Se
is_per_axis_quantized(operand)
:reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
.reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
- (C4)
size(output_shape) = rank(result)
.
Esempi
// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
// %result: [[1, 2], [3, 4], [5, 6]]
dynamic_slice
Semantica
Estrae una sezione dalla tabella operand
utilizzando indici iniziali calcolati in modo dinamico
e produce un tensore result
. start_indices
contengono gli indici iniziali di
la sezione per ogni dimensione soggetta a potenziali aggiustamenti e slice_sizes
contengono le dimensioni della sezione per ogni dimensione. Più formalmente,
result[result_index] = operand[operand_index]
dove:
adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes)
.operand_index = adjusted_start_indices + result_index
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore quantizzato per tensore o per tensore | (C1), (C2), (C4) |
(I2) | start_indices |
numero variadico dei tensori 0-dimensionali di tipo intero | (C2), (C3) |
(I3) | slice_sizes |
Costante tensore monodimensionale di tipo si64 |
(C2), (C4), (C5) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore quantizzato per tensore o per tensore | (C1), (C5) |
Vincoli
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(slice_sizes) = rank(operand)
. - (C3)
same(type(start_indices...))
. - (C4)
0 <= slice_sizes <= shape(operand)
. - (C5)
shape(result) = slice_sizes
.
Esempi
// %operand: [
// [0, 0, 1, 1],
// [0, 0, 1, 1],
// [0, 0, 0, 0],
// [0, 0, 0, 0]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
slice_sizes = array<i64: 2, 2>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
// [1, 1],
// [1, 1]
// ]
dynamic_update_slice
Semantica
Produce un tensore result
uguale al tensore operand
tranne che
la sezione che parte da start_indices
viene aggiornata con i valori in update
.
Più formalmente, result[result_index]
è definito come:
update[update_index]
se0 <= update_index < shape(update)
dove:- .
adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update))
.update_index = result_index - adjusted_start_indices
.
operand[result_index]
in caso contrario.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore quantizzato per tensore o per tensore | (C1-C4), (C6) |
(I2) | update |
tensore quantizzato per tensore o per tensore | (C2), (C3), (C6) |
(I3) | start_indices |
numero variadico dei tensori 0-dimensionali di tipo intero | (C4), (C5) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore quantizzato per tensore o per tensore | (C1) |
Vincoli
- (C1)
type(operand) = type(result)
. - (C2)
element_type(update) = element_type(operand)
. - (C3)
rank(update) = rank(operand)
. - (C4)
size(start_indices) = rank(operand)
. - (C5)
same(type(start_indices...))
. - (C6)
0 <= shape(update) <= shape(operand)
.
Esempi
// %operand: [
// [1, 1, 0, 0],
// [1, 1, 0, 0],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
// %update: [
// [1, 1],
// [1, 1]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
: (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
esponenziale
Semantica
Esegue un'operazione esponenziale a livello di elementi sul tensore operand
e produce un
result
tensore. A seconda del tipo di elemento:
- Per i numeri in virgola mobile:
exp
dallo standard IEEE-754. - Per i numeri complessi: esponenziale complesso.
- Per i tipi quantizzati:
dequantize_op_quantize(exponential, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]
exponential_minus_one
Semantica
Esegue un'operazione esponenziale a livello di elemento meno un'operazione sul tensore operand
e
produce un tensore result
. A seconda del tipo di elemento:
- Per i numeri in virgola mobile:
expm1
dallo standard IEEE-754. - Per i numeri complessi: esponenziale complesso meno uno.
- Per i tipi quantizzati:
dequantize_op_quantize(exponential_minus_one, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]
FFT
Semantica
Esegue le trasformazioni di Fourier avanti e inverse per ottenere dati complessi come input/output.
fft_type
corrisponde a uno dei seguenti:
FFT
: inoltro di FFT da complessi a complessi.IFFT
: FFT da complesso a complesso inverso.RFFT
: inoltro FFT da reale a complesso.IRFFT
: FFT inverso da reale a complesso (ovvero richiede un'operazione complessa, restituisce un risultato reale).
Più formalmente, data la funzione fft
che prende i tensori monodimensionali di
complessi come input, produce tensori monodimensionali degli stessi tipi di
e calcola la trasformata discreta di Fourier:
Per fft_type = FFT
, result
è definito come il risultato finale di una serie di L
in cui L = size(fft_length)
. Ad esempio, per L = 3
:
result1[i0, ..., :] = fft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
.result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Inoltre, data la funzione ifft
che ha lo stesso tipo di firma e
calcola l'inversa di fft
:
Per fft_type = IFFT
, result
è definito come l'inverso dei calcoli
per fft_type = FFT
. Ad esempio, per L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
.result[i0, ..., :] = ifft(result2[i0, ..., :])
.
Inoltre, data la funzione rfft
che prende i tensori monodimensionali di
tipi a virgola mobile, produce tensori monodimensionali di tipi complessi
la stessa semantica con rappresentazione in virgola mobile e funziona come segue:
rfft(real_operand) = truncated_result
dovecomplex_operand... = (real_operand..., 0.0)
.complex_result = fft(complex_operand)
.truncated_result = complex_result[:(rank(complex_result) / 2 + 1)]
.
(Quando la trasformata discreta di Fourier viene calcolata per gli operandi reali, la prima
Gli elementi N/2 + 1
del risultato definiscono in modo inequivocabile il resto del risultato,
quindi il risultato di rfft
viene troncato per evitare il calcolo di elementi ridondanti).
Per fft_type = RFFT
, result
è definito come il risultato finale di una serie di L
in cui L = size(fft_length)
. Ad esempio, per L = 3
:
result1[i0, ..., :] = rfft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
.result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Infine, data la funzione irfft
che ha la stessa firma di tipo e
calcola l'inversa di rfft
:
Per fft_type = IRFFT
, result
è definito come l'inverso dei calcoli
per fft_type = RFFT
. Ad esempio, per L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
.result[i0, ..., :] = irfft(result2[i0, ..., :])
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di rappresentazione in virgola mobile o di tipo complesso | (C1), (C2), (C4), (C5) |
(I2) | fft_type |
enum di FFT , IFFT , RFFT e IRFFT |
(C2), (C5) |
(I3) | fft_length |
Costante tensore monodimensionale di tipo si64 |
(C1), (C3), (C4) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di rappresentazione in virgola mobile o di tipo complesso | (C2), (C4), (C5) |
Vincoli
- (C1)
size(fft_length) <= rank(operand)
. - (C2) La relazione tra i tipi di elementi
operand
eresult
varia:- .
- Se
fft_type = FFT
,element_type(operand)
eelement_type(result)
dello stesso tipo complesso. - Se
fft_type = IFFT
,element_type(operand)
eelement_type(result)
dello stesso tipo complesso. - Se
fft_type = RFFT
,element_type(operand)
è un tipo con rappresentazione in virgola mobile eelement_type(result)
è un tipo complesso dello stesso numero in virgola mobile la semantica. - Se
fft_type = IRFFT
,element_type(operand)
è un tipo complesso eelement_type(result)
è un tipo con rappresentazione in virgola mobile dello stesso tipo di rappresentazione in virgola mobile la semantica.
- Se
- (C3)
1 <= size(fft_length) <= 3
. - (C4) Se tra
operand
eresult
, c'è un tensorereal
di un con rappresentazione in virgola mobile, poishape(real)[-size(fft_length):] = fft_length
. - (C5)
shape(result) = shape(operand)
eccetto:- Se
fft_type = RFFT
,dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1
. - Se
fft_type = IRFFT
,dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1
.
- Se
Esempi
// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
fft_type = #stablehlo<fft_type FFT>,
fft_length = array<i64: 4>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]
floor
Semantica
Esegue il pavimento a livello di elemento del tensore operand
e produce un tensore result
.
Implementa l'operazione roundToIntegralTowardNegative
da IEEE-754
e la specifica del prodotto. Per i tipi quantizzati, esegue
dequantize_op_quantize(floor, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]
raccogliere
Semantica
Raccoglie le sezioni del tensore operand
dagli offset specificati in start_indices
e produce un tensore result
.
Il seguente diagramma mostra come gli elementi in result
vengono mappati sugli elementi in
operand
usando un esempio concreto. Il diagramma sceglie alcuni esempi result
indici e spiega in dettaglio a quali indici operand
corrispondono.
Più formalmente, result[result_index] = operand[operand_index]
dove:
batch_dims = [d for d in axes(result) and d not in offset_dims]
.batch_index = result_index[batch_dims...]
.start_index
è definito come:start_indices[bi0, ..., :, ..., biN]
dovebi
sono singoli elementi inbatch_index
e:
sono inserite all'indiceindex_vector_dim
, seindex_vector_dim
<rank(start_indices)
,[start_indices[batch_index]]
in caso contrario.
- Per
d_operand
aaxes(operand)
,- .
full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])
sed_operand = start_index_map[d_start]
.full_start_index[d_operand] = 0
in caso contrario.
- Per
d_operand
aaxes(operand)
,- .
full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)]
sed_operand = operand_batching_dims[i_batching]
ed_start = start_indices_batching_dims[i_batching]
.full_batching_index[d_operand] = 0
in caso contrario.
offset_index = result_index[offset_dims...]
.full_offset_index = [oi0, ..., 0, ..., oiN]
doveoi
sono individuali inoffset_index
e0
è inserito in corrispondenza degli indici dacollapsed_slice_dims
eoperand_batching_dims
.operand_index = full_start_index + full_batching_index + full_offset_index
.
Se indices_are_sorted
è true
, l'implementazione può presupporre che
start_indices
vengono ordinati rispetto a start_index_map
, altrimenti i valori
non è definito. Più formalmente, per tutti i i1 < i2
di indices(result)
,
full_start_index(i1) <= full_start_index(i2)
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore quantizzato per tensore o per tensore | (C1), (C8), (C11), (C17), (C19-C21), (C23) |
(I2) | start_indices |
tensore di tipo intero | (C2-C3), (C14), (C17), (C22) |
(I3) | offset_dims |
Costante tensore monodimensionale di tipo si64 |
(C1), (C4-C5), (C22) |
(I4) | collapsed_slice_dims |
Costante tensore monodimensionale di tipo si64 |
(C1), (C6-C9), (C22) |
(I5) | operand_batching_dims |
Costante tensore monodimensionale di tipo si64 |
(C1), (C6), (C10-C12), (C16-C18), (C22) |
(I6) | start_indices_batching_dims |
Costante tensore monodimensionale di tipo si64 |
(C13-C17) |
(I7) | start_index_map |
Costante tensore monodimensionale di tipo si64 |
(C3), (C18-C19) |
(I8) | index_vector_dim |
costante di tipo si64 |
(C2-C3), (C15), (C22) |
(I9) | slice_sizes |
Costante tensore monodimensionale di tipo si64 |
(C9), (C12), (C20-C22) |
(I10) | indices_are_sorted |
costante di tipo i1 |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore quantizzato per tensore o per tensore | (C5), (C22-C23) |
Vincoli
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims)
. - (C2)
0 <= index_vector_dim <= rank(start_indices)
. - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
. - (C5)
0 <= offset_dims < rank(result)
. - (C6)
is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))
- (C7)
is_sorted(collapsed_slice_dims)
. - (C8)
0 <= collapsed_slice_dims < rank(operand)
. - (C9)
slice_sizes[collapsed_slice_dims...] <= 1
. - (C10)
is_sorted(operand_batching_dims)
. - (C11)
0 <= operand_batching_dims < rank(operand)
. - (C12)
slice_sizes[operand_batching_dims...] <= 1
. - (C13)
is_unique(start_indices_batching_dims)
. - (C14)
0 <= start_indices_batching_dims < rank(start_indices)
. - (C15)
index_vector_dim not in start_indices_batching_dims
. - (C16)
size(operand_batching_dims) == size(start_indices_batching_dims)
. - (C17)
dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...)
. - (C18)
is_unique(concatenate(start_index_map, operand_batching_dims))
. - (C19)
0 <= start_index_map < rank(operand)
. - (C20)
size(slice_sizes) = rank(operand)
. - (C21)
0 <= slice_sizes <= shape(operand)
. - (C22)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
dove:batch_dim_sizes = shape(start_indices)
ad eccezione del fatto che le dimensioni distart_indices
corrispondente aindex_vector_dim
non è incluso.offset_dim_sizes = slice_sizes
tranne che le dimensioni delle dimensioni inslice_sizes
corrispondente acollapsed_slice_dims
eoperand_batching_dims
non sono inclusi.combine
posizionabatch_dim_sizes
sugli assi corrispondenti abatch_dims
eoffset_dim_sizes
in assi corrispondenti aoffset_dims
.
- (C23)
element_type(operand) = element_type(result)
.
Esempi
// %operand: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %start_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
%result = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stablehlo.gather<
offset_dims = [3, 4],
collapsed_slice_dims = [1],
operand_batching_dims = [0],
start_indices_batching_dims = [1],
start_index_map = [2, 1],
index_vector_dim = 3>,
slice_sizes = array<i64: 1, 1, 2, 2>,
indices_are_sorted = false
} : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32>
// %result: [
// [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[33, 34], [35, 36]],
// [[35, 36], [37, 38]],
// [[41, 42], [43, 44]]
// ]
// ],
// [
// [
// [[1, 2], [3, 4]],
// [[13, 14], [15, 16]],
// [[21, 22], [23, 24]]
// ],
// [
// [[43, 44], [45, 46]],
// [[33, 34], [35, 36]],
// [[27, 28], [29, 30]]
// ]
// ]
// ]
get_dimension_size
Semantica
Restituisce la dimensione dell'elemento dimension
specificato di operand
. Più formalmente,
result = dim(operand, dimension)
. La semantica riguarda solo la forma
del tipo. Il tipo di elemento può essere qualsiasi cosa.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato | (C1) |
(I2) | dimension |
costante di tipo si64 |
(C1) |
Output
Nome | Tipo |
---|---|
result |
Tensore 0dimensionale di tipo si32 |
Vincoli
- (C1)
0 <= dimension < rank(operand)
.
Esempi
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3
get_tuple_element
Semantica
Estrae l'elemento nella posizione index
della tupla operand
e produce un
result
. Più formalmente, result = operand[index]
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tupla | (C1), (C2) |
(I2) | index |
costante di tipo si32 |
(C1), (C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
qualsiasi tipo supportato | (C2) |
Vincoli
- (C1)
0 <= index < size(operand)
. - (C2)
type(result) = tuple_element_types(operand)[index]
.
Esempi
// %operand: ([1.0, 2.0], (3))
index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]
se
Semantica
Genera l'output dall'esecuzione esattamente di una funzione da true_branch
o
false_branch
a seconda del valore di pred
. Più formalmente, result =
pred ? true_branch() : false_branch()
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | pred |
Tensore 0dimensionale di tipo i1 |
|
(I2) | true_branch |
funzione | (C1-C3) |
(I3) | false_branch |
funzione | (C1), (C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadico di tensori, tensori o token quantizzati | (C3) |
Vincoli
- (C1)
input_types(true_branch) = input_types(false_branch) = []
. - (C2)
output_types(true_branch) = output_types(false_branch)
. - (C3)
type(results...) = output_types(true_branch)
.
Esempi
// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
"stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
"stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10
immaginazione
Semantica
Estrae la parte immaginaria, in termini di elementi, dal operand
e produce una
result
tensore. In modo più formale, per ogni elemento x
:
imag(x) = is_complex(x) ? imaginary_part(x) :
constant(0, element_type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di rappresentazione in virgola mobile o di tipo complesso | (C1), (C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore del tipo in virgola mobile | (C1), (C2) |
Vincoli
- (C1)
shape(result) = shape(operand)
. - (C2) Per
element_type(result)
si intende:complex_element_type(element_type(operand))
seis_complex(operand)
.element_type(operand)
in caso contrario.
Esempi
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]
In-feed
Semantica
Legge i dati dal feed Infeed e produce results
.
La semantica di infeed_config
è definita dall'implementazione.
results
è costituito da valori di payload che vengono visualizzati per primi e da un token
per ultimo. In futuro, prevediamo di suddividere il payload e il token in due
separati per migliorare la chiarezza
(N. 670).
Input
Etichetta | Nome | Tipo |
---|---|---|
(I1) | token |
token |
(I2) | infeed_config |
costante di tipo string |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadico di tensori, tensori o token quantizzati | (C1-C3) |
Vincoli
- (C1)
0 < size(results)
. - (C2)
is_empty(result[:-1])
ois_tensor(type(results[:-1]))
. - (C3)
is_token(type(results[-1]))
.
Esempi
// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]
Iota
Semantica
Riempi un tensore output
con valori in ordine crescente partendo da zero
lungo la dimensione iota_dimension
. Più formalmente,
output[output_index] = constant(is_quantized(output) ?
quantize(output_index[iota_dimension], element_type(output)) :
output_index[iota_dimension], element_type(output))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | iota_dimension |
si64 |
(C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
output |
tensore di numeri interi, in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
0 <= iota_dimension < rank(output)
.
Esempi
%output = "stablehlo.iota"() {
iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
%output = "stablehlo.iota"() {
iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4]
// ]
is_finite
Semantica
Esegue la verifica a livello di elemento per verificare se il valore in x
è finito (ovvero non è né l'uno né l'altro
+Inf, -Inf, né NaN) e produce un tensore y
. Implementa il isFinite
il funzionamento secondo la specifica IEEE-754. Per i tipi quantizzati, il risultato è
sempre true
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | x |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
y |
tensore di tipo booleano | (C1) |
Vincoli
- (C1)
shape(x) = shape(y)
.
Esempi
// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]
log
Semantica
Esegue un'operazione logaritmica a livello di elemento sul tensore operand
e produce un
result
tensore. A seconda del tipo di elemento:
- Per i numeri in virgola mobile:
log
dallo standard IEEE-754. - Per i numeri complessi: logaritmo complesso.
- Per i tipi quantizzati:
dequantize_op_quantize(log, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]
log_plus_one
Semantica
Esegue il logaritmo a livello di elemento più un'operazione sul tensore operand
e
produce un tensore result
. A seconda del tipo di elemento:
- Per i numeri in virgola mobile:
logp1
dallo standard IEEE-754. - Per i numeri complessi: logaritmo complesso più uno.
- Per i tipi quantizzati:
dequantize_op_quantize(log_plus_one, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
logistica
Semantica
Esegue un'operazione logistica a livello di elementi sul tensore operand
e produce un'operazione
result
tensore. A seconda del tipo di elemento:
- Per i numeri in virgola mobile:
division(1, addition(1, exp(-x)))
dallo standard IEEE-754. - Per i numeri complessi: logistica complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(logistic, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]
mappa
Semantica
Applica una funzione di mappa computation
a inputs
lungo dimensions
e
produce un tensore result
.
Più formalmente, result[result_index] = computation(inputs...[result_index])
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | inputs |
numero variadico di tensori o tensori quantizzati per tensore | (C1-C4) |
(I2) | dimensions |
Costante tensore monodimensionale di tipo si64 |
(C3) |
(I3) | computation |
funzione | (C4) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore quantizzato per tensore o per tensore | (C1), (C4) |
Vincoli
- (C1)
shape(inputs...) = shape(result)
. - (C2)
0 < size(inputs) = N
. - (C3)
dimensions = range(rank(inputs[0]))
. - (C4)
computation
ha il tipo(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>
doveEi = element_type(inputs[i])
eE' = element_type(result)
.
Esempi
// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
stablehlo.return %0 : tensor<i64>
}) {
dimensions = array<i64: 0, 1>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]
massimo
Semantica
Esegue un'operazione massima a livello di elemento sui tensori lhs
e rhs
e produce un
tensore result
. A seconda del tipo di elemento:
- Per i valori booleani: OR logico.
- Per i numeri interi: numero massimo.
- Per i numeri in virgola mobile:
maximum
dallo standard IEEE-754. - Per i numeri complessi: lessicografico massimo per la coppia
(real, imaginary)
. Imporre un ordinamento sui numeri complessi implica una semantica sorprendente, quindi in futuro abbiamo intenzione di rimuovere il supporto per i numeri complessi per questa operazione (#560). - Per i tipi quantizzati:
dequantize_op_quantize(maximum, lhs, rhs, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore quantizzato per tensore o per tensore | (C1) |
(I2) | rhs |
tensore quantizzato per tensore o per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore quantizzato per tensore o per tensore | (C1) |
Vincoli
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Esempi
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]
minimo
Semantica
Esegue un'operazione minima a livello di elemento sui tensori lhs
e rhs
e produce un
tensore result
. A seconda del tipo di elemento:
- Per i valori booleani: AND logico.
- Per i numeri interi: numero minimo intero.
- Per i numeri in virgola mobile:
minimum
dallo standard IEEE-754. - Per i numeri complessi: minimo lessicografico per la coppia
(real, imaginary)
. Imporre un ordinamento sui numeri complessi implica una semantica sorprendente, quindi in futuro abbiamo intenzione di rimuovere il supporto per i numeri complessi per questa operazione (#560). - Per i tipi quantizzati:
dequantize_op_quantize(minimum, lhs, rhs, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore quantizzato per tensore o per tensore | (C1) |
(I2) | rhs |
tensore quantizzato per tensore o per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore quantizzato per tensore o per tensore | (C1) |
Vincoli
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Esempi
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]
moltiplicazione
Semantica
Esegue il prodotto a livello di elementi di due tensori lhs
e rhs
e produce un
tensore result
. A seconda del tipo di elemento:
- Per i valori booleani: AND logico.
- Per i numeri interi: moltiplicazione di numeri interi.
- Per i numeri in virgola mobile:
multiplication
dallo standard IEEE-754. - Per i numeri complessi: moltiplicazioni complesse.
- Per i tipi quantizzati:
dequantize_op_quantize(multiply, lhs, rhs, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore quantizzato per tensore o per tensore | (C1) |
(I2) | rhs |
tensore quantizzato per tensore o per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore quantizzato per tensore o per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]
negare
Semantica
Esegue la negazione a livello di elemento del tensore operand
e produce un result
tensore. A seconda del tipo di elemento:
- Per i numeri interi firmati: negazione di numeri interi.
- Per i numeri interi senza segno: bitcast a numero intero con segno, negazione di numeri interi, bitcast a un numero intero senza segno.
- Per i numeri in virgola mobile:
negate
dallo standard IEEE-754. - Per i numeri complessi: negazione complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(negate, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di numeri interi, in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di numeri interi, in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]
// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]
non
Semantica
Esegue l'operatore NOT a livello di elemento del tensore operand
e produce un tensore result
.
A seconda del tipo di elemento:
- Per i valori booleani: NOT logico.
- Per i numeri interi: NOT a livello di bit.
Argomenti
Nome | Tipo | Vincoli |
---|---|---|
operand |
tensore di un tipo booleano o intero | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di un tipo booleano o intero | (C1) |
Vincoli
- (C1)
type(operand) = type(result)
.
Esempi
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]
// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]
optimization_barrier
Semantica
Garantisce che le operazioni che producono operand
vengano eseguite prima che
che dipendono da result
e impediscono le trasformazioni del compilatore
di spostare le operazioni oltre la barriera. A parte questo, l'operazione
un'identità, ad esempio result = operand
.
Argomenti
Nome | Tipo | Vincoli |
---|---|---|
operand |
numero variadico di tensori, tensori o token quantiizzati per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
numero variadico di tensori, tensori o token quantiizzati per tensore | (C1) |
Vincoli
- (C1)
type(operand...) = type(result...)
.
Esempi
// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0
o
Semantica
Esegue l'operatore OR a livello di elemento di due tensori lhs
e rhs
e produce un result
tensore. A seconda del tipo di elemento:
- Per i valori booleani: OR logico.
- Per i numeri interi: OR a livello di bit.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di numeri interi o booleani | (C1) |
(I2) | rhs |
tensore di numeri interi o booleani | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di numeri interi o booleani | (C1) |
Vincoli
- (C1)
type(lhs) = type(rhs) = type(result)
.
Esempi
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]
uscita
Semantica
Scrive inputs
nell'outfeed e produce un token result
.
La semantica di outfeed_config
è definita dall'implementazione.
Input
Etichetta | Nome | Tipo |
---|---|---|
(I1) | inputs |
numero variadico di tensori o tensori quantizzati |
(I2) | token |
token |
(I3) | outfeed_config |
costante di tipo string |
Output
Nome | Tipo |
---|---|
result |
token |
Esempi
%result = "stablehlo.outfeed"(%input0, %token) {
outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token
pad
Semantica
Espande operand
inserendo una spaziatura interna attorno al tensore e tra gli elementi
del tensore con l'elemento padding_value
specificato.
edge_padding_low
e edge_padding_high
specificano la quantità di spaziatura interna aggiunta
alla fascia bassa (accanto all'indice 0) e alla fascia alta (accanto all'indice più alto) di
ogni dimensione. La quantità di spaziatura interna può essere negativa, laddove il valore
il valore assoluto di spaziatura interna negativa indica il numero di elementi da rimuovere
dalla dimensione specificata.
interior_padding
specifica la quantità di spaziatura interna aggiunta tra due elementi qualsiasi
elementi in ogni dimensione che potrebbero non essere negativi. Viene applicata la spaziatura interna
prima della spaziatura interna dei bordi, in modo che la spaziatura interna negativa dei bordi rimuova gli elementi
l'operando con riempitivo interno.
Più formalmente, result[result_index]
è definito come:
operand[operand_index]
seresult_index = edge_padding_low + operand_index * (interior_padding + 1)
.padding_value
in caso contrario.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore quantizzato per tensore o per tensore | (C1), (C2), (C4) |
(I2) | padding_value |
tensore 0-dimensionale o tensore quantizzato per tensore | (C1) |
(I3) | edge_padding_low |
Costante tensore monodimensionale di tipo si64 |
(C1), (C4) |
(I4) | edge_padding_high |
Costante tensore monodimensionale di tipo si64 |
(C1), (C4) |
(I5) | interior_padding |
Costante tensore monodimensionale di tipo si64 |
(C2-C4) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore quantizzato per tensore o per tensore | (C3-C6) |
Vincoli
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
. - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
. - (C3)
0 <= interior_padding
. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
.
Esempi
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
edge_padding_low = array<i64: 0, 1>,
edge_padding_high = array<i64: 2, 1>,
interior_padding = array<i64: 1, 2>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
partition_id
Semantica
Genera partition_id
del processo attuale.
Output
Nome | Tipo |
---|---|
result |
Tensore 0dimensionale di tipo ui32 |
Esempi
%result = "stablehlo.partition_id"() : () -> tensor<ui32>
Popcnt
Semantica
Esegue il conteggio a livello di elemento del numero di bit impostato nel tensore operand
e produce un tensore result
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo intero | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo intero | (C1) |
Vincoli
- (C1)
type(operand) = type(result)
.
Esempi
// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]
potenza
Semantica
Esegue l'elevazione a livello di elementi del tensore lhs
per tensore rhs
e
produce un tensore result
. A seconda del tipo di elemento:
- Per i numeri interi: esponenzialità di numeri interi.
- Per i numeri in virgola mobile:
pow
dallo standard IEEE-754. - Per i numeri complessi: esponenziale complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(power, lhs, rhs, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di numeri interi, in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
(I2) | rhs |
tensore di numeri interi, in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di numeri interi, in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]
reale
Semantica
Estrae la parte reale, a livello di elemento, da operand
e produce un result
tensore. In modo più formale, per ogni elemento x
:
real(x) = is_complex(x) ? real_part(x) : x
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di rappresentazione in virgola mobile o di tipo complesso | (C1), (C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore del tipo in virgola mobile | (C1), (C2) |
Vincoli
- (C1)
shape(result) = shape(operand)
. - (C2) Per
element_type(result)
si intende:complex_element_type(element_type(operand))
seis_complex(operand)
.element_type(operand)
in caso contrario.
Esempi
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]
recv
Semantica
Riceve i dati da un canale con channel_id
e genera results
.
Se is_host_transfer
è true
, l'operazione trasferisce i dati dalla
. Altrimenti, i dati vengono trasferiti da un altro dispositivo. Significato
dell'implementazione. Questo flag duplica le informazioni fornite in
channel_type
, quindi in futuro prevediamo di conservarne solo uno
(#666).
results
è costituito da valori di payload che vengono visualizzati per primi e da un token
per ultimo. In futuro, prevediamo di suddividere il payload e il token in due
separati per migliorare la chiarezza
(N. 670).
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | token |
token |
(C4) |
(I2) | channel_id |
costante di tipo si64 |
|
(I3) | channel_type |
enum di DEVICE_TO_DEVICE e HOST_TO_DEVICE |
(C1) |
(I4) | is_host_transfer |
costante di tipo i1 |
(C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadico di tensori, tensori o token quantizzati | (C2-C4) |
Vincoli
- (C1)
channel_type
è definito come:HOST_TO_DEVICE
seis_host_transfer = true
,DEVICE_TO_DEVICE
in caso contrario.
- (C2)
0 < size(results)
. - (C3)
is_empty(result[:-1])
ois_tensor(type(results[:-1]))
. - (C4)
is_token(type(results[-1]))
.
Esempi
%results0, %results1 = "stablehlo.recv"(%token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
reduce
Semantica
Applica una funzione di riduzione body
a inputs
e init_values
lungo la
dimensions
e produce results
tensori.
L'ordine delle riduzioni è definito dall'implementazione, il che significa che body
e
init_values
deve formare un monoide per garantire che l'operazione produca il
gli stessi risultati per tutti gli input in tutte le implementazioni. Tuttavia, questa condizione
non è valida per molte riduzioni popolari. Ad es. aggiunta con rappresentazione in virgola mobile
body
e zero per init_values
non formano un monoide perché
L'aggiunta in virgola mobile non è associativa.
Più formalmente, results...[j0, ..., jR-1] = reduce(input_slices_converted)
dove:
input_slices = inputs...[j0, ..., :, ..., jR-1]
, dove:
sono inseriti alle oredimensions
.input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...)
.init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)
.reduce(input_slices_converted) = exec(schedule)
per un albero binarioschedule
dove:- .
exec(node) = body(exec(node.left), exec(node.right))
.exec(leaf) = leaf.value
.
schedule
è un albero binario completo, definito dall'implementazione, il cui ordine l'attraversamento è costituito da:input_slices_converted...[index]
valori, per tutti iindex
inindex_space(input_slices_converted)
in ordine lessicografico crescente diindex
.- Intervallo con una quantità definita dall'implementazione di
init_values_converted
in posizioni definite dall'implementazione.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | inputs |
numero variadico di tensori o tensori quantizzati per tensore | (C1-C4), (C6), (C7) |
(I2) | init_values |
numero variadico di tensori 0-dimensionali o tensori quantizzati per tensore | (C2), (C3) |
(I3) | dimensions |
Costante tensore monodimensionale di tipo si64 |
(C4), (C5), (C7) |
(I4) | body |
funzione | (C6) |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadico di tensori o tensori quantizzati per tensore | (C3), (C7), (C8) |
Vincoli
- (C1)
same(shape(inputs...))
. - (C2)
element_type(inputs...) = element_type(init_values...)
. - (C3)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C4)
0 <= dimensions < rank(inputs[0])
. - (C5)
is_unique(dimensions)
. - (C6)
body
ha il tipo(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
doveis_promotable(element_type(inputs[i]), Ei)
. - (C7)
shape(results...) = shape(inputs...)
tranne che la dimensione le dimensioniinputs...
corrispondenti adimensions
non sono incluse. - (C8)
element_type(results[i]) = Ei
per tutti ii
in[0,N)
.
Esempi
// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]
reduce_precision
Semantica
Esegue la conversione a livello di elementi di operand
in un altro tipo con rappresentazione in virgola mobile
che utilizza exponent_bits
e mantissa_bits
e torna all'originale
un tipo a virgola mobile e produce un tensore output
.
In modo più formale:
- Le mantissa del valore originale vengono aggiornate per arrotondare l'originale
al valore più vicino rappresentabile con
mantissa_bits
utilizzando semantica diroundToIntegralTiesToEven
. - Poi, se il valore di
mantissa_bits
è inferiore al numero di mantissa di il valore originale, i bit di mantissa vengono troncati amantissa_bits
. - Quindi, se i bit esponenti del risultato intermedio non rientrano nella
intervallo fornito da
exponent_bits
, il risultato intermedio va oltre la all'infinito utilizzando il segno originale oppure si sposta verso zero utilizzando firma originale. - Per i tipi quantizzati, esegue
dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
(I2) | exponent_bits |
costante di tipo si32 |
(C2) |
(I3) | mantissa_bits |
costante di tipo si32 |
(C3) |
Output
Nome | Tipo | Vincoli |
---|---|---|
output |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(output)
. - (C2)
1 <= exponent_bits
. - (C3)
0 <= mantissa_bits
.
Esempi
// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
exponent_bits = 5 : i32,
mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]
reduce_scatter
Semantica
All'interno di ciascun gruppo di processi nella griglia dei processi StableHLO, esegue una riduzione,
usando computations
, sui valori del tensore operand
di ogni processo,
divide il risultato della riduzione in scatter_dimension
in parti e disperde
le parti divise tra i processi per produrre result
.
L'operazione divide la griglia del processo StableHLO in process_groups
, che
definiti come segue:
cross_replica(replica_groups)
sechannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
sechannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
sechannel_id > 0 and use_global_device_ids = true
.
In seguito, all'interno di ogni process_group
:
reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation)
.parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension)
.result@receiver = parts@sender[receiver_index]
per tutti isender
inprocess_group
, dovereceiver_index = process_group.index(receiver)
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore quantizzato per tensore o per tensore | (C1), (C2), (C7), (C8) |
(I2) | scatter_dimension |
costante di tipo si64 |
(C1), (C2), (C8) |
(I3) | replica_groups |
Costante tensore bidimensionale di tipo si64 |
(C3-C5) |
(I4) | channel_id |
costante di tipo si64 |
(C6) |
(I5) | use_global_device_ids |
costante di tipo i1 |
(C6) |
(I6) | computation |
funzione | (C7) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore quantizzato per tensore o per tensore | (C8-C9) |
Vincoli
- (C1)
dim(operand, scatter_dimension) % dim(process_groups, 1) = 0
. - (C2)
0 <= scatter_dimension < rank(operand)
. - (C3)
is_unique(replica_groups)
. - (C4) Per
size(replica_groups)
si intende:num_replicas
se si utilizzacross_replica
.num_replicas
se si utilizzacross_replica_and_partition
.num_processes
se si utilizzaflattened_ids
.
- (C5)
0 <= replica_groups < size(replica_groups)
. - (C6) Se
use_global_device_ids = true
, allorachannel_id > 0
. - (C7)
computation
ha il tipo(tensor<E>, tensor<E>) -> (tensor<E>)
doveis_promotable(element_type(operand), E)
. - (C8)
shape(result) = shape(operand)
eccetto:dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1)
.
- (C9)
element_type(result) = E
.
Esempi
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
// [18, 20]]
// %result@(1, 0): [[14, 16],
// [22, 24]]
reduce_window
Semantica
Applica una funzione di riduzione body
alle finestre di inputs
e init_values
e produce results
.
Il seguente diagramma mostra come vengono calcolati gli elementi in results...
a partire da
inputs...
usando un esempio concreto.
Più formalmente,
results...[result_index] = reduce(windows, init_values, axes(inputs...), body)
(vedi Riduci) dove:
padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1)
.window_start = result_index * window_strides
.window_end = window_start + (window_dimensions - 1) * window_dilations + 1
.windows = slice(padded_inputs..., window_start, window_end, window_dilations)
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | inputs |
numero variadico di tensori o tensori quantizzati per tensore | (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15) |
(I2) | init_values |
numero variadico di tensori 0-dimensionali o tensori quantizzati per tensore | (C1), (C13) |
(I3) | window_dimensions |
Costante tensore monodimensionale di tipo si64 |
(C4), (C5), (C15) |
(I4) | window_strides |
Costante tensore monodimensionale di tipo si64 |
(C6), (C7), (C15) |
(I5) | base_dilations |
Costante tensore monodimensionale di tipo si64 |
(C8), (C9), (C15) |
(I6) | window_dilations |
Costante tensore monodimensionale di tipo si64 |
(C10), (C11), (C15) |
(I7) | padding |
Costante tensore bidimensionale di tipo si64 |
(C12), (C15) |
(I8) | body |
funzione | (C13) |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadico di tensori o tensori quantizzati per tensore | (C1), (C14-C16) |
Vincoli
- (C1)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C2)
same(shape(inputs...))
. - (C3)
element_type(inputs...) = element_type(init_values...)
. - (C4)
size(window_dimensions) = rank(inputs[0])
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(inputs[0])
. - (C7)
0 < window_strides
. - (C8)
size(base_dilations) = rank(inputs[0])
. - (C9)
0 < base_dilations
. - (C10)
size(window_dilations) = rank(inputs[0])
. - (C11)
0 < window_dilations
. - (C12)
shape(padding) = [rank(inputs[0]), 2]
. - (C13)
body
ha il tipo(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
doveis_promotable(element_type(inputs[i]), Ei)
. - (C14)
same(shape(results...))
. - (C15)
shape(results[0]) = num_windows
dove:dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1
.padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]
.dilated_window_shape = (window_dimensions - 1) * window_dilations + 1
.is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape
.num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1
.
- (C16)
element_type(results[i]) = Ei
per tutti ii
in[0,N)
.
Esempi
// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 2, 1>,
window_strides = array<i64: 4, 1>,
base_dilations = array<i64: 2, 1>,
window_dilations = array<i64: 3, 1>,
padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]
resto
Semantica
Esegue il resto degli elementi del dividendo lhs
e dei tensori di divisore rhs
e
produce un tensore result
.
Più formalmente, il segno del risultato viene dedotto dal dividendo,
il valore assoluto del risultato è sempre inferiore al valore assoluto del divisore.
Il resto viene calcolato come lhs - d * rhs
, dove d
è dato da:
- Per i numeri interi:
stablehlo.divide(lhs, rhs)
. - Per i numeri in virgola mobile:
division(lhs, rhs)
da IEEE-754 con attributo di arrotondamentoroundTowardZero
. - Per i numeri complessi: da definire (N. 997).
- Per i tipi quantizzati:
dequantize_op_quantize(remainder, lhs, rhs, type(result))
.
Per i tipi di elementi con rappresentazione in virgola mobile, questa operazione è in contrasto con la
remainder
in base alla specifica IEEE-754, in cui d
è un valore integrale
più vicino al valore esatto di lhs/rhs
con legami a pari.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di numeri interi, in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
(I2) | rhs |
tensore di numeri interi, in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di numeri interi, in virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]
replica_id
Semantica
Genera replica_id
del processo attuale.
Output
Nome | Tipo |
---|---|
result |
Tensore 0dimensionale di tipo ui32 |
Esempi
%result = "stablehlo.replica_id"() : () -> tensor<ui32>
rimodellare
Semantica
Esegue la rimodellamento del tensore operand
in un tensore result
. Concettualmente,
equivale a mantenere la stessa rappresentazione canonica, ma potenzialmente modificando
la forma, ad esempio da tensor<2x3xf32>
a tensor<3x2xf32>
o tensor<6xf32>
.
Più formalmente, result[result_index] = operand[operand_index]
dove
result_index
e operand_index
hanno la stessa posizione nella lessicografica
nell'ordine di index_space(result)
e index_space(operand)
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato | (C1-C3) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato | (C1-C3) |
Vincoli
- (C1) Il valore
element_type(result)
è fornito da:element_type(operand)
, se!is_per_axis_quantized(operand)
.element_type(operand)
eccettoquantization_dimension(operand)
equantization_dimension(result)
potrebbe variare.
- (C2)
size(operand) = size(result)
. - (C3) Se
is_per_axis_quantized(operand)
:reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
.reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
Esempi
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]
inverti
Semantica
Inverte l'ordine degli elementi in operand
lungo il valore dimensions
specificato
e produce un tensore result
. Più formalmente,
result[result_index] = operand[operand_index]
dove:
operand_index[d] = dim(result, d) - result_index[d] - 1
sed
indimensions
.operand_index[d] = result_index[d]
in caso contrario.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore quantizzato per tensore o per tensore | (C1), (C3) |
(I2) | dimensions |
Costante tensore monodimensionale di tipo si64 |
(C2), (C3) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore quantizzato per tensore o per tensore | (C1), (C3) |
Vincoli
- (C1)
type(operand) = type(result)
. - (C2)
is_unique(dimensions)
. - (C3)
0 <= dimensions < rank(result)
.
Esempi
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]
Rng
Semantica
Genera numeri casuali utilizzando l'algoritmo rng_distribution
e produce una
result
tensore di una determinata forma shape
.
Se rng_distribution = UNIFORM
, i numeri casuali vengono generati
seguendo la distribuzione uniforme sull'intervallo [a, b)
. Se a >= b
,
il comportamento è indefinito.
Se rng_distribution = NORMAL
, i numeri casuali vengono generati
secondo la distribuzione normale con media = a
e deviazione standard = b
.
Con b < 0
, il comportamento non è definito.
Il modo esatto in cui vengono generati i numeri casuali è definito dall'implementazione. Per essere deterministici, e possono o meno utilizzare stato nascosto.
Nelle conversazioni con molti stakeholder, questa operazione si è rivelata efficace ritirato, pertanto in futuro abbiamo intenzione di valutarne la rimozione (n. 597).
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | a |
Tensione 0-dimensionale di tipo intero, booleano o con virgola mobile | (C1), (C2) |
(I2) | b |
Tensione 0-dimensionale di tipo intero, booleano o con virgola mobile | (C1), (C2) |
(I3) | shape |
Costante tensore monodimensionale di tipo si64 |
(C3) |
(I4) | rng_distribution |
enum di UNIFORM e NORMAL |
(C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di numeri interi, booleani o in virgola mobile | (C1-C3) |
Vincoli
- (C1)
element_type(a) = element_type(b) = element_type(result)
. - (C2) Se
rng_distribution = NORMAL
, allorais_float(a)
. - (C3)
shape(result) = shape
.
Esempi
// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
// [1, 0, 1],
// [1, 1, 1],
// [0, 0, 0]
// ]
rng_bit_generator
Semantica
Restituisce un valore output
riempito con bit casuali uniformi e uno stato dell'output aggiornato
output_state
utilizza l'algoritmo del generatore di numeri pseudocasuali rng_algorithm
dato uno stato iniziale initial_state
. Viene garantito che l'output sia
funzione deterministica di initial_state
, ma non è garantito che sia
deterministici tra le implementazioni.
rng_algorithm
corrisponde a uno dei seguenti:
DEFAULT
: algoritmo definito dall'implementazione.THREE_FRY
: variante dell'algoritmo Threefry definita dall'implementazione.*PHILOX
: variante dell'algoritmo Philox definita dall'implementazione.*
* Vedi: Salmon et al. SC 2011. Numeri casuali paralleli: 1, 2, 3. di Gemini Advanced.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | rng_algorithm |
enum di DEFAULT , THREE_FRY e PHILOX |
(C2) |
(I2) | initial_state |
Tensore monodimensionale di tipo ui64 |
(C1), (C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
output_state |
Tensore monodimensionale di tipo ui64 |
(C1) |
output |
tensore di numeri interi o in virgola mobile |
Vincoli
- (C1)
type(initial_state) = type(output_state)
. - (C2) Per
size(initial_state)
si intende:- l'implementazione definita se
rng_algorithm = DEFAULT
. 2
serng_algorithm = THREE_FRY
.2
o3
serng_algorithm = PHILOX
.
- l'implementazione definita se
Esempi
// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
// [9236835810183407956, 16087790271692313299],
// [18212823393184779219, 2658481902456610144]
// ]
round_nearest_afz
Semantica
Esegue l'arrotondamento a livello di elemento verso il numero intero più vicino, separando i legami
da zero, sul tensore operand
, e produce un tensore result
. Implementa
l'operazione roundToIntegralTiesToAway
dalla specifica IEEE-754. Per
tipi quantizzati, esegue
dequantize_op_quantize(round_nearest_afz, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]
round_nearest_even
Semantica
Esegue l'arrotondamento degli elementi verso il numero intero più vicino, spezzando i legami
verso il numero intero pari, sul tensore operand
e produce un result
tensore. Implementa l'operazione roundToIntegralTiesToEven
da IEEE-754
e la specifica del prodotto. Per i tipi quantizzati, esegue
dequantize_op_quantize(round_nearest_even, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
rsqrt
Semantica
Esegue un'operazione di radice quadrata reciproca a livello di elemento sul tensore operand
e
produce un tensore result
. A seconda del tipo di elemento:
- Per i numeri in virgola mobile:
rSqrt
dallo standard IEEE-754. - Per i numeri complessi: radice quadrata reciproca complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(rsqrt, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]
dispersione
Semantica
Produce results
tensori uguali a inputs
tensori tranne che
diverse sezioni specificate da scatter_indices
vengono aggiornate con i valori
updates
tramite update_computation
.
Il seguente diagramma mostra come gli elementi in updates...
vengono mappati sugli elementi in
results...
usando un esempio concreto. Il diagramma sceglie alcuni esempi
updates...
e spiega in dettaglio quali indici results...
utilizzano
a cui corrispondono.
Più formalmente, per tutti i update_index
in index_space(updates[0])
:
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]
.update_scatter_index = update_index[update_scatter_dims...]
.start_index
è definito come:scatter_indices[si0, ..., :, ..., siN]
dovesi
sono individuali inupdate_scatter_index
e:
sia inserita Indiceindex_vector_dim
, seindex_vector_dim
<rank(scatter_indices)
,[scatter_indices[update_scatter_index]]
in caso contrario.
- Per
d_input
aaxes(inputs[0])
,- .
full_start_index[d_input] = start_index[d_start]
sed_input = scatter_dims_to_operand_dims[d_start]
.full_start_index[d_input] = 0
in caso contrario.
- Per
d_input
aaxes(inputs[0])
,- .
full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)]
sed_input = input_batching_dims[i_batching]
ed_start = scatter_indices_batching_dims[i_batching]
.full_batching_index[d_input] = 0
in caso contrario.
update_window_index = update_index[update_window_dims...]
.full_window_index = [wi0, ..., 0, ..., wiN]
dovewi
sono individuali inupdate_window_index
e0
è inserito in corrispondenza degli indici dainserted_window_dims
einput_batching_dims
.result_index = full_start_index + full_batching_index + full_window_index
.
Detto questo, results = exec(schedule, inputs)
, dove:
schedule
è una permutazione definita dall'implementazioneindex_space(updates[0])
.exec([update_index, ...], results) = exec([...], updated_results)
dove:- Se
result_index
è entro i limiti dishape(results...)
updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
updated_values = update_computation(results...[result_index], updates_converted)
updated_results
è una copia diresults
conresults...[result_index]
impostato suupdated_values...
.- In caso contrario
updated_results = results
.
- Se
exec([], results) = results
.
Se indices_are_sorted
è true
, l'implementazione può presupporre che
scatter_indices
sono ordinati rispetto a scatter_dims_to_operand_dims
,
altrimenti il comportamento è indefinito. Più formalmente, per tutti i i1 < i2
da
indices(result)
, full_start_index(i1)
<= full_start_index(i2)
.
Se unique_indices
è true
, l'implementazione può presupporre che tutti
Gli indici result_index
distribuiti sono univoci. Se unique_indices
è
true
ma gli indici sparsi su non sono univoci, il comportamento è
non definito.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | inputs |
numero variadico di tensori o tensori quantizzati per tensore | (C1), (C2), (C4-C6), (C11), (C13), (C18), (C21), (C23-C24) |
(I2) | scatter_indices |
tensore di tipo intero | (C4), (C15), (C19), (C22) |
(I3) | updates |
numero variadico di tensori o tensori quantizzati per tensore | (C3-C6), (C8) |
(I4) | update_window_dims |
Costante tensore monodimensionale di tipo si64 |
(C2), (C4), (C7-C8) |
(I5) | inserted_window_dims |
Costante tensore monodimensionale di tipo si64 |
(C2), (C4), (C9-C11) |
(I6) | input_batching_dims |
Costante tensore monodimensionale di tipo si64 |
(C2), (C4), (C9), (C12-13), (C17-18), (C20) |
(I7) | scatter_indices_batching_dims |
Costante tensore monodimensionale di tipo si64 |
(C14-C18) |
(I8) | scatter_dims_to_operand_dims |
Costante tensore monodimensionale di tipo si64 |
(C19-C21) |
(I9) | index_vector_dim |
costante di tipo si64 |
(C4), (C16), (C19), (C22) |
(I10) | indices_are_sorted |
costante di tipo i1 |
|
(I11) | unique_indices |
costante di tipo i1 |
|
(I12) | update_computation |
funzione | (C23) |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadico di tensori o tensori quantizzati per tensore | (C24-C25) |
Vincoli
- (C1)
same(shape(inputs...))
. - (C2) "rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims)"
- size(input_batching_dims)`.
- (C3)
same(shape(updates...))
. - (C4)
shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes)
dove:update_scatter_dim_sizes = shape(scatter_indices)
tranne che la dimensione discatter_indices
corrispondenteindex_vector_dim
non è incluso.update_window_dim_sizes <= shape(inputs[0])
tranne che le dimensioni ininputs[0]
corrispondenti ainserted_window_dims
einput_batching_dims
non sono inclusi.combine
posizionaupdate_scatter_dim_sizes
sugli assi corrispondenti aupdate_scatter_dims
eupdate_window_dim_sizes
negli assi corrispondenti aupdate_window_dims
.
- (C5)
0 < size(inputs) = size(updates) = N
. - (C6)
element_type(updates...) = element_type(inputs...)
. - (C7)
is_unique(update_window_dims) and is_sorted(update_window_dims)
. - (C8)
0 <= update_window_dims < rank(updates[0])
. - (C9)
is_unique(concatenate(inserted_window_dims, input_batching_dims))
- (C10)
is_sorted(inserted_window_dims)
. - (C11)
0 <= inserted_window_dims < rank(inputs[0])
. - (C12)
is_sorted(input_batching_dims)
. - (C13)
0 <= input_batching_dims < rank(inputs[0]))
. - (C14)
is_unique(scatter_indices_batching_dims)
. - (C15)
0 <= scatter_indices_batching_dims < rank(scatter_indices)
. - (C16)
index_vector_dim not in scatter_indices_batching_dims
. - (C17)
size(input_batching_dims) == size(scatter_indices_batching_dims)
. - (C18)
dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...)
. - (C19)
size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1
. - (C20)
is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims))
. - (C21)
0 <= scatter_dims_to_operand_dims < rank(inputs[0])
. - (C22)
0 <= index_vector_dim <= rank(scatter_indices)
. - (C23)
update_computation
ha il tipo(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
, doveis_promotable(element_type(inputs[i]), Ei)
. - (C24)
shape(inputs...) = shape(results...)
. - (C25)
element_type(results[i]) = Ei
per tutti ii
in[0,N)
.
Esempi
// %input: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %scatter_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
// %update: [
// [
// [[1, 1], [1, 1], [1, 1]],
// [[1, 1], [1, 1], [1, 1]]
// ],
// [
// [[1, 1], [1, 1], [1, 1]],
// [[1, 1], [1, 1], [1, 1]]
// ]
// ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [3, 4],
inserted_window_dims = [1],
input_batching_dims = [0],
scatter_indices_batching_dims = [1],
scatter_dims_to_operand_dims = [2, 1],
index_vector_dim = 3>,
indices_are_sorted = false,
unique_indices = false
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
// %result: [
// [
// [[3, 4], [6, 7], [6, 7], [7, 8]],
// [[9, 10],[11, 12], [15, 16], [17, 18]],
// [[17, 18], [19, 20], [22, 23], [24, 25]]
// ],
// [
// [[25, 26], [28, 29], [30, 31], [31, 32]],
// [[35, 36], [38, 39], [38, 39], [39, 40]],
// [[41, 42], [44, 45], [46, 47], [47, 48]]
// ]
// ]
seleziona
Semantica
Produce un tensore result
in cui ogni elemento è selezionato da on_true
o
tensore on_false
in base al valore dell'elemento corrispondente di pred
.
Più formalmente, result[result_index] = pred_element ? on_true[result_index] :
on_false[result_index]
, dove pred_element = rank(pred) = 0 ? pred[] :
pred[result_index]
. Per i tipi quantizzati, esegue
dequantize_select_quantize(pred, on_true, on_false, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | pred |
tensore di tipo i1 |
(C1) |
(I2) | on_true |
tensore quantizzato per tensore o per tensore | (C1-C2) |
(I3) | on_false |
tensore quantizzato per tensore o per tensore | (C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore quantizzato per tensore o per tensore | (C2) |
Vincoli
- (C1)
rank(pred) = 0 or shape(pred) = shape(on_true)
. - (C2)
baseline_type(on_true) = baseline_type(on_false) = baseline_type(result)
.
Esempi
// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]
select_and_scatter
Semantica
Disperde i valori dal tensore source
utilizzando scatter
in base al
risultato di reduce_window
del tensore input
usando select
e produce
un tensore result
.
Il seguente diagramma mostra come vengono calcolati gli elementi in result
a partire da
operand
e source
utilizzando un esempio concreto.
In modo più formale:
selected_values = reduce_window_without_init(...)
con i seguenti input:inputs = [operand].
window_dimensions
,window_strides
epadding
utilizzati così come sono.base_dilations = windows_dilations = 1
.body
è definito come:
def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>: return select(arg0, arg1) ? arg0 : arg1;
dove funzionano
E = element_type(operand)
ereduce_window_without_init
esattamente comereduce_window
, ad eccezione del fatto cheschedule
del valore sottostantereduce
(vedi Riduci) non include valori init. Attualmente cosa succede se la finestra corrispondente non ha valori (n. 731).result[result_index] = reduce([source_values], [init_value], [0], scatter)
dove:source_values = [source[source_index] for source_index in source_indices]
.selected_index(source_index) = operand_index
seselected_values[source_index]
ha l'elementooperand
daoperand_index
.source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index]
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore quantizzato per tensore o per tensore | (C1-C4), (C6), (C8-C11) |
(I2) | source |
tensore quantizzato per tensore o per tensore | (C1), (C2) |
(I3) | init_value |
tensore 0-dimensionale o tensore quantizzato per tensore | (C3) |
(I4) | window_dimensions |
Costante tensore monodimensionale di tipo si64 |
(C2), (C4), (C5) |
(I5) | window_strides |
Costante tensore monodimensionale di tipo si64 |
(C2), (C6), (C7) |
(I6) | padding |
Costante tensore bidimensionale di tipo si64 |
(C2), (C8) |
(I7) | select |
funzione | (C9) |
(I8) | scatter |
funzione | (C10) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore quantizzato per tensore o per tensore | (C11-C12) |
Vincoli
- (C1)
element_type(operand) = element_type(source)
. - (C2)
shape(source) = num_windows
dove:padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1]
.is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape
.num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1
.
- (C3)
element_type(init_value) = element_type(operand)
. - (C4)
size(window_dimensions) = rank(operand)
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(operand)
. - (C7)
0 < window_strides
. - (C8)
shape(padding) = [rank(operand), 2]
. - (C9)
select
ha il tipo(tensor<E>, tensor<E>) -> tensor<i1>
doveE = element_type(operand)
. - (C10)
scatter
ha il tipo(tensor<E>, tensor<E>) -> tensor<E>
doveis_promotable(element_type(operand), E)
. - (C11)
shape(operand) = shape(result)
. - (C12)
element_type(result) = E
.
Esempi
// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 3, 1>,
window_strides = array<i64: 2, 1>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]
Invia
Semantica
Invia inputs
a un canale channel_id
e produce un token result
.
Se is_host_transfer
è true
, l'operazione trasferisce i dati all'account
. Altrimenti, i dati vengono trasferiti su un altro dispositivo. Significato
dell'implementazione. Questo flag duplica le informazioni fornite in
channel_type
, quindi in futuro prevediamo di conservarne solo uno
(#666).
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | inputs |
numero variadico di tensori o tensori quantizzati | |
(I2) | token |
token |
|
(I3) | channel_id |
costante di tipo si64 |
|
(I4) | channel_type |
enum di DEVICE_TO_DEVICE e DEVICE_TO_HOST |
(C1) |
(I5) | is_host_transfer |
costante di tipo i1 |
(C1) |
Output
Nome | Tipo |
---|---|
result |
token |
Vincoli
- (C1)
channel_type
è definito come:DEVICE_TO_HOST
seis_host_transfer = true
,DEVICE_TO_DEVICE
in caso contrario.
Esempi
%result = "stablehlo.send"(%operand, %token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token
shift_left
Semantica
Esegue l'operazione di spostamento a sinistra dell'elemento sul tensore lhs
per un numero di rhs
di bit e produce un tensore result
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di tipo intero | (C1) |
(I2) | rhs |
tensore di tipo intero | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo intero | (C1) |
Vincoli
- (C1)
type(lhs) = type(rhs) = type(result)
.
Esempi
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]
shift_right_arithmetic
Semantica
Esegue l'operazione di spostamento a destra aritmetica degli elementi sul tensore lhs
tramite
rhs
di bit e produce un tensore result
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di tipo intero | (C1) |
(I2) | rhs |
tensore di tipo intero | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo intero | (C1) |
Vincoli
- (C1)
type(lhs) = type(rhs) = type(result)
.
Esempi
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]
shift_right_logical
Semantica
Esegue l'operazione logico di spostamento a destra a livello di elemento sul tensore lhs
in base a rhs
numero di bit e produce un tensore result
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di tipo intero | (C1) |
(I2) | rhs |
tensore di tipo intero | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tipo intero | (C1) |
Vincoli
- (C1)
type(lhs) = type(rhs) = type(result)
.
Esempi
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]
firmare
Semantica
Restituisce il segno dell'elemento operand
e produce un tensore result
.
Più formalmente, per ogni elemento x
, la semantica può essere espressa utilizzando
Sintassi Python come segue:
def sign(x):
if is_integer(x):
if compare(x, 0, LT, SIGNED): return -1
if compare(x, 0, EQ, SIGNED): return 0
return 1
elif is_float(x):
if is_nan(x): return NaN
if compare(x, -0.0, EQ, FLOAT): return -0.0
if compare(x, +0.0, EQ, FLOAT): return +0.0
if compare(x, 0.0, LT, FLOAT): return -1.0
return 1.0
elif is_complex(x):
if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
return divide(x, convert(abs(x), type(x)))
Per i tipi quantizzati, esegue
dequantize_op_quantize(sign, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di numero intero firmato, in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di numero intero firmato, in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
seno
Semantica
Esegue un'operazione seno a livello di elemento sul tensore operand
e produce un result
tensore. A seconda del tipo di elemento:
- Per i numeri in virgola mobile:
sin
dallo standard IEEE-754. - Per i numeri complessi: seno complesso.
- Per i tipi quantizzati:
dequantize_op_quantize(sine, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]
slice
Semantica
Estrae una sezione da operand
utilizzando indici iniziali calcolati in modo statico
e produce un tensore result
. start_indices
contengono gli indici iniziali di
la sezione per ogni dimensione, limit_indices
contiene gli indici finali
(esclusivo) per la sezione per ogni dimensione e strides
contiene i passi
per ogni dimensione.
Più formalmente, result[result_index] = operand[operand_index]
dove
operand_index = start_indices + result_index * strides
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore quantizzato per tensore o per tensore | (C1-C3), (C5) |
(I2) | start_indices |
Costante tensore monodimensionale di tipo si64 |
(C2), (C3), (C5) |
(I3) | limit_indices |
Costante tensore monodimensionale di tipo si64 |
(C2), (C3), (C5) |
(I4) | strides |
Costante tensore monodimensionale di tipo si64 |
(C2), (C4) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore quantizzato per tensore o per tensore | (C1), (C5) |
Vincoli
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(limit_indices) = size(strides) = rank(operand)
. - (C3)
0 <= start_indices <= limit_indices <= shape(operand)
. - (C4)
0 < strides
. - (C5)
shape(result) = ceil((limit_indices - start_indices) / strides)
.
Esempi
// %operand: [
// [0, 0, 0, 0],
// [0, 0, 1, 1],
// [0, 0, 1, 1]
// ]
%result = "stablehlo.slice"(%operand) {
start_indices = array<i64: 1, 2>,
limit_indices = array<i64: 3, 4>,
strides = array<i64: 1, 1>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
// [1, 1],
// [1, 1]
// ]
ordinare
Semantica
Ordina le sezioni monodimensionali di inputs
nella dimensione dimension
insieme,
secondo comparator
e produce results
.
A differenza di input simili in altre operazioni, dimension
consente valori negativi,
con la semantica descritta di seguito. In futuro, ciò potrebbe non essere consentito
per motivi di coerenza
(N. 1377).
Se is_stable
è vero, l'ordinamento è stabile, ossia l'ordine relativo dei
gli elementi considerati uguali dal comparatore vengono conservati. Per il caso
in cui è presente un solo input, due elementi e1
e e2
sono considerati
uguale a quello del comparatore se e solo se
comparator(e1, e2) = comparator(e2, e1) = false
. Consulta la formalizzazione di seguito
per la generalizzazione su più input.
Più formalmente, per tutti i result_index
in index_space(results[0])
:
adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension
.result_slice = [ri0, ..., :, ..., riR-1]
doveriN
sono individuali inresult_index
, mentre:
è inserito inadjusted_dimension
.inputs_together = (inputs[0]..., ..., inputs[N-1]...)
.results_together[result_slice] = sort(inputs_together[result_slice], comparator_together)
.- dove
sort
ordina una sezione unidimensionale in ordine non decrescente in attesa checomparator_together
restituiscetrue
se l'argomento lato sinistro è meno dell'argomento della seconda mano destra. def comparator_together(lhs_together, rhs_together): args = [] for (lhs_el, rhs_el) in zip(lhs_together, rhs_together): args.append(lhs_el) args.append(rhs_el) return comparator(*args)
(results[0]..., ..., results[N-1]...) = results_together
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | inputs |
numero variadico di tensori o tensori quantizzati per tensore | (C1-C5) |
(I2) | dimension |
costante di tipo si64 |
(C4) |
(I3) | is_stable |
costante di tipo i1 |
|
(I4) | comparator |
funzione | (C5) |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadico di tensori o tensori quantizzati per tensore | (C2), (C3) |
Vincoli
- (C1)
0 < size(inputs)
. - (C2)
type(inputs...) = type(results...)
. - (C3)
same(shape(inputs...) + shape(results...))
. - (C4)
-R <= dimension < R
, doveR = rank(inputs[0])
. - (C5)
comparator
ha un tipo(tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>
, doveEi = element_type(inputs[i])
.
Esempi
// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
dimension = 0 : i64,
is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]
sqrt
Semantica
Esegue un'operazione di radice quadrata dell'elemento sul tensore operand
e produce un
result
tensore. A seconda del tipo di elemento:
- Per i numeri in virgola mobile:
squareRoot
dallo standard IEEE-754. - Per i numeri complessi: radice quadrata complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(sqrt, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]
subtract
Semantica
Esegue la sottrazione a livello di elemento di due tensori lhs
e rhs
e produce un
tensore result
. A seconda del tipo di elemento:
- Per i numeri interi: sottrazione di numeri interi.
- Per i numeri in virgola mobile:
subtraction
dallo standard IEEE-754. - Per i numeri complessi: sottrazione complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(subtract, lhs, rhs, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di numeri interi, in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
(I2) | rhs |
tensore di numeri interi, in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di numeri interi, in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Esempi
// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]
tan
Semantica
Esegue un'operazione di tangente a livello di elemento sul tensore operand
e produce un
result
tensore. A seconda del tipo di elemento:
- Per i numeri in virgola mobile:
tan
dallo standard IEEE-754. - Per i numeri complessi: tangente complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(tan, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.tan"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [
// [0.0, 1.63312e+16],
// [0.0, 5.44375e+15]
// ]
Tanh
Semantica
Esegue un'operazione di tangente iperbolica a livello di elemento sul tensore operand
e
produce un tensore result
. A seconda del tipo di elemento:
- Per i numeri in virgola mobile:
tanh
dallo standard IEEE-754. - Per i numeri complessi: tangente iperbolica complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(tanh, operand, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result)
.
Esempi
// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]
trasponi
Semantica
Disattiva le dimensioni del tensore operand
utilizzando permutation
e produce una
tensore result
. Più formalmente, result[result_index] = operand[operand_index]
dove result_index[d] = operand_index[permutation[d]]
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore o tensore quantizzato | (C1-C4) |
(I2) | permutation |
Costante tensore monodimensionale di tipo si64 |
(C2-C4) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore o tensore quantizzato | (C1), (C3-C4) |
Vincoli
- (C1) Il valore
element_type(result)
è fornito da:element_type(operand)
, se!is_per_axis_quantized(operand)
.element_type(operand)
eccettoquantization_dimension(operand)
equantization_dimension(result)
potrebbe variare.
- (C2)
permutation
è una permutazione dirange(rank(operand))
. - (C3)
shape(result) = dim(operand, permutation...)
. - (C4) Se
is_per_axis_quantized(result)
:quantization_dimension(operand) = permutation(quantization_dimension(result))
.
Esempi
// %operand: [
// [[1,2], [3,4], [5,6]],
// [[7,8], [9,10], [11,12]]
// ]
%result = "stablehlo.transpose"(%operand) {
permutation = array<i64: 2, 1, 0>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
// [[1,7], [3,9], [5,11]],
// [[2,8], [4,10], [6,12]]
// ]
triangular_solve
Semantica
Risolvi i batch di sistemi di equazioni lineari con triangolari inferiori o superiori matrici dei coefficienti.
Più formalmente, sulla base di a
e b
, result[i0, ..., iR-3, :, :]
è la soluzione
a op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :]
quando left_side
è
true
o x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :]
quando
left_side
è false
, risolvendo per la variabile x
in cui è determinata op(a)
per transpose_a
, che può essere uno dei seguenti:
NO_TRANSPOSE
: esegui l'operazione utilizzandoa
così com'è.TRANSPOSE
: esegui l'operazione sulla trasposizione dia
.ADJOINT
: esegui l'operazione sulla trasposizione del coniugato dia
.
I dati di input vengono letti solo dal triangolo inferiore a
, se lower
è true
o
triangolo in alto di a
, altrimenti. I dati di output vengono restituiti nello stesso triangolo.
i valori nell'altro triangolo sono definiti dall'implementazione.
Se unit_diagonal
è vero, l'implementazione può presupporre che la diagonale
gli elementi di a
sono uguali a 1, altrimenti il comportamento è indefinito.
Per i tipi quantizzati, esegue
dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower,
unit_diagonal, transpose_a), a, b, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | a |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1-C3) |
(I2) | b |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1-C4) |
(I3) | left_side |
costante di tipo i1 |
(C3) |
(I4) | lower |
costante di tipo i1 |
|
(I5) | unit_diagonal |
costante di tipo i1 |
|
(I6) | transpose_a |
enum di NO_TRANSPOSE , TRANSPOSE e ADJOINT |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di tensore in virgola mobile o di tipo complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_element_type(a) = baseline_element_type(b)
. - (C2)
2 <= rank(a) = rank(b) = R
. - (C3) La relazione tra
shape(a)
eshape(b)
è definita come segue:- .
shape(a)[:-3] = shape(b)[:-3]
.dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1)
.
- (C4)
baseline_type(b) = baseline_type(result)
.
Esempi
// %a = [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
// %b = [
// [2.0, 0.0, 0.0],
// [4.0, 8.0, 0.0],
// [6.0, 10.0, 12.0]
// ]
%result = "stablehlo.triangular_solve"(%a, %b) {
left_side = true,
lower = true,
unit_diagonal = false,
transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
// [2.0, 0.0, 0.0],
// [0.0, 2.0, 0.0],
// [0.0, 0.0, 2.0]
// ]
tupla
Semantica
Genera una tupla result
dai valori val
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | val |
numero variadico di valori | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tupla | (C1) |
Vincoli
- (C1)
result
ha il tipotuple<E0, ..., EN-1>
doveEi = type(val[i])
.
Esempi
// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))
uniform_dequantize
Semantica
Esegue la conversione a livello di elementi del tensore quantizzato operand
in un
tensore a virgola mobile result
in base ai parametri di quantizzazione definiti
in base al tipo operand
.
Più formalmente, result = dequantize(operand)
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore quantizzato | (C1), (C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore del tipo in virgola mobile | (C1), (C2) |
Vincoli
- (C1)
shape(operand) = shape(result)
. - (C2)
element_type(result) = expressed_type(operand)
.
Esempi
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]
uniform_quantize
Semantica
Esegue la conversione a livello di elementi di un tensore a virgola mobile o di un tensore quantizzato
operand
a un tensore quantizzato result
in base alla quantizzazione
definiti dal tipo result
.
Più formalmente,
- Se
is_float(operand)
:result = quantize(operand, type(result))
.
- Se
is_quantized(operand)
:float_result = dequantize(operand)
.result = quantize(float_result, type(result))
.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
tensore di tipo in virgola mobile o quantizzato | (C1), (C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore quantizzato | (C1), (C2) |
Vincoli
- (C1)
shape(operand) = shape(result)
. - (C2)
expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand)
.
Esempi
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]
// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]
mentre
Semantica
Produce l'output dall'esecuzione della funzione body
0 o più volte mentre
La funzione cond
restituisce true
. Più formalmente, la semantica può essere espressa
utilizzando la sintassi Python come segue:
internal_state = operand
while cond(*internal_state):
internal_state = body(*internal_state)
results = internal_state
Il comportamento di un ciclo infinito è da definire (n. 383).
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | operand |
numero variadico di tensori, tensori o token quantizzati | (C1-C3) |
(I2) | cond |
funzione | (C1) |
(I3) | body |
funzione | (C2) |
Output
Nome | Tipo | Vincoli |
---|---|---|
results |
numero variadico di tensori, tensori o token quantizzati | (C3) |
Vincoli
- (C1)
cond
ha il tipo(T0, ..., TN-1) -> tensor<i1>
, doveTi = type(operand[i])
. - (C2)
body
ha il tipo(T0, ..., TN-1) -> (T0, ..., TN-1)
, doveTi = type(operand[i])
. - (C3)
type(results...) = type(operand...)
.
Esempi
// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%cond = "stablehlo.compare"(%arg0, %ten) {
comparison_direction = #stablehlo<comparison_direction LT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %cond : tensor<i1>
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%new_sum = stablehlo.add %arg1, %one : tensor<i64>
%new_i = stablehlo.add %arg0, %one : tensor<i64>
stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10
Xor
Semantica
Esegue un'XOR a livello di elemento di due tensori lhs
e rhs
e produce un result
tensore. A seconda del tipo di elemento:
- Per i valori booleani: XOR logico.
- Per i numeri interi: XOR a livello di bit.
Input
Etichetta | Nome | Tipo | Vincoli |
---|---|---|---|
(I1) | lhs |
tensore di un tipo booleano o intero | (C1) |
(I2) | rhs |
tensore di un tipo booleano o intero | (C1) |
Output
Nome | Tipo | Vincoli |
---|---|---|
result |
tensore di un tipo booleano o intero | (C1) |
Vincoli
- (C1)
type(lhs) = type(rhs) = type(result)
.
Esempi
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]
Interoperabilità dei dialetti
Al momento, i programmi StableHLO in circolazione a volte contengono operazioni che non sono definite da StableHLO.
Modulo, funzione, chiamata e ritorno
StableHLO utilizza operazioni MLIR upstream per ModuleOp, FuncOp, CallOp e Restituzione Ciò è stato fatto per migliorare l'interoperabilità con i macchinari MLIR esistenti, documenti utili sono scritti destinati a FuncOp e ModuleOp e molte compilazioni delle pipeline prevedono la presenza di queste operazioni. Le garanzie di compatibilità completa sono applicate a queste operazioni. Se in queste operazioni si verificano cambiamenti in una in un modo incompatibile (ovvero la rimozione), verranno aggiunti gli equivalenti StableHLO per preservare la compatibilità.
CHLO
L'opset CHLO contiene operazioni di livello superiore che si decompongono in StableHLO. Al momento non esistono garanzie di compatibilità per CHLO. Per compatibilità garantiti, il pass chlo-legalize-to-stablehlo deve essere utilizzato prima della serializzazione.
Operazioni forma
È un caso d'uso comune nella community quello di utilizzare determinate operazioni dall'account
Dialetti MLIR nei programmi dinamici StableHLO per eseguire calcoli di forma.
Più comunemente, si tratta del dialetto shape
operazioni come shape_of
o num_elements
, tensor
dialetto
operazioni come dim
o from_elements
e il tipo index
integrato.
Il documento Dynamism RFC > O2
li denota come fuori ambito, ma alcune funzionalità di supporto per i tipi index
sono
inclusi per scopi di interoperabilità. Non vi sono garanzie di compatibilità per questi
operazioni o tipi di query. Il parametro shape-legalize-to-stablehlo
è possibile usare un pass per convertire queste operazioni in operazioni StableHLO completamente supportate.
Operazioni deprecate
Esistono diverse operazioni StableHLO ereditate da MHLO che sono deprecati e stanno per uscire da StableHLO. I dettagli completi su questi in StableHLO v1.0 Cleanup #2283 (Pulizia StableHLO v1.0 n. 2283). Il problema del tracker per questi ritiri è #2340.
Queste operazioni rientrano in alcune categorie:
- "Non in HLO" delle operazioni StableHLO, che inizialmente facevano parte
l'opset StableHLO, ma in seguito si è ritenuto che non si adattasse bene:
broadcast
,create_token
,cross-replica-sum
,dot
,einsum
,torch_index_select
unary_einsum
(n. 3). - Operazioni inutilizzate: queste operazioni potrebbero essere state utili a un certo punto, ma
erano sottosviluppate oppure le pipeline che utilizzano queste operazioni sono state
con il refactoring in modo che non
siano più richieste. Sono inclusi
map
,tuple
(#598),get_tuple_element
,rng
,complex
confronti #560, e la convoluzionewindow_reversal
(#1181).
Alcune di queste operazioni possono essere rimosse facilmente poiché possono essere espresse utilizzando
operazioni esistenti (broadcast
, create_token
, cross-replica-sum
, dot
,
unary_einsum
) e verranno rimossi al termine della finestra di compatibilità esistente
(6 mesi). Altre sono ancora in fase di indagine per la rimozione (einsum
,
get_tuple_element
, map
, rng
torch_index_select
, tuple
e complex
confronti, window_reversal
). In attesa del feedback della community,
queste operazioni verranno rimosse o aggiunte alle specifiche con il supporto completo. Fino al giorno
le operazioni future sono noti, la compatibilità è garantita solo per 6 mesi.
Esecuzione
Esecuzione sequenziale
Viene eseguito un programma StableHLO fornendo valori di input alla funzione main
e il calcolo dei valori di output. I valori di output di una funzione vengono calcolati
eseguendo il grafico delle operazioni rooted nell'operazione return
corrispondente.
L'ordine di esecuzione è definito dall'implementazione purché sia allineato con
Dataflow, ovvero se le operazioni vengono eseguite prima dei loro utilizzi. In StableHLO,
le operazioni collaterali consumano un token e producono un token (più token possono
essere multiplexato in un unico token tramite after_all
), quindi l'ordine di esecuzione del lato
degli effetti è in linea con Dataflow. Ad esempio, nel programma riportato di seguito
sono possibili due ordini di esecuzione: %0
→ %1
→ %2
→ return
e
%1
→ %0
→ %2
→ return
.
func.func @main() -> tensor<f64> {
%0 = stablehlo.constant dense<1.0> : tensor<f64>
%1 = stablehlo.constant dense<2.0> : tensor<f64>
%2 = stablehlo.add %0, %1 : tensor<f64>
return %2 : tensor<f64>
}
Più formalmente, un processo SttableHLO è una combinazione di:
1) un programma StableHLO, 2) stati dell’operazione (non ancora eseguito,
già eseguito) e 3) i valori intermedi su cui sta lavorando il processo.
Il processo inizia con i valori di input per la funzione main
, progredisce
grafico delle operazioni che aggiornano gli stati delle operazioni e i valori intermedi e
termina con i valori di output. Un'ulteriore formalizzazione è da definire
(N. 484).
Esecuzione parallela
I programmi HLO stabili possono essere eseguiti in parallelo, organizzati in una griglia di processi 2D
di num_replicas
per num_partitions
, entrambi di tipo ui32
.
Nella griglia del processo StableHLO, num_replicas * num_partitions
di StableHLO
processi vengono eseguiti contemporaneamente. Ogni processo ha un unico
process_id = (replica_id, partition_id)
, dove
replica_id
a replica_ids = range(num_replicas)
e
partition_id
in partition_ids = range(num_partitions)
che hanno entrambi
tipo ui32
.
La dimensione della griglia dei processi è nota in modo statico per ogni programma (nel
in futuro, abbiamo intenzione di renderlo parte esplicita dei programmi StableHLO
#650) e la posizione
all'interno della griglia dei processi è nota in modo statico per ogni processo. Ogni processo ha
alla sua posizione all'interno della griglia del processo tramite replica_id
e
partition_id
operazioni
All'interno della griglia dei processi, i programmi possono essere tutti gli stessi (nel menu a discesa programma, più dati" ), possono essere tutti diversi (nella sezione "Più programmi, Più dati" stilizzato) o una via di mezzo. In futuro, pianificheremo per introdurre il supporto di altri modi di definire programmi StableHLO paralleli, incluso GSPMD (#619).
All'interno della griglia dei processi, i processi sono per lo più indipendenti tra loro: hanno stati di operazione separati, valori separati di input/intermedio/output e la maggior parte delle operazioni viene eseguita separatamente tra i processi, con ad eccezione di un numero ridotto di operazioni collettive descritte di seguito.
Dato che l'esecuzione della maggior parte delle operazioni utilizza solo valori dello stesso
di solito, di solito fare riferimento a questi valori con il loro nome è inequivocabile.
Tuttavia, quando si descrive la semantica delle operazioni collettive, ciò è insufficiente,
che dà origine alla notazione name@process_id
per fare riferimento al valore name
in un determinato processo. (Da questo punto di vista, name
non qualificato può essere
considerata una forma abbreviata di name@(replica_id(), partition_id())
).
L'ordine di esecuzione nei vari processi è definito dall'implementazione, ad eccezione delle sincronizzazione introdotta dalla comunicazione point-to-point e dalle operazioni collettive come descritto di seguito.
Comunicazione point-to-point
I processi HLO stabili possono comunicare tra loro
Canali stabileHLO. Un canale è rappresentato da un ID positivo di tipo
si64
. Attraverso varie operazioni, è possibile inviare valori ai canali e
che li ricevono dai canali.
Ulteriore formalizzazione, ad es. da dove provengono questi ID canale, se i programmi ne vengono a conoscenza e qual è il tipo di da loro presentati, è da definire (N. 484).
Comunicazione in streaming
Ogni processo StableHLO ha accesso a due interfacce di flusso:
- Infeed che possono essere letti.
- Outfeed in cui è possibile scrivere.
A differenza dei canali, che vengono utilizzati per comunicare tra i processi e, di conseguenza, hanno processi in entrambi i lati, gli infeed e gli outfeed hanno gli altri o terminare l'implementazione.
Ulteriore formalizzazione, ad es. come la comunicazione in streaming influenza l'esecuzione dell'ordine e del tipo di sincronizzazione che introduce, è da definire (N. 484).
Operazioni collettive
Ci sono sei operazioni collettive in StableHLO: all_gather
, all_reduce
,
all_to_all
, collective_broadcast
, collective_permute
e
reduce_scatter
, Tutte queste operazioni suddividono i processi nel processo StableHLO
in gruppi di processi SttableHLO ed eseguire un calcolo congiunto all'interno
ogni gruppo di processi, indipendentemente dagli altri gruppi di processi.
All'interno di ciascun gruppo di processi, le operazioni collettive possono introdurre ostacolo. Ulteriore formalizzazione, ad es. elaborando su quando esattamente avviene la sincronizzazione, il modo in cui i processi raggiungono questa barriera, e cosa succede se non lo fanno, è da definire (N. 484).
Se il gruppo di processi comporta una comunicazione tra partizioni, ovvero ci sono
nel gruppo di processi con ID di partizione diversi, quindi l'esecuzione
dell'operazione collettiva ha bisogno di un canale e l'operazione collettiva deve fornire
channel_id
positivo di tipo si64
. La comunicazione tra repliche non richiede
canali.
I calcoli eseguiti dalle operazioni collettive sono specifici per le singole operazioni e sono descritti nelle singole sezioni operative precedenti. Tuttavia, le strategie La griglia di processo è suddivisa in gruppi di processi, che vengono condivisi tra queste operazioni descritti in questa sezione. Più formalmente, StableHLO supporta seguendo quattro strategie.
cross_replica
Solo le comunicazioni con repliche diverse avvengono all'interno di ciascun gruppo di processi. Questo
la strategia richiede replica_groups
, un elenco di elenchi di ID replica, e calcola
un prodotto cartesiano di replica_groups
di partition_ids
. replica_groups
devono contenere elementi univoci e coprire tutti i replica_ids
. In modo più formale, usando
Sintassi Python:
def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
for partition_id in partition_ids:
process_group = []
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Ad esempio, per replica_groups = [[0, 1], [2, 3]]
e num_partitions = 2
,
cross_replica
produrrà
[[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]]
.
cross_partition
Solo le comunicazioni tra partizioni avvengono all'interno di ciascun gruppo di processi. Questo
la strategia richiede partition_groups
, un elenco di elenchi di ID partizione, e
calcola un prodotto cartesiano di partition_groups
per replica_ids
.
partition_groups
deve contenere elementi unici e coprire tutti i partition_ids
.
In modo più formale, usando la sintassi Python:
def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
for partition_group in partition_groups:
for replica_id in replica_ids:
process_group = []
for partition_id in partition_group:
process_group.append((replica_id, partition_id))
yield process_group
Ad esempio, per partition_groups = [[0, 1]]
e num_replicas = 4
,
cross_partition
produrrà
[[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]]
.
cross_replica_and_partition
Le comunicazioni cross-replica e cross-partition possono avvenire all'interno di
di un gruppo di processi. Questa strategia richiede replica_groups
, un elenco di elenchi di
ID replica e calcola i prodotti cartesiani di ogni replica_group
in base a
partition_ids
. replica_groups
deve contenere elementi unici e coprire tutti
replica_ids
. In modo più formale, usando la sintassi Python:
def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
process_group = []
for partition_id in partition_ids:
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Ad esempio, per replica_groups = [[0, 1], [2, 3]]
e num_partitions = 2
,
cross_replica_and_partition
produrrà
[[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]]
.
flattened_ids
Questa strategia richiede flattened_id_groups
, un elenco di elenchi di tipo "lineato"
ID di processo nel formato replica_id * num_partitions + partition_id
e
li trasforma in ID di processo. flattened_id_groups
deve avere elementi univoci
e copre tutti i process_ids
. In modo più formale, usando la sintassi Python:
def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
for flattened_id_group in flattened_id_groups:
process_group = []
for flattened_id in flattened_id_group:
replica_id = flattened_id // num_partitions
partition_id = flattened_id % num_partitions
process_group.append((replica_id, partition_id))
yield process_group
Ad esempio, per flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]
,
num_replicas = 4
e num_partitions = 2
, flattened_ids
produrrà
[[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]]
.
Accuratezza
Al momento, StableHLO non fornisce garanzie in merito all'accuratezza numerica, ma la situazione potrebbe cambiare in futuro (N. 1156).
semantica dell'esecuzione dell'operazione quantizzata
L'interpretazione delle operazioni StableHLO quantizzate può variare a seconda requisiti hardware e delle funzionalità. Ad esempio, alcuni hardware potrebbero optare per interpretare le operazioni quantizzate utilizzando un comando "dequantizza ed esegui in virgola mobile operazioni e infine quantificare" strategia. Altri potrebbero eseguire l'intero processo con l'aritmetica di numeri interi. Di conseguenza, l'interpretazione le operazioni StableHLO quantizzate sono determinate esclusivamente dallo specifico implementazione. L'interpretazione della quantizzazione ibrida (#1575) deve essere basata su la sua semantica come prescritto nella specifica (tramite 1792).
Errori
I programmi StableHLO vengono convalidati attraverso un ampio set di vincoli per operazioni individuali, escludendo molte classi di errori prima della fase di esecuzione. Tuttavia, sono comunque possibili condizioni di errore, ad esempio tramite overflow di numeri interi, accessi oltre i limiti e così via. Se non vengono esplicitamente indicati, tutti questi errori determinare un comportamento definito dall'implementazione, ma potrebbe cambiare (#1157).
Eccezioni con virgola mobile
Come eccezione a questa regola, le eccezioni con virgola mobile nei programmi StableHLO
abbiano un comportamento ben definito. Operazioni che comportano eccezioni definite dal
Standard IEEE-754 (operazione non valida, divisione per zero, overflow, underflow o
eccezioni inesatte) producono risultati predefiniti (come definiti nello standard) e
continuare l'esecuzione senza aumentare il flag di stato corrispondente; simile a
raiseNoFlag
gestione delle eccezioni rispetto allo standard. Eccezioni per le richieste non standard
operazioni (come l'aritmetica complessa e alcune funzioni trascendentali)
dell'implementazione.
Forma non corrispondente
StableHLO supporta tensori di forma dinamica. Tuttavia, le forme devono concordare il runtime, altrimenti il comportamento è indefinito. StabileHLO non dichiara esplicitamente fornisce un'operazione che può affermare che un tensore ha una determinata forma in fase di esecuzione. La generazione di codice corretto è responsabilità del producer.
Come esempio specifico, il programma riportato di seguito è valido. Tuttavia, in fase di runtime,
le forme esatte di %arg0
e %arg1
dovranno essere le stesse, altrimenti i
del programma non è definito:
func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
return %0 : tensor<?xi32>
}
Notazione
Per descrivere la sintassi, questo documento utilizza la versione ISO modificata di EBNF
(ISO/IEC 14977:1996,
Wikipedia),
con due modifiche: 1) le regole vengono definite utilizzando ::=
anziché =
;
2) la concatenazione viene espressa tramite giustapposizione anziché ,
.
Per descrivere la semantica (ad esempio all'interno delle sezioni "Tipi", "Costanti" e "Ops"), utilizziamo formule basate sulla sintassi Python estesa con il supporto per descrivere in modo conciso le operazioni sugli array come descritto di seguito. Funziona bene per piccoli snippet di codice, ma in rari casi quando vengono creati snippet di codice più grandi necessaria, usiamo la sintassi Python vanilla, che viene sempre presentata esplicitamente.
Formule
Vediamo come funzionano le formule sulla base di un esempio del dot_general
e la specifica del prodotto. Uno dei vincoli per questa operazione ha il seguente aspetto:
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
.
I nomi utilizzati in questa formula provengono da due origini: 1) funzioni globali,
ad esempio dim
, 2) definizioni dei membri dell'elemento del programma corrispondente, ovvero
Input lhs
, lhs_batching_dimensions
, rhs
e rhs_batching_dimensions
definita nella sezione "Input" di dot_general
.
Come detto in precedenza, la sintassi di questa formula è basata su Python con alcune estensioni orientate alla concisione. Per dare un senso alla formula, trasformiamo nella sintassi Python vanilla.
A) In queste formule utilizziamo =
per rappresentare l'uguaglianza, quindi il primo passaggio
per ottenere la sintassi Python è la sostituzione di =
con ==
, come segue:
dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...)
.
B) Inoltre, queste formule supportano i puntini di sospensione (...
) che trasformano le espressioni scalari
in espressioni tensoriali. In poche parole, f(xs...)
significa indicativamente "per ogni
scalare x
nel tensore xs
, calcola un f(x)
scalare e poi restituisce tutti
questi risultati scalari insieme come risultato tensore". Con la sintassi Python "vanilla",
la formula di esempio si trasforma in:
[dim(lhs, dim1) for dim1 in lhs_batching_dimensions] ==
[dim(rhs, dim2) for dim2 in rhs_batching_dimensions]
.
Grazie alle ellissi, spesso è possibile evitare di lavorare al livello
singoli scalari. Tuttavia, in alcuni casi difficili, viene utile anche
puoi usare la sintassi come nella formula start_indices[bi0, ..., :, ..., biN]
della specifica gather
. Al servizio della concisione, non
un formalismo esatto per tradurre tale sintassi in Python vanilla,
si augura che sia comunque comprensibile in modo intuitivo caso per caso.
Facci sapere se alcune formule specifiche sembrano opache e noi cercheremo di
migliorarli.
Inoltre, noterai che le formule usano i puntini di sospensione per espandere tutti i tipi di elenchi, tra cui tensori, elenchi di tensori (che, ad es., possono derivare da una variazione numero di tensori) e così via. Questa è un'altra area in cui non forniamo un numero formalismo (ad esempio, gli elenchi non fanno nemmeno parte del sistema di tipo StableHLO) e si basano invece su una comprensibilità intuitiva.
C) L'ultimo strumento notazionale degno di nota che utilizziamo è la definizione per la trasmissione dei dati. Anche se l'opset StableHLO non supporta la trasmissione implicita, le formule fanno, anche al servizio della concisione. In poche parole, se uno scalare viene utilizzato in un contesto in cui è previsto un tensore, lo scalare viene trasmesso la forma prevista.
Per continuare con l'esempio dot_general
, ecco un altro vincolo:
0 <= lhs_batching_dimensions < rank(lhs)
. Come definito in dot_general
specifica, lhs_batching_dimensions
è un tensore, tuttavia sia 0
che
rank(lhs)
sono scalari. Dopo aver applicato la trasmissione implicita, la formula
diventa [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)]
.
Se applicata a una determinata operazione dot_general
, questa formula
restituisce un tensore di valori booleani. Quando si utilizzano formule come vincoli,
il vincolo viene bloccato se la formula restituisce true
o un tensore che
ha solo true
elementi.
Nomi
Nelle formule, l'ambito lessicale include: 1) funzioni globali, 2) definizioni dei membri,
3) definizioni locali. L'elenco delle funzioni globali è riportato di seguito. Elenco delle definizioni degli elementi dipende dall'elemento di programma a cui si riferisce la notazione applicata a:
- Per le operazioni, le definizioni dei membri includono i nomi introdotti in "Input" e "Output" sezioni.
- Per tutto il resto, le definizioni dei membri includono le parti strutturali
elemento di programma, chiamato come non-terminali EBNF corrispondenti. La maggior parte di
volta, i nomi di queste parti strutturali si ottengono convertendo
nomi dei non terminali in caso snake (ad es.
IntegerLiteral
=>integer_literal
), ma a volte i nomi vengono abbreviati (ad es.QuantizationStorageType
=>storage_type
), nel qual caso i nomi vengono introdotto in modo esplicito in modo simile agli "Input" / "Output" sezioni in funzione specifiche. - Inoltre, le definizioni dei membri includono sempre
self
per fare riferimento alle elemento di programma corrispondente.
Valori
Quando vengono valutate, le formule funzionano con i seguenti tipi di valori:
1) Value
(valori effettivi, ad es. dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
;
conosce sempre il loro tipo),
2) Placeholder
(valori futuri, ad es. lhs
, rhs
o result
; i loro valori effettivi
non sono ancora noti, solo il loro tipo è noto),
3) Type
(tipi definiti nella sezione "Tipi"),
4) Function
(funzioni globali come definite nella sezione "Funzioni").
A seconda del contesto, i nomi possono riferirsi a valori diversi. Altro
in particolare la "Semantica" sezione relativa alle operazioni (ed equivalenti per altri programmi
) definisce la logica di runtime, in modo che tutti gli input siano disponibili come Value
.
Al contrario, i "Vincoli" per le operazioni (ed equivalenti) definisce
"tempo di compilazione" ovvero qualcosa che in genere viene eseguito prima del runtime,
quindi solo gli input costanti sono disponibili come Value
e altri input sono
disponibile solo come Placeholder
.
Nomi | In "Semantica" | In "Vincoli" |
---|---|---|
Funzioni globali | Function |
Function |
Input costanti | Value |
Value |
Input non costanti | Value |
Placeholder |
Output | Value |
Placeholder |
Definizioni locali | Dipende dalla definizione | Dipende dalla definizione |
Consideriamo un'operazione transpose
di esempio:
%result = "stablehlo.transpose"(%operand) {
permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
Per questa operazione, permutation
è una costante, quindi è disponibile come Value
sia nella semantica che nei vincoli. Al contrario, operand
e result
sono
disponibile come Value
nella semantica, ma solo come Placeholder
nei vincoli.
Funzioni
Costruzione dei tipi
Non esistono funzioni che possono essere utilizzate per costruire tipi. Al contrario,
utilizzare la sintassi del tipo perché di solito è più concisa. Ad es.
(tensor<E>, tensor<E>) -> (tensor<E>)
anziché function_type(
[tensor_type([], E), tensor_type([], E)], [tensor_type([], E)])
.
Funzioni sui tipi
- Il valore
element_type
è definito sui tipi di tensori e sui tipi di tensori quantizzati. restituisce, rispettivamente,TensorElementType
oQuantizedTensorElementType
parte dei valoriTensorType
oQuantizedTensorType
corrispondenti.
def element_type(x: Value | Placeholder | Type):
if type(x) == TensorType:
return tensor_element_type(x)
if type(x) == QuantizedTensorType:
return quantized_tensor_element_type(x)
if type(x) is not Type:
return element_type(type(x))
is_per_axis_quantized(x: Value | Placeholder | Type) -> Value
è una scorciatoia peris_quantized(x) and quantization_dimension(x) is not None
.is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value
è un scorciatoia peris_quantized(x) and quantization_dimension(x) is None
.is_promotable(x: Type, y: Type) -> bool
controlla se il tipox
può essere promosso per digitarey
. Quandox
ey
sonoQuantizedTensorElementType
, la promozione viene applicato solo astorage_type
. Questa versione specifica della promozione è attualmente utilizzato nel contesto del calcolo della riduzione (fai riferimento RFC per ulteriori dettagli).
def is_promotable(x: Type, y: Type) -> Value:
is_same_type = (is_bool(x) and is_bool(y)) or
(is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
(is_complex(x) and is_complex(y)) or
(is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
if is_same_type == False:
return False
if is_integer(x) or is_float(x):
return bitwidth(x) <= bitwidth(y)
if is_complex(x):
return bitwidth(element_type(x)) <= bitwidth(element_type(y))
if is_quantized(x):
return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))
return false
is_quantized(x: Value | Placeholder | Type) -> Value
è una scorciatoia peris_quantized_tensor_element_type(x)
.is_type_name(x: Value | Placeholder | Type) -> Value
. Disponibile per tutti di testo. Ad esempio,is_float(x)
restituiscetrue
sex
è unFloatType
. Sex
è un valore o un segnaposto, questa funzione è una scorciatoia peris_type_name(type(x))
.max_value(x: Type) -> Value
restituisce il valore massimo diTensorElementType
. Sex
non è unTensorElementType
, restituisceNone
.min_value(x: Type) -> Value
restituisce il valore minimo possibile di unTensorElementType
. Sex
non è unTensorElementType
, restituisceNone
.member_name(x: Value | Placeholder | Type) -> Any
. Disponibile per tutti i membri definizionimember_name
di tutti i tipi. Ad esempio,tensor_element_type(x)
restituisce la parteTensorElementType
di unTensorType
corrispondente. Sex
è un valore o un segnaposto, questa funzione è una scorciatoia permember_name(type(x))
. Sex
non è di un tipo con un membro appropriato, oppure un valore o un segnaposto di questo tipo, restituisceNone
.is_empty_algorithm(*args: Type)
controlla se tutti i campi dell'algoritmo dei punti sono impostati aNone
. Questa operazione è necessaria perché gli algoritmi dei punti hanno un'implementazione definita comportamenti predefiniti, quindi specificare un valore predefinito non sarebbe corretto.
Costruzione dei valori
operation_name(*xs: Value | Type) -> Value
. Disponibile per tutte le operazioni. Ad esempio,add(lhs, rhs)
accetta due valori tensorialilhs
erhs
e restituisce l'output della valutazione dell'operazioneadd
con questi input. Per alcune operazioni, ad esempiobroadcast_in_dim
, i tipi di output sono "portante", ovvero necessario per valutare un'operazione. In questo caso, la funzione prende questi tipi come argomenti.
Funzioni sui valori
Sono disponibili tutti gli operatori e le funzioni di Python. Ad es. entrambi abbonamento e slicing le notazioni Python sono disponibili per l'indicizzazione in tensori, tensori quantizzati e tuple.
to_destination_type(x: Value, destination_type: Type) -> Value
definito in tensori e restituisce il valore convertito dix
in base atype(x)
edestination_type
come segue:
def to_destination_type(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x
if is_quantized(destination_type):
if is_quantized(type(x)):
return quantize(x, destination_type)
assert is_float(type(x))
return quantize(x, destination_type)
if is_quantized(type(x)):
assert destination_type = expressed_type(type(x))
return dequantize(type(x))
return convert(x, destination_type)
Si sta discutendo inizialmente dell'unione di convert
, uniform_quantize
e
Operazioni uniform_dequantize
(#1576).
Dopo l'unione non abbiamo bisogno della funzione di cui sopra e possiamo utilizzare il nome dell'operazione
per convert
.
is_nan(x: Value) -> Value
viene definito sui tensori e restituiscetrue
se tutti gli elementi dix
sonoNaN
ofalse
in caso contrario. Sex
non è un tensore, restituisceNone
.is_sorted(x: Value) -> Value
viene definito sui tensori e restituiscetrue
se gli elementi dix
vengono ordinati in ordine crescente rispetto a quelli di livello all'ordine grammaticale dei relativi indici ofalse
in caso contrario. Sex
non è un tensor, restituisceNone
.is_unique(x: Value) -> Value
viene definito sui tensori e restituiscetrue
sex
non ha elementi duplicati ofalse
in caso contrario. Sex
non è un tensore, restituisceNone
.member_name(x: Value) -> Any
è definito per tutte le definizioni dei membrimember_name
di tutti i valori. Ad esempio,real_part(x)
restituisceRealPart
parte di unComplexConstant
corrispondente. Sex
non è un valore con un membro appropriato, restituisceNone
.same(x: Value) -> Value
viene definito sui tensori e restituiscetrue
se elementi dix
sono tutti uguali tra loro ofalse
in caso contrario. Se il tensore non ha elementi, vengono conteggiati come "tutti uguali tra loro", ovvero restituiscetrue
. Sex
non è un tensore, restituisceNone
.split(x: Value, num_results: Value, axis: Value) -> Value
definito in tensori e restituiscenum_results
sezioni dix
lungo l'asseaxis
. Sex
non è un tensore odim(x, axis) % num_results != 0
, restituisceNone
.is_defined_in_parent_scope(x: Value) -> Value
è definito nelle stringhe e restituiscetrue
sex
è il nome di una funzione definita nello stesso ambito come funzione principale dell'operazione pertinente.is_namespaced_op_name(x: Value) -> Value
viene definito nelle stringhe e restituiscetrue
sex
è un nome operativo valido, rispetta la seguente espressione:[a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+
Calcoli di forme
axes(x: Value | Placeholder | Type) -> Value
è una scorciatoia perrange(rank(x))
.dim(x: Value | Placeholder | Type, axis: Value) -> Value
è una scorciatoia pershape(x)[axis]
.dims(x: Value | Placeholder | Type, axes: List) -> List
è una scorciatoia perlist(map(lambda axis: dim(x, axis), axes))
.index_space(x: Value | Placeholder | Type) -> Value
è definito sui tensori e restituisce gli indicisize(x)
per ilTensorType
corrispondente ordinato in ordine lessicografico crescente, ad es.[0, ..., 0]
,[0, ..., 1]
, ...,shape(x) - 1
, Sex
non è un tipo di tensore, un tipo di tensore quantizzato o un valore o un segnaposto di uno di questi tipi, restituisceNone
.rank(x: Value | Placeholder | Type) -> Value
è una scorciatoia persize(shape(x))
.shape(x: Value | Placeholder | Type) -> Value
è definito nella sezione "Funzioni sui tipi" tramitemember_name
.size(x: Value | Placeholder | Type) -> Value
è una scorciatoia perreduce(lambda x, y: x * y, shape(x))
.
Calcoli di quantizzazione
def baseline_element_type(x: Value | Placeholder | Type) -> Type
è un scorciatoia perelement_type(baseline_type(x))
.Il valore
baseline_type
è definito sui tipi di tensori e sui tipi di tensori quantizzati. le trasforma in una "base di riferimento", ovvero di un tipo con la stessa forma ma con i parametri di quantizzazione del tipo di elemento vengono reimpostati sui valori predefiniti. Questo è utilizzato come pratico trucco per confrontare tipi di tensori sia tensori che quantizzati in modo uniforme, cosa necessaria abbastanza spesso. Per i tipi quantizzati, questo consente Confrontando i tipi ignorando i parametri di quantizzazione, ovveroshape
,storage_type
,expressed_type
,storage_min
,storage_max
equantization_dimension
(per il tipo quantizzato per asse) deve corrispondere, mascales
ezero points
potrebbero variare.
def baseline_type(x: Value | Placeholder | Type) -> Type:
if type(x) == TensorType:
return x
if type(x) == QuantizedTensorType:
element_type = quantized_tensor_element_type(x)
baseline_element_type = QuantizedTensorElementType(
storage_type = storage_type(element_type),
storage_min = storage_min(element_type),
storage_max = storage_max(element_type),
expressed_type = expressed_type(element_type),
quantization_dimension = quantization_dimension(element_type),
scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
return QuantizedTensorType(shape(x), baseline_element_type)
if type(x) is not Type:
return baseline_element_type(type(x))
dequantize
viene definito su tipi di tensori quantizzati e li trasforma in i tipi di tensori a virgola mobile. Ciò avviene mediante la conversione di elementi quantizzati che rappresentano i valori interi del tipo di archiviazione in valori in virgola mobile del tipo espresso utilizzando il punto zero e la scala associati al tipo di elemento quantizzato.
def compute_zero_points(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
zero_points[i] = zero_points(quantized_type)[i[d]]
return zero_points
def compute_scales(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
type(result_type))
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
scales[i] = scales(quantized_type)[i[d]]
return scales
def dequantize(x: Value) -> Value:
assert is_quantized(x)
x_storage = bitcast_convert(x, storage_type(x))
x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
x_expressed_sub = convert(x_storage_sub, expressed_type(x))
return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
quantize
è definito su tipi di tensori a virgola mobile e li trasforma in di tensori quantizzati. Ciò avviene mediante la conversione dei valori in virgola mobile del tipo espresso nei valori interi corrispondenti del tipo di archiviazione utilizzando il punto zero e la scala associati al tipo di elemento quantizzato.
def quantize(x: Value, result_type: Type) -> Value:
assert is_float(x) and is_quantized(result_type)
zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
converted_zero_points = convert(zero_points, expressed_type(result_type))
converted_min = convert(storage_min(result_type), expressed_type(result_type))
converted_max = convert(storage_max(result_type), expressed_type(result_type))
x_scaled = x / compute_scales(result_type, type(x))
x_scaled_add_zp = x_scaled + converted_zero_points
x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
x_rounded = round_nearest_even(x_clamped)
return convert(x_rounded, result_type)
dequantize_op_quantize
viene utilizzato per specificare i calcoli a livello di elemento su tensori quantizzati. Dequantizza, cioè trasforma gli elementi quantizzati nei loro i tipi espressi, quindi esegue un'operazione e poi la quantizza, ovvero trasforma i risultati nei rispettivi tipi di archiviazione. Al momento questa funzione per la quantizzazione per tensore. La quantizzazione per asse è in corso (N. 1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
inputs = inputs_and_output_type[:-1]
output_type = inputs_and_output_type[-1]
float_inputs = map(dequantize, inputs)
float_result = op(*float_inputs)
return quantize(float_result, output_type)
def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
inputs = inputs_and_output_type[:-3]
float_inputs = map(dequantize, inputs)
float_results = op(*float_inputs)
return map(quantize, float_results, inputs_and_output_type[-3:])
def dequantize_compare(lhs, rhs, comparison_direction):
float_lhs = dequantize(lhs)
float_rhs = dequantize(rhs)
return compare(float_lhs, float_rhs, comparison_direction, FLOAT)
def dequantize_select_quantize(pred, on_true, on_false, output_type):
float_on_true = dequantize(on_true)
float_on_false = dequantize(on_false)
float_result = select(pred, float_on_true, float_on_false)
return quantize(float_result, output_type)
hybrid_dequantize_then_op
viene utilizzato per specificare la quantizzazione basata solo sulla ponderazione per un'operazione ibrida che accetta lh in rappresentazione in virgola mobile e RP nei tipi quantizzati. it dequantizza gli input quantizzati nei loro tipi espressi ed esegue il calcolo in virgola mobile. Tipo di elemento del tensore lhs in virgola mobile e tipo espresso di URL quantistici tensore deve essere identico.
def hybrid_dequantize_then_op(op, lhs, rhs):
assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
return op(lhs, dequantize(rhs))
Calcoli a griglia
cross_partition(replica_groups: Value) -> Value
. Vedi "cross_replica" sezione precedente.cross_replica(replica_groups: Value) -> Value
. Vedi "cross_replica" sezione precedente.cross_replica_and_partition(replica_groups: Value) -> Value
. Consulta le "cross_replica_and_partition" sezione precedente.flattened_ids(replica_groups: Value) -> Value
. Vedi "flattened_id" sezione precedente.
Dinamismo
I valori HLO stabili possono avere dimensioni di dimensioni dinamiche, ad esempio tensor<?xi64>
.
Tuttavia, i valori StableHLO non possono avere un numero dinamico di dimensioni (senza
dinamismo, ad es. tensor<*xi64>
). Gli operandi e i risultati possono utilizzare
dimensioni, anche se ci sono vincoli sulle dimensioni. I vincoli saranno
verificati in modo statico, se possibile, altrimenti vengono differiti al runtime e
errate corrispondenze daranno un comportamento indefinito. Di seguito sono riportati gli esempi.
Mancata corrispondenza della forma per le operazioni una tantum degli elementi
Considera il seguente programma di giocattoli:
func.func @foo(%arg0: tensor<?xf64>) {
%0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
return
}
Un programma del genere è insolito, perché non è comune conoscere la forma del
ma non la forma dell'input. Tuttavia, questo è un file StableHLO valido
. Non è possibile convalidare in modo statico l'operazione abs
in questo
perché la forma esatta dell'operando non è nota. Tuttavia, le forme
sono sicuramente compatibili e puoi controllarli in modo statico: i risultati di ?
potrebbero
sia 2
in fase di runtime. Non ci sono problemi. Tuttavia, ?
potrebbe
risultano essere anche altri numeri interi, nel qual caso il comportamento non è definito.
Tieni presente che se una dimensione è dinamica nel risultato, non possono essere un comportamento indefinito. In effetti, non c'è nessun "previsto" dimensione, perciò non può esserci non corrispondente.
Mancata corrispondenza della forma per le operazioni binarie elementari
Considera il seguente programma di giocattoli:
func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
%0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
return
}
Quando si tratta di operazioni binarie elementari, le forme degli input e il risultato deve essere accettato in fase di runtime. Al momento della compilazione, le dimensioni statiche devono essere uguali, altrimenti devono solo essere compatibili. Se qualsiasi dimensione è dinamica negli input, la dimensione potrebbe essere indefinita un comportamento durante il runtime, perché la dimensione dinamica potrebbe non corrispondere dimensioni nell'altro operando (statico o dinamico). Se tutti gli input sono statico, la differenza è che il risultato sia dinamico o meno: in modo statico le dimensioni note vengono controllate in modo statico, mentre le dimensioni dinamiche imporre vincoli.
Mancata corrispondenza delle forme per le operazioni che assumono la forma di output come operando
Considera il seguente programma di giocattoli:
func.func @foo(%arg0: tensor<2xi32>) {
%0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
return
}
I valori nell'operando di forma in fase di runtime devono corrispondere alla forma del risultato.
altrimenti il comportamento è indefinito. Ciò significa che in fase di runtime %arg0
deve avere un
valore di dense<[3, 4]> : tensor<2xi32>
. Se l'operando di forma è costante,
possono essere verificati in modo statico. Se la forma dei risultati è completamente dinamica,
non possono essere errate.