StableHLO è un insieme di operazioni per operazioni di alto livello (HLO) nei modelli di machine learning (ML). StableHLO funge da livello di portabilità tra diversi framework ML e compilatori ML: i framework ML che producono programmi StableHLO sono compatibili con i compilatori ML che utilizzano programmi StableHLO.
Il nostro obiettivo è semplificare e accelerare lo sviluppo di ML creando una maggiore interoperabilità tra vari framework ML (come TensorFlow, JAX e PyTorch) e compilatori ML (come XLA e IREE). A questo scopo, questo documento fornisce una specifica per il linguaggio di programmazione StableHLO.
Questa specifica contiene tre sezioni principali. Innanzitutto, la sezione Programmi descrive la struttura dei programmi StableHLO che sono costituiti da funzioni StableHLO che a loro volta sono costituite da operazioni StableHLO. All'interno di questa struttura, la sezione Ops specifica la semantica delle singole operazioni. La sezione Esecuzione fornisce la semantica per tutte queste operazioni eseguite insieme all'interno di un programma. Infine, la sezione Notazione illustra la notazione utilizzata in tutta la specifica.
Per visualizzare le specifiche di una release precedente di StableHLO, apri il repository nella release taggata di tuo interesse. Ad esempio, le specifiche di StableHLO v0.19.0. Per visualizzare le modifiche apportate a ogni aggiornamento della versione secondaria di StableHLO, consulta il log delle versioni in VhloDialect.td.
Programmi
Program ::= {Func}
I programmi StableHLO sono costituiti da un numero arbitrario di funzioni StableHLO.
Di seguito è riportato un programma di esempio con una funzione @main che ha 3 input
(%image, %weights e %bias) e 1 output. Il corpo della funzione
ha 6 operazioni.
func.func @main(
%image: tensor<28x28xf32>,
%weights: tensor<784x10xf32>,
%bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
%0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
%1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
%2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
%3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
%4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
"func.return"(%4): (tensor<1x10xf32>) -> ()
}
Funzioni
Func ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput ::= ValueType
FuncBody ::= {Op}
Le funzioni StableHLO (chiamate anche funzioni con nome) hanno un identificatore, input/output e un corpo. In futuro, prevediamo di introdurre metadati aggiuntivi per le funzioni per ottenere una migliore compatibilità con HLO (#425, #626, #740, #744).
Identificatori
FuncId ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
| '%' letter {letter | digit}
letter ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit ::= '0' | ... | '9'
Gli identificatori StableHLO sono simili agli identificatori di molti linguaggi di programmazione, con due peculiarità: 1) tutti gli identificatori hanno sigilli che distinguono i diversi tipi di identificatori, 2) gli identificatori di valori possono essere completamente numerici per semplificare la generazione di programmi StableHLO.
Tipi
Type ::= ValueType | NonValueType
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType | BufferType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
I tipi StableHLO sono suddivisi in tipi di valori (chiamati anche tipi di prima classe) che rappresentano i valori StableHLO e tipi non di valori che descrivono altri elementi del programma. I tipi StableHLO sono simili a quelli di molti linguaggi di programmazione, con la principale peculiarità di essere specifici per il dominio, il che comporta alcuni risultati insoliti (ad es. i tipi scalari non sono tipi di valore).
TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'
I tipi di tensori rappresentano i tensori, ovvero array multidimensionali. Hanno una
forma e un tipo di elemento, dove una forma rappresenta dimensioni non negative o
sconosciute delle dimensioni in ordine crescente delle
dimensioni corrispondenti (chiamate anche assi) numerate da 0 a R-1. Il
numero di dimensioni R è chiamato rank. Ad esempio, tensor<2x3xf32> è
un tipo di tensore con forma 2x3 e tipo di elemento f32. Ha due dimensioni
(o, in altre parole, due assi) - la dimensione 0 e la dimensione 1 - le cui dimensioni
sono 2 e 3. Il suo ranking è 2.
Le forme possono essere parzialmente o completamente sconosciute (dinamiche), ad esempio tensor<?x2xf64> è parzialmente sconosciuta e tensor<?x?xf64> è completamente sconosciuta. Le dimensioni dinamiche
sono rappresentate utilizzando un ?. Non è possibile rimuovere la classificazione delle forme.
In futuro, prevediamo di esplorare l'estensione dei tipi di tensori oltre le dimensioni e i tipi di elementi, ad esempio per includere layout (#629) e sparsità (#1078).
QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
QuantizationStorageType
['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
':' QuantizationExpressedType
[':' QuantizationDimension]
',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerLiteral
QuantizationStorageMax ::= IntegerLiteral
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerLiteral
QuantizationParameters ::= QuantizationParameter
| '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale [':' QuantizationZeroPoint]
QuantizationScale ::= FloatLiteral
QuantizationZeroPoint ::= IntegerLiteral
| Nome | Tipo | Vincoli |
|---|---|---|
storage_type |
tipo di numero intero | (C1-C3), (C8) |
storage_min |
costante intera | (C1), (C3), (C7) |
storage_max |
costante intera | (C2), (C3), (C7) |
expressed_type |
tipo in virgola mobile | (C4) |
quantization_dimension |
costante intera facoltativa | (C10-C12) |
scales |
numero variabile di costanti in virgola mobile | (C4-C6), (C9), (C10), (C13) |
zero_points |
numero variabile di costanti intere | (C7-C9) |
I tipi di elementi quantizzati rappresentano valori interi di un tipo di archiviazione nell'intervallo da storage_min a storage_max (inclusi) che corrispondono a valori in virgola mobile di un tipo espresso. Per un determinato valore intero i,
il valore in virgola mobile corrispondente f può essere calcolato come
f = (i - zero_point) * scale, dove scale e zero_point sono chiamati
parametri di quantizzazione. storage_min e storage_max sono facoltativi
nella grammatica, ma hanno valori predefiniti di min_value(storage_type) e
max_value(storage_type) rispettivamente. I tipi di elementi quantizzati presentano i seguenti vincoli:
- (C1)
type(storage_min) = storage_type. - (C2)
type(storage_max) = storage_type. - (C3)
min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type). - (C4)
type(scales...) = expressed_type. - (C5)
0 < scales. - (C6)
is_finite(scales...). - (C7)
storage_min <= zero_points <= storage_max. - (C8)
type(zero_points...) = storage_type. - (C9)
size(scales) = size(zero_points). - (C10) Se
is_empty(quantization_dimension), allorasize(scales) = 1. - (C11)
0 <= quantization_dimension.
Al momento, QuantizationScale è una costante in virgola mobile, ma esiste un forte interesse per le scale basate su numeri interi, rappresentate con moltiplicatori e spostamenti. Abbiamo in programma di esplorare questa possibilità nel prossimo futuro
(#1404).
È in corso una discussione sulla semantica di QuantizationZeroPoint,
inclusi il tipo, i valori e la possibilità di avere uno o
più punti zero in un tipo di tensore quantizzato. In base ai risultati di questa discussione, la specifica relativa a zero punti potrebbe cambiare in futuro (#1405).
Un'altra discussione in corso riguarda la semantica di QuantizationStorageMin
e QuantizationStorageMax per determinare se debbano essere imposti vincoli
su questi valori e sui valori dei tensori quantizzati
(#1406).
Infine, stiamo pianificando di esplorare la rappresentazione di scale e punti zero sconosciuti, in modo simile a come stiamo pianificando di esplorare la rappresentazione di dimensioni sconosciute (#1407).
I tipi di tensori quantizzati rappresentano i tensori con elementi quantizzati. Questi tensori sono esattamente uguali ai tensori normali, tranne per il fatto che i loro elementi hanno tipi di elementi quantizzati, anziché tipi di elementi normali.
Nei tensori quantizzati, la quantizzazione può essere per tensore, ovvero con un scale e un zero_point per l'intero tensore, oppure per asse, ovvero con più scales e zero_points, una coppia per ogni sezione di una determinata dimensione quantization_dimension. Più formalmente, in un tensore t
con quantizzazione per asse, ci sono dim(t, quantization_dimension) sezioni
di quantization_dimension: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :],
ecc. Tutti gli elementi nella i-esima sezione utilizzano scales[i] e zero_points[i] come
parametri di quantizzazione. I tipi di tensore quantizzato hanno i seguenti
vincoli:
- Per la quantizzazione per tensore:
- Nessun vincolo aggiuntivo.
- Per la quantizzazione per asse:
- (C12)
quantization_dimension < rank(self). - (C13)
dim(self, quantization_dimension) = size(scales).
- (C12)
TokenType ::= 'token'
I tipi di token rappresentano i token, ovvero valori opachi prodotti e utilizzati da alcune operazioni. I token vengono utilizzati per imporre l'ordine di esecuzione delle operazioni come descritto nella sezione Esecuzione.
TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]
I tipi di buffer rappresentano i buffer. Ad esempio, in XLA i buffer sono array multidimensionali con spazio di archiviazione coerente. Analogamente ai tipi di tensore,
i tipi di buffer hanno una forma e un tipo di elemento, dove una forma rappresenta
dimensioni non negative o sconosciute delle dimensioni in ordine crescente delle
dimensioni corrispondenti (chiamate anche assi) numerate da 0
a R-1. Il numero di dimensioni R è chiamato rank. Ad esempio,
memref<2x3xf32> è un tipo di buffer con forma 2x3 e tipo di elemento f32. Ha due dimensioni (o, in altre parole, due assi): la dimensione 0 e la dimensione 1, le cui dimensioni sono 2 e 3. Il suo ranking è 2.
I buffer possono essere allocati utilizzando un valore compreso tra custom_call e CreateBuffer o Pin e
deallocati tramite un valore compreso tra custom_call e Unpin. Solo gli operatori custom_call possono leggere e
scrivere i contenuti all'interno dei buffer. Per maggiori dettagli, consulta custom_call.
I tipi di tuple rappresentano le tuple, ovvero elenchi eterogenei. Le tuple sono una funzionalità legacy
che esiste solo per la compatibilità con HLO. In HLO, le tuple vengono
utilizzate per rappresentare input e output variadici. In StableHLO, gli input e gli output variadici sono supportati in modo nativo e l'unico utilizzo delle tuple in StableHLO è quello di rappresentare in modo completo l'ABI HLO, in cui, ad esempio, T, tuple<T> e tuple<tuple<T>> possono essere materialmente diversi a seconda di una particolare implementazione. In futuro, prevediamo di apportare modifiche all'ABI HLO
che potrebbero consentirci di rimuovere i tipi di tupla da StableHLO
(#598).
TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3'
| 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2'
| 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
I tipi di elementi rappresentano gli elementi dei tipi di tensore. A differenza di molti linguaggi di programmazione, questi tipi non sono di prima classe in StableHLO. Ciò significa che
i programmi StableHLO non possono rappresentare direttamente i valori di questi tipi (di conseguenza,
è idiomatico rappresentare i valori scalari di tipo T con tensori
di tipo tensor<T> di dimensione 0).
- Il tipo booleano rappresenta i valori booleani
trueefalse. - I tipi interi possono essere con segno (
si) o senza segno (ui) e avere una delle larghezze in bit supportate (2,4,8,16,32o64). I tipisiNcon segno rappresentano valori interi da-2^(N-1)a2^(N-1)-1inclusi, mentre i tipiuiNsenza segno rappresentano valori interi da0a2^N-1inclusi. - I tipi a virgola mobile possono essere uno dei seguenti:
f8E3M4,f8E4M3ef8E5M2numeri in virgola mobile a 8 bit che seguono le convenzioni IEEE-754.f8E4M3FNef8E5M2corrispondenti rispettivamente alle codificheE4M3eE5M2del formato FP8 descritto in Formati FP8 per il deep learning.- Tipi
f8E4M3FNUZef8E5M2FNUZcorrispondenti alle codificheE4M3eE5M2dei formati FP8 descritti in Formati numerici a 8 bit per reti neurali profonde. - Tipo
f8E4M3B11FNUZcorrispondente alla codificaE4M3dei formati FP8 descritti in Hybrid 8-bit Floating Point (HFP8) Training and Inference for Deep Neural Networks. - Tipo
bf16corrispondente al formatobfloat16descritto in BFloat16: il segreto delle prestazioni elevate su Cloud TPU. f16,f32ef64corrispondenti rispettivamente ai formatibinary16("mezza precisione"),binary32("precisione singola") ebinary64("doppia precisione") descritti nello standard IEEE 754.- Il tipo
tf32corrisponde al formato TensorFloat32 ed è supportato in modo limitato in StableHLO. f4E2M1FN,f6E2M3FN,f6E3M2FNef8E8M0FNUMX (microscaling) descritti nella specifica dei formati di microscaling OCP.
- I tipi complessi rappresentano valori complessi che hanno una parte reale
e una parte immaginaria dello stesso tipo di elemento. I tipi complessi
supportati sono
complex<f32>(entrambe le parti sono di tipof32) ecomplex<f64>(entrambe le parti sono di tipof64).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]
I tipi di funzioni rappresentano sia le funzioni con nome che quelle anonime. Hanno
tipi di input (l'elenco dei tipi sul lato sinistro di ->) e tipi di output
(l'elenco dei tipi sul lato destro di ->). In molti linguaggi di programmazione, i tipi di funzioni sono di prima classe, ma non in StableHLO.
StringType ::= 'string'
Il tipo stringa rappresenta sequenze di byte. A differenza di molti linguaggi di programmazione, il tipo di stringa non è di prima classe in StableHLO e viene utilizzato solo per specificare metadati statici per gli elementi del programma.
Operazioni
Le operazioni StableHLO (chiamate anche ops) rappresentano un insieme chiuso di operazioni di alto livello nei modelli di machine learning. Come discusso in precedenza, la sintassi di StableHLO è fortemente ispirata a MLIR, che non è necessariamente l'alternativa più ergonomica, ma è probabilmente la più adatta allo scopo di StableHLO di creare una maggiore interoperabilità tra i framework ML e i compilatori ML.
Op ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic ::= 'abs' | 'add' | ...
Le operazioni StableHLO (chiamate anche ops) hanno un nome,
input/output e una firma. Il nome è composto dal prefisso stablehlo. e da
un mnemonico che identifica in modo univoco una delle operazioni supportate. Di seguito è riportato un elenco completo di tutte le operazioni supportate.
OpInputs ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue ::= ValueId
OpInputFuncs ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs ::= [OpOutput {',' OpOutput} '=']
OpOutput ::= ValueId
Le operazioni utilizzano input e producono output. Gli input sono suddivisi in
valori di input (calcolati durante l'esecuzione), funzioni di input (fornite
staticamente, perché in StableHLO le funzioni non sono valori di prima classe) e
attributi di input (forniti anche staticamente). Il tipo di input e output
utilizzati e prodotti da un'operazione dipende dal relativo mnemonico. Ad esempio, l'operazione add
utilizza 2 valori di input e produce 1 valore di output. Al contrario, l'operatore
select_and_scatter utilizza 3 valori di input, 2 funzioni di input e
3 attributi di input.
OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused ::= '^' digit {digit}
| '^' letter {letter | digit}
Le funzioni di input (chiamate anche funzioni anonime) sono molto simili alle funzioni con nome, tranne per il fatto che: 1) non hanno un identificatore (da cui il nome "anonime "), 2) non dichiarano i tipi di output (i tipi di output vengono dedotti dall'operatore return all'interno della funzione).
La sintassi delle funzioni di input include una parte attualmente inutilizzata (vedi la produzione Unused sopra) che serve per la compatibilità con MLIR. In MLIR,
esiste un concetto più generale di "regioni" che possono avere più "blocchi"
di operazioni collegate tra loro tramite operazioni di salto. Questi blocchi hanno ID che corrispondono
alla produzione di Unused, in modo che possano essere distinti l'uno dall'altro.
StableHLO non ha operazioni di salto, quindi la parte corrispondente della sintassi MLIR non viene utilizzata (ma è ancora presente).
OpInputAttr ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName ::= letter {letter | digit}
OpInputAttrValue ::= Constant
Gli attributi di input hanno un nome e un valore che è una delle costanti supportate. Sono il modo principale per specificare i metadati statici per gli elementi del programma. Ad esempio, l'operatore concatenate utilizza l'attributo dimension per
specificare la dimensione lungo la quale vengono concatenati i valori di input. Analogamente,
l'operatore slice utilizza più attributi come start_indices e limit_indices
per specificare i limiti utilizzati per suddividere il valore di input.
Al momento, i programmi StableHLO in natura a volte contengono attributi che non sono descritti in questo documento. In futuro, prevediamo di incorporare questi attributi nell'opset StableHLO o di vietarne la visualizzazione nei programmi StableHLO. Nel frattempo, ecco l'elenco di questi attributi:
layout(#629).mhlo.frontend_attributes(#628).mhlo.sharding(#619).output_operand_aliases(#740).- Metadati sulla posizione (#594).
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'
La firma dell'operazione è costituita dai tipi di tutti i valori di input (l'elenco dei tipi sul lato sinistro di ->) e dai tipi di tutti i valori di output (l'elenco dei tipi sul lato destro di ->). A rigor di termini, i tipi di input sono ridondanti e anche i tipi di output lo sono quasi sempre (perché per la maggior parte delle operazioni StableHLO, i tipi di output possono essere dedotti dagli input). Tuttavia, la firma
dell'operazione fa deliberatamente parte della sintassi di StableHLO per la compatibilità con MLIR.
Di seguito è riportato un esempio di operazione il cui mnemonico è select_and_scatter. Utilizza 3 valori di input (%operand, %source e %init_value), 2 funzioni di input e 3 attributi di input (window_dimensions, window_strides e padding).
Tieni presente che la firma dell'operazione include solo i tipi dei valori di input
(ma non i tipi di funzioni e attributi di input forniti inline).
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>
Costanti
Constant ::= BooleanConstant
| IntegerConstant
| FloatConstant
| ComplexConstant
| TensorConstant
| QuantizedTensorConstant
| StringConstant
| EnumConstant
Le costanti StableHLO hanno un valore letterale e un tipo che insieme rappresentano
un valore StableHLO. In genere, il tipo fa parte della sintassi della costante, tranne
quando è non ambiguo (ad es. una costante booleana ha inequivocabilmente il tipo i1,
mentre una costante intera può avere più tipi possibili).
BooleanConstant ::= BooleanLiteral
BooleanLiteral ::= 'true' | 'false'
Le costanti booleane rappresentano i valori booleani true e false. Le costanti booleane
hanno tipo i1.
IntegerConstant ::= IntegerLiteral ':' IntegerType
IntegerLiteral ::= ['-' | '+'] DecimalDigits
| ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit ::= '0' | ... | '9'
hexadecimalDigit ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'
Le costanti intere rappresentano valori interi tramite stringhe che utilizzano la notazione decimale o esadecimale. Altre basi, ad esempio binaria o ottale, non sono supportate. Le costanti intere presentano i seguenti vincoli:
- (C1)
is_wellformed(integer_literal, integer_type).
FloatConstant ::= FloatLiteral ':' FloatType
FloatLiteral ::= SignPart IntegerPart FractionalPart ScientificPart
| '0x' [HexadecimalDigits]
SignPart ::= ['-' | '+']
IntegerPart ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]
Le costanti in virgola mobile rappresentano valori in virgola mobile tramite stringhe che utilizzano la notazione decimale o scientifica. Inoltre, la notazione esadecimale può essere utilizzata per specificare direttamente i bit sottostanti nel formato in virgola mobile del tipo corrispondente. Le costanti in virgola mobile presentano i seguenti vincoli:
- (C1) Se viene utilizzata una notazione non esadecimale,
is_wellformed(float_literal, float_type). - (C2) Se viene utilizzata la notazione esadecimale,
size(hexadecimal_digits) = num_bits(float_type) / 4.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral ::= '(' RealPart ',' ImaginaryPart ')'
RealPart ::= FloatLiteral
ImaginaryPart ::= FloatLiteral
Le costanti complesse rappresentano valori complessi utilizzando elenchi di una parte reale
(prima) e una parte immaginaria (seconda). Ad esempio,
(1.0, 0.0) : complex<f32> rappresenta 1.0 + 0.0i e
(0.0, 1.0) : complex<f32> rappresenta 0.0 + 1.0i. L'ordine in cui queste
parti vengono archiviate in memoria è definito dall'implementazione. Le costanti complesse
hanno i seguenti vincoli:
- (C1)
is_wellformed(real_part, complex_element_type(complex_type)). - (C2)
is_wellformed(imaginary_part, complex_element_type(complex_type)).
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral
Le costanti tensoriali rappresentano i valori tensoriali utilizzando elenchi nidificati specificati tramite
la notazione NumPy. Ad esempio, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
rappresenta un valore tensore con il seguente mapping dagli indici agli elementi:
{0, 0} => 1, {0, 1} => 2, {0, 2} => 3, {1, 0} => 4, {1, 1} => 5,
{1, 2} => 6. L'ordine in cui questi elementi vengono archiviati in memoria è definito dall'implementazione. Le costanti tensoriali presentano i seguenti vincoli:
- (C1)
has_syntax(tensor_literal, element_type(tensor_type)), dove:has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type).has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type).
- (C2)
has_shape(tensor_literal, shape(tensor_type)), dove:has_shape(element_literal: Syntax, []) = true.has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:]).- altrimenti,
false.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
Le costanti tensoriali quantizzate rappresentano i valori tensoriali quantizzati utilizzando la stessa notazione delle costanti tensoriali, con gli elementi specificati come costanti del tipo di archiviazione. Le costanti tensoriali quantizzate presentano i seguenti vincoli:
- (C1)
has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type)). - (C2)
has_shape(quantized_tensor_literal, shape(quantized_tensor_type)).
StringConstant ::= StringLiteral
StringLiteral ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))
I valori letterali stringa sono costituiti da byte specificati utilizzando caratteri ASCII e
sequenze di escape. Sono indipendenti dalla codifica, quindi l'interpretazione di questi
byte è definita dall'implementazione. I valori letterali stringa hanno il tipo string.
Operazioni
abs
Semantica
Esegue l'operazione abs elemento per elemento sul tensore operand e produce un tensore result. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i numeri interi con segno: modulo intero.
- Per i numeri in virgola mobile:
absda IEEE-754. - Per i numeri complessi: modulo complesso.
- Per i tipi quantizzati:
dequantize_op_quantize(abs, operand, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore di tipo intero con segno, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1-C2) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo intero con segno o con rappresentazione in virgola mobile o tensore quantizzato per tensore | (C1-C2) |
Vincoli
- (C1)
shape(result) = shape(operand). - (C2)
baseline_element_type(result)è definito come:complex_element_type(element_type(operand))seis_complex(operand).baseline_element_type(operand)altrimenti.
Esempi
// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand)< : (t>ens>or3xi32<) - t>ensor3xi32
// %result: [2, 0, 2]
add
Semantica
Esegue l'addizione elemento per elemento di due tensori lhs e rhs e produce un tensore result. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i valori booleani: OR logico.
- Per i numeri interi: addizione di numeri interi.
- Per i numeri in virgola mobile:
additionda IEEE-754. - Per i numeri complessi: addizione complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(add, lhs, rhs, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | lhs |
tensore o tensore quantizzato | (C1-C6) |
| (I2) | rhs |
tensore o tensore quantizzato | (C1-C5), (C7) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore o tensore quantizzato | (C1-C7) |
Vincoli
- Se l'operazione utilizza tensori non quantizzati:
- (C1)
type(lhs) = type(rhs) = type(result).
- (C1)
- Se l'operazione utilizza tensori quantizzati:
- (C2)
is_quantized(lhs) and is_quantized(rhs) and is_quantized(result). - (C3)
storage_type(lhs) = storage_type(rhs) = storage_type(result). - (C4)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result). - (C5)
(is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result). - (C6) Se
is_per_axis_quantized(lhs), alloraquantization_dimension(lhs) = quantization_dimension(result). - (C7) Se
is_per_axis_quantized(rhs), alloraquantization_dimension(rhs) = quantization_dimension(result).
- (C2)
Esempi
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[6, 8], [10, 12]]
after_all
Semantica
Garantisce che le operazioni che producono inputs vengano eseguite prima di qualsiasi
operazione che dipende da result. L'esecuzione di questa operazione non fa nulla,
esiste solo per stabilire le dipendenze dei dati da result a inputs.
Input
| Etichetta | Nome | Tipo |
|---|---|---|
| (I1) | inputs |
numero variadico di token |
Output
| Nome | Tipo |
|---|---|
result |
token |
Esempi
// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehl>o.token) - !stablehlo.token
all_gather
Semantica
All'interno di ogni gruppo di processi nella griglia dei processi StableHLO, concatena i valori
dei tensori operands di ogni processo lungo all_gather_dim e produce
tensori results.
L'operazione divide la griglia di processi StableHLO in process_groups, che è
definito come segue:
cross_replica(replica_groups)ifchannel_id <= 0 and use_global_device_ids = false.cross_replica_and_partition(replica_groups)ifchannel_id > 0 and use_global_device_ids = false.flattened_ids(replica_groups)ifchannel_id > 0 and use_global_device_ids = true.
Successivamente, all'interno di ogni process_group:
operands...@receiver = [operand@sender for sender in process_group]per tuttireceiverinprocess_group.results...@process = concatenate(operands...@process, all_gather_dim)per tuttiprocessinprocess_group.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operands |
numero variabile di tensori o tensori quantizzati per tensore | (C1), (C6) |
| (I2) | all_gather_dim |
costante di tipo si64 |
(C1), (C6) |
| (I3) | replica_groups |
Costante tensore bidimensionale di tipo si64 |
(C2-C4) |
| (I4) | channel_id |
costante di tipo si64 |
(C5) |
| (I5) | use_global_device_ids |
costante di tipo i1 |
(C5) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
results |
numero variabile di tensori o tensori quantizzati per tensore | (C6) |
Vincoli
- (C1)
0 <= all_gather_dim < rank(operands...). - (C2)
is_unique(replica_groups). - (C3)
size(replica_groups)è definito come:num_replicasse viene utilizzatocross_replica.num_replicasse viene utilizzatocross_replica_and_partition.num_processesse viene utilizzatoflattened_ids.
- (C4)
0 <= replica_groups < size(replica_groups). - (C5) Se
use_global_device_ids = true, allorachannel_id > 0. - (C6)
type(results...) = type(operands...)tranne:dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1).
Esempi
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
all_gather_dim = 1 : i64,
replica_grou<ps = den>se[[0, 1]<] : ten>sor1x2xi64,
// channel_id = 0
channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 0
// use_global_device_ids = false
}< : (ten>sor2x2xi<64, ten>sor>2x2xi64)< - (ten>sor2x4xi<64, ten>sor2x4xi64)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
all_reduce
Semantica
All'interno di ogni gruppo di processi nella griglia dei processi StableHLO, applica una funzione di riduzione
computation ai valori dei tensori operands di ogni processo
e produce tensori results.
L'operazione divide la griglia di processi StableHLO in process_groups, che è
definito come segue:
cross_replica(replica_groups)ifchannel_id <= 0 and use_global_device_ids = false.cross_replica_and_partition(replica_groups)ifchannel_id > 0 and use_global_device_ids = false.flattened_ids(replica_groups)ifchannel_id > 0 and use_global_device_ids = true.
Successivamente, all'interno di ogni process_group:
results...@process[result_index] = exec(schedule)per alcuni alberi binarischeduledove:exec(node)=computation(exec(node.left), exec(node.right)).exec(leaf)=leaf.value.
scheduleè un albero binario definito dall'implementazione la cui attraversamento in ordine èto_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0])).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operands |
numero variabile di tensori o tensori quantizzati per tensore | (C5), (C6) |
| (I2) | replica_groups |
numero variadico di costanti tensoriali unidimensionali di tipo si64 |
(C1-C3) |
| (I3) | channel_id |
costante di tipo si64 |
(C4) |
| (I4) | use_global_device_ids |
costante di tipo i1 |
(C4) |
| (I5) | computation |
funzione | (C5) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
results |
numero variabile di tensori o tensori quantizzati per tensore | (C6-C7) |
Vincoli
- (C1)
is_unique(replica_groups). - (C2)
size(replica_groups)è definito come:num_replicasse viene utilizzatocross_replica.num_replicasse viene utilizzatocross_replica_and_partition.num_processesse viene utilizzatoflattened_ids.
- (C3)
0 <= replica_groups < size(replica_groups). - (C4) Se
use_global_device_ids = true, allorachannel_id > 0. - (C5)
computationè di tipo(tensor<E>, tensor<E>) -> (tensor<E>)doveis_promotable(element_type(operand), E). - (C6)
shape(results...) = shape(operands...). - (C7)
element_type(results...) = E.
Esempi
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
"stable<hlo>.re>turn"(%0) : (tensori64) - ()<
}) {
>replica_g<roups => dense[[0, 1]] : tensor1x2xi64,
// channel_id = 0
channel_hand<le = #stablehlo.chan>nel_handlehandle = 0, type = 0
// use_global_<devic>e_ids = <false>
} >: (tenso<r4xi6>4, tenso<r4xi6>4) - (tensor4xi64, tensor4xi64)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]
all_to_all
Semantica
All'interno di ogni gruppo di processi nella griglia dei processi StableHLO, divide i valori dei tensori operands lungo split_dimension in parti, distribuisce le parti divise tra i processi, concatena le parti distribuite lungo concat_dimension e produce tensori results.
L'operazione divide la griglia di processi StableHLO in process_groups, che è
definito come segue:
cross_replica(replica_groups)sechannel_id <= 0.cross_partition(replica_groups)sechannel_id > 0.
Successivamente, all'interno di ogni process_group:
split_parts...@sender = split(operands...@sender, split_count, split_dimension)per tutti isenderinprocess_group.scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group]dovereceiver_index = process_group.index(receiver).results...@process = concatenate(scattered_parts...@process, concat_dimension).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operands |
numero variabile di tensori o tensori quantizzati per tensore | (C1-C3), (C9) |
| (I2) | split_dimension |
costante di tipo si64 |
(C1), (C2), (C9) |
| (I3) | concat_dimension |
costante di tipo si64 |
(C3), (C9) |
| (I4) | split_count |
costante di tipo si64 |
(C2), (C4), (C8), (C9) |
| (I5) | replica_groups |
Costante tensore bidimensionale di tipo si64 |
(C5-C8) |
| (I6) | channel_id |
costante di tipo si64 |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
results |
numero variabile di tensori o tensori quantizzati per tensore | (C9) |
Vincoli
- (C1)
0 <= split_dimension < rank(operands...). - (C2)
dim(operands..., split_dimension) % split_count = 0. - (C3)
0 <= concat_dimension < rank(operands...). - (C4)
0 < split_count. - (C5)
is_unique(replica_groups). - (C6)
size(replica_groups)è definito come:num_replicasse viene utilizzatocross_replica.num_partitionsse viene utilizzatocross_partition.
- (C7)
0 <= replica_groups < size(replica_groups). - (C8)
dim(replica_groups, 1) = split_count. - (C9)
type(results...) = type(operands...)tranne sesplit_dimension != concat_dimension:dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count.dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count.
Esempi
// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
// [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
// [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_grou<ps = den>se[[0, 1]<] : ten>sor1x2xi64
// channel_id = 0
}< : (ten>sor2x4xi<64, ten>sor>2x4xi64)< - (ten>sor4x2xi<64, ten>sor4x2xi64)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]
e
Semantica
Esegue l'operazione AND elemento per elemento di due tensori lhs e rhs e produce un tensore result. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i valori booleani: AND logico.
- Per i numeri interi: AND bit a bit.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | lhs |
tensore di tipo booleano o intero | (C1) |
| (I2) | rhs |
tensore di tipo booleano o intero | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo booleano o intero | (C1) |
Vincoli
- (C1)
type(lhs) = type(rhs) = type(result).
Esempi
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[1, 2], [3, 0]]
atan2
Semantica
Esegue l'operazione atan2 elemento per elemento sui tensori lhs e rhs e produce un tensore result. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i numeri in virgola mobile:
atan2da IEEE-754. - Per i numeri complessi: atan2 complesso.
- Per i tipi quantizzati:
dequantize_op_quantize(atan2, lhs, rhs, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | lhs |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
| (I2) | rhs |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
Esempi
// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs)< : (t>ensor3xf<64, t>ens>or3xf64<) - t>ensor3xf64
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]
batch_norm_grad
Semantica
Calcola i gradienti di diversi input di batch_norm_training mediante la retropropagazione
da grad_output e produce i tensori grad_operand, grad_scale e grad_offset. Più formalmente, questa operazione può essere espressa come una scomposizione in operazioni StableHLO esistenti utilizzando la sintassi Python come segue:
def compute_sum(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
return sum
def compute_mean(operand, feature_index):
sum = compute_sum(operand, feature_index)
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
# Broadcast inputs to type(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance`
# Intermediate values will be useful for computing gradients
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
# Use the implementation from batchnorm_expander.cc in XLA
# Temporary variables have exactly the same names as in the C++ code
elements_per_feature = broadcast_in_dim(
constant(divide(size(operand), dim(operand, feature_index)),
element_type(grad_output)),
[], type(operand))
i1 = multiply(grad_output, elements_per_feature)
i2 = broadcast_in_dim(
compute_sum(grad_output, feature_index), [feature_index], type(operand))
i3 = broadcast_in_dim(
compute_sum(multiply(grad_output, centered_operand), feature_index),
[feature_index], type(operand))
i4 = multiply(i3, centered_operand)
i5 = divide(i4, add(variance_bcast, epsilon_bcast))
i6 = subtract(subtract(i1, i2), i5)
grad_operand =
multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
grad_scale =
compute_sum(multiply(grad_output, normalized_operand), feature_index)
grad_offset = compute_sum(grad_output, feature_index)
return grad_operand, grad_scale, grad_offset
Per i tipi quantizzati, esegue
dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean,
variance, grad_output: batch_norm_grad(operand, scale, mean, variance,
grad_output, epsilon, feature_index), operand, scale, mean, variance,
grad_output, type(grad_operand), type(grad_scale), type(feature_index)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1-C3), (C5) |
| (I2) | scale |
Tensore unidimensionale di tipo in virgola mobile o quantizzato per tensore | (C2), (C4), (C5) |
| (I3) | mean |
Tensore unidimensionale di tipo in virgola mobile o quantizzato per tensore | (C2), (C4) |
| (I4) | variance |
Tensore unidimensionale di tipo in virgola mobile o quantizzato per tensore | (C2), (C4) |
| (I5) | grad_output |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C2), (C3) |
| (I6) | epsilon |
costante di tipo f32 |
|
| (I7) | feature_index |
costante di tipo si64 |
(C1), (C5) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
grad_operand |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C2), (C3) |
grad_scale |
Tensore unidimensionale di tipo in virgola mobile o quantizzato per tensore | (C2), (C4) |
grad_offset |
Tensore unidimensionale di tipo in virgola mobile o quantizzato per tensore | (C2), (C4) |
Vincoli
- (C1)
0 <= feature_index < rank(operand). - (C2)
operand,scale,mean,variance,grad_output,grad_operand,grad_scaleegrad_offsethanno lo stessobaseline_element_type. - (C3)
operand,grad_outputegrad_operandhanno la stessa forma. - (C4)
scale,mean,variance,grad_scaleegrad_offsethanno la stessa forma. - (C5)
size(scale) = dim(operand, feature_index).
Esempi
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
// [[0.1, 0.1], [0.1, 0.1]],
// [[0.1, 0.1], [0.1, 0.1]]
// ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
}< : (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf<64, t>ensor2xf64,
< tenso>r2x>2x2xf64)< - (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf64)
// %grad_operand: [
// [[0.0, 0.0], [0.0, 0.0]],
// [[0.0, 0.0], [0.0, 0.0]]
// ]
// %grad_scale: [0.0, 0.0]
// %grad_offset: [0.4, 0.4]
batch_norm_inference
Semantica
Normalizza il tensore operand in tutte le dimensioni, tranne la dimensione feature_index, e produce un tensore result. Più formalmente, questa operazione può essere espressa come una scomposizione in operazioni StableHLO esistenti utilizzando la sintassi Python nel seguente modo:
def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
# Broadcast inputs to shape(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance` instead of
# computing them like `batch_norm_training` does.
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
return add(multiply(scale_bcast, normalized_operand), offset_bcast)
Per i tipi quantizzati, esegue
dequantize_op_quantize(lambda operand, scale, offset, mean, variance:
batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index), operand, scale, offset, mean, variance, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1-C7) |
| (I2) | scale |
Tensore unidimensionale di tipo in virgola mobile o quantizzato per tensore | (C2), (C3) |
| (I3) | offset |
Tensore unidimensionale di tipo in virgola mobile o quantizzato per tensore | (C2), (C4) |
| (I4) | mean |
Tensore unidimensionale di tipo in virgola mobile o quantizzato per tensore | (C5) |
| (I5) | variance |
Tensore unidimensionale di tipo in virgola mobile o quantizzato per tensore | (C2), (C6) |
| (I6) | epsilon |
costante di tipo f32 |
|
| (I7) | feature_index |
costante di tipo si64 |
(C1), (C3-C6) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C2), (C7) |
Vincoli
- (C1)
0 <= feature_index < rank(operand). - (C2)
operand,scale,offset,mean,varianceeresulthanno lo stessobaseline_element_type. - (C3)
size(scale) = dim(operand, feature_index). - (C4)
size(offset) = dim(operand, feature_index). - (C5)
size(mean) = dim(operand, feature_index). - (C6)
size(variance) = dim(operand, feature_index). - (C7)
baseline_type(operand) = baseline_type(result).
Esempi
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
}< : (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf<64, t>ensor2xf<64, t>ens>or2xf64<) - tenso>r2x2x2xf64
// %result: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
batch_norm_training
Semantica
Calcola la media e la varianza in tutte le dimensioni, ad eccezione della dimensione feature_index, e normalizza il tensore operand producendo i tensori output, batch_mean e batch_var. Più formalmente, questa operazione può essere espressa come una
scomposizione in operazioni StableHLO esistenti utilizzando la sintassi Python come
segue:
def compute_mean(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def compute_variance(operand, feature_index):
mean = compute_mean(operand, feature_index)
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
centered_operand = subtract(operand, mean_bcast)
return compute_mean(mul(centered_operand, centered_operand), feature_index)
def batch_norm_training(operand, scale, offset, epsilon, feature_index):
mean = compute_mean(operand, feature_index)
variance = compute_variance(operand, feature_index)
return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index),
mean, variance
Per i tipi quantizzati, esegue
dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset:
batch_norm_training(operand, scale, offset, epsilon, feature_index), operand,
scale, offset, type(output), type(batch_mean), type(batch_var)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
| (I2) | scale |
Tensore unidimensionale di numeri in virgola mobile o quantizzati per tensore | (C2), (C3) |
| (I3) | offset |
Tensore unidimensionale di numeri in virgola mobile o quantizzati per tensore | (C2), (C4) |
| (I4) | epsilon |
costante di tipo f32 |
(C1), (C3-C6) |
| (I5) | feature_index |
costante di tipo si64 |
(C1), (C3-C6) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
output |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C7) |
batch_mean |
Tensore unidimensionale di numeri in virgola mobile o quantizzati per tensore | (C2), (C5) |
batch_var |
Tensore unidimensionale di numeri in virgola mobile o quantizzati per tensore | (C2), (C6) |
Vincoli
- (C1)
0 <= feature_index < rank(operand). - (C2)
operand,scale,offset,batch_mean,batch_vareoutputhanno lo stessobaseline_element_type. - (C3)
size(scale) = dim(operand, feature_index). - (C4)
size(offset) = dim(operand, feature_index). - (C5)
size(batch_mean) = dim(operand, feature_index). - (C6)
size(batch_var) = dim(operand, feature_index). - (C7)
baseline_type(output) = baseline_type(operand).
Esempi
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
}< : (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ens>or2xf64) -
< (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf64)
// %output: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]
bitcast_convert
Semantica
Esegue un'operazione di bitcast sul tensore operand e produce un tensore result
in cui i bit dell'intero tensore operand vengono reinterpretati utilizzando il
tipo del tensore result.
Più formalmente, dati E = element_type(operand), E' = element_type(result)
e R = rank(operand):
- Se
num_bits(E') < num_bits(E),bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1]). - Se
num_bits(E') > num_bits(E),bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :]). - Se
num_bits(E') = num_bits(E),bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1]).
bits restituisce la rappresentazione in memoria di un determinato valore e il suo comportamento
è definito dall'implementazione perché la rappresentazione esatta dei tensori è
definita dall'implementazione, così come la rappresentazione esatta dei tipi di elementi.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore o tensore quantizzato | (C1-C2) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore o tensore quantizzato | (C1-C2) |
Vincoli
- (C1) Dato
E = is_quantized(operand) ? storage_type(operand) : element_type(operand),E' = is_quantized(result) ? storage_type(result) : element_type(result)eR = rank(operand):- Se
num_bits(E') = num_bits(E),shape(result) = shape(operand). - Se
num_bits(E') < num_bits(E): rank(result) = R + 1.dim(result, i) = dim(operand, i)per tutti i0 <= i < R.dim(result, R) * num_bits(E') = num_bits(E).- Se
num_bits(E') > num_bits(E): rank(result) = R - 1.dim(result, i) = dim(operand, i)per tutti i0 <= i < R.dim(operand, R - 1) * num_bits(E) = num_bits(E').
- Se
- (C2) Se
is_complex(operand) or is_complex(result), allorais_complex(operand) and is_complex(result).
Esempi
// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand)< : >(te>nsorf64<) - t>ensor4xf16
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation
broadcast_in_dim
Semantica
Espande le dimensioni e/o il rango di un tensore di input duplicando i dati
nel tensore operand e produce un tensore result. In termini più formali,
result[result_index] = operand[operand_index] dove per tutti i d in
axes(operand):
operand_index[d] = 0sedim(operand, d) = 1.operand_index[d] = result_index[broadcast_dimensions[d]]altrimenti.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore o tensore quantizzato | (C1-C2), (C5-C6) |
| (I2) | broadcast_dimensions |
Costante tensore unidimensionale di tipo si64 |
(C2-C6) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore o tensore quantizzato | (C1), (C3), (C5-C6) |
Vincoli
- (C1)
element_type(result)è fornito da:element_type(operand), se!is_per_axis_quantized(operand).element_type(operand), tranne chequantization_dimension(operand),scales(operand)ezero_points(operand)potrebbero differire daquantization_dimension(result),scales(result)ezero_points(result)rispettivamente.
- (C2)
size(broadcast_dimensions) = rank(operand). - (C3)
0 <= broadcast_dimensions < rank(result). - (C4)
is_unique(broadcast_dimensions). - (C5) Per tutti i
dinaxes(operand):dim(operand, d) = 1odim(operand, d) = dim(result, broadcast_dimensions[d]).
- (C6) Se
is_per_axis_quantized(result):quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].- Se
dim(operand, quantization_dimension(operand)) = 1, allorascales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).
Esempi
// %operand: [
// [1, 2, 3]
// ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
broadcast_dimensio<ns = arra>yi64: 2, 1
}< : (ten>sor>1x3xi32<) - tenso>r2x3x2xi32
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
richiesta
Semantica
Produce l'output dell'esecuzione di esattamente una funzione da branches
a seconda del valore di index. In termini più formali, result = selected_branch()
dove:
selected_branch = branches[index]se0 <= index < size(branches).selected_branch = branches[-1]altrimenti.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | index |
Tensore unidimensionale di tipo si32 |
|
| (I2) | branches |
numero variabile di funzioni | (C1-C4) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
results |
numero variabile di tensori, tensori quantizzati o token | (C4) |
Vincoli
- (C1)
0 < size(branches). - (C2)
input_types(branches...) = []. - (C3)
same(output_types(branches...)). - (C4)
type(results...) = output_types(branches[0]).
Esempi
// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
"stablehlo.return"(%result_branch0, %resul<t_bra>nch0) : <(tens>or2>xi64, tensor2xi64) - ()
}, {
"stablehlo.return"(%result_branc<h1, %>result_b<ranch>1) >: (tensor2xi64, <ten>sor>2xi64) -< ()
}>) : (ten<sori3>2) - (tensor2xi64, tensor2xi64)
// %result0: [1, 1]
// %result1: [1, 1]
cbrt
Semantica
Esegue l'operazione di radice cubica elemento per elemento sul tensore operand e produce un tensore result. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i numeri in virgola mobile:
rootn(x, 3)da IEEE-754. - Per i numeri complessi: radice cubica complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(cbrt, operand, type(result))
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result).
Esempi
// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand)< : (t>ens>or4xf64<) - t>ensor4xf64
// %result: [0.0, 1.0, 2.0, 3.0]
ceil
Semantica
Esegue l'operazione ceil elemento per elemento del tensore operand e produce un tensore result.
Implementa l'operazione roundToIntegralTowardPositive dalla specifica IEEE-754. Per i tipi quantizzati, esegue
dequantize_op_quantize(ceil, operand, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result).
Esempi
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand)< : (t>ens>or5xf32<) - t>ensor5xf32
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]
cholesky
Semantica
Calcola la decomposizione di Cholesky di un batch di matrici.
Più formalmente, per tutti i i in index_space(result),
result[i0, ..., iR-3, :, :] è una decomposizione di Cholesky di
a[i0, ..., iR-3, :, :], sotto forma di matrice triangolare inferiore
(se lower è true) o triangolare superiore (se lower è false).
I valori di output nel triangolo opposto, ovvero il triangolo superiore stretto o
il triangolo inferiore stretto, sono definiti dall'implementazione.
Se esiste i in cui la matrice di input non è una matrice hermitiana definita positiva, il comportamento non è definito.
Per i tipi quantizzati, esegue
dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | a |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1-C3) |
| (I2) | lower |
costante di tipo i1 |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(a) = baseline_type(result). - (C2)
2 <= rank(a). - (C3)
dim(a, -2) = dim(a, -1).
Esempi
// %a: [
// [1.0, 2.0, 3.0],
// [2.0, 20.0, 26.0],
// [3.0, 26.0, 70.0]
// ]
%result = "stablehlo.cholesky"(%a) {
lower = true
}< : (ten>sor>3x3xf32<) - ten>sor3x3xf64
// %result: [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
clampare
Semantica
Blocca ogni elemento del tensore operand tra un valore minimo e massimo e produce un tensore result. Più formalmente, result[result_index] =
minimum(maximum(operand[result_index], min_element), max_element),
dove min_element = rank(min) = 0 ? min[] : min[result_index],
max_element = rank(max) = 0 ? max[] : max[result_index]. Per i tipi quantizzati,
esegue dequantize_op_quantize(clamp, min, operand, max, type(result)).
L'imposizione di un ordinamento sui numeri complessi comporta una semantica sorprendente, quindi in futuro prevediamo di rimuovere il supporto per i numeri complessi per questa operazione (#560).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | min |
tensore o tensore quantizzato per tensore | (C1), (C3) |
| (I2) | operand |
tensore o tensore quantizzato per tensore | (C1-C4) |
| (I3) | max |
tensore o tensore quantizzato per tensore | (C2), (C3) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore o tensore quantizzato per tensore | (C4) |
Vincoli
- (C1)
rank(min) = 0 or shape(min) = shape(operand). - (C2)
rank(max) = 0 or shape(max) = shape(operand). - (C3)
baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max). - (C4)
baseline_type(operand) = baseline_type(result).
Esempi
// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max)< : (t>ensor3xi<32, t>ensor3xi<32, t>ens>or3xi32<) - t>ensor3xi32
// %result: [5, 13, 20]
collective_broadcast
Semantica
All'interno di ogni gruppo di processi nella griglia dei processi StableHLO, invia il valore del tensore operand dal processo di origine ai processi di destinazione e genera un tensore result.
L'operazione divide la griglia di processi StableHLO in process_groups, che è
definito come segue:
cross_replica(replica_groups)sechannel_id <= 0.cross_partition(replica_groups)sechannel_id > 0.
Successivamente, result@process viene fornito da:
operand@process_groups[i, 0]se esiste unitale che la procedura sia inprocess_groups[i].broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))altrimenti.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore o tensore quantizzato per tensore | (C3) |
| (I2) | replica_groups |
numero variadico di costanti tensoriali unidimensionali di tipo si64 |
(C1), (C2) |
| (I3) | channel_id |
costante di tipo si64 |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore o tensore quantizzato per tensore | (C3) |
Vincoli
- (C1)
is_unique(replica_groups). - (C2)
0 <= replica_groups < NdoveNè definito come:num_replicasse viene utilizzatocross_replica.num_partitionsse viene utilizzatocross_partition.
- (C3)
type(result) = type(operand).
Esempi
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_grou<ps = den>se[[2, 1]<] : ten>sor1x2xi64,
channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 0
} : (ten>sor>1x2xi64<) - ten>sor1x2xi64
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
collective_permute
Semantica
All'interno di ogni gruppo di processi nella griglia dei processi StableHLO, invia il valore del tensore operand dal processo di origine al processo di destinazione e produce un tensore result.
L'operazione divide la griglia di processi StableHLO in process_groups, che è
definito come segue:
cross_replica(source_target_pairs)sechannel_id <= 0.cross_partition(source_target_pairs)sechannel_id > 0.
Successivamente, result@process viene fornito da:
operand@process_groups[i, 0], se esiste unitale cheprocess_groups[i, 1] = process.broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))altrimenti.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore o tensore quantizzato per tensore | (C5) |
| (I2) | source_target_pairs |
Costante tensore bidimensionale di tipo si64 |
(C1-C4) |
| (I3) | channel_id |
costante di tipo si64 |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
dim(source_target_pairs, 1) = 2. - (C2)
is_unique(source_target_pairs[:, 0]). - (C3)
is_unique(source_target_pairs[:, 1]). - (C4)
0 <= source_target_pairs < N, doveNè definito come:num_replicasse viene utilizzatocross_replica.num_partitionsse viene utilizzatocross_partition.
- (C5)
type(result) = type(operand).
Esempi
// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
source_target_pai<rs = dense[[0, 1>], [1, 2]<] : ten>sor2x2xi64,
channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 0
}< : (ten>sor>2x2xi64<) - ten>sor2x2xi64
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]
confronta
Semantica
Esegue il confronto elemento per elemento dei tensori lhs e rhs in base a
comparison_direction e compare_type e produce un tensore result.
I valori di comparison_direction e compare_type hanno la seguente
semantica:
Per i tipi di elementi booleani e interi:
EQ:lhs = rhs.NE:lhs != rhs.GE:lhs >= rhs.GT:lhs > rhs.LE:lhs <= rhs.LT:lhs < rhs.
Per i tipi di elementi in virgola mobile con compare_type = FLOAT, l'operazione implementa
le seguenti operazioni IEEE-754:
EQ:compareQuietEqual.NE:compareQuietNotEqual.GE:compareQuietGreaterEqual.GT:compareQuietGreater.LE:compareQuietLessEqual.LT:compareQuietLess.
Per i tipi di elementi in virgola mobile con compare_type = TOTALORDER, l'operazione
utilizza la combinazione delle operazioni totalOrder e compareQuietEqual di
IEEE-754.
Per i tipi di elementi complessi, il confronto lessicografico delle coppie (real, imag) viene
eseguito utilizzando comparison_direction e compare_type forniti.
L'imposizione di un ordinamento sui numeri complessi comporta una semantica sorprendente, quindi in futuro prevediamo di rimuovere il supporto per i numeri complessi quando comparison_direction è GE, GT, LE o LT
(#560).
Per i tipi quantizzati, esegue dequantize_compare(lhs, rhs,
comparison_direction).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | lhs |
tensore o tensore quantizzato per tensore | (C1-C3) |
| (I2) | rhs |
tensore o tensore quantizzato per tensore | (C1-C2) |
| (I3) | comparison_direction |
enum di EQ, NE, GE, GT, LE e LT |
|
| (I4) | compare_type |
enum di FLOAT, TOTALORDER, SIGNED e UNSIGNED |
(C3) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo booleano | (C2) |
Vincoli
- (C1)
baseline_element_type(lhs) = baseline_element_type(rhs). - (C2)
shape(lhs) = shape(rhs) = shape(result). - (C3)
compare_typeè definito come:SIGNEDseis_signed_integer(element_type(lhs)).UNSIGNEDseis_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs)).FLOAToTOTALORDERseis_float(element_type(lhs)).FLOATseis_complex(element_type(lhs)).
Esempi
// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
comparison_direction = <#stablehlocomparison_di>rection LT,
compare_type = <#stablehlocomparison_>type FLOAT
}< : (t>ensor2xf<32, t>ens>or2xf32<) - >tensor2xi1
// %result: [true, false]
complesso
Semantica
Esegue la conversione elemento per elemento in un valore complesso da una coppia di valori reali e
immaginari, lhs e rhs, e produce un tensore result.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | lhs |
tensore di tipo f32 o f64 |
(C1-C3) |
| (I2) | rhs |
tensore di tipo f32 o f64 |
(C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo complesso | (C2), (C3) |
Vincoli
- (C1)
type(lhs) = type(rhs). - (C2)
shape(result) = shape(lhs). - (C3)
element_type(result)è di tipocomplex<E>doveE = element_type(lhs).
Esempi
// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs)< : (t>ensor2xf<64, t>ens>or2xf64<) - tenso<r2x>>complexf64
// %result: [(1.0, 2.0), (3.0, 4.0)]
composito
Semantica
Incapsula un'operazione composta da altre operazioni StableHLO,
che prende inputs e composite_attributes e produce results. La semantica dell'operazione viene implementata dall'attributo decomposition. L'operatore
composite può essere sostituito con la sua decomposizione senza modificare la semantica del programma. Nei casi in cui l'incorporamento della scomposizione non fornisce la stessa semantica dell'operazione, preferisci utilizzare custom_call.
Il campo version (il valore predefinito è 0) viene utilizzato per indicare quando cambiano le
semantiche di un composito.
Input
| Etichetta | Nome | Tipo |
|---|---|---|
| (I1) | inputs |
numero variabile di valori |
| (I2) | name |
costante di tipo string |
| (I3) | composite_attributes |
dizionario degli attributi |
| (I4) | decomposition |
costante di tipo string |
| (I5) | version |
costante di tipo si32 |
Output
| Nome | Tipo |
|---|---|
results |
numero variabile di valori |
Vincoli
- (C1)
is_namespaced_op_name(name) - (C2)
is_defined_in_parent_scope(decomposition) - (C3)
types(inputs...) == input_types(decomposition) - (C4)
types(results...) == output_types(decomposition)
Esempi
%results = "stablehlo.composite"(%input0, %input1) {
name = "my_namespace.my_op",
composite_attributes = {
my_attribute = "my_value"
},
decomposition = @my_op,
< ve>rsion = <1 :> i3>2
} : (<ten>sorf32, tensorf32) - tensorf32
concatenare
Semantica
Concatena inputs lungo la dimensione dimension nello stesso ordine degli argomenti
forniti e produce un tensore result. In termini più formali,
result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1], dove:
id = d0 + ... + dk-1 + kd.dè uguale adimensioned0, ... sono le dimensioni dellada dimensione diinputs.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | inputs |
numero variabile di tensori o tensori quantizzati per tensore | (C1-C6) |
| (I2) | dimension |
costante di tipo si64 |
(C2), (C4), (C6) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore o tensore quantizzato per tensore | (C5-C6) |
Vincoli
- (C1)
same(element_type(inputs...)). - (C2)
same(shape(inputs...))ad eccezione didim(inputs..., dimension). - (C3)
0 < size(inputs). - (C4)
0 <= dimension < rank(inputs[0]). - (C5)
element_type(result) = element_type(inputs[0]). - (C6)
shape(result) = shape(inputs[0])ad eccezione di:dim(result, dimension) = dim(inputs[0], dimension) + ....
Esempi
// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
dimension = 0 : i64
}< : (ten>sor3x2xi<64, ten>sor>1x2xi64<) - ten>sor4x2xi64
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
costante
Semantica
Produce un tensore output da una costante value.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | value |
costante | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
output |
tensore o tensore quantizzato | (C1) |
Vincoli
- (C1)
type(value) = type(output).
Esempi
%output = "stablehlo.constant"() {
val<ue = dense[[0.0, 1.0], [>2.0, 3.0]<] : ten>sor2x2xf3>2
} : (<) - ten>sor2x2xf32
// %output: [[0.0, 1.0], [2.0, 3.0]]
convertire
Semantica
Esegue una conversione elemento per elemento da un tipo di elemento a un altro sul tensore operand e produce un tensore result.
Per le conversioni boolean-to-any-supported-type, il valore false viene
convertito in zero e il valore true viene convertito in uno. Per le conversioni
any-supported-type-to-boolean, un valore pari a zero viene convertito in
false, mentre i valori diversi da zero vengono convertiti in true. Continua a leggere per scoprire come
funziona per i tipi complessi.
Per le conversioni che coinvolgono interi, interi a virgola mobile o virgola mobile a virgola mobile, se il valore di origine può essere rappresentato esattamente nel tipo di destinazione, il valore del risultato è la rappresentazione esatta. In caso contrario, il comportamento è da definire (#180).
Per le conversioni che coinvolgono floating-point-to-integer, la parte frazionaria viene troncata. Se il valore troncato non può essere rappresentato nel tipo di destinazione, il comportamento è da definire (#180).
La conversione che coinvolge numeri complessi in numeri complessi segue lo stesso comportamento delle conversioni da virgola mobile a virgola mobile per la conversione di parti reali e immaginarie.
Per le conversioni da complesso a qualsiasi altro tipo e da qualsiasi altro tipo a complesso, il valore immaginario di origine viene ignorato o il valore immaginario di destinazione viene azzerato, rispettivamente. La conversione della parte reale segue le conversioni in virgola mobile.
In linea di principio, questa operazione potrebbe esprimere la dequantizzazione (conversione da tensori quantizzati a tensori regolari), la quantizzazione (conversione da tensori regolari a tensori quantizzati) e la riquantizzazione (conversione tra tensori quantizzati), ma al momento disponiamo di operazioni dedicate per questo scopo: uniform_dequantize per il primo caso d'uso e uniform_quantize per il secondo e il terzo. In futuro, queste due operazioni potrebbero essere unite
in convert (#1576).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore | (C1) |
Vincoli
- (C1)
shape(operand) = shape(result).
Esempi
// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand)< : (t>ens>or3xi64<) - tenso<r3x>>complexf64
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]
convoluzione
Semantica
Calcola i prodotti scalari tra finestre di lhs e sezioni di rhs e produce
result. Il seguente diagramma mostra come vengono calcolati gli elementi in result da
lhs e rhs utilizzando un esempio concreto.
In termini più formali, considera la seguente riformulazione degli input in termini di lhs
per poter esprimere finestre di lhs:
lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension)).lhs_window_strides = lhs_shape(1, window_strides, 1).lhs_padding = lhs_shape([0, 0], padding, [0, 0]).lhs_base_dilations = lhs_shape(1, lhs_dilation, 1).lhs_window_dilations = lhs_shape(1, rhs_dilation, 1).
Questa riformulazione utilizza le seguenti funzioni helper:
lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]).result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]).permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]dovej[d] = i[permutation[d]].
Se feature_group_count = 1 e batch_group_count = 1, allora per tutti
output_spatial_index in index_space(dim(result, output_spatial_dimensions...)),
result[result_shape(:, output_spatial_index, :)] = dot_product dove:
padding_value = constant(0, element_type(lhs)).padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1).lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides.lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations).reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true]). Questa funzionalità sembra inutilizzata, quindi in futuro prevediamo di rimuoverla (#1181).dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension]).
Se feature_group_count > 1:
lhses = split(lhs, feature_group_count, input_feature_dimension).rhses = split(rhs, feature_group_count, kernel_output_feature_dimension).results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...).result = concatenate(results, output_feature_dimension).
Se batch_group_count > 1:
lhses = split(lhs, batch_group_count, input_batch_dimension).rhses = split(rhs, batch_group_count, kernel_output_feature_dimension).results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...).result = concatenate(results, output_feature_dimension).
Per i tipi quantizzati, esegue dequantize_op_quantize(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs,
type(result)).
Per i tipi quantizzati ibridi, esegue hybrid_dequantize_then_op(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | lhs |
tensore o tensore quantizzato per tensore | (C1), (C10-C11), (C14) (C25), (C27-C28), (C31-C32), (C34) |
| (I2) | rhs |
tensore o tensore quantizzato | (C1), (C14-C16), (C25), (C27-C29), (C31-C34) |
| (I3) | window_strides |
Costante tensore unidimensionale di tipo si64 |
(C2-C3), (C25) |
| (I4) | padding |
Costante tensore bidimensionale di tipo si64 |
(C4), (C25) |
| (I5) | lhs_dilation |
Costante tensore unidimensionale di tipo si64 |
(C5-C6), (C25) |
| (I6) | rhs_dilation |
Costante tensore unidimensionale di tipo si64 |
(C7-C8), (C25) |
| (I7) | window_reversal |
Costante tensore unidimensionale di tipo i1 |
(C9) |
| (I8) | input_batch_dimension |
costante di tipo si64 |
(C10), (C13), (C25) |
| (I9) | input_feature_dimension |
costante di tipo si64 |
(C11), (C13-C14) |
| (I10) | input_spatial_dimensions |
Costante tensore unidimensionale di tipo si64 |
(C12), (C13), (C25) |
| (I11) | kernel_input_feature_dimension |
costante di tipo si64 |
(C14), (C18) |
| (I12) | kernel_output_feature_dimension |
costante di tipo si64 |
(C15-C16), (C18), (C25), (C29) |
| (I13) | kernel_spatial_dimensions |
Costante tensore unidimensionale di tipo si64 |
(C17-C18), (C25) |
| (I14) | output_batch_dimension |
costante di tipo si64 |
(C20), (C25) |
| (I15) | output_feature_dimension |
costante di tipo si64 |
(C20), (C25), (C30) |
| (I16) | output_spatial_dimensions |
Costante tensore unidimensionale di tipo si64 |
(C19-C20), (C25) |
| (I17) | feature_group_count |
costante di tipo si64 |
(C11), (C14), (C16), (C21), (C23) |
| (I18) | batch_group_count |
costante di tipo si64 |
(C10), (C15), (C22), (C23), (C25) |
| (I19) | precision_config |
numero variabile di enumerazioni di DEFAULT, HIGH e HIGHEST |
(C24) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore o tensore quantizzato | (C25-C28), (C30), (C32-34) |
Vincoli
- (C1)
N = rank(lhs) = rank(rhs). - (C2)
size(window_strides) = N - 2. - (C3)
0 < window_strides. - (C4)
shape(padding) = [N - 2, 2]. - (C5)
size(lhs_dilation) = N - 2. - (C6)
0 < lhs_dilation. - (C7)
size(rhs_dilation) = N - 2. - (C8)
0 < rhs_dilation. - (C9)
size(window_reversal) = N - 2. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0. - (C12)
size(input_spatial_dimensions) = N - 2. - (C13) Dato
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:is_unique(input_dimensions).0 <= input_dimensions < N.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0. - (C17)
size(kernel_spatial_dimensions) = N - 2. - (C18) Dato
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:is_unique(kernel_dimensions).0 <= kernel_dimensions < N.
- (C19)
size(output_spatial_dimensions) = N - 2. - (C20) Dato
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:is_unique(output_dimensions).0 <= output_dimensions < N.
- (C21)
0 < feature_group_count. - (C22)
0 < batch_group_count. - (C23)
feature_group_count = 1 or batch_group_count = 1. - (C24)
size(precision_config) = 2. - (C25)
dim(result, result_dim)è definito come:dim(lhs, input_batch_dimension) / batch_group_countseresult_dim = output_batch_dimension.dim(rhs, kernel_output_feature_dimension)seresult_dim = output_feature_dimension.num_windowsaltrimenti, dove:output_spatial_dimensions[spatial_dim] = result_dim.lhs_dim = input_spatial_dimensions[spatial_dim].rhs_dim = kernel_spatial_dimensions[spatial_dim].dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
- (C26)
rank(result) = N. - Se l'operazione utilizza tensori non quantizzati:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result).
- (C27)
- Se l'operazione utilizza tensori quantizzati:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs). - (C29) Se
is_per_axis_quantized(rhs), alloraquantization_dimension(rhs) = kernel_output_feature_dimension. - (C30) Se
is_per_axis_quantized(result), alloraquantization_dimension(result) = output_feature_dimension. - Se
is_quantized(lhs): - (C31)
storage_type(lhs) = storage_type(rhs). - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result). - (C33) Se
is_per_tensor_quantized(rhs), allorais_per_tensor_quantized(result). - Se
!is_quantized(lhs): - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result).
- (C28)
Esempi
// %lhs: [[
// [
// [1], [2], [5], [6]
// ],
// [
// [3], [4], [7], [8]
// ],
// [
// [10], [11], [14], [15]
// ],
// [
// [12], [13], [16], [17]
// ]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strid<es = arra>yi64: 4, 4,
paddi<n>g = dense<0 : ten>sor2x2xi64,
lhs_dilati<on = arra>yi64: 2, 2,
rhs_dilati<on = arra>yi64: 1, 1,
window_revers<al = arrayi1: fa>lse, false,
// In the StableHLO dialect, dimension numbers are encoded vi<a:
// `[input >dim<ensions]x[kernel >di>mensions]-[output dimensions]`.
// "b" is batch dimension, "f" is feature dimension,
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" a<re spatial dimensions.
d>imension_num>bers = #stablehlo.conv[b, 0, 1, f]x[0, 1, i, o]-[b, 0, 1, f],
batch_group_count = 1 : i64,
fea<ture_group_count >= 1 : i64,
< precision_config> = [#stablehl<oprecision >DEFAULT,< #stablehlo>pre>cision <DEFAULT]
} >: (tensor1x4x4x1xi64, tensor3x3x1x1xi64) - tensor1x2x2x1xi64
// %result: [[
// [[10], [26]],
// [[46], [62]]
// ]]
coseno
Semantica
Esegue l'operazione di coseno elemento per elemento sul tensore operand e produce un tensore result. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i numeri in virgola mobile:
cosda IEEE-754. - Per i numeri complessi: coseno complesso.
- Per i tipi quantizzati:
dequantize_op_quantize(cosine, operand, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result).
Esempi
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.cosine"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[1.0, 0.0], [-1.0, 0.0]]
count_leading_zeros
Semantica
Esegue il conteggio elemento per elemento del numero di bit zero iniziali nel tensore operand e produce un tensore result.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore di tipo intero | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo intero | (C1) |
Vincoli
- (C1)
type(operand) = type(result).
Esempi
// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand)< : (ten>sor>2x2xi64<) - ten>sor2x2xi64
// %result: [[64, 63], [56, 0]]
custom_call
Semantica
Incapsula un'operazione definita dall'implementazione call_target_name che accetta
inputs e called_computations e produce results. has_side_effect,
backend_config e api_version possono essere utilizzati per fornire metadati aggiuntivi
definiti dall'implementazione.
Al momento, questa operazione contiene una raccolta di metadati piuttosto disorganizzata che riflette l'evoluzione organica della sua operazione omologa nel compilatore XLA. In futuro, prevediamo di unificare questi metadati (#741).
Input
| Etichetta | Nome | Tipo |
|---|---|---|
| (I1) | inputs |
numero variabile di valori |
| (I2) | call_target_name |
costante di tipo string |
| (I3) | has_side_effect |
costante di tipo i1 |
| (I4) | backend_config |
costante di tipo string o dizionario degli attributi |
| (I5) | api_version |
costante di tipo si32 |
| (I6) | called_computations |
numero variabile di costanti di tipo string |
| (I7) | output_operand_aliases |
specificare le parti di aliasing negli output e negli operandi |
Output
| Nome | Tipo |
|---|---|
results |
numero variabile di valori |
(Supporto GPU XLA) Target custom_call speciali
Esistono tre call_target_name speciali correlati ai tipi buffer:
CreateBuffer crea un buffer non inizializzato, Pin crea un buffer inizializzato e Unpin dealloca un buffer e restituisce il contenuto del buffer.
%uninitialized_buffer = "stablehlo.custom_call"() {
call_target_name = "CreateBuffer",
api_version> = 4 : <i32,
>} : () - memref4xf64
%initialized_buffer = "stablehlo.custom_call"(%init_value) {
call_target_name = "Pin&quo<t;,
> ap>i_versi<on = >4 : i32,
} : (tensor4xf64) - memref4xf64
%dealloc_buffer = "stablehlo.custom_call"(%initialized_buffer) {
call_target_na<me = >&qu>ot;Unpi<n&quo>t;,
api_version = 4 : i32,
} : (memref4xf64) - tensor4xf64
Alias
Alcune operazioni custom_call potrebbero richiedere che una parte degli output e una parte degli
operandi condividano la stessa memoria. Questo valore può essere espresso tramite
output_operand_aliases. Una rappresentazione di una coppia di alias è costituita da un elenco di indici di tuple di output che rappresentano la parte di output e da un operand_index insieme a un elenco di indici di tuple di operandi che rappresentano la parte di operando. L'elenco degli indici di output
o di tuple di operandi è vuoto se il tipo corrispondente non è un tipo tuple
e può essere arbitrariamente lungo per un tipo di tupla arbitrariamente nidificato. Questo
è simile alla rappresentazione dell'alias XLA.
La parte di output e la parte di input in una coppia di alias devono avere lo stesso tipo. Per
le operazioni personalizzate di chiamata che non sono chiamate a CreateBuffer, Pin e Unpin, un
operando buffer può essere visualizzato al massimo in una coppia di alias e un output buffer
deve essere visualizzato in una coppia di alias.
Esempi
%results = "stablehlo.custom_call"(%input0) {
call_target_name = "foo",
has_side_effect = false,
backend_config = {bar = 42 : i32},
api_version = 4 : i32,
called_computations <= [>@fo>o]
} : <(te>nsorf64) - tensorf64
%updated_buffer = "stablehlo.custom_call"(%buffer) {
call_target_name = "Update",
api_version = 4 : i32,
output_operand_aliases< = [
#stablehlo.output_operand_aliasoutput_tuple_indices = [],
operand_ind>ex = 0,
< oper>and>_tuple_<indic>es = []]
} : (memref4xf64) - memref4xf64
divisione
Semantica
Esegue la divisione elemento per elemento dei tensori dividendo lhs e divisore rhs e
produce un tensore result. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i numeri interi: divisione intera che produce il quoziente algebrico con qualsiasi parte frazionaria scartata.
- Per i numeri in virgola mobile:
divisionda IEEE-754. - Per i numeri complessi: divisione complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(divide, lhs, rhs, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | lhs |
tensore di tipo intero, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
| (I2) | rhs |
tensore di tipo intero, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo intero, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
Esempi
// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs)< : (t>ensor4xf<32, t>ens>or4xf32<) - t>ensor4xf32
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]
dot_general
Semantica
Calcola i prodotti scalari tra le sezioni di lhs e le sezioni di rhs e produce un
tensore result.
Più formalmente, result[result_index] = dot_product, dove:
lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions].rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions].result_batching_index + result_lhs_index + result_rhs_index = result_indexdovesize(result_batching_index) = size(lhs_batching_dimensions),size(result_lhs_index) = size(lhs_result_dimensions)esize(result_rhs_index) = size(rhs_result_dimensions).transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions).transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :]).reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions)).transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions).transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :]).reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions)).dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y)).
Per i tipi quantizzati, esegue dequantize_op_quantize(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs, type(result)).
Per i tipi quantizzati ibridi, esegue hybrid_dequantize_then_op(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs).
precision_config controlla il compromesso tra velocità e precisione per
i calcoli sui backend dell'acceleratore. Può essere uno dei seguenti valori (al momento, la semantica di questi valori enum non è specificata, ma prevediamo di risolvere il problema in #755):
DEFAULT: Calcolo più veloce, ma approssimazione meno precisa del numero originale.HIGH: Calcolo più lento, ma approssimazione più accurata al numero originale.HIGHEST: Calcolo più lento, ma approssimazione più accurata al numero originale.
Un DotAlgorithm definisce le proprietà principali dell'algoritmo utilizzato per implementare
l'operazione dot, che definisce anche la precisione. Se i campi dell'attributo algoritmo
sono impostati, precision_config deve essere DEFAULT. DotAlgorithms
non hanno un valore predefinito, in quanto i parametri predefiniti sono definiti dall'implementazione. Pertanto, tutti i campi dell'algoritmo punto possono essere impostati su None per specificare un
algoritmo punto vuoto, che utilizzerà invece il valore precision_config.
I campi DotAlgorithm includono:
lhs_precision_typeerhs_precision_type, le precisioni a cui vengono arrotondati i lati sinistro e destro dell'operazione. I tipi di precisione sono indipendenti dai tipi di archiviazione degli input e dell'output.accumulation_typela precisione utilizzata per l'accumulo.lhs_component_count,rhs_component_countenum_primitive_operationsvengono applicati quando eseguiamo un algoritmo che scompone il lato sinistro e/o destro in più componenti ed esegue più operazioni "primitive" sui valori, di solito per emulare una precisione maggiore (ad es. Utilizzo del tipo di dati di intelligenza artificiale bfloat16 per calcoli di maggiore precisione: bf16_6x tf32_3x e così via). Per gli algoritmi senza decomposizione, questi valori devono essere impostati su1.allow_imprecise_accumulationper specificare se l'accumulo a precisione inferiore è consentito per alcuni passaggi (ad es.CUBLASLT_MATMUL_DESC_FAST_ACCUM).
Attributi DotAlgorithm di esempio:
// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false}
// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
rhs_precision_type = bf16,
accumulation_type = f32,
lhs_component_count = 3,
rhs_component_count = 3,
num_primitive_operations = 6,
allow_imprecise_accumulation = false}
// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
rhs_precision_type = f8e5m2,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = true}
Spetta alle implementazioni decidere quali combinazioni sono supportate. In generale, non è garantito che ogni algoritmo sia supportato su ogni tipo di acceleratore dal consumer di StableHLO. Se un determinato algoritmo non è supportato, deve essere generato un errore anziché ricorrere a un'alternativa. La verifica StableHLO fornirà una verifica ottimale, impedendo l'utilizzo di algoritmi che non sono noti per essere supportati su qualsiasi hardware.
Consulta xla_data.proto > Algorithm
per alcuni valori di algoritmo supportati. Il ticket n. 2483 descrive il piano per creare un documento centralizzato sugli algoritmi supportati dal backend.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | lhs |
tensore o tensore quantizzato per tensore | (C5-C6), (C9-C10), (C12-C14), (C17-C18), (C20) |
| (I2) | rhs |
tensore o tensore quantizzato | (C7-C10), (C12-C20) |
| (I3) | lhs_batching_dimensions |
Costante tensore unidimensionale di tipo si64 |
(C1), (C3), (C5), (C9), (C12) |
| (I4) | rhs_batching_dimensions |
Costante tensore unidimensionale di tipo si64 |
(C1), (C4), (C7), (C9) |
| (I5) | lhs_contracting_dimensions |
Costante tensore unidimensionale di tipo si64 |
(C2), (C3), (C6), (C10) |
| (I6) | rhs_contracting_dimensions |
Costante tensore unidimensionale di tipo si64 |
(C2), (C4), (C8), (C10), (C16) |
| (I7) | precision_config |
numero variabile di enumerazioni di DEFAULT, HIGH e HIGHEST |
(C11), (C21) |
| (I8) | lhs_precision_type |
FloatType o TensorFloat32 | (C21) |
| (I9) | rhs_precision_type |
FloatType o TensorFloat32 | (C21) |
| (I10) | accumulation_type |
FloatType o TensorFloat32 | (C21) |
| (I11) | lhs_component_count |
costante di tipo si32 |
(C21), (C22) |
| (I12) | rhs_component_count |
costante di tipo si32 |
(C21), (C23) |
| (I13) | num_primitive_operations |
costante di tipo si32 |
(C21), (C24) |
| (I14) | allow_imprecise_accumulation |
costante di tipo bool |
(C21) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore o tensore quantizzato | (C12), (C14), (C18-C20) |
Vincoli
- (C1)
size(lhs_batching_dimensions) = size(rhs_batching_dimensions). - (C2)
size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions). - (C3)
is_unique(lhs_batching_dimensions + lhs_contracting_dimensions). - (C4)
is_unique(rhs_batching_dimensions + rhs_contracting_dimensions). - (C5)
0 <= lhs_batching_dimensions < rank(lhs). - (C6)
0 <= lhs_contracting_dimensions < rank(lhs). - (C7)
0 <= rhs_batching_dimensions < rank(rhs). - (C8)
0 <= rhs_contracting_dimensions < rank(rhs). - (C9)
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...). - (C10)
dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...). - (C11)
size(precision_config) = 2. - (C12)
shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions). - Se l'operazione utilizza tensori non quantizzati:
- (C13)
element_type(lhs) = element_type(rhs).
- (C13)
- Se l'operazione utilizza tensori quantizzati:
- (C14)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs). - (C15)
zero_points(rhs) = 0. - (C16) Se
is_per_axis_quantized(rhs), alloraquantization_dimension(rhs)non inrhs_contracting_dimensions. - Se
is_quantized(lhs): - (C17)
storage_type(lhs) = storage_type(rhs). - (C18)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result). - (C19) Se
is_per_tensor_quantized(rhs), allorais_per_tensor_quantized(result). - Se
!is_quantized(lhs): - (C20)
element_type(lhs) = expressed_type(rhs) = element_type(result).
- (C14)
- Se
!is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation):- (C21)
precision_config... = DEFAULT. - (C22)
0 < lhs_component_count. - (C23)
0 < rhs_component_count. - (C24)
0 < num_primitive_operations.
- (C21)
Esempi
// %lhs: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
// %rhs: [
// [[1, 0],
// [0, 1]],
// [[1, 0],
// [0, 1]]
// ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
dot_dimension_numbers = #sta<blehlo.dot
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimension>s = [1]
,
precision_config = [<#stablehloprecisi>on DEFAULT, <#stablehloprecisi>on DEFAULT],
algorithm = #stablehlo.dot<_algorithm
lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation >= false
}< : (tenso>r2x2x2xi<64, tenso>r2x>2x2xi64<) - tenso>r2x2x2xi64
// %result: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
dynamic_broadcast_in_dim
Semantica
Questa operazione è funzionalmente identica all'operazione
broadcast_in_dim, ma la forma del risultato viene specificata dinamicamente tramite output_dimensions.
L'operazione accetta anche gli attributi facoltativi known_expanding_dimensions, known_nonexpanding_dimensions
per esprimere la conoscenza statica del comportamento di espansione delle dimensioni.
Se non specificato, si presume che tutte le dimensioni possano essere in espansione.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore o tensore quantizzato | (C1-C2), (C5-C6), (C9) |
| (I2) | output_dimensions |
Tensore unidimensionale di tipo intero | (C7) |
| (I3) | broadcast_dimensions |
Tensore costante unidimensionale di tipo intero | (C2-C6) |
| (I4) | known_expanding_dimensions |
Tensore costante unidimensionale di tipo intero | (C8-C9) |
| (I5) | known_nonexpanding_dimensions |
Tensore costante unidimensionale di tipo intero | (C8-C9) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore o tensore quantizzato | (C1), (C3), (C5-C7) |
Vincoli
- (C1)
element_type(result)è fornito da:element_type(operand), se!is_per_axis_quantized(operand).element_type(operand), tranne chequantization_dimension(operand),scales(operand)ezero_points(operand)potrebbero differire daquantization_dimension(result),scales(result)ezero_points(result)rispettivamente.
- (C2)
size(broadcast_dimensions) = rank(operand). - (C3)
0 <= broadcast_dimensions < rank(result). - (C4)
is_unique(broadcast_dimensions). - (C5) Per tutti i
dinaxes(operand):dim(operand, d) = 1odim(operand, d) = dim(result, broadcast_dimensions[d]).
- (C6) Se
is_per_axis_quantized(result):quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].- Se
dim(operand, quantization_dimension(operand)) = 1, allorascales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).
- (C7)
size(output_dimensions) = rank(result). - (C8)
is_unique(known_expanding_dimensions + known_nonexpanding_dimensions). - (C9)
0 <= known_expanding_dimensions < rank(operand). - (C10)
0 <= known_nonexpanding_dimensions < rank(operand).
Esempi
// %operand: [
// [1, 2, 3]
// ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
broadcast_dimensio<ns = arra>yi64: 2, 1,
known_expanding_dimensio<ns = a>rrayi64: 0,
known_nonexpanding_dimensio<ns = a>rrayi64: 1
}< : (ten>sor1x3xi<64, t>ens>or3xi64<) - tenso>r2x3x2xi64
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
dynamic_conv
Semantica
Questa operazione è funzionalmente identica all'operazione
convoluzione, ma il padding viene specificato dinamicamente tramite padding.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | lhs |
tensore o tensore quantizzato per tensore | (C1), (C10-C11), (C14) (C25), (C26-C27), (C30-C31), (C33) |
| (I2) | rhs |
tensore o tensore quantizzato | (C1), (C14-C16), (C26-C28), (C30-C33) |
| (I3) | padding |
Tensore bidimensionale di tipo intero | (C4) |
| (I4) | window_strides |
Costante tensore unidimensionale di tipo si64 |
(C2-C3) |
| (I5) | lhs_dilation |
Costante tensore unidimensionale di tipo si64 |
(C5-C6) |
| (I6) | rhs_dilation |
Costante tensore unidimensionale di tipo si64 |
(C7-C8) |
| (I7) | window_reversal |
Costante tensore unidimensionale di tipo i1 |
(C9) |
| (I8) | input_batch_dimension |
costante di tipo si64 |
(C10), (C13) |
| (I9) | input_feature_dimension |
costante di tipo si64 |
(C11), (C13-C14) |
| (I10) | input_spatial_dimensions |
Costante tensore unidimensionale di tipo si64 |
(C12), (C13) |
| (I11) | kernel_input_feature_dimension |
costante di tipo si64 |
(C14), (C18) |
| (I12) | kernel_output_feature_dimension |
costante di tipo si64 |
(C15-C16), (C18), (C28) |
| (I13) | kernel_spatial_dimensions |
Costante tensore unidimensionale di tipo si64 |
(C17-C18) |
| (I14) | output_batch_dimension |
costante di tipo si64 |
(C20) |
| (I15) | output_feature_dimension |
costante di tipo si64 |
(C20), (C29) |
| (I16) | output_spatial_dimensions |
Costante tensore unidimensionale di tipo si64 |
(C19-C20) |
| (I17) | feature_group_count |
costante di tipo si64 |
(C11), (C14), (C16), (C21), (C23) |
| (I18) | batch_group_count |
costante di tipo si64 |
(C10), (C15), (C22), (C23) |
| (I19) | precision_config |
numero variabile di enumerazioni di DEFAULT, HIGH e HIGHEST |
(C24) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore o tensore quantizzato | (C25-C27), (C29), (C31-C33) |
Vincoli
- (C1)
N = rank(lhs) = rank(rhs). - (C2)
size(window_strides) = N - 2. - (C3)
0 < window_strides. - (C4)
shape(padding) = [N - 2, 2]. - (C5)
size(lhs_dilation) = N - 2. - (C6)
0 < lhs_dilation. - (C7)
size(rhs_dilation) = N - 2. - (C8)
0 < rhs_dilation. - (C9)
size(window_reversal) = N - 2. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0. - (C12)
size(input_spatial_dimensions) = N - 2. - (C13) Dato
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:is_unique(input_dimensions).0 <= input_dimensions < N.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0. - (C17)
size(kernel_spatial_dimensions) = N - 2. - (C18) Dato
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:is_unique(kernel_dimensions).0 <= kernel_dimensions < N.
- (C19)
size(output_spatial_dimensions) = N - 2. - (C20) Dato
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:is_unique(output_dimensions).0 <= output_dimensions < N.
- (C21)
0 < feature_group_count. - (C22)
0 < batch_group_count. - (C23)
feature_group_count = 1 or batch_group_count = 1. - (C24)
size(precision_config) = 2. - (C25)
dim(result, result_dim)è definito come:dim(lhs, input_batch_dimension) / batch_group_countseresult_dim = output_batch_dimension.dim(rhs, kernel_output_feature_dimension)seresult_dim = output_feature_dimension.num_windowsaltrimenti, dove:output_spatial_dimensions[spatial_dim] = result_dim.lhs_dim = input_spatial_dimensions[spatial_dim].rhs_dim = kernel_spatial_dimensions[spatial_dim].dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
- (C26)
rank(result) = N. - Se l'operazione utilizza tensori non quantizzati:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result).
- (C27)
- Se l'operazione utilizza tensori quantizzati:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs). - (C29) Se
is_per_axis_quantized(rhs), alloraquantization_dimension(rhs) = kernel_output_feature_dimension. - (C30) Se
is_per_axis_quantized(result), alloraquantization_dimension(result) = output_feature_dimension. - Se
is_quantized(lhs): - (C31)
storage_type(lhs) = storage_type(rhs). - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result). - (C33) Se
is_per_tensor_quantized(rhs), allorais_per_tensor_quantized(result). - Se
!is_quantized(lhs): - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result).
- (C28)
Esempi
// %lhs: [[
// [[1], [2], [5], [6]],
// [[3], [4], [7], [8]],
// [[10], [11], [14], [15]],
// [[12], [13], [16], [17]]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
// %padding: [[1, 1],
// [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
window_strid<es = arra>yi64: 4, 4,
lhs_dilati<on = arra>yi64: 2, 2,
rhs_dilati<on = arra>yi64: 1, 1,
window_revers<al = arrayi1: fa>lse, false,
dimension_numbers = #stab<lehlo.convraw
input_batch_dimension = 0,
input_feature_dimension = 3,
input_spatial_dimensions = [0, 1],
kernel_input_feature_dimension = 2,
kernel_output_feature_dimension = 3,
kernel_spatial_dimensions = [0, 1],
output_batch_dimension = 0,
output_feature_dimension = 3,
output_spatial_dimensions => [1, 2]
,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
precision_config = [<#stablehloprecisi>on DEFAULT, <#stablehloprecisi>on DEFAULT]
}< : (tensor1>x4x4x1xi<64, tensor3>x3x1x1xi<64, ten>sor>2x2xi64<) - tensor1>x2x2x1xi64
// %result: [[
// [[1], [5]],
// [[10], [14]]
// ]]
dynamic_gather
Semantica
Questa operazione è funzionalmente identica all'operazione
gather, con slice_sizes specificato dinamicamente come valore.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore o tensore quantizzato per tensore | (C1), (C7), (C10-C12), (C14) |
| (I2) | start_indices |
tensore di tipo intero | (C2), (C3), (C13) |
| (I3) | slice_sizes |
Tensore unidimensionale di tipo intero | (C8), (C11-C13) |
| (I4) | offset_dims |
Costante tensore unidimensionale di tipo si64 |
(C1), (C4-C5), (C13) |
| (I5) | collapsed_slice_dims |
Costante tensore unidimensionale di tipo si64 |
(C1), (C6-C8), (C13) |
| (I6) | start_index_map |
Costante tensore unidimensionale di tipo si64 |
(C3), (C9), (C10) |
| (I7) | index_vector_dim |
costante di tipo si64 |
(C2), (C3), (C13) |
| (I8) | indices_are_sorted |
costante di tipo i1 |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore o tensore quantizzato per tensore | (C5), (C13-C14) |
Vincoli
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims). - (C2)
0 <= index_vector_dim <= rank(start_indices). - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims). - (C5)
0 <= offset_dims < rank(result). - (C6)
is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims). - (C7)
0 <= collapsed_slice_dims < rank(operand). - (C8)
slice_sizes[collapsed_slice_dims...] <= 1. - (C9)
is_unique(start_index_map). - (C10)
0 <= start_index_map < rank(operand). - (C11)
size(slice_sizes) = rank(operand). - (C12)
0 <= slice_sizes <= shape(operand). - (C13)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)dove:batch_dim_sizes = shape(start_indices), ad eccezione della dimensionestart_indicescorrispondente aindex_vector_dim.offset_dim_sizes = shape(slice_sizes), ad eccezione delle dimensioni della dimensioneslice_sizescorrispondente acollapsed_slice_dims.combineposizionabatch_dim_sizessugli assi corrispondenti abatch_dimseoffset_dim_sizessugli assi corrispondenti aoffset_dims.
- (C14)
element_type(operand) = element_type(result).
Esempi
// %operand: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %start_indices: [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 2]]
// ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
dimension_numbers = #stable<hlo.gather
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vect>or_dim = 2,
indices_are_sorted = false
}< : (tenso>r3x4x2xi<64, tenso>r2x3x2xi<64, t>ens>or3xi64<) - tensor2>x3x2x2xi64
// %result: [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[9, 10], [11, 12]],
// [[11, 12], [13, 14]],
// [[17, 18], [19, 20]]
// ]
// ]
dynamic_iota
Semantica
Questa operazione è funzionalmente identica all'operazione
iota, ma la forma del risultato viene specificata dinamicamente tramite output_shape.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | output_shape |
Tensore unidimensionale di tipo intero | (C1), (C2) |
| (I2) | iota_dimension |
si64 |
(C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo intero, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C2) |
Vincoli
- (C1)
0 <= iota_dimension < size(output_shape). - (C2)
rank(result) = size(output_shape).
Esempi
%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
iota_dimension = 0 : i64
}< : (t>ens>or2xi64<) - ten>sor4x5xi64
// %result: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
dynamic_pad
Semantica
Questa operazione è funzionalmente identica all'operazione
pad, ma con edge_padding_low, edge_padding_high e interior_padding
specificati dinamicamente come valori.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore o tensore quantizzato per tensore | (C1), (C2), (C4) |
| (I2) | padding_value |
Tensore 0-dimensionale o tensore quantizzato per tensore | (C1) |
| (I3) | edge_padding_low |
Tensore unidimensionale di tipo intero | (C1), (C4) |
| (I4) | edge_padding_high |
Tensore unidimensionale di tipo intero | (C1), (C4) |
| (I5) | interior_padding |
Tensore unidimensionale di tipo intero | (C2-C4) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore o tensore quantizzato per tensore | (C3-C6) |
Vincoli
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result). - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand). - (C3)
0 <= interior_padding. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.
Esempi
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
%edge_padding_low, %edge_padding_high, %interior_padding
)< : (ten>sor2x3xi<64,> tensori<64, t>ensor2xi<64, t>ensor2xi<64, t>ens>or2xi64<) - ten>sor5x9xi64
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
dynamic_reshape
Semantica
Questa operazione è funzionalmente identica all'operazione
reshape, ma la forma del risultato viene specificata dinamicamente tramite output_shape.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore o tensore quantizzato | (C1-C3) |
| (I2) | output_shape |
Tensore unidimensionale di tipo intero | (C4) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore o tensore quantizzato | (C1-C4) |
Vincoli
- (C1)
element_type(result)è fornito da:element_type(operand), se!is_per_axis_quantized(operand).element_type(operand), tranne chequantization_dimension(operand)equantization_dimension(result)potrebbero essere diversi.
- (C2)
size(operand) = size(result). - (C3) Se
is_per_axis_quantized(operand):reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
- (C4)
size(output_shape) = rank(result).
Esempi
// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape)< : (ten>sor2x3xi<64, t>ens>or2xi64<) - ten>sor3x2xi64
// %result: [[1, 2], [3, 4], [5, 6]]
dynamic_slice
Semantica
Estrae una sezione da operand utilizzando indici iniziali calcolati dinamicamente
e produce un tensore result. start_indices contengono gli indici iniziali della
sezione per ogni dimensione soggetta a potenziali aggiustamenti e slice_sizes
contengono le dimensioni della sezione per ogni dimensione. In termini più formali,
result[result_index] = operand[operand_index] dove:
adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes).operand_index = adjusted_start_indices + result_index.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore o tensore quantizzato per tensore | (C1), (C2), (C4) |
| (I2) | start_indices |
numero variabile di tensori di tipo intero a 0 dimensioni | (C2), (C3) |
| (I3) | slice_sizes |
Costante tensore unidimensionale di tipo si64 |
(C2), (C4), (C5) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore o tensore quantizzato per tensore | (C1), (C5) |
Vincoli
- (C1)
element_type(operand) = element_type(result). - (C2)
size(start_indices) = size(slice_sizes) = rank(operand). - (C3)
same(type(start_indices...)). - (C4)
0 <= slice_sizes <= shape(operand). - (C5)
shape(result) = slice_sizes.
Esempi
// %operand: [
// [0, 0, 1, 1],
// [0, 0, 1, 1],
// [0, 0, 0, 0],
// [0, 0, 0, 0]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
slice_siz<es = arra>yi64: 2, 2
}< : (ten>sor4x4xi<32,> tensori<64,> te>nsori64<) - ten>sor2x2xi32
// %result: [
// [1, 1],
// [1, 1]
// ]
dynamic_update_slice
Semantica
Produce un tensore result uguale al tensore operand, tranne per il fatto che
la sezione che inizia a start_indices viene aggiornata con i valori in update.
In termini più formali, result[result_index] è definito come:
update[update_index]se0 <= update_index < shape(update)dove:adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update)).update_index = result_index - adjusted_start_indices.
operand[result_index]altrimenti.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore o tensore quantizzato per tensore | (C1-C4), (C6) |
| (I2) | update |
tensore o tensore quantizzato per tensore | (C2), (C3), (C6) |
| (I3) | start_indices |
numero variabile di tensori di tipo intero a 0 dimensioni | (C4), (C5) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
type(operand) = type(result). - (C2)
element_type(update) = element_type(operand). - (C3)
rank(update) = rank(operand). - (C4)
size(start_indices) = rank(operand). - (C5)
same(type(start_indices...)). - (C6)
0 <= shape(update) <= shape(operand).
Esempi
// %operand: [
// [1, 1, 0, 0],
// [1, 1, 0, 0],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
// %update: [
// [1, 1],
// [1, 1]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
< : (ten>sor4x4xi<32, ten>sor2x2xi<32,> tensori<64,> te>nsori64<) - ten>sor4x4xi32
// %result: [
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
esponenziale
Semantica
Esegue l'operazione esponenziale elemento per elemento sul tensore operand e produce un tensore result. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i numeri in virgola mobile:
expda IEEE-754. - Per i numeri complessi: esponenziale complesso.
- Per i tipi quantizzati:
dequantize_op_quantize(exponential, operand, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result).
Esempi
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]
exponential_minus_one
Semantica
Esegue l'operazione esponenziale meno 1 elemento per elemento sul tensore operand e
produce un tensore result. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i numeri in virgola mobile:
expm1da IEEE-754. - Per i numeri complessi: esponenziale complesso meno uno.
- Per i tipi quantizzati:
dequantize_op_quantize(exponential_minus_one, operand, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result).
Esempi
// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand)< : (t>ens>or2xf64<) - t>ensor2xf64
// %result: [0.0, 1.71828187]
fft
Semantica
Esegue le trasformate di Fourier diretta e inversa per input/output reali e complessi.
fft_type è uno dei seguenti valori:
FFT: Inoltra FFT da complesso a complesso.IFFT: FFT inversa da complesso a complesso.RFFT: Inoltra FFT da reale a complesso.IRFFT: FFT reale-complessa inversa (ovvero prende numeri complessi e restituisce numeri reali).
Più formalmente, data la funzione fft che accetta tensori unidimensionali di tipi complessi come input, produce tensori unidimensionali degli stessi tipi come output e calcola la trasformata di Fourier discreta:
Per fft_type = FFT, result è definito come il risultato finale di una serie di calcoli L
in cui L = size(fft_length). Ad esempio, per L = 3:
result1[i0, ..., :] = fft(operand[i0, ..., :]).result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).
Inoltre, data la funzione ifft che ha la stessa firma di tipo e calcola l'inverso di fft:
Per fft_type = IFFT, result è definito come l'inverso dei calcoli
per fft_type = FFT. Ad esempio, per L = 3:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).result[i0, ..., :] = ifft(result2[i0, ..., :]).
Inoltre, data la funzione rfft che accetta tensori unidimensionali di
tipi in virgola mobile, produce tensori unidimensionali di tipi complessi con la
stessa semantica in virgola mobile e funziona nel seguente modo:
rfft(real_operand) = truncated_resultdovecomplex_operand... = (real_operand..., 0.0).complex_result = fft(complex_operand).truncated_result = complex_result[:(rank(complex_result) / 2 + 1)].
Quando la trasformata di Fourier discreta viene calcolata per operandi reali, i primi
N/2 + 1 elementi del risultato definiscono in modo univoco il resto del risultato,
quindi il risultato di rfft viene troncato per evitare di calcolare elementi ridondanti.
Per fft_type = RFFT, result è definito come il risultato finale di una serie di calcoli L
in cui L = size(fft_length). Ad esempio, per L = 3:
result1[i0, ..., :] = rfft(operand[i0, ..., :]).result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).
Infine, data la funzione irfft che ha la stessa firma di tipo e
calcola l'inverso di rfft:
Per fft_type = IRFFT, result è definito come l'inverso dei calcoli
per fft_type = RFFT. Ad esempio, per L = 3:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).result[i0, ..., :] = irfft(result2[i0, ..., :]).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore di tipo complesso o in virgola mobile | (C1), (C2), (C4), (C5) |
| (I2) | fft_type |
enum di FFT, IFFT, RFFT e IRFFT |
(C2), (C5) |
| (I3) | fft_length |
Costante tensore unidimensionale di tipo si64 |
(C1), (C3), (C4) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo complesso o in virgola mobile | (C2), (C4), (C5) |
Vincoli
- (C1)
size(fft_length) <= rank(operand). - (C2) La relazione tra i tipi di elementi
operanderesultvaria:- Se
fft_type = FFT,element_type(operand)eelement_type(result)hanno lo stesso tipo complesso. - Se
fft_type = IFFT,element_type(operand)eelement_type(result)hanno lo stesso tipo complesso. - Se
fft_type = RFFT,element_type(operand)è un tipo a virgola mobile eelement_type(result)è un tipo complesso con la stessa semantica a virgola mobile. - Se
fft_type = IRFFT,element_type(operand)è un tipo complesso eelement_type(result)è un tipo a virgola mobile con la stessa semantica a virgola mobile.
- Se
- (C3)
1 <= size(fft_length) <= 3. - (C4) Se tra
operanderesultè presente un tensorerealdi tipo in virgola mobile, allorashape(real)[-size(fft_length):] = fft_length. - (C5)
shape(result) = shape(operand)tranne:- Se
fft_type = RFFT,dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1. - Se
fft_type = IRFFT,dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1.
- Se
Esempi
// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
fft_type = <#stablehloff>t_type FFT,
fft_leng<th = a>rrayi64: 4
}< : (tenso<r4x>>com>plexf32<) - tenso<r4x>>complexf32
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]
piano
Semantica
Esegue l'operazione floor elemento per elemento del tensore operand e produce un tensore result.
Implementa l'operazione roundToIntegralTowardNegative dalla specifica IEEE-754. Per i tipi quantizzati, esegue
dequantize_op_quantize(floor, operand, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result).
Esempi
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand)< : (t>ens>or5xf32<) - t>ensor5xf32
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]
gather
Semantica
Raccoglie le sezioni dal tensore operand dagli offset specificati in start_indices
e produce un tensore result.
Il seguente diagramma mostra come gli elementi in result vengono mappati sugli elementi in
operand utilizzando un esempio concreto. Il diagramma seleziona alcuni esempi di indici result
e spiega nel dettaglio a quali indici operand corrispondono.
Più formalmente, result[result_index] = operand[operand_index] dove:
batch_dims = [d for d in axes(result) and d not in offset_dims].batch_index = result_index[batch_dims...].start_indexè definito come:start_indices[bi0, ..., :, ..., biN]dovebisono singoli elementi inbatch_indexe:viene inserito nell'indiceindex_vector_dim, seindex_vector_dim<rank(start_indices).[start_indices[batch_index]]altrimenti.
- Per
d_operandinaxes(operand),full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])ifd_operand = start_index_map[d_start].full_start_index[d_operand] = 0altrimenti.
- Per
d_operandinaxes(operand),full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)]sed_operand = operand_batching_dims[i_batching]ed_start = start_indices_batching_dims[i_batching].full_batching_index[d_operand] = 0altrimenti.
offset_index = result_index[offset_dims...].full_offset_index = [oi0, ..., 0, ..., oiN]doveoisono singoli elementi inoffset_indexe0viene inserito negli indici dacollapsed_slice_dimseoperand_batching_dims.operand_index = full_start_index + full_batching_index + full_offset_index.
Se indices_are_sorted è true, l'implementazione può presupporre che
start_indices siano ordinati rispetto a start_index_map, altrimenti il
comportamento è indefinito. Più formalmente, per tutti i i1 < i2 da indices(result),
full_start_index(i1) <= full_start_index(i2).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore o tensore quantizzato per tensore | (C1), (C8), (C11), (C17), (C19-C21), (C23) |
| (I2) | start_indices |
tensore di tipo intero | (C2-C3), (C14), (C17), (C22) |
| (I3) | offset_dims |
Costante tensore unidimensionale di tipo si64 |
(C1), (C4-C5), (C22) |
| (I4) | collapsed_slice_dims |
Costante tensore unidimensionale di tipo si64 |
(C1), (C6-C9), (C22) |
| (I5) | operand_batching_dims |
Costante tensore unidimensionale di tipo si64 |
(C1), (C6), (C10-C12), (C16-C18), (C22) |
| (I6) | start_indices_batching_dims |
Costante tensore unidimensionale di tipo si64 |
(C13-C17) |
| (I7) | start_index_map |
Costante tensore unidimensionale di tipo si64 |
(C3), (C18-C19) |
| (I8) | index_vector_dim |
costante di tipo si64 |
(C2-C3), (C15), (C22) |
| (I9) | slice_sizes |
Costante tensore unidimensionale di tipo si64 |
(C9), (C12), (C20-C22) |
| (I10) | indices_are_sorted |
costante di tipo i1 |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore o tensore quantizzato per tensore | (C5), (C22-C23) |
Vincoli
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims). - (C2)
0 <= index_vector_dim <= rank(start_indices). - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims). - (C5)
0 <= offset_dims < rank(result). - (C6)
is_unique(concatenate(collapsed_slice_dims, operand_batching_dims)) - (C7)
is_sorted(collapsed_slice_dims). - (C8)
0 <= collapsed_slice_dims < rank(operand). - (C9)
slice_sizes[collapsed_slice_dims...] <= 1. - (C10)
is_sorted(operand_batching_dims). - (C11)
0 <= operand_batching_dims < rank(operand). - (C12)
slice_sizes[operand_batching_dims...] <= 1. - (C13)
is_unique(start_indices_batching_dims). - (C14)
0 <= start_indices_batching_dims < rank(start_indices). - (C15)
index_vector_dim not in start_indices_batching_dims. - (C16)
size(operand_batching_dims) == size(start_indices_batching_dims). - (C17)
dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...). - (C18)
is_unique(concatenate(start_index_map, operand_batching_dims)). - (C19)
0 <= start_index_map < rank(operand). - (C20)
size(slice_sizes) = rank(operand). - (C21)
0 <= slice_sizes <= shape(operand). - (C22)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)dove:batch_dim_sizes = shape(start_indices), ad eccezione della dimensionestart_indicescorrispondente aindex_vector_dim.offset_dim_sizes = slice_sizesad eccezione delle dimensioni inslice_sizescorrispondenti acollapsed_slice_dimseoperand_batching_dims.combineposizionabatch_dim_sizessugli assi corrispondenti abatch_dimseoffset_dim_sizessugli assi corrispondenti aoffset_dims.
- (C23)
element_type(operand) = element_type(result).
Esempi
// %operand: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %start_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
%result = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stable<hlo.gather
offset_dims = [3, 4],
collapsed_slice_dims = [1],
operand_batching_dims = [0],
start_indices_batching_dims = [1],
start_index_map = [2, 1],
index_vect>or_dim = 3,
slice_siz<es = arrayi64: >1, 1, 2, 2,
indices_are_sorted = false
}< : (tensor2>x3x4x2xi<32, tensor2>x2x>3x2xi64<) - tensor2x2>x3x2x2xi32
// %result: [
// [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[33, 34], [35, 36]],
// [[35, 36], [37, 38]],
// [[41, 42], [43, 44]]
// ]
// ],
// [
// [
// [[1, 2], [3, 4]],
// [[13, 14], [15, 16]],
// [[21, 22], [23, 24]]
// ],
// [
// [[43, 44], [45, 46]],
// [[33, 34], [35, 36]],
// [[27, 28], [29, 30]]
// ]
// ]
// ]
get_dimension_size
Semantica
Restituisce la dimensione del dimension specificato di operand. In termini più formali,
result = dim(operand, dimension). La semantica riguarda solo il componente della forma del tipo. Il tipo di elemento può essere qualsiasi cosa.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore o tensore quantizzato | (C1) |
| (I2) | dimension |
costante di tipo si64 |
(C1) |
Output
| Nome | Tipo |
|---|---|
result |
Tensore unidimensionale di tipo si32 |
Vincoli
- (C1)
0 <= dimension < rank(operand).
Esempi
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
dimension = 1 : i64
}< : (ten>sor>2x3xi64<) -> tensori32
// %result: 3
get_tuple_element
Semantica
Estrae l'elemento nella posizione index della tupla operand e produce un
result. Più formalmente, result = operand[index].
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tupla | (C1), (C2) |
| (I2) | index |
costante di tipo si32 |
(C1), (C2) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
qualsiasi valore | (C2) |
Vincoli
- (C1)
0 <= index < size(operand). - (C2)
type(result) = tuple_element_types(operand)[index].
Esempi
// %operand: ([1.0, 2.0], (3))
%result = "stablehlo.get_tuple_element"(<%operand) {index >= 0 : i32<} : (t<uplet>ensor2x<f64, t<upl>>>ete>nsori64<) - t>ensor2xf64
// %result: [1.0, 2.0]
se
Semantica
Produce l'output dall'esecuzione di esattamente una funzione da true_branch o
false_branch a seconda del valore di pred. Più formalmente, result =
pred ? true_branch() : false_branch().
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | pred |
Tensore unidimensionale di tipo i1 |
|
| (I2) | true_branch |
funzione | (C1-C3) |
| (I3) | false_branch |
funzione | (C1), (C2) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
results |
numero variabile di tensori, tensori quantizzati o token | (C3) |
Vincoli
- (C1)
input_types(true_branch) = input_types(false_branch) = []. - (C2)
output_types(true_branch) = output_types(false_branch). - (C3)
type(results...) = output_types(true_branch).
Esempi
// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
"stablehlo.return"(%result_tr<ue_>bra>nch) : (tensori32) - ()
}, {
"stablehlo.return"(%<res>ult>_false_branch) :< (>ten>sori32)< - >()
}) : (tensori1) - tensori32
// %result: 10
imag
Semantica
Estrae la parte immaginaria, elemento per elemento, da operand e produce un tensore result. Più formalmente, per ogni elemento x:
imag(x) = is_complex(x) ? imaginary_part(x) :
constant(0, element_type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore di tipo complesso o in virgola mobile | (C1), (C2) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo in virgola mobile | (C1), (C2) |
Vincoli
- (C1)
shape(result) = shape(operand). - (C2)
element_type(result)è definito come:complex_element_type(element_type(operand))seis_complex(operand).element_type(operand)altrimenti.
Esempi
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand)< : (tenso<r2x>>com>plexf32<) - t>ensor2xf32
// %result: [2.0, 4.0]
infeed
Semantica
Legge i dati dal feed e produce results.
La semantica di infeed_config è definita dall'implementazione.
results sono costituiti da valori di payload che vengono per primi e da un token che viene
per ultimo. In futuro, prevediamo di dividere il payload e il token in due output separati per migliorare la chiarezza
(#670).
Input
| Etichetta | Nome | Tipo |
|---|---|---|
| (I1) | token |
token |
| (I2) | infeed_config |
costante di tipo string |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
results |
numero variabile di tensori, tensori quantizzati o token | (C1-C3) |
Vincoli
- (C1)
0 < size(results). - (C2)
is_empty(result[:-1])ois_tensor(type(results[:-1])). - (C3)
is_token(type(results[-1])).
Esempi
// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : >(!stable<hlo.tok>en) - (tensor2x2xi64, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
infeed_config> = "<;">
} : (!stablehlo.token) - (tensor2x2xi64, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]
iota
Semantica
Riempie un tensore output con valori in ordine crescente a partire da zero
lungo la dimensione iota_dimension. Più formalmente,
output[output_index] = constant(is_quantized(output) ?
quantize(output_index[iota_dimension], element_type(output)) :
output_index[iota_dimension], element_type(output)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | iota_dimension |
si64 |
(C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
output |
tensore di tipo intero, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
0 <= iota_dimension < rank(output).
Esempi
%output = "stablehlo.iota"() {
iota_dimension = 0 : i6>4
} : (<) - ten>sor4x5xi32
// %output: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
%output = "stablehlo.iota"() {
iota_dimensio>n = 1 :< i64
} >: () - tensor4x5xi32
// %output: [
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4]
// ]
is_finite
Semantica
Esegue un controllo elemento per elemento per verificare se il valore in x è finito (ovvero non è
+Inf, -Inf o NaN) e produce un tensore y. Implementa l'operazione isFinite
dalla specifica IEEE-754. Per i tipi quantizzati, il risultato è
sempre true.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | x |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
y |
tensore di tipo booleano | (C1) |
Vincoli
- (C1)
shape(x) = shape(y).
Esempi
// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x)< : (tens>or7xf64<) - >tensor7xi1
// %y: [false, false, false, true, true, true, true]
log
Semantica
Esegue l'operazione di logaritmo elemento per elemento sul tensore operand e produce un tensore result. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i numeri in virgola mobile:
logda IEEE-754. - Per i numeri complessi: logaritmo complesso.
- Per i tipi quantizzati:
dequantize_op_quantize(log, operand, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result).
Esempi
// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]
log_plus_one
Semantica
Esegue l'operazione di logaritmo più uno elemento per elemento sul tensore operand e
produce un tensore result. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i numeri in virgola mobile:
logp1da IEEE-754. - Per i numeri complessi:
complex(log(hypot(real(x) + 1, imag(x))), atan2(imag(x), real(x) + 1)) - Per i tipi quantizzati:
dequantize_op_quantize(log_plus_one, operand, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result).
Esempi
// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
logistico
Semantica
Esegue l'operazione logistica elemento per elemento sul tensore operand e produce un tensore result. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i numeri in virgola mobile:
division(1, addition(1, exp(-x)))da IEEE-754. - Per i numeri complessi: logistica complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(logistic, operand, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result).
Esempi
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]
mappa
Semantica
Applica una funzione di mappatura computation a inputs lungo dimensions e
produce un tensore result.
Più formalmente, result[result_index] = computation(inputs...[result_index]).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | inputs |
numero variabile di tensori o tensori quantizzati per tensore | (C1-C4) |
| (I2) | dimensions |
Costante tensore unidimensionale di tipo si64 |
(C3) |
| (I3) | computation |
funzione | (C4) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore o tensore quantizzato per tensore | (C1), (C4) |
Vincoli
- (C1)
shape(inputs...) = shape(result). - (C2)
0 < size(inputs) = N. - (C3)
dimensions = range(rank(inputs[0])). - (C4)
computationè di tipo(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>doveEi = element_type(inputs[i])eE' = element_type(result).
Esempi
// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = stablehlo.multiply %arg0, %arg<1 :> tensori64
stablehlo.return %<0 :> tensori64
}) {
dimensio<ns = arra>yi64: 0, 1
}< : (ten>sor2x2xi<64, ten>sor>2x2xi64<) - ten>sor2x2xi64
// %result: [[0, 5], [12, 21]]
massimo
Semantica
Esegue l'operazione max elemento per elemento sui tensori lhs e rhs e produce un tensore result. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i valori booleani: OR logico.
- Per i numeri interi: numero intero massimo.
- Per i numeri in virgola mobile:
maximumda IEEE-754. - Per i numeri complessi: massimo lessicografico per la coppia
(real, imaginary). L'imposizione di un ordinamento sui numeri complessi comporta una semantica sorprendente, quindi in futuro prevediamo di rimuovere il supporto per i numeri complessi per questa operazione (#560). - Per i tipi quantizzati:
dequantize_op_quantize(maximum, lhs, rhs, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | lhs |
tensore o tensore quantizzato per tensore | (C1) |
| (I2) | rhs |
tensore o tensore quantizzato per tensore | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
Esempi
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 6], [7, 8]]
minimo
Semantica
Esegue l'operazione min elemento per elemento sui tensori lhs e rhs e produce un tensore result. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i valori booleani: AND logico.
- Per i numeri interi: numero intero minimo.
- Per i numeri in virgola mobile:
minimumda IEEE-754. - Per i numeri complessi: minimo lessicografico per la coppia
(real, imaginary). L'imposizione di un ordinamento sui numeri complessi comporta una semantica sorprendente, quindi in futuro prevediamo di rimuovere il supporto per i numeri complessi per questa operazione (#560). - Per i tipi quantizzati:
dequantize_op_quantize(minimum, lhs, rhs, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | lhs |
tensore o tensore quantizzato per tensore | (C1) |
| (I2) | rhs |
tensore o tensore quantizzato per tensore | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
Esempi
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[1, 2], [3, 4]]
moltiplicazione
Semantica
Esegue il prodotto elemento per elemento di due tensori lhs e rhs e produce un tensore result. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i valori booleani: AND logico.
- Per i numeri interi: moltiplicazione di numeri interi.
- Per i numeri in virgola mobile:
multiplicationda IEEE-754. - Per i numeri complessi: moltiplicazione complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(multiply, lhs, rhs, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | lhs |
tensore o tensore quantizzato per tensore | (C1) |
| (I2) | rhs |
tensore o tensore quantizzato per tensore | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result).
Esempi
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 12], [21, 32]]
negate
Semantica
Esegue la negazione elemento per elemento del tensore operand e produce un tensore result. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i numeri interi con segno: negazione del numero intero.
- Per i numeri interi non firmati: bitcast a numero intero firmato, negazione del numero intero, bitcast di nuovo a numero intero non firmato.
- Per i numeri in virgola mobile:
negateda IEEE-754. - Per i numeri complessi: negazione complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(negate, operand, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore di tipo intero, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo intero, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result).
Esempi
// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand)< : (t>ens>or2xi32<) - t>ensor2xi32
// %result: [0, 2]
// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"<(%operand<) :>> (t>ensor1x<complexf3<2) >>- tensor1xcomplexf32
// %result: [-2.5, -0.0]
non
Semantica
Esegue l'operazione NOT elemento per elemento del tensore operand e produce un tensore result.
A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i valori booleani: NOT logico.
- Per i numeri interi: NOT bit a bit.
Argomenti
| Nome | Tipo | Vincoli |
|---|---|---|
operand |
tensore di tipo booleano o intero | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo booleano o intero | (C1) |
Vincoli
- (C1)
type(operand) = type(result).
Esempi
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand)< : (ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[-2, -3], [-4, -5]]
// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"<(%op>era>nd) : (<tens>or2xi1) - tensor2xi1
// %result: [false, true]
optimization_barrier
Semantica
Garantisce che le operazioni che producono operand vengano eseguite prima di qualsiasi
operazione che dipende da result e impedisce alle trasformazioni del compilatore
di spostare le operazioni oltre la barriera. A parte questo, l'operazione è
un'identità, ovvero result = operand.
Argomenti
| Nome | Tipo | Vincoli |
|---|---|---|
operand |
numero variabile di tensori, tensori o token quantizzati per tensore | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
numero variabile di tensori, tensori o token quantizzati per tensore | (C1) |
Vincoli
- (C1)
type(operand...) = type(result...).
Esempi
// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1)< : >(tensorf<32,> te>nsorf32)< - >(tensorf<32,> tensorf32)
// %result0: 0.0
// %result1: 1.0
o
Semantica
Esegue l'operazione OR a livello di elemento di due tensori lhs e rhs e produce un tensore result. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i valori booleani: OR logico.
- Per i numeri interi: OR a livello di bit.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | lhs |
tensore di tipo intero o booleano | (C1) |
| (I2) | rhs |
tensore di tipo intero o booleano | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo intero o booleano | (C1) |
Vincoli
- (C1)
type(lhs) = type(rhs) = type(result).
Esempi
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 6], [7, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%<lhs, %>rhs) : (<tensor>2x2>xi1, te<nsor2x>2xi1) - tensor2x2xi1
// %result: [[false, true], [true, true]]
outfeed
Semantica
Scrive inputs nel feed di output e produce un token result.
La semantica di outfeed_config è definita dall'implementazione.
Input
| Etichetta | Nome | Tipo |
|---|---|---|
| (I1) | inputs |
numero variabile di tensori o tensori quantizzati |
| (I2) | token |
token |
| (I3) | outfeed_config |
costante di tipo string |
Output
| Nome | Tipo |
|---|---|
result |
token |
Esempi
%result = "stablehlo.outfeed"(%input0, %token) {
outfeed_config = &quo<t;"<>/span>
} : (tensor2x2x2xi64,> !stablehlo.token) - !stablehlo.token
pad
Semantica
Espande operand aggiungendo spaziatura interna intorno al tensore e tra gli elementi
del tensore con il valore padding_value specificato.
edge_padding_low e edge_padding_high specificano la quantità di spaziatura interna aggiunta
all'estremità inferiore (accanto all'indice 0) e all'estremità superiore (accanto all'indice più alto) di
ogni dimensione rispettivamente. La quantità di spaziatura interna può essere negativa, dove il valore assoluto della spaziatura interna negativa indica il numero di elementi da rimuovere dalla dimensione specificata.
interior_padding specifica la quantità di spaziatura interna aggiunta tra due elementi
in ogni dimensione, che non può essere negativa. Il padding interno viene applicato
prima del padding dei bordi, in modo che il padding dei bordi negativo rimuova gli elementi
dall'operando con padding interno.
In termini più formali, result[result_index] è definito come:
operand[operand_index]ifresult_index = edge_padding_low + operand_index * (interior_padding + 1).padding_valuealtrimenti.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore o tensore quantizzato per tensore | (C1), (C2), (C4) |
| (I2) | padding_value |
Tensore 0-dimensionale o tensore quantizzato per tensore | (C1) |
| (I3) | edge_padding_low |
Costante tensore unidimensionale di tipo si64 |
(C1), (C4) |
| (I4) | edge_padding_high |
Costante tensore unidimensionale di tipo si64 |
(C1), (C4) |
| (I5) | interior_padding |
Costante tensore unidimensionale di tipo si64 |
(C2-C4) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore o tensore quantizzato per tensore | (C3-C6) |
Vincoli
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result). - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand). - (C3)
0 <= interior_padding. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.
Esempi
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
edge_padding_l<ow = arra>yi64: 0, 1,
edge_padding_hi<gh = arra>yi64: 2, 1,
interior_paddi<ng = arra>yi64: 1, 2
}< : (ten>sor2x3xi<32,> te>nsori32<) - ten>sor5x9xi32
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
partition_id
Semantica
Produce partition_id del processo corrente.
Output
| Nome | Tipo |
|---|---|
result |
Tensore unidimensionale di tipo ui32 |
Esempi
%result = "stablehlo.partition_id">;() : (<) - >tensorui32
popcnt
Semantica
Esegue il conteggio elemento per elemento del numero di bit impostati nel tensore operand
e produce un tensore result.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore di tipo intero | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo intero | (C1) |
Vincoli
- (C1)
type(operand) = type(result).
Esempi
// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand)< : (t>ens>or4xi64<) - t>ensor4xi64
// %result: [0, 1, 1, 7]
potenza
Semantica
Esegue l'esponenziazione elemento per elemento del tensore lhs per il tensore rhs e
produce un tensore result. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i numeri interi: elevamento a potenza di numeri interi.
- Per i numeri in virgola mobile:
powda IEEE-754. - Per i numeri complessi: elevamento a potenza complesso.
- Per i tipi quantizzati:
dequantize_op_quantize(power, lhs, rhs, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | lhs |
tensore di tipo intero, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
| (I2) | rhs |
tensore di tipo intero, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo intero, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result).
Esempi
// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs)< : (t>ensor6xf<64, t>ens>or6xf64<) - t>ensor6xf64
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]
reale
Semantica
Estrae la parte reale, elemento per elemento, da operand e produce un tensore result. Più formalmente, per ogni elemento x:
real(x) = is_complex(x) ? real_part(x) : x.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore di tipo complesso o in virgola mobile | (C1), (C2) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo in virgola mobile | (C1), (C2) |
Vincoli
- (C1)
shape(result) = shape(operand). - (C2)
element_type(result)è definito come:complex_element_type(element_type(operand))seis_complex(operand).element_type(operand)altrimenti.
Esempi
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand)< : (tenso<r2x>>com>plexf32<) - t>ensor2xf32
// %result: [1.0, 3.0]
recv
Semantica
Riceve i dati da un canale con channel_id e produce results.
Se is_host_transfer è true, l'operazione trasferisce i dati dall'host. In caso contrario, trasferisce i dati da un altro dispositivo in base ai valori di
source_target_pairs. Questo flag duplica le informazioni fornite in
channel_type, quindi in futuro prevediamo di conservarne solo uno
(#666). Se is_host_transfer
= false e source_target_pairs è None o vuoto, il comportamento è considerato
indefinito.
results sono costituiti da valori di payload che vengono per primi e da un token che viene
per ultimo. In futuro, prevediamo di dividere il payload e il token in due output separati per migliorare la chiarezza
(#670).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | token |
token |
|
| (I2) | channel_id |
costante di tipo si64 |
|
| (I3) | channel_type |
enum di DEVICE_TO_DEVICE e DEVICE_TO_HOST |
(C5) |
| (I4) | is_host_transfer |
costante di tipo i1 |
(C5-C6) |
| (I5) | source_target_pairs |
Costante tensore bidimensionale di tipo si64 |
(C1-C4), (C6) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
results |
numero variabile di tensori, tensori quantizzati o token | (C2-C4) |
Vincoli
- (C1)
dim(source_target_pairs, 1) = 2. - (C2)
is_unique(source_target_pairs[:, 0]). - (C3)
is_unique(source_target_pairs[:, 1]). - (C4)
0 <= source_target_pairs < N, doveNè definito come:num_replicasse viene utilizzatocross_replica.num_partitionsse viene utilizzatocross_partition.
- (C5)
channel_typeè definito come:DEVICE_TO_HOSTseis_host_transfer = true,DEVICE_TO_DEVICEaltrimenti.
Esempi
%results0, %results1 = "stablehlo.recv"(%token) {
channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 1,
is_host_transfer = false,
source_target_pai<rs = dense[[0, 1>], [1, 2]<] : ten>sor2x2xi64
} : (!stablehl>o.token)< - (ten>sor2x2xi64, !stablehlo.token)
reduce
Semantica
Applica una funzione di riduzione body a inputs e init_values lungo
dimensions e produce tensori results.
L'ordine delle riduzioni è definito dall'implementazione, il che significa che body e
init_values devono formare un monoide per garantire che l'operazione produca gli
stessi risultati per tutti gli input in tutte le implementazioni. Tuttavia, questa condizione
non vale per molte riduzioni popolari. Ad esempio, l'addizione in virgola mobile per
body e zero per init_values non formano effettivamente un monoide perché
l'addizione in virgola mobile non è associativa.
Più formalmente, results...[j0, ..., jR-1] = reduce(input_slices_converted) dove:
input_slices = inputs...[j0, ..., :, ..., jR-1], dove:vengono inseriti adimensions.input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...).init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...).reduce(input_slices_converted) = exec(schedule)per alcuni alberi binarischeduledove:exec(node) = body(exec(node.left), exec(node.right)).exec(leaf) = leaf.value.
scheduleè un albero binario completo definito dall'implementazione la cui visita in ordine consiste in:- Valori di
input_slices_converted...[index], per tutti iindexinindex_space(input_slices_converted)in ordine lessicografico crescente diindex. - Interspersed with an implementation-defined amount of
init_values_convertedat implementation-defined positions.
- Valori di
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | inputs |
numero variabile di tensori o tensori quantizzati per tensore | (C1-C4), (C6), (C7) |
| (I2) | init_values |
numero variabile di tensori unidimensionali o tensori quantizzati per tensore | (C2), (C3) |
| (I3) | dimensions |
Costante tensore unidimensionale di tipo si64 |
(C4), (C5), (C7) |
| (I4) | body |
funzione | (C6) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
results |
numero variabile di tensori o tensori quantizzati per tensore | (C3), (C7), (C8) |
Vincoli
- (C1)
same(shape(inputs...)). - (C2)
element_type(inputs...) = element_type(init_values...). - (C3)
0 < size(inputs) = size(init_values) = size(results) = N. - (C4)
0 <= dimensions < rank(inputs[0]). - (C5)
is_unique(dimensions). - (C6)
bodyè di tipo(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)doveis_promotable(element_type(inputs[i]), Ei). - (C7)
shape(results...) = shape(inputs...), ad eccezione delle dimensioni della dimensioneinputs...corrispondenti adimensions, che non sono incluse. - (C8)
element_type(results[i]) = Eiper tutti iiin[0,N).
Esempi
// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
"stable<hlo>.re>turn"(%0) : (tensori64) <- ()
}>) {
dimens<ions = >arrayi64<: 1>
} >: (tens<or1x6>xi64, tensori64) - tensor1xi64
// %result = [15]
reduce_precision
Semantica
Esegue la conversione elemento per elemento di operand in un altro tipo di virgola mobile
che utilizza exponent_bits e mantissa_bits e di nuovo nel tipo di virgola mobile
originale e produce un tensore output.
Più formalmente:
- I bit della mantissa del valore originale vengono aggiornati per arrotondare il valore originale
al valore più vicino rappresentabile con
mantissa_bitsutilizzando la semanticaroundToIntegralTiesToEven. - Poi, se
mantissa_bitssono inferiori al numero di bit della mantissa del valore originale, i bit della mantissa vengono troncati amantissa_bits. - Se i bit dell'esponente del risultato intermedio non rientrano nell'intervallo fornito da
exponent_bits, il risultato intermedio va in overflow all'infinito utilizzando il segno originale o in underflow a zero utilizzando il segno originale. - Per i tipi quantizzati, esegue
dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
| (I2) | exponent_bits |
costante di tipo si32 |
(C2) |
| (I3) | mantissa_bits |
costante di tipo si32 |
(C3) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
output |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(output). - (C2)
1 <= exponent_bits. - (C3)
0 <= mantissa_bits.
Esempi
// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
exponent_bits = 5 : i32,
mantissa_bits = 10 : i32
}< : (t>ens>or6xf64<) - t>ensor6xf64
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]
reduce_scatter
Semantica
All'interno di ogni gruppo di processi nella griglia dei processi StableHLO, esegue la riduzione,
utilizzando computations, sui valori del tensore operand di ogni processo,
divide il risultato della riduzione lungo scatter_dimension in parti e distribuisce
le parti divise tra i processi per produrre result.
L'operazione divide la griglia di processi StableHLO in process_groups, che è
definito come segue:
cross_replica(replica_groups)ifchannel_id <= 0 and use_global_device_ids = false.cross_replica_and_partition(replica_groups)ifchannel_id > 0 and use_global_device_ids = false.flattened_ids(replica_groups)ifchannel_id > 0 and use_global_device_ids = true.
Successivamente, all'interno di ogni process_group:
reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation).parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension).result@receiver = parts@sender[receiver_index]per tutti isenderinprocess_group, dovereceiver_index = process_group.index(receiver).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore o tensore quantizzato per tensore | (C1), (C2), (C7), (C8) |
| (I2) | scatter_dimension |
costante di tipo si64 |
(C1), (C2), (C8) |
| (I3) | replica_groups |
Costante tensore bidimensionale di tipo si64 |
(C3-C5) |
| (I4) | channel_id |
costante di tipo si64 |
(C6) |
| (I5) | use_global_device_ids |
costante di tipo i1 |
(C6) |
| (I6) | computation |
funzione | (C7) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore o tensore quantizzato per tensore | (C8-C9) |
Vincoli
- (C1)
dim(operand, scatter_dimension) % dim(process_groups, 1) = 0. - (C2)
0 <= scatter_dimension < rank(operand). - (C3)
is_unique(replica_groups). - (C4)
size(replica_groups)è definito come:num_replicasse viene utilizzatocross_replica.num_replicasse viene utilizzatocross_replica_and_partition.num_processesse viene utilizzatoflattened_ids.
- (C5)
0 <= replica_groups < size(replica_groups). - (C6) Se
use_global_device_ids = true, allorachannel_id > 0. - (C7)
computationè di tipo(tensor<E>, tensor<E>) -> (tensor<E>)doveis_promotable(element_type(operand), E). - (C8)
shape(result) = shape(operand)tranne:dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1).
- (C9)
element_type(result) = E.
Esempi
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
"stable<hlo>.re>turn"(%0) : (tensori64) - ()
}) {
scatter_dimension = 1 :< i64,
>replica_g<roups => dense[[0, 1]] : tensor1x2xi64,
channel_hand<le = #stablehlo.chan>nel_handleha<ndle = >0, >type = <0
} : (>tensor2x4xi64) - tensor2x2xi64
//
// %result@(0, 0): [[10, 12],
// [18, 20]]
// %result@(1, 0): [[14, 16],
// [22, 24]]
reduce_window
Semantica
Applica una funzione di riduzione body a finestre di inputs e init_values
e produce results.
Il seguente diagramma mostra come vengono calcolati gli elementi in results... da inputs... utilizzando un esempio concreto.
Più formalmente,
results...[result_index] = reduce(windows, init_values, axes(inputs...), body)
(vedi reduce), dove:
padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1).window_start = result_index * window_strides.window_end = window_start + (window_dimensions - 1) * window_dilations + 1.windows = slice(padded_inputs..., window_start, window_end, window_dilations).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | inputs |
numero variabile di tensori o tensori quantizzati per tensore | (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15) |
| (I2) | init_values |
numero variabile di tensori unidimensionali o tensori quantizzati per tensore | (C1), (C13) |
| (I3) | window_dimensions |
Costante tensore unidimensionale di tipo si64 |
(C4), (C5), (C15) |
| (I4) | window_strides |
Costante tensore unidimensionale di tipo si64 |
(C6), (C7), (C15) |
| (I5) | base_dilations |
Costante tensore unidimensionale di tipo si64 |
(C8), (C9), (C15) |
| (I6) | window_dilations |
Costante tensore unidimensionale di tipo si64 |
(C10), (C11), (C15) |
| (I7) | padding |
Costante tensore bidimensionale di tipo si64 |
(C12), (C15) |
| (I8) | body |
funzione | (C13) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
results |
numero variabile di tensori o tensori quantizzati per tensore | (C1), (C14-C16) |
Vincoli
- (C1)
0 < size(inputs) = size(init_values) = size(results) = N. - (C2)
same(shape(inputs...)). - (C3)
element_type(inputs...) = element_type(init_values...). - (C4)
size(window_dimensions) = rank(inputs[0]). - (C5)
0 < window_dimensions. - (C6)
size(window_strides) = rank(inputs[0]). - (C7)
0 < window_strides. - (C8)
size(base_dilations) = rank(inputs[0]). - (C9)
0 < base_dilations. - (C10)
size(window_dilations) = rank(inputs[0]). - (C11)
0 < window_dilations. - (C12)
shape(padding) = [rank(inputs[0]), 2]. - (C13)
bodyè di tipo(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)doveis_promotable(element_type(inputs[i]), Ei). - (C14)
same(shape(results...)). - (C15)
shape(results[0]) = num_windowsdove:dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1.padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1].dilated_window_shape = (window_dimensions - 1) * window_dilations + 1.is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape.num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1.
- (C16)
element_type(results[i]) = Eiper tutti iiin[0,N).
Esempi
// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
"stable<hlo>.re>turn"(%0) : (tensori64) - ()
})< {
wind>ow_dimensions = arrayi64: <2, 1,
w>indow_strides = arrayi64: <4, 1,
b>ase_dilations = arrayi64: 2,< 1,
win>dow_dilations = arr<ayi64: 3, 1,
p>adding = <dense[[>2, 1], [0, 0<]] : te>nsor2x2x<i64>
} >: (tens<or3x2xi>64, tensori64) - tensor2x2xi64
// %result = [[0, 0], [3, 4]]
resto
Semantica
Esegue il resto elemento per elemento dei tensori dividendo lhs e divisore rhs e
produce un tensore result.
Più formalmente, il segno del risultato viene preso dal dividendo e il valore assoluto del risultato è sempre inferiore al valore assoluto del divisore.
Il resto viene calcolato come lhs - d * rhs, dove d è dato da:
- Per i numeri interi:
stablehlo.divide(lhs, rhs). - Per i numeri in virgola mobile:
division(lhs, rhs)da IEEE-754 con attributo di arrotondamentoroundTowardZero. - Per i numeri complessi: TBD (#997).
- Per i tipi quantizzati:
dequantize_op_quantize(remainder, lhs, rhs, type(result)).
Per i tipi di elementi in virgola mobile, questa operazione è in contrasto con l'operazione remainder della specifica IEEE-754, in cui d è un valore intero più vicino al valore esatto di lhs/rhs con arrotondamento al numero pari.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | lhs |
tensore di tipo intero, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
| (I2) | rhs |
tensore di tipo intero, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo intero, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result).
Esempi
// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs)< : (t>ensor4xi<64, t>ens>or4xi64<) - t>ensor4xi64
// %result: [2, -2, 2, -2]
replica_id
Semantica
Produce replica_id del processo corrente.
Output
| Nome | Tipo |
|---|---|
result |
Tensore unidimensionale di tipo ui32 |
Esempi
%result = "stablehlo.replica_id">;() : (<) - >tensorui32
rimodellare
Semantica
Esegue il rimodellamento del tensore operand in un tensore result. Concettualmente, equivale a mantenere la stessa rappresentazione canonica, ma potenzialmente a cambiare la forma, ad esempio da tensor<2x3xf32> a tensor<3x2xf32> o tensor<6xf32>.
Più formalmente, result[result_index] = operand[operand_index] dove
result_index e operand_index hanno la stessa posizione nell'ordine
lessicografico di index_space(result) e index_space(operand).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore o tensore quantizzato | (C1-C3) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore o tensore quantizzato | (C1-C3) |
Vincoli
- (C1)
element_type(result)è fornito da:element_type(operand), se!is_per_axis_quantized(operand).element_type(operand), tranne chequantization_dimension(operand)equantization_dimension(result)potrebbero essere diversi.
- (C2)
size(operand) = size(result). - (C3) Se
is_per_axis_quantized(operand):reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
Esempi
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand)< : (ten>sor>2x3xi32<) - ten>sor3x2xi32
// %result: [[1, 2], [3, 4], [5, 6]]
inverti
Semantica
Inverte l'ordine degli elementi in operand lungo dimensions specificato
e produce un tensore result. In termini più formali,
result[result_index] = operand[operand_index] dove:
operand_index[d] = dim(result, d) - result_index[d] - 1sedindimensions.operand_index[d] = result_index[d]altrimenti.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore o tensore quantizzato per tensore | (C1), (C3) |
| (I2) | dimensions |
Costante tensore unidimensionale di tipo si64 |
(C2), (C3) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore o tensore quantizzato per tensore | (C1), (C3) |
Vincoli
- (C1)
type(operand) = type(result). - (C2)
is_unique(dimensions). - (C3)
0 <= dimensions < rank(result).
Esempi
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
dimensio<ns = a>rrayi64: 1
}< : (ten>sor>3x2xi32<) - ten>sor3x2xi32
// %result: [[2, 1], [4, 3], [6, 5]]
rng
Semantica
Genera numeri casuali utilizzando l'algoritmo rng_distribution e produce un tensore result di una determinata forma shape.
Se rng_distribution = UNIFORM, i numeri casuali vengono generati
seguendo la distribuzione uniforme nell'intervallo [a, b). Se a >= b,
il comportamento non è definito.
Se rng_distribution = NORMAL, i numeri casuali vengono generati
seguendo la distribuzione normale con media = a e deviazione standard = b.
Se b < 0, il comportamento non è definito.
Il modo esatto in cui vengono generati i numeri casuali è definito dall'implementazione. Ad esempio, potrebbero essere deterministici o meno e potrebbero utilizzare o meno uno stato nascosto.
In conversazioni con molti stakeholder, questa operazione è stata considerata effettivamente obsoleta, quindi in futuro prevediamo di esplorare la possibilità di rimuoverla (#597).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | a |
Tensore unidimensionale di tipo intero, booleano o con rappresentazione in virgola mobile | (C1), (C2) |
| (I2) | b |
Tensore unidimensionale di tipo intero, booleano o con rappresentazione in virgola mobile | (C1), (C2) |
| (I3) | shape |
Costante tensore unidimensionale di tipo si64 |
(C3) |
| (I4) | rng_distribution |
enum di UNIFORM e NORMAL |
(C2) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo intero, booleano o in virgola mobile | (C1-C3) |
Vincoli
- (C1)
element_type(a) = element_type(b) = element_type(result). - (C2) Se
rng_distribution = NORMAL, allorais_float(a). - (C3)
shape(result) = shape.
Esempi
// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
rng_distribution = <#stablehlorng_distributi>on UNIFORM
}< : >(tensori<32,> tensori<32, t>ens>or2xi64<) - ten>sor3x3xi32
// %result: [
// [1, 0, 1],
// [1, 1, 1],
// [0, 0, 0]
// ]
rng_bit_generator
Semantica
Restituisce un output riempito con bit casuali uniformi e uno stato di output aggiornato
output_state utilizzando l'algoritmo del generatore di numeri pseudocasuali rng_algorithm
dato uno stato iniziale initial_state. L'output è garantito per essere
una funzione deterministica di initial_state, ma non è garantito che sia
deterministico tra le implementazioni.
rng_algorithm è uno dei seguenti valori:
DEFAULT: Algoritmo definito dall'implementazione.THREE_FRY: Variante dell'algoritmo Threefry definita dall'implementazione.*PHILOX: Variante dell'algoritmo Philox definita dall'implementazione.*
* Vedi: Salmon et al. SC 2011. Numeri casuali paralleli: un gioco da ragazzi.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | rng_algorithm |
enum di DEFAULT, THREE_FRY e PHILOX |
(C2) |
| (I2) | initial_state |
Tensore unidimensionale di tipo ui64 |
(C1), (C2) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
output_state |
Tensore unidimensionale di tipo ui64 |
(C1) |
output |
tensore di tipo intero o in virgola mobile |
Vincoli
- (C1)
type(initial_state) = type(output_state). - (C2)
size(initial_state)è definito come:- definito dall'implementazione se
rng_algorithm = DEFAULT. 2serng_algorithm = THREE_FRY.2o3serng_algorithm = PHILOX.
- definito dall'implementazione se
Esempi
// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
rng_algorithm = <#stablehlorng_algorithm> THREE_FRY
}< : (te>nso>r2xui64)< - (te>nsor2xui<64, tens>or2x2xui64)
// %output_state: [1, 6]
// %output: [
// [9236835810183407956, 16087790271692313299],
// [18212823393184779219, 2658481902456610144]
// ]
round_nearest_afz
Semantica
Esegue l'arrotondamento elemento per elemento verso l'intero più vicino, interrompendo i pareggi allontanandosi
da zero, sul tensore operand e produce un tensore result. Implementa
l'operazione roundToIntegralTiesToAway dalla specifica IEEE-754. Per i tipi
quantizzati, esegue
dequantize_op_quantize(round_nearest_afz, operand, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result).
Esempi
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]
round_nearest_even
Semantica
Esegue l'arrotondamento elemento per elemento al numero intero più vicino, risolvendo i pareggi
verso il numero intero pari, sul tensore operand e produce un tensore result. Implementa l'operazione roundToIntegralTiesToEven dalla specifica IEEE-754. Per i tipi quantizzati, esegue
dequantize_op_quantize(round_nearest_even, operand, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo in virgola mobile o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result).
Esempi
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
rsqrt
Semantica
Esegue l'operazione di radice quadrata reciproca elemento per elemento sul tensore operand e
produce un tensore result. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i numeri in virgola mobile:
rSqrtda IEEE-754. - Per i numeri complessi: radice quadrata reciproca complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(rsqrt, operand, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result).
Esempi
// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[1.0, 0.5], [0.33333343, 0.2]]
dispersione
Semantica
Produce tensori results uguali ai tensori inputs, tranne per il fatto che
diverse sezioni specificate da scatter_indices vengono aggiornate con i valori
updates utilizzando update_computation.
Il seguente diagramma mostra come gli elementi in updates... vengono mappati sugli elementi in
results... utilizzando un esempio concreto. Il diagramma seleziona alcuni esempi
di indici updates... e spiega in dettaglio a quali indici results... corrispondono.
Più formalmente, per tutti i update_index in index_space(updates[0]):
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims].update_scatter_index = update_index[update_scatter_dims...].start_indexè definito come:scatter_indices[si0, ..., :, ..., siN]dovesisono singoli elementi inupdate_scatter_indexe:viene inserito nell'indiceindex_vector_dim, seindex_vector_dim<rank(scatter_indices).[scatter_indices[update_scatter_index]]altrimenti.
- Per
d_inputinaxes(inputs[0]),full_start_index[d_input] = start_index[d_start]ifd_input = scatter_dims_to_operand_dims[d_start].full_start_index[d_input] = 0altrimenti.
- Per
d_inputinaxes(inputs[0]),full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)]sed_input = input_batching_dims[i_batching]ed_start = scatter_indices_batching_dims[i_batching].full_batching_index[d_input] = 0altrimenti.
update_window_index = update_index[update_window_dims...].full_window_index = [wi0, ..., 0, ..., wiN]dovewisono singoli elementi inupdate_window_indexe0viene inserito negli indici dainserted_window_dimseinput_batching_dims.result_index = full_start_index + full_batching_index + full_window_index.
In base a questo, results = exec(schedule, inputs), dove:
scheduleè una permutazione definita dall'implementazione diindex_space(updates[0]).exec([update_index, ...], results) = exec([...], updated_results)dove:- Se
result_indexè nei limiti dishape(results...) updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )updated_values = update_computation(results...[result_index], updates_converted)updated_resultsè una copia diresultsconresults...[result_index]impostato suupdated_values....- In caso contrario
updated_results = results.
- Se
exec([], results) = results.
Se indices_are_sorted è true, l'implementazione può presupporre che
scatter_indices siano ordinati rispetto a scatter_dims_to_operand_dims,
altrimenti il comportamento è indefinito. Più formalmente, per tutti i i1 < i2 da
indices(result), full_start_index(i1) <= full_start_index(i2).
Se unique_indices è true, l'implementazione può presupporre che tutti gli indici result_index sparsi siano univoci. Se unique_indices è
true, ma gli indici a cui vengono distribuiti i dati non sono univoci, il comportamento è
indefinito.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | inputs |
numero variabile di tensori o tensori quantizzati per tensore | (C1), (C2), (C4-C6), (C11), (C13), (C18), (C21), (C23-C24) |
| (I2) | scatter_indices |
tensore di tipo intero | (C4), (C15), (C19), (C22) |
| (I3) | updates |
numero variabile di tensori o tensori quantizzati per tensore | (C3-C6), (C8) |
| (I4) | update_window_dims |
Costante tensore unidimensionale di tipo si64 |
(C2), (C4), (C7-C8) |
| (I5) | inserted_window_dims |
Costante tensore unidimensionale di tipo si64 |
(C2), (C4), (C9-C11) |
| (I6) | input_batching_dims |
Costante tensore unidimensionale di tipo si64 |
(C2), (C4), (C9), (C12-13), (C17-18), (C20) |
| (I7) | scatter_indices_batching_dims |
Costante tensore unidimensionale di tipo si64 |
(C14-C18) |
| (I8) | scatter_dims_to_operand_dims |
Costante tensore unidimensionale di tipo si64 |
(C19-C21) |
| (I9) | index_vector_dim |
costante di tipo si64 |
(C4), (C16), (C19), (C22) |
| (I10) | indices_are_sorted |
costante di tipo i1 |
|
| (I11) | unique_indices |
costante di tipo i1 |
|
| (I12) | update_computation |
funzione | (C23) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
results |
numero variabile di tensori o tensori quantizzati per tensore | (C24-C25) |
Vincoli
- (C1)
same(shape(inputs...)). - (C2)
rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims) + size(input_batching_dims). - (C3)
same(shape(updates...)). - (C4)
shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes)dove:update_scatter_dim_sizes = shape(scatter_indices), ad eccezione del fatto che le dimensioni della dimensionescatter_indicescorrispondenti aindex_vector_dimnon sono incluse.update_window_dim_sizes <= shape(inputs[0]), ad eccezione delle dimensioni ininputs[0]corrispondenti ainserted_window_dimseinput_batching_dims, che non sono incluse.combineposizionaupdate_scatter_dim_sizessugli assi corrispondenti aupdate_scatter_dimseupdate_window_dim_sizessugli assi corrispondenti aupdate_window_dims.
- (C5)
0 < size(inputs) = size(updates) = N. - (C6)
element_type(updates...) = element_type(inputs...). - (C7)
is_unique(update_window_dims) and is_sorted(update_window_dims). - (C8)
0 <= update_window_dims < rank(updates[0]). - (C9)
is_unique(concatenate(inserted_window_dims, input_batching_dims)) - (C10)
is_sorted(inserted_window_dims). - (C11)
0 <= inserted_window_dims < rank(inputs[0]). - (C12)
is_sorted(input_batching_dims). - (C13)
0 <= input_batching_dims < rank(inputs[0])). - (C14)
is_unique(scatter_indices_batching_dims). - (C15)
0 <= scatter_indices_batching_dims < rank(scatter_indices). - (C16)
index_vector_dim not in scatter_indices_batching_dims. - (C17)
size(input_batching_dims) == size(scatter_indices_batching_dims). - (C18)
dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...). - (C19)
size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1. - (C20)
is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims)). - (C21)
0 <= scatter_dims_to_operand_dims < rank(inputs[0]). - (C22)
0 <= index_vector_dim <= rank(scatter_indices). - (C23)
update_computationè di tipo(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), doveis_promotable(element_type(inputs[i]), Ei). - (C24)
shape(inputs...) = shape(results...). - (C25)
element_type(results[i]) = Eiper tutti iiin[0,N).
Esempi
// %input: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %scatter_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
// %update: [
// [
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
// ],
// [
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
// ]
// ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
"stable<hlo>.re>turn"(%0) : (tensori64) - ()
}) {
scatter_dimensio<n_numbers = #stablehlo.scatter
update_window_dims = [3, 4],
inserted_window_dims = [1],
input_batching_dims = [0],
scatter_indices_batching_dims = [1],
scatter_dims_to_operand_dims = [2>, 1],
index_vector_dim = 3,
indices_are_sorted = false,
uniq<ue_indices >= false
<} : (tensor>2x3x4x2x<i64, tensor2x>2x3>x2xi64,< tensor2x2x>3x2x2xi64) - tensor2x3x4x2xi64
// %result: [
// [
// [[3, 4], [6, 7], [6, 7], [7, 8]],
// [[9, 10],[11, 12], [15, 16], [17, 18]],
// [[17, 18], [19, 20], [22, 23], [24, 25]]
// ],
// [
// [[25, 26], [28, 29], [30, 31], [31, 32]],
// [[35, 36], [38, 39], [38, 39], [39, 40]],
// [[41, 42], [44, 45], [46, 47], [47, 48]]
// ]
// ]
seleziona
Semantica
Produce un tensore result in cui ogni elemento viene selezionato dal tensore on_true o
on_false in base al valore dell'elemento corrispondente di pred.
Più formalmente, result[result_index] = pred_element ? on_true[result_index] :
on_false[result_index], dove pred_element = rank(pred) = 0 ? pred[] :
pred[result_index]. Per i tipi quantizzati, esegue
dequantize_select_quantize(pred, on_true, on_false, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | pred |
tensore di tipo i1 |
(C1) |
| (I2) | on_true |
tensore o tensore quantizzato per tensore | (C1-C2) |
| (I3) | on_false |
tensore o tensore quantizzato per tensore | (C2) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore o tensore quantizzato per tensore | (C2) |
Vincoli
- (C1)
rank(pred) = 0 or shape(pred) = shape(on_true). - (C2)
baseline_type(on_true) = baseline_type(on_false) = baseline_type(result).
Esempi
// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false)< : (te>nsor2x2x<i1, ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 2], [3, 8]]
select_and_scatter
Semantica
Distribuisce i valori del tensore source utilizzando scatter in base al
risultato di reduce_window del tensore input utilizzando select e produce
un tensore result.
Il seguente diagramma mostra come vengono calcolati gli elementi in result da
operand e source utilizzando un esempio concreto.
Più formalmente:
selected_values = reduce_window_without_init(...)con i seguenti input:inputs = [operand].window_dimensions,window_stridesepadding, che vengono utilizzati così come sono.base_dilations = windows_dilations = 1.bodyè definito come:
def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>: return select(arg0, arg1) ? arg0 : arg1;dove
E = element_type(operand)ereduce_window_without_initfunzionano esattamente comereduce_window, tranne per il fatto chescheduledell'reducesottostante (vedi reduce) non include i valori iniziali. Al momento non è specificato cosa succede se la finestra corrispondente non ha valori (#731).result[result_index] = reduce([source_values], [init_value], [0], scatter)dove:source_values = [source[source_index] for source_index in source_indices].selected_index(source_index) = operand_indexseselected_values[source_index]ha l'elementooperanddioperand_index.source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index].
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore o tensore quantizzato per tensore | (C1-C4), (C6), (C8-C11) |
| (I2) | source |
tensore o tensore quantizzato per tensore | (C1), (C2) |
| (I3) | init_value |
Tensore 0-dimensionale o tensore quantizzato per tensore | (C3) |
| (I4) | window_dimensions |
Costante tensore unidimensionale di tipo si64 |
(C2), (C4), (C5) |
| (I5) | window_strides |
Costante tensore unidimensionale di tipo si64 |
(C2), (C6), (C7) |
| (I6) | padding |
Costante tensore bidimensionale di tipo si64 |
(C2), (C8) |
| (I7) | select |
funzione | (C9) |
| (I8) | scatter |
funzione | (C10) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore o tensore quantizzato per tensore | (C11-C12) |
Vincoli
- (C1)
element_type(operand) = element_type(source). - (C2)
shape(source) = num_windowsdove:padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1].is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape.num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1.
- (C3)
element_type(init_value) = element_type(operand). - (C4)
size(window_dimensions) = rank(operand). - (C5)
0 < window_dimensions. - (C6)
size(window_strides) = rank(operand). - (C7)
0 < window_strides. - (C8)
shape(padding) = [rank(operand), 2]. - (C9)
selectè di tipo(tensor<E>, tensor<E>) -> tensor<i1>doveE = element_type(operand). - (C10)
scatterè di tipo(tensor<E>, tensor<E>) -> tensor<E>doveis_promotable(element_type(operand), E). - (C11)
shape(operand) = shape(result). - (C12)
element_type(result) = E.
Esempi
// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_di<rection = #stablehlocom>parison_directio<n G>E
} <: (>ten>sori64,< t>ensori64) - tensori1
"stable<hl>o.r>eturn"(%0) : (tensori1) <- (>)
}, {
^bb0(%<arg>0: tensori64, %arg1: tensori64):
%0 = "sta<ble>hlo.add&<quo>t;(>%arg0, <%ar>g1) : (tensori64, tensori64) - tensor<i64>
> "stablehlo.return"(%0) :< (tensori>64) - ()
}) {
window_dim<ensions => arrayi64: 3, 1,
<window_strides => arrayi64<: 2, 1,>
padding =< dense[>[0, 1], <[0, 0]]> : tenso<r2x>2xi>64
} : <(tensor>4x2xi64, tensor2x2xi64, tensori64) - tensor4x2xi64
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]
invia
Semantica
Invia inputs a un canale channel_id. Gli input vengono quindi inviati ad altri dispositivi
nell'ordine specificato da source_target_pairs. L'operazione produce un
token result.
Se is_host_transfer è true, l'operazione trasferisce i dati all'host. In caso contrario, trasferisce i dati a un altro dispositivo in base ai valori di
source_target_pairs. Questo flag duplica le informazioni fornite in
channel_type, quindi in futuro prevediamo di conservarne solo uno
(#666). Se is_host_transfer
= false e source_target_pairs è None o vuoto, il comportamento è considerato
indefinito.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | inputs |
numero variabile di tensori o tensori quantizzati | |
| (I2) | token |
token |
|
| (I3) | channel_id |
costante di tipo si64 |
|
| (I4) | channel_type |
enum di DEVICE_TO_DEVICE e DEVICE_TO_HOST |
(C5) |
| (I5) | is_host_transfer |
costante di tipo i1 |
(C5-C6) |
| (I6) | source_target_pairs |
Costante tensore bidimensionale di tipo si64 |
(C1-C4), (C6) |
Output
| Nome | Tipo |
|---|---|
result |
token |
Vincoli
- (C1)
dim(source_target_pairs, 1) = 2. - (C2)
is_unique(source_target_pairs[:, 0]). - (C3)
is_unique(source_target_pairs[:, 1]). - (C4)
0 <= source_target_pairs < N, doveNè definito come:num_replicasse viene utilizzatocross_replica.num_partitionsse viene utilizzatocross_partition.
- (C5)
channel_typeè definito come:DEVICE_TO_HOSTseis_host_transfer = true,DEVICE_TO_DEVICEaltrimenti.
Esempi
%result = "stablehlo.send"(%operand, %token) {
channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 1,
is_host_transfer = false,
source_target_pai<rs = dense[[0, 1>], [1, 2]<] : ten>sor2x2xi64
}< : (ten>sor2x2xi64, !stablehl>o.token) - !stablehlo.token
shift_left
Semantica
Esegue l'operazione di spostamento a sinistra elemento per elemento sul tensore lhs di rhs bit e produce un tensore result.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | lhs |
tensore di tipo intero | (C1) |
| (I2) | rhs |
tensore di tipo intero | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo intero | (C1) |
Vincoli
- (C1)
type(lhs) = type(rhs) = type(result).
Esempi
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs<): (t>ensor3xi<64, t>ens>or3xi64<) - t>ensor3xi64
// %result: [-2, 0, 8]
shift_right_arithmetic
Semantica
Esegue l'operazione di spostamento aritmetico a destra elemento per elemento sul tensore lhs di
rhs bit e produce un tensore result.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | lhs |
tensore di tipo intero | (C1) |
| (I2) | rhs |
tensore di tipo intero | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo intero | (C1) |
Vincoli
- (C1)
type(lhs) = type(rhs) = type(result).
Esempi
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs<): (t>ensor3xi<64, t>ens>or3xi64<) - t>ensor3xi64
// %result: [-1, 0, 1]
shift_right_logical
Semantica
Esegue l'operazione di spostamento logico a destra elemento per elemento sul tensore lhs di rhs
bit e produce un tensore result.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | lhs |
tensore di tipo intero | (C1) |
| (I2) | rhs |
tensore di tipo intero | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo intero | (C1) |
Vincoli
- (C1)
type(lhs) = type(rhs) = type(result).
Esempi
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs<): (t>ensor3xi<64, t>ens>or3xi64<) - t>ensor3xi64
// %result: [9223372036854775807, 0, 1]
firmare
Semantica
Restituisce il segno dell'elemento operand e produce un tensore result.
In termini più formali, per ogni elemento x, la semantica può essere espressa utilizzando la sintassi Python come segue:
def sign(x):
if is_integer(x):
if compare(x, 0, LT, SIGNED): return -1
if compare(x, 0, EQ, SIGNED): return 0
return 1
elif is_float(x):
if is_nan(x): return NaN
if compare(x, -0.0, EQ, FLOAT): return -0.0
if compare(x, +0.0, EQ, FLOAT): return +0.0
if compare(x, 0.0, LT, FLOAT): return -1.0
return 1.0
elif is_complex(x):
if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
return divide(x, convert(abs(x), type(x)))
Per i tipi quantizzati, esegue
dequantize_op_quantize(sign, operand, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore di tipo intero con segno, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo intero con segno, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result).
Esempi
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
seno
Semantica
Esegue l'operazione seno elemento per elemento sul tensore operand e produce un tensore result. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i numeri in virgola mobile:
sinda IEEE-754. - Per i numeri complessi: seno complesso.
- Per i tipi quantizzati:
dequantize_op_quantize(sine, operand, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result).
Esempi
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.sine"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[0.0, 1.0], [0.0, -1.0]]
slice
Semantica
Estrae una sezione da operand utilizzando indici iniziali calcolati in modo statico
e produce un tensore result. start_indices contengono gli indici iniziali
della sezione per ogni dimensione, limit_indices contengono gli indici finali
(esclusivi) della sezione per ogni dimensione e strides contengono i passi
per ogni dimensione.
Più formalmente, result[result_index] = operand[operand_index] dove
operand_index = start_indices + result_index * strides.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore o tensore quantizzato per tensore | (C1-C3), (C5) |
| (I2) | start_indices |
Costante tensore unidimensionale di tipo si64 |
(C2), (C3), (C5) |
| (I3) | limit_indices |
Costante tensore unidimensionale di tipo si64 |
(C2), (C3), (C5) |
| (I4) | strides |
Costante tensore unidimensionale di tipo si64 |
(C2), (C4) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore o tensore quantizzato per tensore | (C1), (C5) |
Vincoli
- (C1)
element_type(operand) = element_type(result). - (C2)
size(start_indices) = size(limit_indices) = size(strides) = rank(operand). - (C3)
0 <= start_indices <= limit_indices <= shape(operand). - (C4)
0 < strides. - (C5)
shape(result) = ceil((limit_indices - start_indices) / strides).
Esempi
// %operand: [
// [0, 0, 0, 0],
// [0, 0, 1, 1],
// [0, 0, 1, 1]
// ]
%result = "stablehlo.slice"(%operand) {
start_indic<es = arra>yi64: 1, 2,
limit_indic<es = arra>yi64: 3, 4,
strid<es = arra>yi64: 1, 1
}< : (ten>sor>3x4xi64<) - ten>sor2x2xi64
// % result: [
// [1, 1],
// [1, 1]
// ]
ordinare
Semantica
Ordina le sezioni unidimensionali di inputs lungo la dimensione dimension,
in base a un comparator e produce results.
A differenza di input simili in altre operazioni, dimension consente valori negativi,
con la semantica descritta di seguito. In futuro, questa operazione potrebbe non essere consentita
per motivi di coerenza
(#1377).
Se is_stable è true, l'ordinamento è stabile, ovvero l'ordine relativo degli
elementi considerati uguali dal comparatore viene conservato. Nel caso in cui
esista un singolo input, due elementi e1 e e2 sono considerati
uguali dal comparatore se e solo se
comparator(e1, e2) = comparator(e2, e1) = false. Consulta la formalizzazione riportata di seguito
per scoprire come questa operazione viene generalizzata a più input.
Più formalmente, per tutti i result_index in index_space(results[0]):
adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension.result_slice = [ri0, ..., :, ..., riR-1], doveriNsono singoli elementi inresult_indexe:viene inserito inadjusted_dimension.inputs_together = (inputs[0]..., ..., inputs[N-1]...).results_together[result_slice] = sort(inputs_together[result_slice], comparator_together).- dove
sortordina una sezione unidimensionale in ordine non decrescente, prevedendo checomparator_togetherrestituiscatruese l'argomento a sinistra è minore del secondo argomento a destra. def comparator_together(lhs_together, rhs_together): args = [] for (lhs_el, rhs_el) in zip(lhs_together, rhs_together): args.append(lhs_el) args.append(rhs_el) return comparator(*args)(results[0]..., ..., results[N-1]...) = results_together.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | inputs |
numero variabile di tensori o tensori quantizzati per tensore | (C1-C5) |
| (I2) | dimension |
costante di tipo si64 |
(C4) |
| (I3) | is_stable |
costante di tipo i1 |
|
| (I4) | comparator |
funzione | (C5) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
results |
numero variabile di tensori o tensori quantizzati per tensore | (C2), (C3) |
Vincoli
- (C1)
0 < size(inputs). - (C2)
type(inputs...) = type(results...). - (C3)
same(shape(inputs...) + shape(results...)). - (C4)
-R <= dimension < R, doveR = rank(inputs[0]). - (C5)
comparatorha il tipo(tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>, doveEi = element_type(inputs[i]).
Esempi
// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64, %ar<g2:> tensori64, %ar<g3:> tensori64):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_di<rection = #stablehlocom>parison_directio<n G>T
} <: (>ten>sori64,< t>ensori64) - tensori1
"stablehlo.retu<rn>&qu>ot;(%predicate) : (tensori1) - ()
}) {
dimension = 0 : i64,
< is_st>able = t<rue
} :> (t>ensor2x3<xi64, t>ensor2x3<xi64) -> (tensor2x3xi64, tensor2x3xi64)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]
sqrt
Semantica
Esegue l'operazione di radice quadrata elemento per elemento sul tensore operand e produce un tensore result. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i numeri in virgola mobile:
squareRootda IEEE-754. - Per i numeri complessi: radice quadrata complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(sqrt, operand, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result).
Esempi
// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[0.0, 1.0], [2.0, 3.0]]
subtract
Semantica
Esegue la sottrazione elemento per elemento di due tensori lhs e rhs e produce un tensore result. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i numeri interi: sottrazione di numeri interi.
- Per i numeri in virgola mobile:
subtractionda IEEE-754. - Per i numeri complessi: sottrazione complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(subtract, lhs, rhs, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | lhs |
tensore di tipo intero, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
| (I2) | rhs |
tensore di tipo intero, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo intero, in virgola mobile o complesso oppure tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
Esempi
// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs)< : (ten>sor2x2xf<32, ten>sor>2x2xf32)< - (ten>sor2x2xf32)
// %result: [[1, 2], [3, 4]]
tan
Semantica
Esegue l'operazione di tangente elemento per elemento sul tensore operand e produce un tensore result. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i numeri in virgola mobile:
tanda IEEE-754. - Per i numeri complessi: tangente complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(tan, operand, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result).
Esempi
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.tan"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [
// [0.0, 1.63312e+16],
// [0.0, 5.44375e+15]
// ]
tanh
Semantica
Esegue l'operazione di tangente iperbolica elemento per elemento sul tensore operand e
produce un tensore result. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i numeri in virgola mobile:
tanhda IEEE-754. - Per i numeri complessi: tangente iperbolica complessa.
- Per i tipi quantizzati:
dequantize_op_quantize(tanh, operand, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_type(operand) = baseline_type(result).
Esempi
// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand)< : (t>ens>or3xf32<) - t>ensor3xf32
// %result: [-0.76159416, 0.0, 0.76159416]
transpose
Semantica
Permuta le dimensioni del tensore operand utilizzando permutation e produce un tensore result. Più formalmente, result[result_index] = operand[operand_index]
dove result_index[d] = operand_index[permutation[d]].
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore o tensore quantizzato | (C1-C4) |
| (I2) | permutation |
Costante tensore unidimensionale di tipo si64 |
(C2-C4) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore o tensore quantizzato | (C1), (C3-C4) |
Vincoli
- (C1)
element_type(result)è fornito da:element_type(operand), se!is_per_axis_quantized(operand).element_type(operand), tranne chequantization_dimension(operand)equantization_dimension(result)potrebbero essere diversi.
- (C2)
permutationè una permutazione dirange(rank(operand)). - (C3)
shape(result) = dim(operand, permutation...). - (C4) Se
is_per_axis_quantized(result), alloraquantization_dimension(operand) = permutation(quantization_dimension(result)).
Esempi
// %operand: [
// [[1,2], [3,4], [5,6]],
// [[7,8], [9,10], [11,12]]
// ]
%result = "stablehlo.transpose"(%operand) {
permutati<on = arrayi6>4: 2, 1, 0
}< : (tenso>r2x>3x2xi32<) - tenso>r2x3x2xi32
// %result: [
// [[1,7], [3,9], [5,11]],
// [[2,8], [4,10], [6,12]]
// ]
triangular_solve
Semantica
Risolve batch di sistemi di equazioni lineari con matrici dei coefficienti triangolari inferiori o superiori.
Più formalmente, dati a e b, result[i0, ..., iR-3, :, :] è la soluzione
di op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :] quando left_side è
true o x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :] quando
left_side è false, risolvendo la variabile x dove op(a) è determinato
da transpose_a, che può essere uno dei seguenti:
NO_TRANSPOSE: esegui l'operazione utilizzandoacosì com'è.TRANSPOSE: esegui l'operazione sulla trasposizione dia.ADJOINT: esegui l'operazione sulla trasposta coniugata dia.
I dati di input vengono letti solo dal triangolo inferiore di a, se lower è true o
dal triangolo superiore di a, altrimenti. I dati di output vengono restituiti nello stesso triangolo;
i valori nell'altro triangolo sono definiti dall'implementazione.
Se unit_diagonal è true, l'implementazione può presupporre che gli elementi diagonali di a siano uguali a 1, altrimenti il comportamento è indefinito.
Per i tipi quantizzati, esegue
dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower,
unit_diagonal, transpose_a), a, b, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | a |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1-C3) |
| (I2) | b |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1-C4) |
| (I3) | left_side |
costante di tipo i1 |
(C3) |
| (I4) | lower |
costante di tipo i1 |
|
| (I5) | unit_diagonal |
costante di tipo i1 |
|
| (I6) | transpose_a |
enum di NO_TRANSPOSE, TRANSPOSE e ADJOINT |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo a virgola mobile o complesso o tensore quantizzato per tensore | (C1) |
Vincoli
- (C1)
baseline_element_type(a) = baseline_element_type(b). - (C2)
2 <= rank(a) = rank(b) = R. - (C3) La relazione tra
shape(a)eshape(b)è definita come segue:shape(a)[:-3] = shape(b)[:-3].dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1).
- (C4)
baseline_type(b) = baseline_type(result).
Esempi
// %a = [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
// %b = [
// [2.0, 0.0, 0.0],
// [4.0, 8.0, 0.0],
// [6.0, 10.0, 12.0]
// ]
%result = "stablehlo.triangular_solve"(%a, %b) {
left_side = true,
lower = true,
unit_diagonal = false,
transpose_a = <#stablehlotranspose NO>_TRANSPOSE
}< : (ten>sor3x3xf<32, ten>sor>3x3xf32<) - ten>sor3x3xf32
// %result: [
// [2.0, 0.0, 0.0],
// [0.0, 2.0, 0.0],
// [0.0, 0.0, 2.0]
// ]
tupla
Semantica
Produce una tupla result dai valori val.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | val |
numero variabile di valori | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tupla | (C1) |
Vincoli
- (C1)
resultè di tipotuple<E0, ..., EN-1>, doveEi = type(val[i]).
Esempi
// %val0: memref[1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1)< : (m>emref2x<f32, t<upl>>ete>nsori3<2) - t<uplem>emref2x<f32, t<upl>>>etensori32
// %result: (memref[1.0, 2.0], (3))
uniform_dequantize
Semantica
Esegue la conversione elemento per elemento del tensore quantizzato operand in un tensore
in virgola mobile result in base ai parametri di quantizzazione definiti
dal tipo operand.
Più formalmente, result = dequantize(operand).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore quantizzato | (C1), (C2) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo in virgola mobile | (C1), (C2) |
Vincoli
- (C1)
shape(operand) = shape(result). - (C2)
element_type(result) = expressed_type(operand).
Esempi
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand)< : (tensor2x!qua<nt.uniformi8:f32:0, {0.1:-3>>0,0>.5:-20}<) - t>ensor2xf32
// %result: [4.0, 15.0]
uniform_quantize
Semantica
Esegue la conversione elemento per elemento del tensore in virgola mobile o del tensore quantizzato
operand in un tensore quantizzato result in base ai parametri di quantizzazione
definiti dal tipo result.
Più formalmente,
- Se
is_float(operand):result = quantize(operand, type(result)).
- Se
is_quantized(operand):float_result = dequantize(operand).result = quantize(float_result, type(result)).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
tensore di tipo in virgola mobile o quantizzato | (C1), (C2) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore quantizzato | (C1), (C2) |
Vincoli
- (C1)
shape(operand) = shape(result). - (C2)
expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand).
Esempi
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand)< : (t>ens>or2xf32<) - tensor2x!qua<nt.uniformi8:f32:0, {0.1:-3>>0,0.5:-20}
// %result: [10, 10]
// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"<(%operand) : (te<nsor2x!quant.uniformi8:f32:>>0, >{0.1:-3<0,0.5:-20}) - te<nsor2x!quant.uniformi8:f32:>>0, {0.1:-20,0.2:-30}
// %result: [20, 45]
mentre
Semantica
Produce l'output dall'esecuzione della funzione body 0 o più volte mentre la
funzione cond restituisce true. Più formalmente, la semantica può essere espressa
utilizzando la sintassi Python nel seguente modo:
internal_state = operand
while cond(*internal_state):
internal_state = body(*internal_state)
results = internal_state
Il comportamento di un loop infinito è da definire (#383).
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | operand |
numero variabile di valori | (C1-C3) |
| (I2) | cond |
funzione | (C1) |
| (I3) | body |
funzione | (C2) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
results |
numero variabile di valori | (C3) |
Vincoli
- (C1)
condè di tipo(T0, ..., TN-1) -> tensor<i1>, doveTi = type(operand[i]). - (C2)
bodyha tipo(T0, ..., TN-1) -> (T0, ..., TN-1), doveTi = type(operand[i]). - (C3)
type(results...) = type(operand...).
Esempi
// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%cond = "stablehlo.compare"(%arg0, %ten) {
comparison_di<rection = #stablehlocom>parison_directio<n L>T
} <: (>ten>sori64,< t>ensori64) - tensori1
stablehlo.r<et>urn %cond : tensori1
}, {
< ^>bb0(%arg0: tens<ori>64, %arg1: tensori64):
%new_sum = stablehlo.add <%ar>g1, %one : tensori64
%new_i = stablehlo.add <%ar>g0, %one : tensori64
stablehlo.return %new_<i, >%new_sum< : >tensori64, te<nso>ri64
}) <: (>ten>sori64, <ten>sori64) <- (>tensori64, tensori64)
// %results0: 10
// %results1: 10
xor
Semantica
Esegue l'operazione XOR elemento per elemento di due tensori lhs e rhs e produce un tensore result. A seconda del tipo di elemento, esegue le seguenti operazioni:
- Per i valori booleani: XOR logico.
- Per gli interi: XOR bit a bit.
Input
| Etichetta | Nome | Tipo | Vincoli |
|---|---|---|---|
| (I1) | lhs |
tensore di tipo booleano o intero | (C1) |
| (I2) | rhs |
tensore di tipo booleano o intero | (C1) |
Output
| Nome | Tipo | Vincoli |
|---|---|---|
result |
tensore di tipo booleano o intero | (C1) |
Vincoli
- (C1)
type(lhs) = type(rhs) = type(result).
Esempi
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[4, 4], [4, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%<lhs, %>rhs) : (<tensor>2x2>xi1, te<nsor2x>2xi1) - tensor2x2xi1
// %result: [[false, true], [true, false]]
Dialect Interop
Al momento, i programmi StableHLO in natura a volte contengono operazioni che non sono definite da StableHLO.
Modulo, funzione, chiamata e ritorno
StableHLO utilizza operazioni MLIR upstream per ModuleOp, FuncOp, CallOp e ReturnOp. Questa operazione è stata eseguita per una migliore interoperabilità con i meccanismi MLIR esistenti, poiché molti pass utili sono scritti in modo da avere come target FuncOp e ModuleOp e molte pipeline di compilazione si aspettano che queste operazioni siano presenti. A queste operazioni vengono applicate garanzie di compatibilità completa. Se queste operazioni cambiano in modo incompatibile (ad es. rimozione), verranno aggiunte le operazioni StableHLO equivalenti per preservare la compatibilità.
CHLO
L'opset CHLO contiene operazioni di livello superiore che si decompongono in StableHLO. Al momento non sono previste garanzie di compatibilità per CHLO. Per garantire la compatibilità, il passaggio chlo-legalize-to-stablehlo deve essere utilizzato prima della serializzazione.
Operazioni sulle forme
È un caso d'uso comune nella community utilizzare determinate operazioni dei dialetti MLIR di base nei programmi StableHLO dinamici per eseguire calcoli di forma.
Più comunemente, includono operazioni sul dialetto shape come shape_of o num_elements, operazioni sul dialetto tensor come dim o from_elements e il tipo index integrato.
La RFC sul dinamismo > O2
indica che questi tipi non rientrano nell'ambito, ma è incluso un certo supporto per i tipi index
per scopi di interoperabilità. Non sono previste garanzie di compatibilità per queste
operazioni o questi tipi. Il pass shape-legalize-to-stablehlo
può essere utilizzato per convertire queste operazioni in operazioni StableHLO completamente supportate.
Operazioni deprecate
Esistono diverse operazioni StableHLO ereditate da MHLO che sono deprecate e in fase di rimozione da StableHLO. I dettagli completi di queste rimozioni sono disponibili in StableHLO v1.0 Cleanup #2283. Il problema di monitoraggio per questi ritiri è #2340.
Queste operazioni rientrano in alcune categorie:
- Categoria "Not in HLO" delle operazioni StableHLO. Inizialmente facevano parte
dell'opset StableHLO, ma in seguito è stato ritenuto che non si adattassero bene:
broadcast,create_token,cross-replica-sum,dot,einsum,torch_index_select,unary_einsum(#3). - Operazioni inutilizzate: queste operazioni potrebbero essere state utili in passato, ma sono
sottosviluppate o le pipeline che le utilizzano sono state
refattorizzate in modo da non richiederle più. Sono inclusi
map,tuple(#598),get_tuple_element,rng,complexconfronti #560, e convoluzionewindow_reversal(#1181).
Alcune di queste operazioni possono essere rimosse facilmente in quanto possono essere espresse utilizzando
operazioni esistenti (broadcast, create_token, cross-replica-sum, dot,
unary_einsum) e verranno rimosse dopo il periodo di compatibilità esistente (6 mesi). Altri sono ancora in fase di esplorazione per la rimozione (einsum,
get_tuple_element, map, rng torch_index_select, tuple, complex
confronti, window_reversal). In attesa del feedback della community,
queste operazioni verranno rimosse o aggiunte alla specifica con supporto completo. Fino a quando
non saranno noti questi futuri sistemi operativi, è garantita solo la compatibilità per 6 mesi.
Esecuzione
Esecuzione sequenziale
Un programma StableHLO viene eseguito fornendo valori di input alla funzione main
e calcolando i valori di output. I valori di output di una funzione vengono calcolati
eseguendo il grafico delle operazioni con radice nell'operazione return corrispondente.
L'ordine di esecuzione è definito dall'implementazione, purché sia allineato al
flusso di dati, ovvero se le operazioni vengono eseguite prima dei loro utilizzi. In StableHLO, tutte le operazioni con effetti collaterali consumano un token e ne producono uno (più token possono essere multiplexati in un unico token tramite after_all), quindi l'ordine di esecuzione degli effetti collaterali è allineato anche al flusso di dati. Ad esempio, nel programma riportato di seguito
sono possibili due ordini di esecuzione: %0 → %1 → %2 → return e
%1 → %0 → %2 → return.
func.func @main() -> tensor<f64> {
%0 = stablehlo.constant dense<1.0> : tensor<f64>
%1 = stablehlo.constant dense<2.0> : tensor<f64>
%2 = stablehlo.add %0, %1 : tensor<f64>
return %2 : tensor<f64>
}
Più formalmente, un processo StableHLO è una combinazione di:
1) un programma StableHLO, 2) stati delle operazioni (non ancora eseguite,
già eseguite) e 3) valori intermedi su cui il processo sta lavorando.
Il processo inizia con i valori di input della funzione main, prosegue attraverso
il grafico delle operazioni che aggiornano gli stati delle operazioni e i valori intermedi e
termina con i valori di output. Ulteriori formalizzazioni sono da definire
(#484).
Esecuzione parallela
I programmi StableHLO possono essere eseguiti in parallelo, organizzati in una griglia di processi 2D
di num_replicas x num_partitions, entrambi di tipo ui32.
Nella griglia di processi StableHLO, num_replicas * num_partitions dei processi StableHLO
vengono eseguiti contemporaneamente. Ogni processo ha un process_id = (replica_id, partition_id) univoco, dove replica_id in replica_ids = range(num_replicas) e partition_id in partition_ids = range(num_partitions) hanno entrambi il tipo ui32.
Le dimensioni della griglia di processi sono note staticamente per ogni programma (in futuro, prevediamo di renderle una parte esplicita dei programmi StableHLO #650) e la posizione all'interno della griglia di processi è nota staticamente per ogni processo. Ogni processo ha
accesso alla sua posizione all'interno della griglia dei processi tramite le operazioni replica_id e
partition_id.
All'interno della griglia di processo, i programmi possono essere tutti uguali (nello stile "Singolo programma, più dati"), tutti diversi (nello stile "Più programmi, più dati") o qualcosa di intermedio. In futuro, prevediamo di introdurre il supporto di altre espressioni per definire programmi StableHLO paralleli, incluso GSPMD (#619).
All'interno della griglia dei processi, i processi sono per lo più indipendenti l'uno dall'altro: hanno stati di operazione separati, valori di input/intermedi/output separati e la maggior parte delle operazioni viene eseguita separatamente tra i processi, ad eccezione di un piccolo numero di operazioni collettive descritte di seguito.
Dato che l'esecuzione della maggior parte delle operazioni utilizza solo valori dello stesso processo, in genere è univoco fare riferimento a questi valori con i loro nomi.
Tuttavia, quando si descrivono le semantiche delle operazioni collettive, questo non è sufficiente e
dà origine alla notazione name@process_id per fare riferimento al valore name
all'interno di un determinato processo. Da questo punto di vista, name non qualificato può essere
considerato un'abbreviazione di name@(replica_id(), partition_id()).
L'ordine di esecuzione tra i processi è definito dall'implementazione, ad eccezione della sincronizzazione introdotta dalla comunicazione point-to-point e dalle operazioni collettive come descritto di seguito.
Comunicazione punto a punto
I processi StableHLO possono comunicare tra loro tramite
canali StableHLO. Un canale è rappresentato da un ID positivo di tipo
si64. Tramite varie operazioni, è possibile inviare valori ai canali e
riceverli dai canali.
Ulteriori formalizzazioni, ad esempio da dove provengono questi ID canale, come i programmi di elaborazione ne vengono a conoscenza e quale tipo di sincronizzazione viene introdotto, sono da definire (#484).
Comunicazione in streaming
Ogni processo StableHLO ha accesso a due interfacce di streaming:
- Feed da cui è possibile leggere.
- Outfeed su cui è possibile scrivere.
A differenza dei canali, che vengono utilizzati per la comunicazione tra processi e quindi hanno processi a entrambe le estremità, gli input e gli output hanno l'altra estremità definita dall'implementazione.
Ulteriori formalizzazioni, ad esempio come la comunicazione in streaming influisce sull'ordine di esecuzione e quale tipo di sincronizzazione viene introdotto, sono ancora da definire (#484).
Collective ops
In StableHLO sono presenti sei operazioni collettive: all_gather, all_reduce,
all_to_all, collective_broadcast, collective_permute e
reduce_scatter. Tutte queste operazioni dividono i processi nella griglia di processi StableHLO in gruppi di processi StableHLO ed eseguono un calcolo congiunto all'interno di ciascun gruppo di processi, indipendentemente dagli altri gruppi di processi.
All'interno di ogni gruppo di processi, le operazioni collettive possono introdurre una barriera di sincronizzazione. Ulteriori formalizzazioni, ad esempio l'elaborazione di quando esattamente avviene questa sincronizzazione, come esattamente i processi raggiungono questa barriera e cosa succede se non lo fanno, sono da definire (#484).
Se il gruppo di processi prevede una comunicazione tra partizioni, ovvero se
nel gruppo di processi sono presenti processi con ID partizione diversi, l'esecuzione
dell'operazione collettiva richiede un canale e l'operazione collettiva deve fornire un
channel_id positivo di tipo si64. La comunicazione tra repliche non richiede
canali.
I calcoli eseguiti dalle operazioni collettive sono specifici per le singole operazioni e sono descritti nelle sezioni delle singole operazioni riportate sopra. Tuttavia, le strategie in base alle quali la griglia dei processi viene suddivisa in gruppi di processi sono condivise tra queste operazioni e sono descritte in questa sezione. Più formalmente, StableHLO supporta le seguenti quattro strategie.
cross_replica
All'interno di ogni gruppo di processi si verificano solo comunicazioni tra repliche. Questa
strategia prende replica_groups, un elenco di elenchi di ID replica, e calcola
un prodotto cartesiano di replica_groups per partition_ids. replica_groups
deve avere elementi univoci e coprire tutti i replica_ids. Più formalmente, utilizzando
la sintassi Python:
def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
for partition_id in partition_ids:
process_group = []
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Ad esempio, per replica_groups = [[0, 1], [2, 3]] e num_partitions = 2,
cross_replica produrrà
[[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]].
cross_partition
All'interno di ogni gruppo di processi si verificano solo comunicazioni tra partizioni. Questa strategia prende partition_groups, un elenco di elenchi di ID partizione, e calcola un prodotto cartesiano di partition_groups per replica_ids.
partition_groups deve avere elementi univoci e coprire tutti i partition_ids.
Più formalmente, utilizzando la sintassi Python:
def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
for partition_group in partition_groups:
for replica_id in replica_ids:
process_group = []
for partition_id in partition_group:
process_group.append((replica_id, partition_id))
yield process_group
Ad esempio, per partition_groups = [[0, 1]] e num_replicas = 4,
cross_partition produrrà
[[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]].
cross_replica_and_partition
Le comunicazioni tra repliche e tra partizioni possono avvenire all'interno di ciascun
gruppo di processi. Questa strategia prende replica_groups, un elenco di elenchi di
ID replica, e calcola i prodotti cartesiani di ogni replica_group per
partition_ids. replica_groups deve avere elementi univoci e coprire tutti i
replica_ids. Più formalmente, utilizzando la sintassi Python:
def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
process_group = []
for partition_id in partition_ids:
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Ad esempio, per replica_groups = [[0, 1], [2, 3]] e num_partitions = 2,
cross_replica_and_partition produrrà
[[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]].
flattened_ids
Questa strategia prende flattened_id_groups, un elenco di elenchi di ID processo "appiattiti" nel formato replica_id * num_partitions + partition_id, e li trasforma in ID processo. flattened_id_groups deve avere elementi univoci
e coprire tutti i process_ids. Più formalmente, utilizzando la sintassi Python:
def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
for flattened_id_group in flattened_id_groups:
process_group = []
for flattened_id in flattened_id_group:
replica_id = flattened_id // num_partitions
partition_id = flattened_id % num_partitions
process_group.append((replica_id, partition_id))
yield process_group
Ad esempio, per flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]],
num_replicas = 4 e num_partitions = 2, flattened_ids produrrà
[[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]].
Precisione
Al momento, StableHLO non fornisce garanzie sull'accuratezza numerica, ma ciò potrebbe cambiare in futuro (#1156).
Semantica di esecuzione dell'operazione quantizzata
L'interpretazione delle operazioni StableHLO quantizzate può variare a seconda dei requisiti e delle funzionalità hardware. Ad esempio, alcuni hardware potrebbero scegliere di interpretare le operazioni quantizzate utilizzando una strategia di tipo "dequantizzazione, esecuzione dell'operazione in virgola mobile e infine quantizzazione". Altri potrebbero eseguire l'intero calcolo con l'aritmetica degli interi. Di conseguenza, l'interpretazione delle operazioni StableHLO quantizzate è determinata esclusivamente dall'implementazione specifica. L'interpretazione della quantizzazione ibrida (#1575) deve basarsi sulla sua semantica come prescritto nella specifica (tramite 1792).
Errori
I programmi StableHLO vengono convalidati tramite un ampio insieme di vincoli per le singole operazioni, il che esclude molte classi di errori prima dell'esecuzione. Tuttavia, sono ancora possibili condizioni di errore, ad esempio tramite overflow di numeri interi, accessi fuori dai limiti e così via. A meno che non siano esplicitamente indicati, tutti questi errori comportano un comportamento definito dall'implementazione, ma questo potrebbe cambiare in futuro (#1157).
Eccezioni in virgola mobile
Come eccezione a questa regola, le eccezioni in virgola mobile nei programmi StableHLO
hanno un comportamento ben definito. Le operazioni che generano eccezioni definite dallo standard IEEE-754 (operazione non valida, divisione per zero, overflow, underflow o eccezioni inesatte) producono risultati predefiniti (come definiti nello standard) e continuano l'esecuzione senza generare il flag di stato corrispondente, in modo simile alla gestione delle eccezioni raiseNoFlag dello standard. Le eccezioni per le operazioni non standard (ad es. aritmetica complessa e alcune funzioni trascendentali) sono definite dall'implementazione.
Mancata corrispondenza delle forme
StableHLO supporta i tensori con forma dinamica. Tuttavia, le forme devono corrispondere in fase di runtime, altrimenti il comportamento non è definito. StableHLO non fornisce esplicitamente un'operazione che possa asserire che un tensore ha una determinata forma in fase di runtime. La generazione del codice corretto è responsabilità del produttore.
Come esempio specifico, il programma riportato di seguito è valido. Tuttavia, in fase di runtime, le forme esatte di %arg0 e %arg1 dovranno essere le stesse, altrimenti il comportamento del programma non è definito:
func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
return %0 : tensor<?xi32>
}
Notazione
Per descrivere la sintassi, questo documento utilizza la variante ISO modificata della sintassi EBNF (ISO/IEC 14977:1996, Wikipedia), con due modifiche: 1) le regole sono definite utilizzando ::= anziché =,
2) La concatenazione viene espressa utilizzando l'accostamento anziché ,.
Per descrivere la semantica (ad esempio nelle sezioni "Tipi", "Costanti" e "Operazioni"), utilizziamo formule basate sulla sintassi Python estesa con il supporto per esprimere in modo conciso le operazioni sugli array, come descritto di seguito. Questa soluzione funziona bene per piccoli snippet di codice, ma in rari casi in cui sono necessari snippet di codice più grandi, utilizziamo la sintassi Python standard, che viene sempre introdotta in modo esplicito.
Formule
Vediamo come funzionano le formule in base a un esempio tratto dalle dot_general
specifiche. Uno dei vincoli per questa operazione è il seguente:
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).
I nomi utilizzati in questa formula provengono da due origini: 1) funzioni globali,
ad es. dim, 2) definizioni dei membri dell'elemento del programma corrispondente, ad es.
lhs, lhs_batching_dimensions, rhs e rhs_batching_dimensions input
definiti nella sezione "Input" di dot_general.
Come accennato in precedenza, la sintassi di questa formula è basata su Python con alcune estensioni orientate alla concisione. Per dare un senso alla formula, trasformiamola nella sintassi Python standard.
A) In queste formule, utilizziamo = per rappresentare l'uguaglianza, quindi il primo passaggio
per ottenere la sintassi Python consiste nel sostituire = con ==, come segue:
dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...).
B) Inoltre, queste formule supportano i puntini di sospensione (...), che trasformano le espressioni scalari
in espressioni tensoriali. In breve, f(xs...) significa approssimativamente "per ogni
scalare x nel tensore xs, calcola uno scalare f(x) e poi restituisci tutti
questi risultati scalari insieme come risultato tensore". Nella sintassi Python standard,
la nostra formula di esempio diventa:
[dim(lhs, dim1) for dim1 in lhs_batching_dimensions] ==
[dim(rhs, dim2) for dim2 in rhs_batching_dimensions].
Grazie ai puntini di sospensione, spesso è possibile evitare di lavorare a livello di
singoli scalari. Tuttavia, in alcuni casi difficili, è possibile utilizzare una sintassi semi-informale di livello inferiore, come nella formula start_indices[bi0, ..., :, ..., biN] della specifica gather. Per brevità, non
forniamo un formalismo esatto per la traduzione di questa sintassi in Python standard, nella
speranza che sia comunque comprensibile in modo intuitivo caso per caso.
Se alcune formule specifiche ti sembrano poco chiare, comunicacelo e cercheremo di
migliorarle.
Inoltre, noterai che le formule utilizzano i puntini di sospensione per espandere tutti i tipi di elenchi, inclusi tensori, elenchi di tensori (che ad esempio possono derivare da un numero variabile di tensori) e così via. Questo è un altro ambito in cui non forniamo un formalismo esatto (ad esempio, gli elenchi non fanno nemmeno parte del sistema di tipi StableHLO) e ci affidiamo invece a una comprensibilità intuitiva.
C) L'ultimo strumento di notazione degno di nota che utilizziamo è la trasmissione implicita. Sebbene l'opset StableHLO non supporti la trasmissione implicita, le formule lo fanno, anche per favorire la concisione. In breve, se uno scalare viene utilizzato in un contesto in cui è previsto un tensore, lo scalare viene trasmesso alla forma prevista.
Per continuare con l'esempio di dot_general, ecco un altro vincolo:
0 <= lhs_batching_dimensions < rank(lhs). Come definito nella specifica dot_general, lhs_batching_dimensions è un tensore, mentre 0 e
rank(lhs) sono scalari. Dopo l'applicazione della trasmissione implicita, la formula
diventerà [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)].
Se applicata a una determinata operazione dot_general, questa formula
restituisce un tensore di valori booleani. Quando le formule vengono utilizzate come vincoli, il vincolo è valido se la formula restituisce true o un tensore che contiene solo elementi true.
Nomi
Nelle formule, l'ambito lessicale include: 1) funzioni globali, 2) definizioni dei membri,
3) definizioni locali. Di seguito è riportato l'elenco delle funzioni globali. L'elenco delle definizioni degli elementi dipende dall'elemento del programma a cui viene applicata la notazione:
- Per le operazioni, le definizioni dei membri includono i nomi introdotti nelle sezioni "Input" e "Output".
- Per tutto il resto, le definizioni dei membri includono le parti strutturali dell'elemento del programma, denominate in base ai non terminali EBNF corrispondenti. Nella maggior parte dei casi, i nomi di queste parti strutturali vengono ottenuti convertendo i nomi dei non terminali in snake case (ad es.
IntegerLiteral=>integer_literal), ma a volte i nomi vengono abbreviati nel processo (ad es.QuantizationStorageType=>storage_type), nel qual caso i nomi vengono introdotti esplicitamente in modo simile alle sezioni "Input" / "Output" nelle specifiche dell'operazione. - Inoltre, le definizioni dei membri includono sempre
selfper fare riferimento all'elemento del programma corrispondente.
Valori
Quando vengono valutate, le formule funzionano con i seguenti tipi di valori:
1) Value (valori effettivi, ad es. dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>;
i tipi sono sempre noti),
2) Placeholder (valori futuri, ad es. lhs, rhs o result; i valori effettivi
non sono ancora noti, ma solo i tipi),
3) Type (tipi come definiti nella sezione "Tipi"),
4) Function (funzioni globali come definite nella sezione "Funzioni").
A seconda del contesto, i nomi possono fare riferimento a valori diversi. Più
in particolare, la sezione "Semantica" per le operazioni (e gli equivalenti per altri elementi
del programma) definisce la logica di runtime, quindi tutti gli input sono disponibili come Value.
Al contrario, la sezione "Vincoli" per le operazioni (e gli equivalenti) definisce
la logica "in fase di compilazione", ovvero qualcosa che viene in genere eseguito prima dell'esecuzione,
quindi solo gli input costanti sono disponibili come Value e gli altri input
sono disponibili solo come Placeholder.
| Nomi | In "Semantica" | In "Vincoli" |
|---|---|---|
| Funzioni globali | Function |
Function |
| Input costanti | Value |
Value |
| Input non costanti | Value |
Placeholder |
| Output | Value |
Placeholder |
| Definizioni locali | Dipende dalla definizione | Dipende dalla definizione |
Vediamo un esempio di operazione transpose:
%result = "stablehlo.transpose"(%operand) {
permutati<on = dens>e[2, 1, 0<] : t>ensor3xi64
}< : (tenso>r2x>3x2xi32<) - tenso>r2x3x2xi32
Per questa operazione, permutation è una costante, quindi è disponibile come Value
sia nella semantica che nei vincoli. Al contrario, operand e result sono
disponibili come Value nella semantica, ma solo come Placeholder nei vincoli.
Funzioni
Costruzione di tipi
Non esistono funzioni che possono essere utilizzate per costruire tipi. Utilizziamo invece direttamente la sintassi dei tipi perché in genere è più concisa. Ad esempio,
(tensor<E>, tensor<E>) -> (tensor<E>) anziché function_type(
[tensor_type([], E), tensor_type([], E)], [tensor_type([], E)]).
Funzioni sui tipi
element_typeè definito sui tipi di tensore e sui tipi di tensore quantizzati e restituisce, rispettivamente, la parteTensorElementTypeoQuantizedTensorElementTypedelTensorTypeoQuantizedTensorTypecorrispondente.
def element_type(x: Value | Placeholder | Type):
if type(x) == TensorType:
return tensor_element_type(x)
if type(x) == QuantizedTensorType:
return quantized_tensor_element_type(x)
if type(x) is not Type:
return element_type(type(x))
is_per_axis_quantized(x: Value | Placeholder | Type) -> Valueè una scorciatoia peris_quantized(x) and quantization_dimension(x) is not None.is_per_tensor_quantized(x: Value | Placeholder | Type) -> Valueè una scorciatoia peris_quantized(x) and quantization_dimension(x) is None.is_promotable(x: Type, y: Type) -> boolverifica se il tipoxpuò essere promosso al tipoy. QuandoxeysonoQuantizedTensorElementType, la promozione viene applicata solo astorage_type. Questa versione specifica della promozione viene attualmente utilizzata nel contesto del calcolo della riduzione (fai riferimento alla RFC per maggiori dettagli).
def is_promotable(x: Type, y: Type) -> Value:
is_same_type = (is_bool(x) and is_bool(y)) or
(is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
(is_complex(x) and is_complex(y)) or
(is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
if is_same_type == False:
return False
if is_integer(x) or is_float(x):
return bitwidth(x) <= bitwidth(y)
if is_complex(x):
return bitwidth(element_type(x)) <= bitwidth(element_type(y))
if is_quantized(x):
return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))
return false
is_quantized(x: Value | Placeholder | Type) -> Valueè una scorciatoia peris_quantized_tensor_element_type(x).is_type_name(x: Value | Placeholder | Type) -> Value. Disponibile per tutti i tipi. Ad esempio,is_float(x)restituiscetruesexè unFloatType. Sexè un valore o un segnaposto, questa funzione è una scorciatoia peris_type_name(type(x)).max_value(x: Type) -> Valuerestituisce il valore massimo di unTensorElementType. Sexnon è unTensorElementType, restituisceNone.min_value(x: Type) -> Valuerestituisce il valore minimo possibile di unTensorElementType. Sexnon è unTensorElementType, restituisceNone.member_name(x: Value | Placeholder | Type) -> Any. Disponibile per tutte le definizioni dei membrimember_namedi tutti i tipi. Ad esempio,tensor_element_type(x)restituisce la parteTensorElementTypedi unTensorTypecorrispondente. Sexè un valore o un segnaposto, questa funzione è una scorciatoia permember_name(type(x)). Sexnon è un tipo con un membro appropriato o un valore o un segnaposto di questo tipo, restituisceNone.is_empty_algorithm(*args: Type)controlla se tutti i campi dell'algoritmo punto sono impostati suNone. Questo è necessario perché gli algoritmi dei punti hanno comportamenti predefiniti definiti dall'implementazione, quindi specificare un valore predefinito sarebbe errato.
Costruzione di valori
operation_name(*xs: Value | Type) -> Value. Disponibile per tutte le operazioni. Ad esempio,add(lhs, rhs)accetta due valori tensorelhserhse restituisce l'output della valutazione dell'operazioneaddcon questi input. Per alcune operazioni, ad esempiobroadcast_in_dim, i tipi di output sono "portanti", ovvero necessari per valutare un'operazione. In questo caso, la funzione accetta questi tipi come argomenti.
Funzioni sui valori
Sono disponibili tutti gli operatori e le funzioni di Python. Ad esempio, sono disponibili sia le notazioni subscription che slicing di Python per l'indicizzazione di tensori, tensori quantizzati e tuple.
to_destination_type(x: Value, destination_type: Type) -> Valueè definito sui tensori e restituisce il valore convertito dixin base atype(x)edestination_typecome segue:
def to_destination_type(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x
if is_quantized(destination_type):
if is_quantized(type(x)):
return quantize(x, destination_type)
assert is_float(type(x))
return quantize(x, destination_type)
if is_quantized(type(x)):
assert destination_type = expressed_type(type(x))
return dequantize(type(x))
return convert(x, destination_type)
È in corso una discussione preliminare sull'unione delle operazioni convert, uniform_quantize e
uniform_dequantize (#1576).
Dopo l'unione, non abbiamo più bisogno della funzione precedente e possiamo utilizzare il nome dell'operazione
per convert.
is_nan(x: Value) -> Valueè definito sui tensori e restituiscetruese tutti gli elementi dixsonoNaNofalsealtrimenti. Sexnon è un tensore, restituisceNone.is_sorted(x: Value) -> Valueè definito sui tensori e restituiscetruese gli elementi dixsono ordinati in ordine crescente rispetto all'ordine lessicografico crescente dei relativi indici ofalsealtrimenti. Sexnon è un tensore, restituisceNone.is_unique(x: Value) -> Valueè definito sui tensori e restituiscetruesexnon ha elementi duplicati ofalsein caso contrario. Sexnon è un tensore, restituisceNone.member_name(x: Value) -> Anyè definito per tutte le definizioni dei membrimember_namedi tutti i valori. Ad esempio,real_part(x)restituisce la parteRealPartdi unComplexConstantcorrispondente. Sexnon è un valore che ha un membro appropriato, restituisceNone.same(x: Value) -> Valueè definito sui tensori e restituiscetruese gli elementi dixsono tutti uguali tra loro ofalsein caso contrario. Se il tensore non ha elementi, viene considerato "tutti uguali tra loro", ovvero la funzione restituiscetrue. Sexnon è un tensore, restituisceNone.split(x: Value, num_results: Value, axis: Value) -> Valueè definito sui tensori e restituiscenum_resultssezioni dixlungo l'asseaxis. Sexnon è un tensore odim(x, axis) % num_results != 0, restituisceNone.is_defined_in_parent_scope(x: Value) -> Valueè definito sulle stringhe e restituiscetruesexè il nome di una funzione definita nello stesso ambito della funzione principale dell'operazione pertinente.is_namespaced_op_name(x: Value) -> Valueè definito sulle stringhe e restituiscetruesexè un nome di operazione valido, ovvero rispetta la seguente espressione regolare:[a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+
Calcoli delle forme
axes(x: Value | Placeholder | Type) -> Valueè una scorciatoia perrange(rank(x)).dim(x: Value | Placeholder | Type, axis: Value) -> Valueè una scorciatoia pershape(x)[axis].dims(x: Value | Placeholder | Type, axes: List) -> Listè una scorciatoia perlist(map(lambda axis: dim(x, axis), axes)).index_space(x: Value | Placeholder | Type) -> Valueè definito sui tensori e restituisce gli indicisize(x)per ilTensorTypecorrispondente ordinato in ordine lessicografico crescente, ovvero[0, ..., 0],[0, ..., 1], ...,shape(x) - 1. Sexnon è un tipo di tensore, un tipo di tensore quantizzato o un valore o un segnaposto di uno di questi tipi, restituisceNone.rank(x: Value | Placeholder | Type) -> Valueè una scorciatoia persize(shape(x)).shape(x: Value | Placeholder | Type) -> Valueè definito nella sezione "Funzioni sui tipi" tramitemember_name.size(x: Value | Placeholder | Type) -> Valueè una scorciatoia perreduce(lambda x, y: x * y, shape(x)).
Calcoli di quantizzazione
def baseline_element_type(x: Value | Placeholder | Type) -> Typeè una scorciatoia perelement_type(baseline_type(x)).baseline_typeè definito sui tipi di tensore e sui tipi di tensore quantizzati e li trasforma in un tipo "di base", ovvero un tipo con la stessa forma, ma con i parametri di quantizzazione del tipo di elemento reimpostati sui valori predefiniti. Viene utilizzato come trucco pratico per confrontare in modo uniforme i tipi di tensore e tensore quantizzato, cosa che si verifica spesso. Per i tipi quantizzati, questo consente di confrontare i tipi ignorando i parametri di quantizzazione, ovveroshape,storage_type,expressed_type,storage_min,storage_maxequantization_dimension(per il tipo quantizzato per asse) devono corrispondere, mascalesezero pointspossono differire.
def baseline_type(x: Value | Placeholder | Type) -> Type:
if type(x) == TensorType:
return x
if type(x) == QuantizedTensorType:
element_type = quantized_tensor_element_type(x)
baseline_element_type = QuantizedTensorElementType(
storage_type = storage_type(element_type),
storage_min = storage_min(element_type),
storage_max = storage_max(element_type),
expressed_type = expressed_type(element_type),
quantization_dimension = quantization_dimension(element_type),
scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
return QuantizedTensorType(shape(x), baseline_element_type)
if type(x) is not Type:
return baseline_element_type(type(x))
dequantizeè definito per i tipi di tensore quantizzati e li trasforma in tipi di tensore in virgola mobile. Ciò avviene convertendo gli elementi quantizzati che rappresentano valori interi del tipo di archiviazione nei corrispondenti valori in virgola mobile del tipo espresso utilizzando il punto zero e la scala associati al tipo di elemento quantizzato.
def compute_zero_points(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
zero_points[i] = zero_points(quantized_type)[i[d]]
return zero_points
def compute_scales(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
type(result_type))
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
scales[i] = scales(quantized_type)[i[d]]
return scales
def dequantize(x: Value) -> Value:
assert is_quantized(x)
x_storage = bitcast_convert(x, storage_type(x))
x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
x_expressed_sub = convert(x_storage_sub, expressed_type(x))
return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
quantizeè definito sui tipi di tensore in virgola mobile e li trasforma in tipi di tensore quantizzati. Ciò avviene tramite la conversione dei valori in virgola mobile del tipo espresso nei valori interi corrispondenti del tipo di archiviazione utilizzando il punto zero e la scala associati al tipo di elemento quantizzato.
def quantize(x: Value, result_type: Type) -> Value:
assert is_float(x) and is_quantized(result_type)
zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
converted_zero_points = convert(zero_points, expressed_type(result_type))
converted_min = convert(storage_min(result_type), expressed_type(result_type))
converted_max = convert(storage_max(result_type), expressed_type(result_type))
x_scaled = x / compute_scales(result_type, type(x))
x_scaled_add_zp = x_scaled + converted_zero_points
x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
x_rounded = round_nearest_even(x_clamped)
return convert(x_rounded, result_type)
dequantize_op_quantizeviene utilizzato per specificare i calcoli elemento per elemento sui tensori quantizzati. Esegue la dequantizzazione, ovvero trasforma gli elementi quantizzati nei tipi espressi, quindi esegue un'operazione e poi esegue la quantizzazione, ovvero trasforma i risultati nei tipi di archiviazione. Al momento, questa funzione funziona solo per la quantizzazione per tensore. La quantizzazione per asse è in corso di elaborazione (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
inputs = inputs_and_output_type[:-1]
output_type = inputs_and_output_type[-1]
float_inputs = map(dequantize, inputs)
float_result = op(*float_inputs)
return quantize(float_result, output_type)
def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
inputs = inputs_and_output_type[:-3]
float_inputs = map(dequantize, inputs)
float_results = op(*float_inputs)
return map(quantize, float_results, inputs_and_output_type[-3:])
def dequantize_compare(lhs, rhs, comparison_direction):
float_lhs = dequantize(lhs)
float_rhs = dequantize(rhs)
return compare(float_lhs, float_rhs, comparison_direction, FLOAT)
def dequantize_select_quantize(pred, on_true, on_false, output_type):
float_on_true = dequantize(on_true)
float_on_false = dequantize(on_false)
float_result = select(pred, float_on_true, float_on_false)
return quantize(float_result, output_type)
hybrid_dequantize_then_opviene utilizzato per specificare la quantizzazione solo del peso per l'operazione ibrida che accetta l'operando di sinistra in virgola mobile e l'operando di destra in tipi quantizzati. Dequantizza gli input quantizzati nei tipi espressi ed esegue il calcolo in virgola mobile. Il tipo di elemento del tensore lhs float e il tipo espresso del tensore rhs quantizzato devono essere identici.
def hybrid_dequantize_then_op(op, lhs, rhs):
assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
return op(lhs, dequantize(rhs))
Calcoli della griglia
cross_partition(replica_groups: Value) -> Value. Vedi la sezione "cross_replica" sopra.cross_replica(replica_groups: Value) -> Value. Vedi la sezione "cross_replica" sopra.cross_replica_and_partition(replica_groups: Value) -> Value. Consulta la sezione "cross_replica_and_partition" sopra.flattened_ids(replica_groups: Value) -> Value. Vedi la sezione "flattened_ids" sopra.
Dinamismo
I valori StableHLO possono avere dimensioni dinamiche, ad esempio tensor<?xi64>.
Tuttavia, i valori StableHLO non possono avere un numero dinamico di dimensioni (dinamismo
senza classificazione, ad es. tensor<*xi64>). Gli operandi e i risultati possono utilizzare dimensioni
dinamiche, anche se esistono vincoli sulle dimensioni. I vincoli verranno
verificati staticamente, se possibile, altrimenti verranno rimandati al runtime e
le mancate corrispondenze comporteranno un comportamento indefinito. Di seguito sono riportati gli esempi.
Mancata corrispondenza delle forme per le operazioni unarie elemento per elemento
Considera il seguente programma di giocattoli:
func.func @foo(%arg0: tensor<?xf64>) {
%0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
return
}
Un programma di questo tipo è insolito, perché non è comune conoscere la forma del risultato ma non quella dell'input. Tuttavia, questo è un programma StableHLO valido. Non è possibile convalidare staticamente l'operazione abs in questo
programma perché la forma esatta dell'operando è sconosciuta. Tuttavia, le forme
sono sicuramente compatibili e questo può essere verificato staticamente: ? potrebbe risultare
2 in fase di runtime e non ci sarebbero problemi. Tuttavia, ? potrebbe
anche risultare un altro numero intero, nel qual caso il comportamento è indefinito.
Tieni presente che se le dimensioni di una dimensione sono dinamiche nel risultato, non può verificarsi un comportamento indefinito. Infatti, non esiste una dimensione "prevista", quindi non può esserci una mancata corrispondenza.
Mancata corrispondenza delle forme per le operazioni binarie elemento per elemento
Considera il seguente programma di giocattoli:
func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
%0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
return
}
Per quanto riguarda le operazioni binarie elemento per elemento, le forme degli input e del risultato devono corrispondere in fase di runtime. In fase di compilazione, le dimensioni statiche devono essere uguali, altrimenti devono solo essere compatibili. Se una qualsiasi dimensione è dinamica negli input, potrebbe verificarsi un comportamento indefinito in fase di runtime, perché le dimensioni dinamiche potrebbero non corrispondere alle dimensioni corrispondenti nell'altro operando (statico o dinamico). Se tutti gli input sono statici, non importa se il risultato è dinamico o meno: le dimensioni note staticamente verranno controllate staticamente e le dimensioni dinamiche non impongono vincoli.
Mancata corrispondenza delle forme per le operazioni che prendono la forma di output come operando
Considera il seguente programma di giocattoli:
func.func @foo(%arg0: tensor<2xi32>) {
%0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
return
}
I valori nell'operando della forma in fase di runtime devono corrispondere alla forma del risultato,
altrimenti il comportamento non è definito. ovvero, in fase di runtime %arg0 deve avere un valore di dense<[3, 4]> : tensor<2xi32>. Se l'operando della forma è costante, questo
può essere verificato staticamente. Se la forma del risultato è completamente dinamica, non
può esserci una mancata corrispondenza.