Specifica StableHLO

StableHLO è un set di operazioni per operazioni di alto livello (HLO) nei modelli di machine learning (ML). StableHLO funziona come livello di portabilità tra diversi framework ML e compilatori ML: i framework ML che producono programmi StableHLO sono compatibili con i compilatori ML che utilizzano programmi StableHLO.

Il nostro obiettivo è semplificare e accelerare lo sviluppo ML creando una maggiore interoperabilità tra vari framework ML (come TensorFlow, JAX e PyTorch) e compilatori ML (come XLA e IREE). A questo scopo, questo documento fornisce una specifica per il linguaggio di programmazione StableHLO.

Questa specifica contiene tre sezioni principali. Innanzitutto, la sezione Programmi descrive la struttura dei programmi StableHLO, costituiti da funzioni StableHLO, a loro volta composte da operazioni StableHLO. All'interno di questa struttura, la sezione Ops specifica la semantica delle singole operazioni. La sezione Esecuzione fornisce la semantica per tutte queste operazioni eseguite insieme all'interno di un programma. Infine, la sezione Notazione illustra la notazione utilizzata in tutta la specifica.

Per visualizzare le specifiche di una release precedente di StableHLO, apri il repository in corrispondenza della release taggata che ti interessa. Ad esempio, la specifica StableHLO v0.19.0. Per visualizzare le modifiche apportate a ogni picco di versione secondaria di StableHLO, fai riferimento al log della versione in VhloDialect.td.

Programmi

Program ::= {Func}

I programmi StableHLO sono costituiti da un numero arbitrario di funzioni StableHLO. Di seguito è riportato un programma di esempio con una funzione @main che ha 3 input (%image, %weights e %bias) e 1 output. Il corpo della funzione ha 6 operazioni.

func.func @main(
  %image: tensor<28x28xf32>,
  %weights: tensor<784x10xf32>,
  %bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
  %0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
  %1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
  %2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  %3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
  %4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  "func.return"(%4): (tensor<1x10xf32>) -> ()
}

Funzioni

Func        ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs  ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput   ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput  ::= ValueType
FuncBody    ::= {Op}

Le funzioni StableHLO (chiamate anche funzioni con nome) hanno un identificatore, input/output e un corpo. In futuro, abbiamo in programma di introdurre metadati aggiuntivi per le funzioni al fine di ottenere una migliore compatibilità con HLO (#425, #626, #740, #744).

Identificatori

FuncId  ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
          | '%' letter {letter | digit}
letter  ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit   ::= '0' | ... | '9'

Gli identificatori StabiliHLO sono simili a quelli di molti linguaggi di programmazione, con due peculiarità: 1) tutti gli identificatori hanno sigilli che distinguono diversi tipi di identificatori, 2) gli identificatori dei valori possono essere completamente numerici per semplificare la generazione di programmi StableHLO.

Tipi

Type         ::= ValueType | NonValueType
ValueType    ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType

I tipi di StableHLO sono classificati in tipi di valore (chiamati anche tipi di prima classe), che rappresentano i valori StableHLO e tipi non valore, che descrivono altri elementi del programma. I tipi StableHLO sono simili ai tipi in molti linguaggi di programmazione, la cui peculiarità principale è la natura specifica del dominio di StableHLO, che genera risultati insoliti (ad esempio, i tipi scalari non sono tipi di valori).

TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'

I tipi di tensori rappresentano i tensori, ovvero array multidimensionali. Hanno una forma e un tipo di elemento, in cui una forma rappresenta dimensioni non negative o sconosciute in ordine crescente delle dimensioni corrispondenti (chiamate anche assi) numerate da 0 a R-1. Il numero di dimensioni R è chiamato ranking. Ad esempio, tensor<2x3xf32> è un tipo di tensore con forma 2x3 e tipo di elemento f32. Presenta due dimensioni (o, in altre parole, due assi), la 0a dimensione e la 1a dimensione, le cui dimensioni sono 2 e 3. Il suo ranking è 2.

Le forme possono essere parzialmente o completamente sconosciute (dinamiche), ad esempio tensor<?x2xf64> è parzialmente sconosciuto e tensor<?x?xf64> è completamente sconosciuto. Le dimensioni dinamiche sono rappresentate utilizzando un ?. Non è possibile rimuovere il ranking delle forme.

In futuro, abbiamo in programma di esplorare i tipi di tensore oltre alle dimensioni e ai tipi di elementi, ad esempio per includere layout (#629) e sparsità (#1078).

QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
                  QuantizationStorageType
                  ['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
                  ':' QuantizationExpressedType
                  [':' QuantizationDimension]
                  ',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerConstant
QuantizationStorageMax ::= IntegerConstant
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerConstant
QuantizationParameters ::= QuantizationParameter
                         | '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale ':' QuantizationZeroPoint
QuantizationScale ::= FloatConstant
QuantizationZeroPoint ::= IntegerConstant
Nome Tipo Vincoli
storage_type tipo intero (C1-C3), (C8)
storage_min costante intera (C1), (C3), (C7)
storage_max costante intera (C2), (C3), (C7)
expressed_type tipo a virgola mobile (C4)
quantization_dimension costante intera facoltativa (C10-C12)
scales numero variadico di costanti in virgola mobile (C4-C6), (C9), (C10), (C13)
zero_points numero variadico di costanti numeriche (C7-C9)

I tipi di elementi quantizzati rappresentano i valori interi di un tipo di archiviazione nell'intervallo da storage_min a storage_max (incluso) che corrispondono a valori in virgola mobile di un tipo espresso. Per un determinato valore intero i, il valore in virgola mobile corrispondente f può essere calcolato come f = (i - zero_point) * scale, dove scale e zero_point sono chiamati parametri di quantizzazione. storage_min e storage_max sono facoltativi nella grammatica, ma hanno valori predefiniti rispettivamente di min_value(storage_type) e max_value(storage_type). I tipi di elementi quantizzati hanno i seguenti vincoli:

  • (C1) type(storage_min) = storage_type.
  • (C2) type(storage_max) = storage_type.
  • (C3) min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type).
  • (C4) type(scales...) = expressed_type.
  • (C5) 0 < scales.
  • (C6) is_finite(scales...).
  • (C7) storage_min <= zero_points <= storage_max.
  • (C8) type(zero_points...) = storage_type.
  • (C9) size(scales) = size(zero_points).
  • (C10) Se is_empty(quantization_dimension), allora size(scales) = 1.
  • (C11) 0 <= quantization_dimension.

Al momento, QuantizationScale è una costante in virgola mobile, ma c'è un forte interesse per le scale basate sui numeri interi, rappresentate con moltiplicatori e variazioni. Prevediamo di esaminare questo aspetto nel prossimo futuro (#1404).

È in corso un dibattito sulla semantica di QuantizationZeroPoint, compreso il tipo, i valori e se possono esserci solo uno o più punti zero in un tipo di tensore quantizzato. In base ai risultati di questa discussione, la specifica relativa a zero punti potrebbe cambiare in futuro (#1405).

Un'altra discussione in corso riguarda la semantica di QuantizationStorageMin e QuantizationStorageMax per determinare se è necessario applicare vincoli a questi valori e ai valori dei tensori quantizzati (#1406).

Infine, prevediamo di esplorare la rappresentazione di scale e punti zero sconosciuti, analogamente a come intendiamo esplorare la rappresentazione di dimensioni delle dimensioni sconosciute (#1407).

I tipi di tensori quantizzati rappresentano i tensori con elementi quantizzati. Questi tensori sono esattamente gli stessi dei tensori normali, ad eccezione del fatto che i loro elementi hanno tipi di elementi quantiizzati, anziché tipi di elementi regolari.

Nei tensori quantizzati, la quantizzazione può essere per tensore, ovvero avere un scale e zero_point per l'intero tensore o per asse, ovvero avere più scales e zero_points, una coppia per sezione di una determinata dimensione quantization_dimension. Più formalmente, in un tensore t con quantizzazione per asse, ci sono sezioni dim(t, quantization_dimension) di quantization_dimension: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :] e così via. Tutti gli elementi nella ia sezione utilizzano scales[i] e zero_points[i] come parametri di quantizzazione. I tipi di tensore quantificati hanno i seguenti vincoli:

  • Per la quantizzazione per tensore:
    • Nessun vincolo aggiuntivo.
  • Per la quantizzazione per asse:
    • (C13) quantization_dimension < rank(self).
    • (C14) dim(self, quantization_dimension) = size(scales).
TokenType ::= 'token'

I tipi di token rappresentano i token, ovvero i valori opachi prodotti e consumati da alcune operazioni. I token vengono utilizzati per imporre l'ordine di esecuzione sulle operazioni, come descritto nella sezione Esecuzione.

TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]

I tipi di tuple rappresentano tuple, ovvero elenchi eterogenei. Le tuple sono una funzionalità legacy solo per compatibilità con HLO. In HLO si usano tuple per rappresentare input e output variadi. In StableHLO, gli input e gli output variadici sono supportati in modo nativo e l'unico uso delle tuple in StableHLO è rappresentare in modo completo l'ABI HLO, dove ad esempio T, tuple<T> e tuple<tuple<T>> possono essere sostanzialmente diversi a seconda di una particolare implementazione. In futuro, abbiamo in programma di apportare modifiche ad HLO ABI, che potrebbero permetterci di rimuovere i tipi di tuple da StableHLO (#598).

TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ'
            | 'f8E4M3B11FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'

I tipi di elementi rappresentano gli elementi di tipi tensori. A differenza di molti linguaggi di programmazione, questi tipi non sono di prima classe in StableHLO. Ciò significa che i programmi StabiliHLO non possono rappresentare direttamente valori di questi tipi (di conseguenza, è idiomatico rappresentare valori scalari di tipo T con valori di tensori 0 dimensioni di tipo tensor<T>).

  • Il tipo booleano rappresenta i valori booleani true e false.
  • I tipi di numeri interi possono essere con segno (si) o non firmati (ui) e avere una delle larghezze in bit supportate (4, 8, 16, 32 o 64). I tipi siN firmati rappresentano valori interi da -2^(N-1) a 2^(N-1)-1 inclusi, mentre i tipi uiN senza firma rappresentano valori interi da 0 a 2^N-1 inclusi.
  • I tipi di virgola mobile possono essere uno dei seguenti:
  • I tipi complessi rappresentano valori complessi che hanno una parte reale e una parte immaginaria dello stesso tipo di elemento. I tipi complessi supportati sono complex<f32> (entrambe le parti sono di tipo f32) e complex<f64> (entrambe le parti sono di tipo f64).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]

I tipi di funzione rappresentano funzioni sia con nome che anonime. Hanno tipi di input (l'elenco dei tipi sul lato sinistro di ->) e tipi di output (l'elenco dei tipi a destra di ->). In molti linguaggi di programmazione, i tipi di funzione sono di prima classe, ma non in StableHLO.

StringType ::= 'string'

Il tipo di stringa rappresenta le sequenze di byte. A differenza di molti linguaggi di programmazione, il tipo di stringa non è il primo livello in StableHLO e viene utilizzato solo per specificare metadati statici per gli elementi del programma.

Suite operativa

Le operazioni StableHLO (chiamate anche operazioni) rappresentano un insieme chiuso di operazioni di alto livello nei modelli di machine learning. Come discusso in precedenza, la sintassi stabile è fortemente ispirata a MLIR, che non è necessariamente l'alternativa più ergonomica, ma probabilmente è la soluzione migliore per l'obiettivo di StableHLO di creare una maggiore interoperabilità tra framework ML e compilatori ML.

Op            ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName        ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic    ::= 'abs' | 'add' | ...

Le operazioni StabiliHLO (chiamate anche operazioni) hanno un nome, input/output e una firma. Il nome è composto dal prefisso stablehlo. e da un mnemonico che identifica in modo univoco una delle operazioni supportate. Di seguito è riportato un elenco completo di tutte le operazioni supportate.

OpInputs        ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues   ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue    ::= ValueId
OpInputFuncs    ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs    ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs       ::= [OpOutput {',' OpOutput} '=']
OpOutput        ::= ValueId

Le operazioni utilizzano gli input e producono gli output. Gli input sono classificati in valori di input (calcolati durante l'esecuzione), funzioni di input (fornite in modo statico, perché in StableHLO le funzioni non sono valori di prima classe) e attributi di input (forniti anche in modo statico). Il tipo di input e output consumati e prodotti da un'operazione dipende dalla sua mnemonica. Ad esempio, l'operazione add consuma 2 valori di input e produce 1 valore di output. In confronto, l'operazione select_and_scatter utilizza 3 valori di input, 2 funzioni di input e 3 attributi di input.

OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused      ::= '^' digit {digit}
              | '^' letter {letter | digit}

Le funzioni di input (chiamate anche funzioni anonime) sono molto simili alle funzioni con nome, tranne per il fatto che: 1) non hanno un identificatore (da qui il nome "anonimo "), 2) non dichiarano i tipi di output (i tipi di output vengono dedotti dall'operazione return all'interno della funzione).

La sintassi per le funzioni di input include una parte attualmente inutilizzata (vedi la produzione Unused sopra), disponibile per garantire la compatibilità con MLIR. Nel machine learning esiste un concetto più generale di "regioni", in cui possono essere connessi più "blocchi" di operazioni tramite jump ops. Questi blocchi hanno ID che corrispondono alla produzione Unused, in modo che possano essere distinti l'uno dall'altro. StableHLO non ha operazioni jump, quindi la parte corrispondente della sintassi MLIR è inutilizzata (ma è ancora presente).

OpInputAttr      ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName  ::= letter {letter | digit}
OpInputAttrValue ::= Constant

Gli attributi di input hanno un nome e un valore che rappresentano una delle costanti supportate. Sono il modo principale per specificare i metadati statici per gli elementi del programma. Ad esempio, l'operazione concatenate utilizza l'attributo dimension per specificare la dimensione con cui sono concatenati i relativi valori di input. Analogamente, l'operazione slice utilizza più attributi come start_indices e limit_indices per specificare i limiti utilizzati per suddividere il valore di input.

Al momento, i programmi StableHLO in circolazione a volte contengono attributi non descritti in questo documento. In futuro, prevediamo di assorbire questi attributi nell'opset StableHLO o di impedirne la visualizzazione nei programmi StableHLO. Nel frattempo, ecco l'elenco di questi attributi:

  • layout (n. 629).
  • mhlo.frontend_attributes (n. 628).
  • mhlo.sharding (n. 619).
  • output_operand_aliases (#740).
  • Metadati sulla posizione (#594).
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'

La firma op è composta dai tipi di tutti i valori di input (l'elenco dei tipi sul lato sinistro di ->) e dai tipi di tutti i valori di output (l'elenco dei tipi a destra di ->). A proposito, i tipi di input sono ridondanti e anche i tipi di output sono quasi sempre ridondanti (perché per la maggior parte delle operazioni StableHLO, i tipi di output possono essere dedotti dagli input). Ciononostante, la firma operativa fa parte deliberatamente della sintassi StableHLO per la compatibilità con MLIR.

Di seguito è riportato un esempio di operazione la cui mnemonica è select_and_scatter. Utilizza 3 valori di input (%operand, %source e %init_value), 2 funzioni di input e 3 attributi di input (window_dimensions, window_strides e padding). Nota che la firma dell'operazione include solo i tipi dei suoi valori di input (ma non i tipi di funzioni di input e attributi che sono forniti in linea).

%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i32>, tensor<i32>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
    "stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>

Costanti

Constant ::= BooleanConstant
           | IntegerConstant
           | FloatConstant
           | ComplexConstant
           | TensorConstant
           | QuantizedTensorConstant
           | StringConstant
           | EnumConstant

Le costanti StableHLO hanno un valore letterale e un tipo che insieme rappresentano un valore StableHLO. In genere, il tipo fa parte della sintassi costante, tranne nei casi in cui non sia ambigua (ad es. una costante booleana ha il tipo i1 in modo inequivocabile, mentre una costante intera può avere più tipi possibili).

BooleanConstant ::= BooleanLiteral
BooleanLiteral  ::= 'true' | 'false'

Le costanti booleane rappresentano i valori booleani true e false. Le costanti booleane hanno il tipo i1.

IntegerConstant   ::= IntegerLiteral ':' IntegerType
IntegerLiteral    ::= ['-' | '+'] DecimalDigits
                    | ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits     ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit      ::= '0' | ... | '9'
hexadecimalDigit  ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'

Le costanti intere rappresentano valori interi tramite stringhe che utilizzano la notazione decimale o esadecimale. Altre basi, ad esempio binari o ottali, non sono supportate. Le costanti numeriche hanno i seguenti vincoli:

  • (C1) is_wellformed(integer_literal, integer_type).
FloatConstant  ::= FloatLiteral ':' FloatType
FloatLiteral   ::= SignPart IntegerPart FractionalPart ScientificPart
                 | '0x' [HexadecimalDigits]
SignPart       ::= ['-' | '+']
IntegerPart    ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]

Le costanti a virgola mobile rappresentano valori in virgola mobile tramite stringhe che utilizzano la notazione decimale o scientifica. Inoltre, puoi utilizzare la notazione esadecimale per specificare direttamente i bit sottostanti nel formato a virgola mobile del tipo corrispondente. Le costanti a virgola mobile hanno i seguenti vincoli:

  • (C1) Se viene utilizzata la notazione non esadecimale, is_wellformed(float_literal, float_type).
  • (C2) Se viene utilizzata la notazione esadecimale, size(hexadecimal_digits) = num_bits(float_type) / 4.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral  ::= '(' RealPart ',' ImaginaryPart ')'
RealPart        ::= FloatLiteral
ImaginaryPart   ::= FloatLiteral

Le costanti complesse rappresentano valori complessi utilizzando gli elenchi di una parte reale (viene per prima) e di una parte immaginaria (viene seconda). Ad esempio, (1.0, 0.0) : complex<f32> rappresenta 1.0 + 0.0i e (0.0, 1.0) : complex<f32> rappresenta 0.0 + 1.0i. L'ordine in cui queste parti vengono archiviate in memoria è definito dall'implementazione. Le costanti complesse hanno i seguenti vincoli:

  • (C1) is_wellformed(real_part, complex_element_type(complex_type)).
  • (C2) is_wellformed(imaginary_part, complex_element_type(complex_type)).
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral   ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements  ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral

Le costanti tensore rappresentano i valori tensori utilizzando elenchi nidificati specificati tramite la notazione NumPy. Ad esempio, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> rappresenta un valore tensore con la seguente mappatura dagli indici agli elementi: {0, 0} => 1, {0, 1} => 2, {0, 2} => 3, {1, 0} => 4, {1, 1} => 5, {1, 2} => 6. L'ordine in cui questi elementi vengono archiviati in memoria è definito dall'implementazione. Le costanti tensori hanno i seguenti vincoli:

  • (C1) has_syntax(tensor_literal, element_type(tensor_type)), dove:
    • has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type).
    • has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type).
  • (C2) has_shape(tensor_literal, shape(tensor_type)), dove:
    • has_shape(element_literal: Syntax, []) = true.
    • has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:]).
    • altrimenti false.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'

Le costanti tensori quantificate rappresentano i valori dei tensori quantizzati utilizzando la stessa notazione delle costanti tensoriali, con elementi specificati come costanti del loro tipo di archiviazione. Le costanti tensori quantificate hanno i seguenti vincoli:

  • (C1) has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type)).
  • (C2) has_shape(quantized_tensor_literal, shape(quantized_tensor_type)).
StringConstant  ::= StringLiteral
StringLiteral   ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence  ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))

I valori letterali stringa sono costituiti da byte specificati utilizzando caratteri ASCII e sequenze di escape. Poiché sono indipendenti dalla codifica, l'interpretazione di questi byte è definita dall'implementazione. I valori letterali stringa hanno il tipo string.

o MLOps.

abs

Semantica

Esegue l'operazione ASS di elemento a livello di operand e produce un tensore di result. A seconda del tipo di elemento, procedi come segue:

  • Per i numeri interi con segno: modulo intero.
  • Per i valori in virgola mobile: abs da IEEE-754.
  • Per i numeri complessi: modulo complesso.
  • Per i tipi quantizzati: dequantize_op_quantize(abs, operand, type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore di numero intero con segno, in virgola mobile o di tipo complesso oppure tensore quantizzato per tensore (C1-C2)

Output

Nome Tipo Vincoli
result tensore di tipo intero con segno o a virgola mobile o tensore quantizzato per tensore (C1-C2)

Vincoli

  • (C1) shape(result) = shape(operand).
  • (C2) Per baseline_element_type(result) si intende:
    • complex_element_type(element_type(operand)) se is_complex(operand).
    • baseline_element_type(operand) in caso contrario.

Esempi

// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]

Altri esempi

add

Semantica

Esegue l'aggiunta a livello di elemento di due tensori lhs e rhs e produce un tensore result. A seconda del tipo di elemento, procedi come segue:

  • Per i valori booleani: OR logico.
  • Per i numeri interi: addizione di numeri interi.
  • Per i valori in virgola mobile: addition da IEEE-754.
  • Per i numeri complessi: addizione complessa.
  • Per i tipi quantizzati: dequantize_op_quantize(add, lhs, rhs, type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) lhs tensore o tensore quantizzato (C1-C6)
(I2) rhs tensore o tensore quantizzato (C1-C5), (C7)

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato (C1-C7)

Vincoli

  • Se l'operazione utilizza tensori non quantiizzati:
    • (C1) type(lhs) = type(rhs) = type(result).
  • Se l'operazione utilizza tensori quantizzati:
    • (C2) is_quantized(lhs) and is_quantized(rhs) and is_quantized(result).
    • (C3) storage_type(lhs) = storage_type(rhs) = storage_type(result).
    • (C4) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C5) (is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result).
    • (C6) Se is_per_axis_quantized(lhs), allora quantization_dimension(lhs) = quantization_dimension(result).
    • (C7) Se is_per_axis_quantized(rhs), allora quantization_dimension(rhs) = quantization_dimension(result).

Esempi

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]

Altri esempi

after_all

Semantica

Garantisce che le operazioni che producono inputs vengano eseguite prima di qualsiasi operazione che dipende da result. L'esecuzione di questa operazione non ha alcun effetto, esiste solo per stabilire dipendenze dei dati dal giorno result al giorno inputs.

Input

Etichetta Nome Tipo
(I1) inputs numero variadic di token

Output

Nome Tipo
result token

Esempi

// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token

Altri esempi

all_gather

Semantica

All'interno di ogni gruppo di processi nella griglia dei processi StableHLO, concatena i valori del tensore operand di ogni processo lungo all_gather_dim e produce un tensore result.

L'operazione suddivide la griglia del processo StableHLO in process_groups, definita come segue:

  • cross_replica(replica_groups) se channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) se channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) se channel_id > 0 and use_global_device_ids = true.

Successivamente, all'interno di ogni process_group:

  • operands@receiver = [operand@sender for sender in process_group] per tutti i receiver di process_group.
  • result@process = concatenate(operands@process, all_gather_dim) per tutti i process di process_group.

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore o tensore quantizzato per tensore (C1), (C6)
(I2) all_gather_dim costante di tipo si64 (C1), (C6)
(I3) replica_groups Costante del tensore bidimensionale di tipo si64 (C2-C4)
(I4) channel_id costante di tipo si64 (C5)
(I5) use_global_device_ids costante di tipo i1 (C5)

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato per tensore (C6)

Vincoli

  • (C1) 0 <= all_gather_dim < rank(operand).
  • (C2) is_unique(replica_groups).
  • (C3) Per size(replica_groups) si intende:
    • num_replicas se si utilizza cross_replica.
    • num_replicas se si utilizza cross_replica_and_partition.
    • num_processes se si utilizza flattened_ids.
  • (C4) 0 <= replica_groups < size(replica_groups).
  • (C5) Se use_global_device_ids = true, allora channel_id > 0.
  • (C6) type(result) = type(operand) eccetto:
    • dim(result, all_gather_dim) = dim(operand, all_gather_dim) * dim(process_groups, 1).

Esempi

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
%result = "stablehlo.all_gather"(%operand) {
  all_gather_dim = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  // channel_id = 0
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  // use_global_device_ids = false
} : (tensor<2x2xi64>) -> tensor<2x4xi64>
// %result@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]

Altri esempi

all_reduce

Semantica

All'interno di ogni gruppo di processi nella griglia dei processi StableHLO, applica una funzione di riduzione computation ai valori del tensore operand di ogni processo e produce un tensore result.

L'operazione suddivide la griglia del processo StableHLO in process_groups, definita come segue:

  • cross_replica(replica_groups) se channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) se channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) se channel_id > 0 and use_global_device_ids = true.

Successivamente, all'interno di ogni process_group:

  • result@process[result_index] = exec(schedule) per un albero binario schedule dove:
    • exec(node) = computation(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule è una struttura binaria definita dall'implementazione il cui attraversamento in ordine è to_destination_type(operands@process_group...[result_index], type(func_inputs(computation)[0])).

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore o tensore quantizzato per tensore (C5), (C6)
(I2) replica_groups numero variadi delle costanti tensoriali unidimensionali di tipo si64 (C1-C3)
(I3) channel_id costante di tipo si64 (C4)
(I4) use_global_device_ids costante di tipo i1 (C4)
(I5) computation funzione (C5)

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato per tensore (C6-C7)

Vincoli

  • (C1) is_unique(replica_groups).
  • (C2) Per size(replica_groups) si intende:
    • num_replicas se si utilizza cross_replica.
    • num_replicas se si utilizza cross_replica_and_partition.
    • num_processes se si utilizza flattened_ids.
  • (C3) 0 <= replica_groups < size(replica_groups).
  • (C4) Se use_global_device_ids = true, allora channel_id > 0.
  • (C5) computation ha il tipo (tensor<E>, tensor<E>) -> (tensor<E>), dove is_promotable(element_type(operand), E).
  • (C6) shape(result) = shape(operand).
  • (C7) element_type(result) = E.

Esempi

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [1, 2, 3, 4]
// %operand@(1, 0): [5, 6, 7, 8]
%result = "stablehlo.all_reduce"(%operand) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<4xi64>) -> tensor<4xi64>
// %result@(0, 0): [6, 8, 10, 12]
// %result@(1, 0): [6, 8, 10, 12]

Altri esempi

all_to_all

Semantica

all_to_all

All'interno di ogni gruppo di processi nella griglia dei processi StableHLO, suddivide i valori del tensore operand lungo split_dimension in parti, distribuisce le parti divise tra i processi, concatena le parti sparse lungo concat_dimension e produce un tensore result.

L'operazione suddivide la griglia del processo StableHLO in process_groups, definita come segue:

  • cross_replica(replica_groups) se channel_id <= 0.
  • cross_partition(replica_groups) se channel_id > 0.

Successivamente, all'interno di ogni process_group:

  • split_parts@sender = split(operand@sender, split_count, split_dimension) per tutti i sender in process_group.
  • scattered_parts@receiver = [split_parts@sender[receiver_index] for sender in process_group] dove receiver_index = process_group.index(receiver).
  • result@process = concatenate(scattered_parts@process, concat_dimension).

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore o tensore quantizzato per tensore (C1-C3), (C9)
(I2) split_dimension costante di tipo si64 (C1), (C2), (C9)
(I3) concat_dimension costante di tipo si64 (C3), (C9)
(I4) split_count costante di tipo si64 (C2), (C4), (C8), (C9)
(I5) replica_groups Costante del tensore bidimensionale di tipo si64 (C5-C8)
(I6) channel_id costante di tipo si64

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato per tensore (C9)

Vincoli

  • (C1) 0 <= split_dimension < rank(operand).
  • (C2) dim(operand, split_dimension) % split_count = 0.
  • (C3) 0 <= concat_dimension < rank(operand).
  • (C4) 0 < split_count.
  • (C5) is_unique(replica_groups).
  • (C6) Per size(replica_groups) si intende:
    • num_replicas se si utilizza cross_replica.
    • num_partitions se si utilizza cross_partition.
  • (C7) 0 <= replica_groups < size(replica_groups).
  • (C8) dim(replica_groups, 1) = split_count.
  • (C9) type(result) = type(operand) eccetto:
    • dim(result, split_dimension) = dim(operand, split_dimension) / split_count.
    • dim(result, concat_dimension) = dim(operand, concat_dimension) * split_count.

Esempi

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.all_to_all"(%operand) {
  split_dimension = 1 : i64,
  concat_dimension = 0 : i64,
  split_count = 2 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
} : (tensor<2x4xi64>) -> tensor<4x2xi64>
// %result@(0, 0): [[1, 2],
//                  [5, 6],
//                  [9, 10],
//                  [13, 14]]
// %result@(1, 0): [[3, 4],
//                  [7, 8],
//                  [11, 12],
//                  [15, 16]]

Altri esempi

e

Semantica

Esegue l'operatore AND a livello di elemento di due tensori lhs e rhs e produce un tensore result. A seconda del tipo di elemento, procedi come segue:

  • Per i valori booleani: AND logico.
  • Per i numeri interi: AND bit a bit.

Input

Etichetta Nome Tipo Vincoli
(I1) lhs tensore di tipo booleano o numero intero (C1)
(I2) rhs tensore di tipo booleano o numero intero (C1)

Output

Nome Tipo Vincoli
result tensore di tipo booleano o numero intero (C1)

Vincoli

  • (C1) type(lhs) = type(rhs) = type(result).

Esempi

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]

Altri esempi

atan2

Semantica

Esegue l'operazione atan2 basata sugli elementi su lhs e rhs e produce un tensore result. A seconda del tipo di elemento, procedi come segue:

  • Per i valori in virgola mobile: atan2 da IEEE-754.
  • Per i numeri complessi: atan2 complesso.
  • Per i tipi quantizzati: dequantize_op_quantize(atan2, lhs, rhs, type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) lhs tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore (C1)
(I2) rhs tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore (C1)

Output

Nome Tipo Vincoli
result tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore (C1)

Vincoli

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Esempi

// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]

Altri esempi

batch_norm_grad

Semantica

Calcola i gradienti di diversi input della retropropagazione della retropropagazione di batch_norm_training da grad_output e produce i tensori grad_operand, grad_scale e grad_offset. Più formalmente, questa operazione può essere espressa come una scomposizione delle operazioni StableHLO esistenti utilizzando la sintassi Python, come segue:

def compute_sum(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  return sum

def compute_mean(operand, feature_index):
  sum = compute_sum(operand, feature_index)
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
  # Broadcast inputs to type(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance`
  # Intermediate values will be useful for computing gradients
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)

  # Use the implementation from batchnorm_expander.cc in XLA
  # Temporary variables have exactly the same names as in the C++ code
  elements_per_feature = broadcast_in_dim(
      constant(divide(size(operand), dim(operand, feature_index)),
               element_type(grad_output)),
      [], type(operand))
  i1 = multiply(grad_output, elements_per_feature)
  i2 = broadcast_in_dim(
      compute_sum(grad_output, feature_index), [feature_index], type(operand))
  i3 = broadcast_in_dim(
      compute_sum(multiply(grad_output, centered_operand), feature_index),
      [feature_index], type(operand))
  i4 = multiply(i3, centered_operand)
  i5 = divide(i4, add(variance_bcast, epsilon_bcast))
  i6 = subtract(subtract(i1, i2), i5)

  grad_operand =
      multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
  grad_scale =
      compute_sum(multiply(grad_output, normalized_operand), feature_index)
  grad_offset = compute_sum(grad_output, feature_index)

  return grad_operand, grad_scale, grad_offset

Per i tipi quantizzati, esegue dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean, variance, grad_output: batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index), operand, scale, mean, variance, grad_output, type(grad_operand), type(grad_scale), type(feature_index)).

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore di tipo in virgola mobile o tensore quantizzato per tensore (C1-C3), (C5)
(I2) scale Tensore unidimensionale di tipo quantizzato in virgola mobile o per tensore (C2), (C4), (C5)
(I3) mean Tensore unidimensionale di tipo quantizzato in virgola mobile o per tensore (C2), (C4)
(I4) variance Tensore unidimensionale di tipo quantizzato in virgola mobile o per tensore (C2), (C4)
(I5) grad_output tensore di tipo in virgola mobile o tensore quantizzato per tensore (C2), (C3)
(I6) epsilon costante di tipo f32
(I7) feature_index costante di tipo si64 (C1), (C5)

Output

Nome Tipo Vincoli
grad_operand tensore di tipo in virgola mobile o tensore quantizzato per tensore (C2), (C3)
grad_scale Tensore unidimensionale di tipo quantizzato in virgola mobile o per tensore (C2), (C4)
grad_offset Tensore unidimensionale di tipo quantizzato in virgola mobile o per tensore (C2), (C4)

Vincoli

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, mean, variance, grad_output, grad_operand, grad_scale e grad_offset hanno lo stesso baseline_element_type.
  • (C3) operand, grad_output e grad_operand hanno la stessa forma.
  • (C4) scale, mean, variance, grad_scale e grad_offset hanno la stessa forma.
  • (C5) size(scale) = dim(operand, feature_index).

Esempi

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
//                [[0.1, 0.1], [0.1, 0.1]],
//                [[0.1, 0.1], [0.1, 0.1]]
//               ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
     tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
//                 [[0.0, 0.0], [0.0, 0.0]],
//                 [[0.0, 0.0], [0.0, 0.0]]
//                ]
// %grad_scale:  [0.0, 0.0]
// %grad_offset: [0.4, 0.4]

batch_norm_inference

Semantica

Normalizza il tensore operand in tutte le dimensioni tranne la dimensione feature_index e produce un tensore result. In modo più formale, questa operazione può essere espressa come scomposizione delle operazioni StableHLO esistenti utilizzando la sintassi Python come segue:

def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
  # Broadcast inputs to shape(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance` instead of
  # computing them like `batch_norm_training` does.
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)
  return add(multiply(scale_bcast, normalized_operand), offset_bcast)

Per i tipi quantizzati, esegue dequantize_op_quantize(lambda operand, scale, offset, mean, variance: batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index), operand, scale, offset, mean, variance, type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore di tipo in virgola mobile o tensore quantizzato per tensore (C1-C7)
(I2) scale Tensore unidimensionale di tipo quantizzato in virgola mobile o per tensore (C2), (C3)
(I3) offset Tensore unidimensionale di tipo quantizzato in virgola mobile o per tensore (C2), (C4)
(I4) mean Tensore unidimensionale di tipo quantizzato in virgola mobile o per tensore (C5)
(I5) variance Tensore unidimensionale di tipo quantizzato in virgola mobile o per tensore (C2), (C6)
(I6) epsilon costante di tipo f32
(I7) feature_index costante di tipo si64 (C1), (C3-C6)

Output

Nome Tipo Vincoli
result tensore di tipo in virgola mobile o tensore quantizzato per tensore (C2), (C7)

Vincoli

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, offset, mean, variance e result hanno lo stesso baseline_element_type.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(mean) = dim(operand, feature_index).
  • (C6) size(variance) = dim(operand, feature_index).
  • (C7) baseline_type(operand) = baseline_type(result).

Esempi

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]

batch_norm_training

Semantica

Calcola la media e la varianza in tutte le dimensioni, tranne la dimensione feature_index, e normalizza il tensore operand producendo i tensori output, batch_mean e batch_var. In modo più formale, questa operazione può essere espressa come scomposizione di operazioni StableHLO esistenti utilizzando la sintassi Python, come segue:

def compute_mean(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def compute_variance(operand, feature_index):
  mean = compute_mean(operand, feature_index)
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  centered_operand = subtract(operand, mean_bcast)
  return compute_mean(mul(centered_operand, centered_operand), feature_index)

def batch_norm_training(operand, scale, offset, epsilon, feature_index):
  mean = compute_mean(operand, feature_index)
  variance = compute_variance(operand, feature_index)
  return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
                              feature_index),
         mean, variance

Per i tipi quantizzati, esegue dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset: batch_norm_training(operand, scale, offset, epsilon, feature_index), operand, scale, offset, type(output), type(batch_mean), type(batch_var)).

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore di tipo in virgola mobile o tensore quantizzato per tensore (C1)
(I2) scale Tensore unidimensionale di virgola mobile o per tensore quantizzato (C2), (C3)
(I3) offset Tensore unidimensionale di virgola mobile o per tensore quantizzato (C2), (C4)
(I4) epsilon costante di tipo f32 (C1), (C3-C6)
(I5) feature_index costante di tipo si64 (C1), (C3-C6)

Output

Nome Tipo Vincoli
output tensore di tipo in virgola mobile o tensore quantizzato per tensore (C7)
batch_mean Tensore unidimensionale di virgola mobile o per tensore quantizzato (C2), (C5)
batch_var Tensore unidimensionale di virgola mobile o per tensore quantizzato (C2), (C6)

Vincoli

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, offset, batch_mean, batch_var e output hanno lo stesso baseline_element_type.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(batch_mean) = dim(operand, feature_index).
  • (C6) size(batch_var) = dim(operand, feature_index).
  • (C7) baseline_type(output) = baseline_type(operand).

Esempi

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
    (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]

bitcast_convert

Semantica

Esegue un'operazione bitcast sul tensore operand e produce un tensore result in cui i bit dell'intero tensore operand vengono reinterpretati utilizzando il tipo del tensore result.

Più formalmente, dati E = element_type(operand), E' = element_type(result) e R = rank(operand):

  • Se num_bits(E') < num_bits(E), bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1]).
  • Se num_bits(E') > num_bits(E), bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :]).
  • Se num_bits(E') = num_bits(E), bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1]).

bits restituisce la rappresentazione in memoria di un determinato valore e il suo comportamento è definito dall'implementazione perché la rappresentazione esatta dei tensori è definita dall'implementazione e anche l'esatta rappresentazione dei tipi di elementi è definita dall'implementazione.

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore o tensore quantizzato (C1-C2)

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato (C1-C2)

Vincoli

  • (C1) Dati E = is_quantized(operand) ? storage_type(operand) : element_type(operand), E' = is_quantized(result) ? storage_type(result) : element_type(result) e R = rank(operand):
    • Se num_bits(E') = num_bits(E), shape(result) = shape(operand).
    • Se num_bits(E') < num_bits(E):
    • rank(result) = R + 1.
    • dim(result, i) = dim(operand, i) per tutti i 0 <= i < R.
    • dim(result, R) * num_bits(E') = num_bits(E).
    • Se num_bits(E') > num_bits(E):
    • rank(result) = R - 1.
    • dim(result, i) = dim(operand, i) per tutti i 0 <= i < R.
    • dim(operand, R - 1) * num_bits(E) = num_bits(E').
  • (C2) Se is_complex(operand) or is_complex(result), allora is_complex(operand) and is_complex(result).

Esempi

// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation

Altri esempi

broadcast_in_dim

Semantica

Espande le dimensioni e/o il ranking di un tensore di input duplicando i dati nel tensore operand e produce un tensore result. Più formalmente, result[result_index] = operand[operand_index] dove per tutti i d in axes(operand):

  • operand_index[d] = 0 se dim(operand, d) = 1.
  • operand_index[d] = result_index[broadcast_dimensions[d]] in caso contrario.

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore o tensore quantizzato (C1-C2), (C5-C6)
(I2) broadcast_dimensions Costante del tensore unidimensionale di tipo si64 (C2-C6)

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato (C1), (C3), (C5-C6)

Vincoli

  • (C1) element_type(result) è dato da:
    • element_type(operand), se !is_per_axis_quantized(operand).
    • element_type(operand) ad eccezione del fatto che quantization_dimension(operand), scales(operand) e zero_points(operand) possono differire da quantization_dimension(result), scales(result) e zero_points(result) o negli altri casi.
  • (C2) size(broadcast_dimensions) = rank(operand).
  • (C3) 0 <= broadcast_dimensions < rank(result).
  • (C4) is_unique(broadcast_dimensions).
  • (C5) Per tutti i valori d in axes(operand):
    • dim(operand, d) = 1 o
    • dim(operand, d) = dim(result, broadcast_dimensions[d]).
  • (C6) Se is_per_axis_quantized(result):
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].
    • Se dim(operand, quantization_dimension(operand)) = 1, allora scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).

Esempi

// %operand: [
//            [1, 2, 3]
//           ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
  broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

Altri esempi

richiesta

Semantica

Genera l'output dall'esecuzione di una sola funzione da branches a seconda del valore di index. In forma più formale, result = selected_branch() dove:

  • selected_branch = branches[index] se 0 <= index < size(branches).
  • selected_branch = branches[-1] in caso contrario.

Input

Etichetta Nome Tipo Vincoli
(I1) index Tensore 0-dimensionale di tipo si32
(I2) branches numero variadi delle funzioni (C1-C4)

Output

Nome Tipo Vincoli
results numero variadi di tensori, tensori quantizzati o token (C4)

Vincoli

  • (C1) 0 < size(branches).
  • (C2) input_types(branches...) = [].
  • (C3) same(output_types(branches...)).
  • (C4) type(results...) = output_types(branches[0]).

Esempi

// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
  "stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
  "stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]

Altri esempi

crt

Semantica

Esegue un'operazione di radice cubica a livello di elemento sul tensore operand e produce un tensore result. A seconda del tipo di elemento, procedi come segue:

  • Per i valori in virgola mobile: rootn(x, 3) da IEEE-754.
  • Per i numeri complessi: radice cubica complessa.
  • Per i tipi quantizzati: dequantize_op_quantize(cbrt, operand, type(result))

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore (C1)

Output

Nome Tipo Vincoli
result tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore (C1)

Vincoli

  • (C1) baseline_type(operand) = baseline_type(result).

Esempi

// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]

Altri esempi

ceil

Semantica

Esegue il ceil a livello di elemento del tensore operand e produce un tensore result. Implementa l'operazione roundToIntegralTowardPositive dalla specifica IEEE-754. Per i tipi quantizzati, esegue dequantize_op_quantize(ceil, operand, type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore di tipo in virgola mobile o tensore quantizzato per tensore (C1)

Output

Nome Tipo Vincoli
result tensore di tipo in virgola mobile o tensore quantizzato per tensore (C1)

Vincoli

  • (C1) baseline_type(operand) = baseline_type(result).

Esempi

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]

Altri esempi

Cholesky

Semantica

Calcola la scomposizione di Cholesky di un batch di matrici.

Più formalmente, per tutti i valori i in index_space(result), result[i0, ..., iR-3, :, :] è una decomposizione di Cholesky di a[i0, ..., iR-3, :, :], sotto forma di una matrice triangolare inferiore (se lower è true) o triangolare superiore (se lower è false). I valori di output nel triangolo opposto, ovvero il triangolo superiore rigido o il triangolo inferiore stretto, di conseguenza, sono definiti dall'implementazione.

Se esiste i in cui la matrice di input non è una matrice Hermitiana con definizione positiva, il comportamento è indefinito.

Per i tipi quantizzati, esegue dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) a tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore (C1-C3)
(I2) lower Costante del tensore 0-dimensionale di tipo i1

Output

Nome Tipo Vincoli
result tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore (C1)

Vincoli

  • (C1) baseline_type(a) = baseline_type(result).
  • (C2) 2 <= rank(a).
  • (C3) dim(a, -2) = dim(a, -1).

Esempi

// %a: [
//      [1.0, 2.0, 3.0],
//      [2.0, 20.0, 26.0],
//      [3.0, 26.0, 70.0]
//     ]
%result = "stablehlo.cholesky"(%a) {
  lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
//           [1.0, 0.0, 0.0],
//           [2.0, 4.0, 0.0],
//           [3.0, 5.0, 6.0]
//          ]

clampare

Semantica

Collega ogni elemento del tensore operand tra un valore minimo e un valore massimo e produce un tensore result. In forma più formale, result[result_index] = minimum(maximum(operand[result_index], min_element), max_element), dove min_element = rank(min) = 0 ? min[] : min[result_index], max_element = rank(max) = 0 ? max[] : max[result_index]. Per i tipi quantizzati, esegue dequantize_op_quantize(clamp, min, operand, max, type(result)).

L'imposizione di un ordinamento sui numeri complessi comporta una semantica sorprendente, quindi in futuro abbiamo in programma di rimuovere il supporto per i numeri complessi per questa operazione (#560).

Input

Etichetta Nome Tipo Vincoli
(I1) min tensore o tensore quantizzato per tensore (C1), (C3)
(I2) operand tensore o tensore quantizzato per tensore (C1-C4)
(I3) max tensore o tensore quantizzato per tensore (C2), (C3)

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato per tensore (C4)

Vincoli

  • (C1) rank(min) = 0 or shape(min) = shape(operand).
  • (C2) rank(max) = 0 or shape(max) = shape(operand).
  • (C3) baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max).
  • (C4) baseline_type(operand) = baseline_type(result).

Esempi

// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]

Altri esempi

collective_broadcast

Semantica

All'interno di ogni gruppo di processi nella griglia dei processi StableHLO, invia il valore del tensore operand dal processo di origine ai processi di destinazione e produce un tensore result.

L'operazione suddivide la griglia del processo StableHLO in process_groups, definita come segue:

  • cross_replica(replica_groups) se channel_id <= 0.
  • cross_partition(replica_groups) se channel_id > 0.

Dopodiché, result@process viene fornito da:

  • operand@process_groups[i, 0] se esiste un i che fa sì che la procedura sia in process_groups[i].
  • broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result)) altrimenti.

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore o tensore quantizzato per tensore (C3)
(I2) replica_groups numero variadi delle costanti tensoriali unidimensionali di tipo si64 (C1), (C2)
(I3) channel_id costante di tipo si64

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato per tensore (C3)

Vincoli

  • (C1) is_unique(replica_groups).
  • (C2) 0 <= replica_groups < N dove N è definito come:
    • num_replicas se si utilizza cross_replica.
    • num_partitions se si utilizza cross_partition.
  • (C3) type(result) = type(operand).

Esempi

// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
  replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]

collective_permute

Semantica

All'interno di ogni gruppo di processi nella griglia dei processi StableHLO, invia il valore del tensore operand dal processo di origine al processo di destinazione e produce un tensore result.

L'operazione suddivide la griglia del processo StableHLO in process_groups, definita come segue:

  • cross_replica(source_target_pairs) se channel_id <= 0.
  • cross_partition(source_target_pairs) se channel_id > 0.

Dopodiché, result@process viene fornito da:

  • operand@process_groups[i, 0], se esiste un i che process_groups[i, 1] = process.
  • broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result)) altrimenti.

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore o tensore quantizzato per tensore (C5)
(I2) source_target_pairs Costante del tensore bidimensionale di tipo si64 (C1-C4)
(I3) channel_id costante di tipo si64

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato per tensore (C1)

Vincoli

  • (C1) dim(source_target_pairs, 1) = 2.
  • (C2) is_unique(source_target_pairs[:, 0]).
  • (C3) is_unique(source_target_pairs[:, 1]).
  • (C4) 0 <= source_target_pairs < N, dove N è definito come:
    • num_replicas se si utilizza cross_replica.
    • num_partitions se si utilizza cross_partition.
  • (C5) type(result) = type(operand).

Esempi

// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
  source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]

Altri esempi

compare

Semantica

Esegue il confronto a livello di elemento dei tensori lhs e rhs in base a comparison_direction e compare_type e produce un tensore result.

I valori di comparison_direction e compare_type hanno la seguente semantica:

Per i tipi di elementi booleani e interi:

  • EQ: lhs = rhs.
  • NE: lhs != rhs.
  • GE: lhs >= rhs.
  • GT: lhs > rhs.
  • LE: lhs <= rhs.
  • LT: lhs < rhs.

Per i tipi di elementi a virgola mobile con compare_type = FLOAT, l'operatore implementa le seguenti operazioni IEEE-754:

  • EQ: compareQuietEqual.
  • NE: compareQuietNotEqual.
  • GE: compareQuietGreaterEqual.
  • GT: compareQuietGreater.
  • LE: compareQuietLessEqual.
  • LT: compareQuietLess.

Per i tipi di elementi con virgola mobile con compare_type = TOTALORDER, l'operazione utilizza la combinazione di operazioni totalOrder e compareQuietEqual da IEEE-754.

Per i tipi di elementi complessi, il confronto lessicografico delle coppie (real, imag) viene eseguito utilizzando i valori comparison_direction e compare_type forniti. L'imposizione di un ordinamento sui numeri complessi comporta una semantica sorprendente, pertanto in futuro prevediamo di rimuovere il supporto per questi numeri quando comparison_direction è GE, GT, LE o LT (#560).

Per i tipi quantizzati, esegue dequantize_compare(lhs, rhs, comparison_direction).

Input

Etichetta Nome Tipo Vincoli
(I1) lhs tensore o tensore quantizzato per tensore (C1-C3)
(I2) rhs tensore o tensore quantizzato per tensore (C1-C2)
(I3) comparison_direction enum di EQ, NE, GE, GT, LE e LT
(I4) compare_type enum di FLOAT, TOTALORDER, SIGNED e UNSIGNED (C3)

Output

Nome Tipo Vincoli
result tensore di tipo booleano (C2)

Vincoli

  • (C1) baseline_element_type(lhs) = baseline_element_type(rhs).
  • (C2) shape(lhs) = shape(rhs) = shape(result).
  • (C3) Per compare_type si intende:
    • SIGNED se is_signed_integer(element_type(lhs)).
    • UNSIGNED se is_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs)).
    • FLOAT o TOTALORDER se is_float(element_type(lhs)).
    • FLOAT se is_complex(element_type(lhs)).

Esempi

// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
  comparison_direction = #stablehlo<comparison_direction LT>,
  compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]

Altri esempi

complesso

Semantica

Esegue la conversione a livello di elemento in un valore complesso da una coppia di valori reali e immaginari, lhs e rhs, e produce un tensore result.

Input

Etichetta Nome Tipo Vincoli
(I1) lhs tensore di tipo f32 o f64 (C1-C3)
(I2) rhs tensore di tipo f32 o f64 (C1)

Output

Nome Tipo Vincoli
result tensore di tipo complesso (C2), (C3)

Vincoli

  • (C1) type(lhs) = type(rhs).
  • (C2) shape(result) = shape(lhs).
  • (C3) element_type(result) ha il tipo complex<E> dove E = element_type(lhs).

Esempi

// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]

Altri esempi

composito

Semantica

Incapsula un'operazione composta (composta) da altre operazioni StableHLO, prendendo inputs e composite_attributes e producendo results. La semantica dell'operazione è implementata dall'attributo decomposition. L'op composite può essere sostituita con la sua scomposizione senza modificare la semantica del programma. Nei casi in cui l'incorporamento della scomposizione non fornisce la stessa semantica dell'op, è preferibile utilizzare custom_call.

Il campo version (l'impostazione predefinita è 0) viene utilizzato per indicare quando la semantica di un elemento composito cambia.

Input

Etichetta Nome Tipo
(I1) inputs numero variadic di valori
(I2) name costante di tipo string
(I3) composite_attributes dizionario degli attributi
(I4) decomposition costante di tipo string
(I5) version costante di tipo si32

Output

Nome Tipo
results numero variadic di valori

Vincoli

  • (C1) is_namespaced_op_name(name)
  • (C2) is_defined_in_parent_scope(decomposition)
  • (C3) types(inputs...) == input_types(decomposition)
  • (C4) types(results...) == output_types(decomposition)

Esempi

%results = "stablehlo.composite"(%input0, %input1) {
  name = "my_namespace.my_op",
  composite_attributes = {
    my_attribute = "my_value"
  },
  decomposition = @my_op,
  version = 1 : i32
} : (tensor<f32>, tensor<f32>) -> tensor<f32>

Altri esempi

concatenate

Semantica

Concatena inputs lungo la dimensione dimension nello stesso ordine degli argomenti specificati e produce un tensore result. In forma più formale, result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1], dove:

  1. id = d0 + ... + dk-1 + kd.
  2. d è uguale a dimension e d0, ... sono da dimensioni di dimensione di inputs.

Input

Etichetta Nome Tipo Vincoli
(I1) inputs numero variadi di tensori o tensori quantizzati per tensore (C1-C6)
(I2) dimension costante di tipo si64 (C2), (C4), (C6)

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato per tensore (C5-C6)

Vincoli

  • (C1) same(element_type(inputs...)).
  • (C2) same(shape(inputs...)) tranne dim(inputs..., dimension).
  • (C3) 0 < size(inputs).
  • (C4) 0 <= dimension < rank(inputs[0]).
  • (C5) element_type(result) = element_type(inputs[0]).
  • (C6) shape(result) = shape(inputs[0]) eccetto:
    • dim(result, dimension) = dim(inputs[0], dimension) + ....

Esempi

// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
  dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]

Altri esempi

costante

Semantica

Genera un tensore output da una costante value.

Input

Etichetta Nome Tipo Vincoli
(I1) value costante (C1)

Output

Nome Tipo Vincoli
output tensore o tensore quantizzato (C1)

Vincoli

  • (C1) type(value) = type(output).

Esempi

%output = "stablehlo.constant"() {
  value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]

Altri esempi

effettuare una conversione

Semantica

Esegue una conversione a livello di elemento da un tipo di elemento a un altro sul tensore operand e produce un tensore result.

Per le conversioni di tipo boolean-to-any-supported-type, il valore false viene convertito in zero, mentre il valore true viene convertito in uno. Per le conversioni any-supported-type-to-boolean, il valore zero viene convertito in false, mentre i valori diversi da zero vengono convertiti in true. Vedi sotto per scoprire come funziona per i tipi complessi.

Per le conversioni che prevedono il valore numero intero a numero intero, da intero a virgola mobile o da virgola mobile a virgola mobile, se il valore di origine può essere rappresentato esattamente nel tipo di destinazione, il valore risultante è quella rappresentazione esatta. In caso contrario, il comportamento è da definire (#180).

Per le conversioni che coinvolgono floating-point-to-integer, la parte frazionaria viene troncata. Se il valore troncato non può essere rappresentato nel tipo di destinazione, il comportamento è da definire (#180).

Le conversioni da complesso a complesso seguono lo stesso comportamento delle conversioni da punto a virgola mobile per la conversione di parti reali e immaginarie.

Per le conversioni di tipo complex-to-any-other-type e any-other-type-to-complex, il valore immaginario di origine viene ignorato o il valore immaginario di destinazione viene azzerato rispettivamente. La conversione della parte reale segue le conversioni in virgola mobile.

In linea di principio, questa operazione potrebbe esprimere la dequantizzazione (conversione da tensori quantizzati a tensori regolari), la quantizzazione (conversione da tensori regolari a tensori quantizzati) e la riquantizzazione (conversione tra tensori quantizzati), ma al momento abbiamo operazioni dedicate: uniform_dequantize per il primo caso d'uso e uniform_quantize per il secondo e il terzo caso d'uso. In futuro, queste due operazioni potrebbero essere unite in convert (#1576).

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore (C1)

Output

Nome Tipo Vincoli
result tensore (C1)

Vincoli

  • (C1) shape(operand) = shape(result).

Esempi

// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]

Altri esempi

convoluzione

Semantica

Calcola i prodotti punteggiati tra le finestre di lhs e le sezioni di rhs, quindi produce result. Il seguente diagramma mostra come vengono calcolati gli elementi in result da lhs e rhs utilizzando un esempio concreto.

convoluzione

In modo più formale, valuta la seguente ristrutturazione degli input in termini di lhs per poter esprimere finestre di lhs:

  • lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension)).
  • lhs_window_strides = lhs_shape(1, window_strides, 1).
  • lhs_padding = lhs_shape([0, 0], padding, [0, 0]).
  • lhs_base_dilations = lhs_shape(1, lhs_dilation, 1).
  • lhs_window_dilations = lhs_shape(1, rhs_dilation, 1).

Questa rielaborazione utilizza le seguenti funzioni helper:

  • lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]).
  • result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]).
  • permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1] dove j[d] = i[permutation[d]].

Se feature_group_count = 1 e batch_group_count = 1, per tutti i output_spatial_index in index_space(dim(result, output_spatial_dimensions...)), result[result_shape(:, output_spatial_index, :)] = dot_product dove:

  • padding_value = constant(0, element_type(lhs)).
  • padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1).
  • lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides.
  • lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations).
  • reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true]). Questa funzionalità sembra inutilizzata, perciò abbiamo intenzione di rimuoverla in futuro (#1181).
  • dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension]).

Se feature_group_count > 1:

  • lhses = split(lhs, feature_group_count, input_feature_dimension).
  • rhses = split(rhs, feature_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension).

Se batch_group_count > 1:

  • lhses = split(lhs, batch_group_count, input_batch_dimension).
  • rhses = split(rhs, batch_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension).

Per i tipi quantizzati, esegue dequantize_op_quantize( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs, type(result)).

Per i tipi quantizzati ibridi, esegue hybrid_dequantize_then_op( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs).

Input

Etichetta Nome Tipo Vincoli
(I1) lhs tensore o tensore quantizzato per tensore (C1), (C10-C11), (C14) (C25), (C27-C28), (C31-C32), (C34)
(I2) rhs tensore o tensore quantizzato (C1), (C14-C16), (C25), (C27-C29), (C31-C34)
(I3) window_strides Costante del tensore unidimensionale di tipo si64 (C2-C3), (C25)
(I4) padding Costante del tensore bidimensionale di tipo si64 (C4), (C25)
(I5) lhs_dilation Costante del tensore unidimensionale di tipo si64 (C5-C6), (C25)
(I6) rhs_dilation Costante del tensore unidimensionale di tipo si64 (C7-C8), (C25)
(I7) window_reversal Costante del tensore unidimensionale di tipo i1 (C9)
(I8) input_batch_dimension costante di tipo si64 (C10), (C13), (C25)
(I9) input_feature_dimension costante di tipo si64 (C11), (C13-C14)
(I10) input_spatial_dimensions Costante del tensore unidimensionale di tipo si64 (C12), (C13), (C25)
(I11) kernel_input_feature_dimension costante di tipo si64 (C14), (C18)
(I12) kernel_output_feature_dimension costante di tipo si64 (C15-C16), (C18), (C25), (C29)
(I13) kernel_spatial_dimensions Costante del tensore unidimensionale di tipo si64 (C17-C18), (C25)
(I14) output_batch_dimension costante di tipo si64 (C20), (C25)
(I15) output_feature_dimension costante di tipo si64 (C20), (C25), (C30)
(I16) output_spatial_dimensions Costante del tensore unidimensionale di tipo si64 (C19-C20), (C25)
(I17) feature_group_count costante di tipo si64 (C11), (C14), (C16), (C21), (C23)
(I18) batch_group_count costante di tipo si64 (C10), (C15), (C22), (C23), (C25)
(I19) precision_config numero variadi delle enumerazioni di DEFAULT, HIGH e HIGHEST (C24)

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato (C25-C28), (C30), (C32-34)

Vincoli

  • (C1) N = rank(lhs) = rank(rhs).
  • (C2) size(window_strides) = N - 2.
  • (C3) 0 < window_strides.
  • (C4) shape(padding) = [N - 2, 2].
  • (C5) size(lhs_dilation) = N - 2.
  • (C6) 0 < lhs_dilation.
  • (C7) size(rhs_dilation) = N - 2.
  • (C8) 0 < rhs_dilation.
  • (C9) size(window_reversal) = N - 2.
  • (C10) dim(lhs, input_batch_dimension) % batch_group_count = 0.
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0.
  • (C12) size(input_spatial_dimensions) = N - 2.
  • (C13) Dati input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:
    • is_unique(input_dimensions).
    • 0 <= input_dimensions < N.
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count.
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0.
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0.
  • (C17) size(kernel_spatial_dimensions) = N - 2.
  • (C18) Dato kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:
    • is_unique(kernel_dimensions).
    • 0 <= kernel_dimensions < N.
  • (C19) size(output_spatial_dimensions) = N - 2.
  • (C20) Dati output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:
    • is_unique(output_dimensions).
    • 0 <= output_dimensions < N.
  • (C21) 0 < feature_group_count.
  • (C22) 0 < batch_group_count.
  • (C23) feature_group_count = 1 or batch_group_count = 1.
  • (C24) size(precision_config) = 2.
  • (C25) Per dim(result, result_dim) si intende:
    • dim(lhs, input_batch_dimension) / batch_group_count se result_dim = output_batch_dimension.
    • dim(rhs, kernel_output_feature_dimension) se result_dim = output_feature_dimension.
    • num_windows; in caso contrario, dove:
    • output_spatial_dimensions[spatial_dim] = result_dim.
    • lhs_dim = input_spatial_dimensions[spatial_dim].
    • rhs_dim = kernel_spatial_dimensions[spatial_dim].
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
  • (C26) rank(result) = N.
  • Se l'operazione utilizza tensori non quantiizzati:
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result).
  • Se l'operazione utilizza tensori quantizzati:
    • (C28) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • (C29) Se is_per_axis_quantized(rhs), quindi quantization_dimension(rhs) = kernel_output_feature_dimension.
    • (C30) Se is_per_axis_quantized(result), allora quantization_dimension(result) = output_feature_dimension.
    • Se is_quantized(lhs):
    • (C31) storage_type(lhs) = storage_type(rhs).
    • (C32) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C33) Se is_per_tensor_quantized(rhs), allora is_per_tensor_quantized(result).
    • Se !is_quantized(lhs):
    • (C34) element_type(lhs) = expressed_type(rhs) = element_type(result).

Esempi

// %lhs: [[
//        [
//          [1], [2], [5], [6]
//        ],
//        [
//          [3], [4], [7], [8]
//        ],
//        [
//          [10], [11], [14], [15]
//        ],
//        [
//          [12], [13], [16], [17]
//        ]
//      ]]
//
// %rhs: [
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]]
//       ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
  window_strides = array<i64: 4, 4>,
  padding = dense<0> : tensor<2x2xi64>,
  lhs_dilation = array<i64: 2, 2>,
  rhs_dilation = array<i64: 1, 1>,
  window_reversal = array<i1: false, false>,
  // In the StableHLO dialect, dimension numbers are encoded via:
  // `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
  // "b" is batch dimension, "f" is feature dimension,
  // "i" is input feature dimension, "o" is output feature dimension,
  // "0/1/etc" are spatial dimensions.
  dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
  batch_group_count = 1 : i64,
  feature_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
//            [[10], [26]],
//            [[46], [62]]
//          ]]

Altri esempi

coseno

Semantica

Esegue un'operazione coseno a livello di elemento sul tensore operand e produce un tensore result. A seconda del tipo di elemento, procedi come segue:

  • Per i valori in virgola mobile: cos da IEEE-754.
  • Per i numeri complessi: coseno complesso.
  • Per i tipi quantizzati: dequantize_op_quantize(cosine, operand, type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore (C1)

Output

Nome Tipo Vincoli
result tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore (C1)

Vincoli

  • (C1) baseline_type(operand) = baseline_type(result).

Esempi

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]

Altri esempi

count_leading_zeros

Semantica

Esegue il conteggio a livello di elemento del numero di zero bit iniziali nel tensore operand e produce un tensore result.

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore di tipo intero (C1)

Output

Nome Tipo Vincoli
result tensore di tipo intero (C1)

Vincoli

  • (C1) type(operand) = type(result).

Esempi

// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]

Altri esempi

custom_call

Semantica

Incapsula un'operazione call_target_name definita dall'implementazione che prende inputs e called_computations e produce results. has_side_effect, backend_config e api_version possono essere usati per fornire ulteriori metadati definiti dall'implementazione.

Al momento questa operazione contiene una raccolta piuttosto disorganizzata di metadati che riflette l'evoluzione organica della sua controparte nel compilatore XLA. In futuro, abbiamo in programma di unificare questi metadati (#741).

Input

Etichetta Nome Tipo
(I1) inputs numero variadic di valori
(I2) call_target_name costante di tipo string
(I3) has_side_effect costante di tipo i1
(I4) backend_config costante di tipo string
(I5) api_version costante di tipo si32
(I6) called_computations numero variadi delle costanti di tipo string

Output

Nome Tipo
results numero variadic di valori

Esempi

%results = "stablehlo.custom_call"(%input0) {
  call_target_name = "foo",
  has_side_effect = false,
  backend_config = "bar",
  api_version = 1 : i32,
  called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>

divisione

Semantica

Esegue la divisione per elemento dei tensori del dividendo lhs e del divisore rhs e produce un tensore result. A seconda del tipo di elemento, procedi come segue:

  • Per i numeri interi: divisione intera che produce il quoziente algebrico con qualsiasi parte eliminata.
  • Per i valori in virgola mobile: division da IEEE-754.
  • Per i numeri complessi: divisione complessa.
  • Per i tipi quantizzati:
    • dequantize_op_quantize(divide, lhs, rhs, type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) lhs tensore di tipo intero, in virgola mobile o complesso o tensore quantizzato per tensore (C1)
(I2) rhs tensore di tipo intero, in virgola mobile o complesso o tensore quantizzato per tensore (C1)

Output

Nome Tipo Vincoli
result tensore di numero intero, in virgola mobile o di tipo complesso o tensore quantizzato per tensore (C1)

Vincoli

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Esempi

// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]

Altri esempi

dot_general

Semantica

Calcola i prodotti punti tra le sezioni di lhs e le sezioni di rhs e produce un tensore result.

In forma più formale, result[result_index] = dot_product, dove:

  • lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions].
  • rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions].
  • result_batching_index + result_lhs_index + result_rhs_index = result_index dove size(result_batching_index) = size(lhs_batching_dimensions), size(result_lhs_index) = size(lhs_result_dimensions) e size(result_rhs_index) = size(rhs_result_dimensions).
  • transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions).
  • transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :]).
  • reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions)).
  • transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions).
  • transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :]).
  • reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions)).
  • dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y)).

Per i tipi quantizzati, esegue dequantize_op_quantize( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs, type(result)).

Per i tipi quantizzati ibridi, esegue hybrid_dequantize_then_op( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs).

precision_config controlla il compromesso tra velocità e precisione per i calcoli sui backend degli acceleratori. Può essere uno dei seguenti (al momento, la semantica di questi valori di enum è sottospecificata, ma abbiamo intenzione di affrontarla in #755):

  • DEFAULT: calcolo più veloce, ma meno accurata approssimazione al numero originale.
  • HIGH: calcolo più lento, ma approssimazione più accurata al numero originale.
  • HIGHEST: calcolo più lento, ma approssimazione più precisa al numero originale.

Input

Etichetta Nome Tipo Vincoli
(I1) lhs tensore o tensore quantizzato per tensore (C5-C6), (C9-C10), (C12-C14), (C17-C18), (C20)
(I2) rhs tensore o tensore quantizzato (C7-C10), (C12-C20)
(I3) lhs_batching_dimensions Costante del tensore unidimensionale di tipo si64 (C1), (C3), (C5), (C9), (C12)
(I4) rhs_batching_dimensions Costante del tensore unidimensionale di tipo si64 (C1), (C4), (C7), (C9)
(I5) lhs_contracting_dimensions Costante del tensore unidimensionale di tipo si64 (C2), (C3), (C6), (C10)
(I6) rhs_contracting_dimensions Costante del tensore unidimensionale di tipo si64 (C2), (C4), (C8), (C10), (C16)
(I7) precision_config numero variadi delle enumerazioni di DEFAULT, HIGH e HIGHEST (C11)

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato (C12), (C14), (C18-C20)

Vincoli

  • (C1) size(lhs_batching_dimensions) = size(rhs_batching_dimensions).
  • (C2) size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions).
  • (C3) is_unique(lhs_batching_dimensions + lhs_contracting_dimensions).
  • (C4) is_unique(rhs_batching_dimensions + rhs_contracting_dimensions).
  • (C5) 0 <= lhs_batching_dimensions < rank(lhs).
  • (C6) 0 <= lhs_contracting_dimensions < rank(lhs).
  • (C7) 0 <= rhs_batching_dimensions < rank(rhs).
  • (C8) 0 <= rhs_contracting_dimensions < rank(rhs).
  • (C9) dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).
  • (C10) dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...).
  • (C11) size(precision_config) = 2.
  • (C12) shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions).
  • Se l'operazione utilizza tensori non quantiizzati:
    • (C13) element_type(lhs) = element_type(rhs).
  • Se l'operazione utilizza tensori quantizzati:
    • (C14) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • (C15) zero_points(rhs) = 0.
    • (C16) Se is_per_axis_quantized(rhs), allora quantization_dimension(rhs) non in rhs_contracting_dimensions.
    • Se is_quantized(lhs):
    • (C17) storage_type(lhs) = storage_type(rhs).
    • (C18) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C19) Se is_per_tensor_quantized(rhs), allora is_per_tensor_quantized(result).
    • Se !is_quantized(lhs):
    • (C20) element_type(lhs) = expressed_type(rhs) = element_type(result).

Esempi

// %lhs: [
//        [[1, 2],
//         [3, 4]],
//        [[5, 6],
//         [7, 8]]
//       ]
// %rhs: [
//        [[1, 0],
//         [0, 1]],
//        [[1, 0],
//         [0, 1]]
//       ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
  dot_dimension_numbers = #stablehlo.dot<
    lhs_batching_dimensions = [0],
    rhs_batching_dimensions = [0],
    lhs_contracting_dimensions = [2],
    rhs_contracting_dimensions = [1]
  >,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
//           [[1, 2],
//            [3, 4]],
//           [[5, 6],
//            [7, 8]]
//          ]

Altri esempi

dynamic_broadcast_in_dim

Semantica

Questa operazione è identica dal punto di vista funzionale all'operazione broadcast_in_dim, ma la forma del risultato viene specificata in modo dinamico tramite output_dimensions.

L'operazione accetta anche gli attributi facoltativi known_expanding_dimensions, known_non_expanding_dimensions per esprimere conoscenze statiche sul comportamento espandibile delle dimensioni. Se non specificato, si presume che tutte le dimensioni siano potenzialmente in espansione.

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore o tensore quantizzato (C1-C2), (C5-C6), (C9)
(I2) output_dimensions Tensore unidimensionale di tipo intero (C7)
(I3) broadcast_dimensions Tensore costante unidimensionale di tipo intero (C2-C6)
(I4) known_expanding_dimensions Tensore costante unidimensionale di tipo intero (C8-C9)
(I5) known_non_expanding_dimensions Tensore costante unidimensionale di tipo intero (C8-C9)

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato (C1), (C3), (C5-C7)

Vincoli

  • (C1) element_type(result) è dato da:
    • element_type(operand), se !is_per_axis_quantized(operand).
    • element_type(operand) ad eccezione del fatto che quantization_dimension(operand), scales(operand) e zero_points(operand) possono differire da quantization_dimension(result), scales(result) e zero_points(result) o negli altri casi.
  • (C2) size(broadcast_dimensions) = rank(operand).
  • (C3) 0 <= broadcast_dimensions < rank(result).
  • (C4) is_unique(broadcast_dimensions).
  • (C5) Per tutti i valori d in axes(operand):
    • dim(operand, d) = 1 o
    • dim(operand, d) = dim(result, broadcast_dimensions[d]).
  • (C6) Se is_per_axis_quantized(result):
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].
    • Se dim(operand, quantization_dimension(operand)) = 1, allora scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).
  • (C7) size(output_dimensions) = rank(result).
  • (C8) is_unique(known_expanding_dimensions + known_non_expanding_dimensions).
  • (C9) 0 <= known_expanding_dimensions < rank(operand).
  • (C10) 0 <= known_non_expanding_dimensions < rank(operand).

Esempi

// %operand: [
//            [1, 2, 3]
//           ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
  broadcast_dimensions = array<i64: 2, 1>,
  known_expanding_dimensions = array<i64: 0>,
  known_non_expanding_dimensions = array<i64: 1>
} : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

Altri esempi

dynamic_conv

Semantica

Questa operazione è sostanzialmente identica a quella di convolution, ma la spaziatura interna viene specificata in modo dinamico tramite padding.

Input

Etichetta Nome Tipo Vincoli
(I1) lhs tensore o tensore quantizzato per tensore (C1), (C10-C11), (C14) (C25), (C26-C27), (C30-C31), (C33)
(I2) rhs tensore o tensore quantizzato (C1), (C14-C16), (C26-C28), (C30-C33)
(I3) padding Tensore bidimensionale di tipo intero (C4)
(I4) window_strides Costante del tensore unidimensionale di tipo si64 (C2-C3)
(I5) lhs_dilation Costante del tensore unidimensionale di tipo si64 (C5-C6)
(I6) rhs_dilation Costante del tensore unidimensionale di tipo si64 (C7-C8)
(I7) window_reversal Costante del tensore unidimensionale di tipo i1 (C9)
(I8) input_batch_dimension costante di tipo si64 (C10), (C13)
(I9) input_feature_dimension costante di tipo si64 (C11), (C13-C14)
(I10) input_spatial_dimensions Costante del tensore unidimensionale di tipo si64 (C12), (C13)
(I11) kernel_input_feature_dimension costante di tipo si64 (C14), (C18)
(I12) kernel_output_feature_dimension costante di tipo si64 (C15-C16), (C18), (C28)
(I13) kernel_spatial_dimensions Costante del tensore unidimensionale di tipo si64 (C17-C18)
(I14) output_batch_dimension costante di tipo si64 (C20)
(I15) output_feature_dimension costante di tipo si64 (C20), (C29)
(I16) output_spatial_dimensions Costante del tensore unidimensionale di tipo si64 (C19-C20)
(I17) feature_group_count costante di tipo si64 (C11), (C14), (C16), (C21), (C23)
(I18) batch_group_count costante di tipo si64 (C10), (C15), (C22), (C23)
(I19) precision_config numero variadi delle enumerazioni di DEFAULT, HIGH e HIGHEST (C24)

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato (C25-C27), (C29), (C31-C33)

Vincoli

  • (C1) N = rank(lhs) = rank(rhs).
  • (C2) size(window_strides) = N - 2.
  • (C3) 0 < window_strides.
  • (C4) shape(padding) = [N - 2, 2].
  • (C5) size(lhs_dilation) = N - 2.
  • (C6) 0 < lhs_dilation.
  • (C7) size(rhs_dilation) = N - 2.
  • (C8) 0 < rhs_dilation.
  • (C9) size(window_reversal) = N - 2.
  • (C10) dim(lhs, input_batch_dimension) % batch_group_count = 0.
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0.
  • (C12) size(input_spatial_dimensions) = N - 2.
  • (C13) Dati input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:
    • is_unique(input_dimensions).
    • 0 <= input_dimensions < N.
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count.
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0.
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0.
  • (C17) size(kernel_spatial_dimensions) = N - 2.
  • (C18) Dato kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:
    • is_unique(kernel_dimensions).
    • 0 <= kernel_dimensions < N.
  • (C19) size(output_spatial_dimensions) = N - 2.
  • (C20) Dati output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:
    • is_unique(output_dimensions).
    • 0 <= output_dimensions < N.
  • (C21) 0 < feature_group_count.
  • (C22) 0 < batch_group_count.
  • (C23) feature_group_count = 1 or batch_group_count = 1.
  • (C24) size(precision_config) = 2.
  • (C25) Per dim(result, result_dim) si intende:
    • dim(lhs, input_batch_dimension) / batch_group_count se result_dim = output_batch_dimension.
    • dim(rhs, kernel_output_feature_dimension) se result_dim = output_feature_dimension.
    • num_windows; in caso contrario, dove:
    • output_spatial_dimensions[spatial_dim] = result_dim.
    • lhs_dim = input_spatial_dimensions[spatial_dim].
    • rhs_dim = kernel_spatial_dimensions[spatial_dim].
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
  • (C26) rank(result) = N.
  • Se l'operazione utilizza tensori non quantiizzati:
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result).
  • Se l'operazione utilizza tensori quantizzati:
    • (C28) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • (C29) Se is_per_axis_quantized(rhs), quindi quantization_dimension(rhs) = kernel_output_feature_dimension.
    • (C30) Se is_per_axis_quantized(result), allora quantization_dimension(result) = output_feature_dimension.
    • Se is_quantized(lhs):
    • (C31) storage_type(lhs) = storage_type(rhs).
    • (C32) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C33) Se is_per_tensor_quantized(rhs), allora is_per_tensor_quantized(result).
    • Se !is_quantized(lhs):
    • (C34) element_type(lhs) = expressed_type(rhs) = element_type(result).

Esempi

// %lhs: [[
//        [[1], [2], [5], [6]],
//        [[3], [4], [7], [8]],
//        [[10], [11], [14], [15]],
//        [[12], [13], [16], [17]]
//      ]]
//
// %rhs: [
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]]
//        ]
// %padding: [[1, 1],
//            [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
  window_strides = array<i64: 4, 4>,
  lhs_dilation = array<i64: 2, 2>,
  rhs_dilation = array<i64: 1, 1>,
  window_reversal = array<i1: false, false>,
  dimension_numbers = #stablehlo.conv<raw
    input_batch_dimension = 0,
    input_feature_dimension = 3,
    input_spatial_dimensions = [0, 1],
    kernel_input_feature_dimension = 2,
    kernel_output_feature_dimension = 3,
    kernel_spatial_dimensions = [0, 1],
    output_batch_dimension = 0,
    output_feature_dimension = 3,
    output_spatial_dimensions = [1, 2]
  >,
  feature_group_count = 1 : i64,
  batch_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
//            [[1], [5]],
//            [[10], [14]]
//          ]]

Altri esempi

dynamic_gather

Semantica

Questa operazione è identica dal punto di vista funzionale dell'operazione di raccolta, con slice_sizes specificato dinamicamente come valore.

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore o tensore quantizzato per tensore (C1), (C7), (C10-C12), (C14)
(I2) start_indices tensore di tipo intero (C2), (C3), (C13)
(I3) slice_sizes Tensore unidimensionale di tipo intero (C8), (C11-C13)
(I4) offset_dims Costante del tensore unidimensionale di tipo si64 (C1), (C4-C5), (C13)
(I5) collapsed_slice_dims Costante del tensore unidimensionale di tipo si64 (C1), (C6-C8), (C13)
(I6) start_index_map Costante del tensore unidimensionale di tipo si64 (C3), (C9), (C10)
(I7) index_vector_dim costante di tipo si64 (C2), (C3), (C13)
(I8) indices_are_sorted costante di tipo i1

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato per tensore (C5), (C13-C14)

Vincoli

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims).
  • (C2) 0 <= index_vector_dim <= rank(start_indices).
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1.
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims).
  • (C5) 0 <= offset_dims < rank(result).
  • (C6) is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims).
  • (C7) 0 <= collapsed_slice_dims < rank(operand).
  • (C8) slice_sizes[collapsed_slice_dims...] <= 1.
  • (C9) is_unique(start_index_map).
  • (C10) 0 <= start_index_map < rank(operand).
  • (C11) size(slice_sizes) = rank(operand).
  • (C12) 0 <= slice_sizes <= shape(operand).
  • (C13) shape(result) = combine(batch_dim_sizes, offset_dim_sizes) dove:
    • batch_dim_sizes = shape(start_indices) tranne per il fatto che le dimensioni di start_indices corrispondenti a index_vector_dim non sono incluse.
    • offset_dim_sizes = shape(slice_sizes) tranne per il fatto che le dimensioni delle dimensioni in slice_sizes corrispondenti a collapsed_slice_dims non sono incluse.
    • combine posiziona batch_dim_sizes sugli assi corrispondenti a batch_dims e offset_dim_sizes sugli assi corrispondenti a offset_dims.
  • (C14) element_type(operand) = element_type(result).

Esempi

// %operand: [
//            [[1, 2], [3, 4], [5, 6], [7, 8]],
//            [[9, 10],[11, 12], [13, 14], [15, 16]],
//            [[17, 18], [19, 20], [21, 22], [23, 24]]
//           ]
// %start_indices: [
//                  [[0, 0], [1, 0], [2, 1]],
//                  [[0, 1], [1, 1], [0, 2]]
//                 ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [2, 3],
    collapsed_slice_dims = [0],
    start_index_map = [1, 0],
    index_vector_dim = 2>,
  indices_are_sorted = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi64>
// %result: [
//            [
//              [[1, 2], [3, 4]],
//              [[3, 4], [5, 6]],
//              [[13, 14], [15, 16]]
//            ],
//            [
//              [[9, 10], [11, 12]],
//              [[11, 12], [13, 14]],
//              [[17, 18], [19, 20]]
//            ]
//          ]

Altri esempi

dynamic_iota

Semantica

Questa operazione è sostanzialmente identica a iota op, ma la forma del risultato viene specificata in modo dinamico tramite output_shape.

Input

Etichetta Nome Tipo Vincoli
(I1) output_shape Tensore unidimensionale di tipo intero (C1), (C2)
(I2) iota_dimension si64 (C1)

Output

Nome Tipo Vincoli
result tensore di numero intero, in virgola mobile o di tipo complesso o tensore quantizzato per tensore (C2)

Vincoli

  • (C1) 0 <= iota_dimension < size(output_shape).
  • (C2) rank(result) = size(output_shape).

Esempi

%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
  iota_dimension = 0 : i64
} : (tensor<2xi64>) -> tensor<4x5xi64>
// %result: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

Altri esempi

dynamic_pad

Semantica

Questa operazione è identica dal punto di vista funzionale di pad op, ma con edge_padding_low, edge_padding_high e interior_padding specificati in modo dinamico come valori.

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore o tensore quantizzato per tensore (C1), (C2), (C4)
(I2) padding_value tensore 0-dimensionale o tensore quantizzato per tensore (C1)
(I3) edge_padding_low Tensore unidimensionale di tipo intero (C1), (C4)
(I4) edge_padding_high Tensore unidimensionale di tipo intero (C1), (C4)
(I5) interior_padding Tensore unidimensionale di tipo intero (C2-C4)

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato per tensore (C3-C6)

Vincoli

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result).
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand).
  • (C3) 0 <= interior_padding.
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.

Esempi

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
  %edge_padding_low, %edge_padding_high, %interior_padding
) : (tensor<2x3xi64>, tensor<i64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x9xi64>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

Altri esempi

dynamic_reshape

Semantica

Questa operazione è identica dal punto di vista funzionale di rimodellazione, ma la forma del risultato viene specificata in modo dinamico tramite output_shape.

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore o tensore quantizzato (C1-C3)
(I2) output_shape Tensore unidimensionale di tipo intero (C4)

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato (C1-C4)

Vincoli

  • (C1) element_type(result) è dato da:
    • element_type(operand), se !is_per_axis_quantized(operand).
    • element_type(operand) ad eccezione del fatto che quantization_dimension(operand) e quantization_dimension(result) potrebbero variare in altro modo.
  • (C2) size(operand) = size(result).
  • (C3) Se is_per_axis_quantized(operand):
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
  • (C4) size(output_shape) = rank(result).

Esempi

// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
// %result: [[1, 2], [3, 4], [5, 6]]

Altri esempi

dynamic_slice

Semantica

Estrae una sezione da operand utilizzando indici iniziali calcolati dinamicamente e produce un tensore result. start_indices contengono gli indici iniziali della sezione per ogni dimensione soggetta a potenziale aggiustamento e slice_sizes contengono le dimensioni della sezione per ogni dimensione. In modo più formale, result[result_index] = operand[operand_index] dove:

  • adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes).
  • operand_index = adjusted_start_indices + result_index.

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore o tensore quantizzato per tensore (C1), (C2), (C4)
(I2) start_indices numero variadic di tensori 0-dimensionali di tipo intero (C2), (C3)
(I3) slice_sizes Costante del tensore unidimensionale di tipo si64 (C2), (C4), (C5)

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato per tensore (C1), (C5)

Vincoli

  • (C1) element_type(operand) = element_type(result).
  • (C2) size(start_indices) = size(slice_sizes) = rank(operand).
  • (C3) same(type(start_indices...)).
  • (C4) 0 <= slice_sizes <= shape(operand).
  • (C5) shape(result) = slice_sizes.

Esempi

// %operand: [
//            [0, 0, 1, 1],
//            [0, 0, 1, 1],
//            [0, 0, 0, 0],
//            [0, 0, 0, 0]
//           ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
  slice_sizes = array<i64: 2, 2>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
//           [1, 1],
//           [1, 1]
//          ]

Altri esempi

dynamic_update_slice

Semantica

Genera un tensore result uguale al tensore operand, ad eccezione del fatto che la sezione che inizia da start_indices viene aggiornata con i valori in update. Più formalmente, per result[result_index] si intende:

  • update[update_index] se 0 <= update_index < shape(update) dove:
    • adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update)).
    • update_index = result_index - adjusted_start_indices.
  • operand[result_index] in caso contrario.

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore o tensore quantizzato per tensore (C1-C4), (C6)
(I2) update tensore o tensore quantizzato per tensore (C2), (C3), (C6)
(I3) start_indices numero variadic di tensori 0-dimensionali di tipo intero (C4), (C5)

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato per tensore (C1)

Vincoli

  • (C1) type(operand) = type(result).
  • (C2) element_type(update) = element_type(operand).
  • (C3) rank(update) = rank(operand).
  • (C4) size(start_indices) = rank(operand).
  • (C5) same(type(start_indices...)).
  • (C6) 0 <= shape(update) <= shape(operand).

Esempi

// %operand: [
//            [1, 1, 0, 0],
//            [1, 1, 0, 0],
//            [1, 1, 1, 1],
//            [1, 1, 1, 1]
//           ]
// %update: [
//           [1, 1],
//           [1, 1]
//          ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
  : (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1]
//          ]

Altri esempi

esponenziale

Semantica

Esegue un'operazione esponenziale a livello di elemento sul tensore operand e produce un tensore result. A seconda del tipo di elemento, procedi come segue:

  • Per i valori in virgola mobile: exp da IEEE-754.
  • Per i numeri complessi: esponenziale complesso.
  • Per i tipi quantizzati: dequantize_op_quantize(exponential, operand, type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore (C1)

Output

Nome Tipo Vincoli
result tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore (C1)

Vincoli

  • (C1) baseline_type(operand) = baseline_type(result).

Esempi

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]

Altri esempi

exponential_minus_one

Semantica

Esegue un'operazione esponenziale a livello di elemento meno un'operazione sul tensore operand e produce un tensore result. A seconda del tipo di elemento, procedi come segue:

  • Per i valori in virgola mobile: expm1 da IEEE-754.
  • Per i numeri complessi: esponenziale complesso meno uno.
  • Per i tipi quantizzati: dequantize_op_quantize(exponential_minus_one, operand, type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore (C1)

Output

Nome Tipo Vincoli
result tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore (C1)

Vincoli

  • (C1) baseline_type(operand) = baseline_type(result).

Esempi

// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]

Altri esempi

fft

Semantica

Esegue le trasformate di Fourier diretta e inversa per input/output reali e complessi.

fft_type è uno dei seguenti:

  • FFT: inoltro FFT complesso a complesso.
  • IFFT: FFT da complesso a complesso inverso.
  • RFFT: inoltro di FFT reale a complesso.
  • IRFFT: FFT inverso da reale a complesso (ad esempio, richiede complesso, restituisce reale).

Più formalmente, data la funzione fft che prende tensori unidimensionali di tipi complessi come input, produce tensori unidimensionali degli stessi tipi dell'output e calcola la trasformata discreta di Fourier:

Per fft_type = FFT, result è definito come il risultato finale di una serie di calcoli L in cui L = size(fft_length). Ad esempio, per L = 3:

  • result1[i0, ..., :] = fft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

Inoltre, data la funzione ifft che ha la stessa firma del tipo e calcola l'inverso di fft:

Per fft_type = IFFT, result è definito come l'inverso dei calcoli di fft_type = FFT. Ad esempio, per L = 3:

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = ifft(result2[i0, ..., :]).

Inoltre, data la funzione rfft, che prende i tensori unidimensionali dei tipi in virgola mobile, produce tensori unidimensionali di tipi complessi con la stessa semantica in virgola mobile e funziona nel seguente modo:

  • rfft(real_operand) = truncated_result dove
  • complex_operand... = (real_operand..., 0.0).
  • complex_result = fft(complex_operand).
  • truncated_result = complex_result[:(rank(complex_result) / 2 + 1)].

(Quando viene calcolata la trasformata discreta di Fourier per gli operandi reali, i primi N/2 + 1 elementi del risultato definiscono in modo univoco il resto del risultato, di conseguenza il risultato di rfft viene troncato per evitare il calcolo di elementi ridondanti).

Per fft_type = RFFT, result è definito come il risultato finale di una serie di calcoli L in cui L = size(fft_length). Ad esempio, per L = 3:

  • result1[i0, ..., :] = rfft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

Infine, data la funzione irfft, che ha lo stesso tipo di firma e calcola l'inverso di rfft:

Per fft_type = IRFFT, result è definito come l'inverso dei calcoli di fft_type = RFFT. Ad esempio, per L = 3:

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = irfft(result2[i0, ..., :]).

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore di tipo in virgola mobile o complesso (C1), (C2), (C4), (C5)
(I2) fft_type enum di FFT, IFFT, RFFT e IRFFT (C2), (C5)
(I3) fft_length Costante del tensore unidimensionale di tipo si64 (C1), (C3), (C4)

Output

Nome Tipo Vincoli
result tensore di tipo in virgola mobile o complesso (C2), (C4), (C5)

Vincoli

  • (C1) size(fft_length) <= rank(operand).
  • (C2) La relazione tra i tipi di elementi operand e result varia:
    • Se fft_type = FFT, element_type(operand) e element_type(result) hanno lo stesso tipo complesso.
    • Se fft_type = IFFT, element_type(operand) e element_type(result) hanno lo stesso tipo complesso.
    • Se fft_type = RFFT, element_type(operand) è un tipo con virgola mobile e element_type(result) è un tipo complesso con la stessa semantica in virgola mobile.
    • Se fft_type = IRFFT, element_type(operand) è un tipo complesso e element_type(result) è un tipo in virgola mobile con la stessa semantica in virgola mobile.
  • (C3) 1 <= size(fft_length) <= 3.
  • (C4) Se tra operand e result, esiste un tensore real di tipo in virgola mobile, quindi shape(real)[-size(fft_length):] = fft_length.
  • (C5) shape(result) = shape(operand) tranne:
    • Se fft_type = RFFT, dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1.
    • Se fft_type = IRFFT, dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1.

Esempi

// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
  fft_type = #stablehlo<fft_type FFT>,
  fft_length = array<i64: 4>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]

floor

Semantica

Esegue il valore minimo per elemento del tensore operand e produce un tensore result. Implementa l'operazione roundToIntegralTowardNegative dalla specifica IEEE-754. Per i tipi quantizzati, esegue dequantize_op_quantize(floor, operand, type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore di tipo in virgola mobile o tensore quantizzato per tensore (C1)

Output

Nome Tipo Vincoli
result tensore di tipo in virgola mobile o tensore quantizzato per tensore (C1)

Vincoli

  • (C1) baseline_type(operand) = baseline_type(result).

Esempi

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]

Altri esempi

raccogliere

Semantica

Raccoglie le sezioni dal tensore operand dagli offset specificati in start_indices e produce un tensore result.

Il seguente diagramma mostra in che modo gli elementi in result vengono mappati sugli elementi in operand utilizzando un esempio concreto. Il diagramma seleziona alcuni indici result di esempio e spiega nel dettaglio a quali indici operand corrispondono.

raccogliere

In forma più formale, result[result_index] = operand[operand_index] dove:

  • batch_dims = [d for d in axes(result) and d not in offset_dims].
  • batch_index = result_index[batch_dims...].
  • Per start_index si intende:
    • start_indices[bi0, ..., :, ..., biN], dove bi sono singoli elementi in batch_index e : viene inserito nell'indice index_vector_dim, se index_vector_dim < rank(start_indices).
    • [start_indices[batch_index]] in caso contrario.
  • Per d_operand in axes(operand),
    • full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand]) se d_operand = start_index_map[d_start].
    • full_start_index[d_operand] = 0 in caso contrario.
  • Per d_operand in axes(operand),
    • full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)] se d_operand = operand_batching_dims[i_batching] e d_start = start_indices_batching_dims[i_batching].
    • full_batching_index[d_operand] = 0 in caso contrario.
  • offset_index = result_index[offset_dims...].
  • full_offset_index = [oi0, ..., 0, ..., oiN], dove oi sono singoli elementi di offset_index e 0 è inserito negli indici da collapsed_slice_dims e operand_batching_dims.
  • operand_index = full_start_index + full_batching_index + full_offset_index.

Se indices_are_sorted è true, l'implementazione può presumere che i valori start_indices siano ordinati rispetto a start_index_map, altrimenti il comportamento non è definito. Più formalmente, per tutti i i1 < i2 di indices(result), full_start_index(i1) <= full_start_index(i2).

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore o tensore quantizzato per tensore (C1), (C8), (C11), (C17), (C19-C21), (C23)
(I2) start_indices tensore di tipo intero (C2-C3), (C14), (C17), (C22)
(I3) offset_dims Costante del tensore unidimensionale di tipo si64 (C1), (C4-C5), (C22)
(I4) collapsed_slice_dims Costante del tensore unidimensionale di tipo si64 (C1), (C6-C9), (C22)
(I5) operand_batching_dims Costante del tensore unidimensionale di tipo si64 (C1), (C6), (C10-C12), (C16-C18), (C22)
(I6) start_indices_batching_dims Costante del tensore unidimensionale di tipo si64 (C13-C17)
(I7) start_index_map Costante del tensore unidimensionale di tipo si64 (C3), (C18-C19)
(I8) index_vector_dim costante di tipo si64 (C2-C3), (C15), (C22)
(I9) slice_sizes Costante del tensore unidimensionale di tipo si64 (C9), (C12), (C20-C22)
(I10) indices_are_sorted costante di tipo i1

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato per tensore (C5), (C22-C23)

Vincoli

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims).
  • (C2) 0 <= index_vector_dim <= rank(start_indices).
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1.
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims).
  • (C5) 0 <= offset_dims < rank(result).
  • (C6) is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))
  • (C7) is_sorted(collapsed_slice_dims).
  • (C8) 0 <= collapsed_slice_dims < rank(operand).
  • (C9) slice_sizes[collapsed_slice_dims...] <= 1.
  • (C10) is_sorted(operand_batching_dims).
  • (C11) 0 <= operand_batching_dims < rank(operand).
  • (C12) slice_sizes[operand_batching_dims...] <= 1.
  • (C13) is_unique(start_indices_batching_dims).
  • (C14) 0 <= start_indices_batching_dims < rank(start_indices).
  • (C15) index_vector_dim not in start_indices_batching_dims.
  • (C16) size(operand_batching_dims) == size(start_indices_batching_dims).
  • (C17) dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...).
  • (C18) is_unique(concatenate(start_index_map, operand_batching_dims)).
  • (C19) 0 <= start_index_map < rank(operand).
  • (C20) size(slice_sizes) = rank(operand).
  • (C21) 0 <= slice_sizes <= shape(operand).
  • (C22) shape(result) = combine(batch_dim_sizes, offset_dim_sizes) dove:
    • batch_dim_sizes = shape(start_indices) tranne per il fatto che le dimensioni di start_indices corrispondenti a index_vector_dim non sono incluse.
    • offset_dim_sizes = slice_sizes tranne per il fatto che le dimensioni delle dimensioni in slice_sizes corrispondenti a collapsed_slice_dims e operand_batching_dims non sono incluse.
    • combine posiziona batch_dim_sizes sugli assi corrispondenti a batch_dims e offset_dim_sizes sugli assi corrispondenti a offset_dims.
  • (C23) element_type(operand) = element_type(result).

Esempi

// %operand: [
//            [
//             [[1, 2], [3, 4], [5, 6], [7, 8]],
//             [[9, 10],[11, 12], [13, 14], [15, 16]],
//             [[17, 18], [19, 20], [21, 22], [23, 24]]
//            ],
//            [
//             [[25, 26], [27, 28], [29, 30], [31, 32]],
//             [[33, 34], [35, 36], [37, 38], [39, 40]],
//             [[41, 42], [43, 44], [45, 46], [47, 48]]
//            ]
//           ]
// %start_indices: [
//                  [
//                   [[0, 0], [1, 0], [2, 1]],
//                   [[0, 1], [1, 1], [0, 9]]
//                  ],
//                  [
//                   [[0, 0], [2, 1], [2, 2]],
//                   [[1, 2], [0, 1], [1, 0]]
//                  ]
//                 ]
%result = "stablehlo.gather"(%operand, %start_indices) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [3, 4],
    collapsed_slice_dims = [1],
    operand_batching_dims = [0],
    start_indices_batching_dims = [1],
    start_index_map = [2, 1],
    index_vector_dim = 3>,
  slice_sizes = array<i64: 1, 1, 2, 2>,
  indices_are_sorted = false
} : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32>
// %result: [
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[3, 4], [5, 6]],
//             [[13, 14], [15, 16]]
//            ],
//            [
//             [[33, 34], [35, 36]],
//             [[35, 36], [37, 38]],
//             [[41, 42], [43, 44]]
//            ]
//           ],
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[13, 14], [15, 16]],
//             [[21, 22], [23, 24]]
//            ],
//            [
//             [[43, 44], [45, 46]],
//             [[33, 34], [35, 36]],
//             [[27, 28], [29, 30]]
//            ]
//           ]
//          ]

Altri esempi

get_dimension_size

Semantica

Genera le dimensioni del valore dimension specificato di operand. In forma più formale, result = dim(operand, dimension). La semantica riguarda solo il componente forma del tipo. Il tipo di elemento può essere qualsiasi cosa.

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore o tensore quantizzato (C1)
(I2) dimension costante di tipo si64 (C1)

Output

Nome Tipo
result Tensore 0-dimensionale di tipo si32

Vincoli

  • (C1) 0 <= dimension < rank(operand).

Esempi

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
  dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3

Altri esempi

get_tuple_element

Semantica

Estrae l'elemento nella posizione index della tupla operand e produce un elemento result. In forma più formale, result = operand[index].

Input

Etichetta Nome Tipo Vincoli
(I1) operand tuple (C1), (C2)
(I2) index costante di tipo si32 (C1), (C2)

Output

Nome Tipo Vincoli
result qualsiasi tipo supportato (C2)

Vincoli

  • (C1) 0 <= index < size(operand).
  • (C2) type(result) = tuple_element_types(operand)[index].

Esempi

// %operand: ([1.0, 2.0], (3))
  index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]

Altri esempi

if

Semantica

Genera l'output dall'esecuzione di una sola funzione da true_branch o false_branch, a seconda del valore di pred. In forma più formale, result = pred ? true_branch() : false_branch().

Input

Etichetta Nome Tipo Vincoli
(I1) pred Tensore 0-dimensionale di tipo i1
(I2) true_branch funzione (C1-C3)
(I3) false_branch funzione (C1), (C2)

Output

Nome Tipo Vincoli
results numero variadi di tensori, tensori quantizzati o token (C3)

Vincoli

  • (C1) input_types(true_branch) = input_types(false_branch) = [].
  • (C2) output_types(true_branch) = output_types(false_branch).
  • (C3) type(results...) = output_types(true_branch).

Esempi

// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
  "stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
  "stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10

Altri esempi

immaginare

Semantica

Estrae la parte immaginaria, per elemento, da operand e produce un tensore result. In modo più formale, per ogni elemento x: imag(x) = is_complex(x) ? imaginary_part(x) : constant(0, element_type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore di tipo in virgola mobile o complesso (C1), (C2)

Output

Nome Tipo Vincoli
result tensore di tipo in virgola mobile (C1), (C2)

Vincoli

  • (C1) shape(result) = shape(operand).
  • (C2) Per element_type(result) si intende:
    • complex_element_type(element_type(operand)) se is_complex(operand).
    • element_type(operand) in caso contrario.

Esempi

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]

Altri esempi

annuncio in-feed

Semantica

Legge i dati dal feed e produce results.

La semantica del campo infeed_config è definita dall'implementazione.

results è costituito da valori di payload che vengono prima e un token che termina per ultimo. In futuro, abbiamo in programma di suddividere il payload e il token in due output separati per migliorare la chiarezza (#670).

Input

Etichetta Nome Tipo
(I1) token token
(I2) infeed_config costante di tipo string

Output

Nome Tipo Vincoli
results numero variadi di tensori, tensori quantizzati o token (C1-C3)

Vincoli

  • (C1) 0 < size(results).
  • (C2) is_empty(result[:-1]) o is_tensor(type(results[:-1])).
  • (C3) is_token(type(results[-1])).

Esempi

// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]

Altri esempi

iota

Semantica

Riempie un tensore output con valori in ordine crescente a partire da zero lungo la dimensione iota_dimension. In modo più formale,

output[output_index] = constant(is_quantized(output) ? quantize(output_index[iota_dimension], element_type(output)) : output_index[iota_dimension], element_type(output)).

Input

Etichetta Nome Tipo Vincoli
(I1) iota_dimension si64 (C1)

Output

Nome Tipo Vincoli
output tensore di numero intero, in virgola mobile o di tipo complesso o tensore quantizzato per tensore (C1)

Vincoli

  • (C1) 0 <= iota_dimension < rank(output).

Esempi

%output = "stablehlo.iota"() {
  iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

%output = "stablehlo.iota"() {
  iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4]
//          ]

Altri esempi

is_finite

Semantica

Esegue un controllo a livello di elemento se il valore in x è finito (ovvero non è né +Inf, -Inf o NaN) e produce un tensore y. Implementa l'operazione isFinite dalla specifica IEEE-754. Per i tipi quantizzati, il risultato è sempre true.

Input

Etichetta Nome Tipo Vincoli
(I1) x tensore di tipo in virgola mobile o tensore quantizzato per tensore (C1)

Output

Nome Tipo Vincoli
y tensore di tipo booleano (C1)

Vincoli

  • (C1) shape(x) = shape(y).

Esempi

// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]

Altri esempi

log

Semantica

Esegue un'operazione logaritmica a livello di elemento sul tensore operand e produce un tensore result. A seconda del tipo di elemento, procedi come segue:

  • Per i valori in virgola mobile: log da IEEE-754.
  • Per i numeri complessi: logaritmo complesso.
  • Per i tipi quantizzati: dequantize_op_quantize(log, operand, type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore (C1)

Output

Nome Tipo Vincoli
result tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore (C1)

Vincoli

  • (C1) baseline_type(operand) = baseline_type(result).

Esempi

// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]

Altri esempi

log_plus_one

Semantica

Esegue il logaritmo a livello di elemento e un'operazione sul tensore operand e produce un tensore result. A seconda del tipo di elemento, procedi come segue:

  • Per i valori in virgola mobile: logp1 da IEEE-754.
  • Per i numeri complessi: logaritmo complesso più uno.
  • Per i tipi quantizzati: dequantize_op_quantize(log_plus_one, operand, type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore (C1)

Output

Nome Tipo Vincoli
result tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore (C1)

Vincoli

  • (C1) baseline_type(operand) = baseline_type(result).

Esempi

// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]

Altri esempi

logistica

Semantica

Esegue un'operazione logistica basata sugli elementi sul tensore operand e produce un tensore result. A seconda del tipo di elemento, procedi come segue:

  • Per i valori in virgola mobile: division(1, addition(1, exp(-x))) da IEEE-754.
  • Per i numeri complessi: logistica complessa.
  • Per i tipi quantizzati: dequantize_op_quantize(logistic, operand, type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore (C1)

Output

Nome Tipo Vincoli
result tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore (C1)

Vincoli

  • (C1) baseline_type(operand) = baseline_type(result).

Esempi

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]

Altri esempi

mappa

Semantica

Applica una funzione di mappa computation a inputs lungo il dimensions e genera un tensore result.

In forma più formale, result[result_index] = computation(inputs...[result_index]).

Input

Etichetta Nome Tipo Vincoli
(I1) inputs numero variadi di tensori o tensori quantizzati per tensore (C1-C4)
(I2) dimensions Costante del tensore unidimensionale di tipo si64 (C3)
(I3) computation funzione (C4)

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato per tensore (C1), (C4)

Vincoli

  • (C1) shape(inputs...) = shape(result).
  • (C2) 0 < size(inputs) = N.
  • (C3) dimensions = range(rank(inputs[0])).
  • (C4) computation ha il tipo (tensor<E0>, ..., tensor<EN-1>) -> tensor<E'> dove Ei = element_type(inputs[i]) e E' = element_type(result).

Esempi

// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
    stablehlo.return %0 : tensor<i64>
}) {
  dimensions = array<i64: 0, 1>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]

Altri esempi

massima

Semantica

Esegue l'operazione massima a livello di elemento sui tensori lhs e rhs e produce un tensore result. A seconda del tipo di elemento, procedi come segue:

  • Per i valori booleani: OR logico.
  • Per i numeri interi: massimo numero intero.
  • Per i valori in virgola mobile: maximum da IEEE-754.
  • Per i numeri complessi: valore lessicografico massimo per la coppia (real, imaginary). L'imposizione di un ordinamento sui numeri complessi comporta una semantica sorprendente, quindi in futuro abbiamo in programma di rimuovere il supporto per i numeri complessi per questa operazione (#560).
  • Per i tipi quantizzati:
    • dequantize_op_quantize(maximum, lhs, rhs, type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) lhs tensore o tensore quantizzato per tensore (C1)
(I2) rhs tensore o tensore quantizzato per tensore (C1)

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato per tensore (C1)

Vincoli

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Esempi

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]

Altri esempi

minima

Semantica

Esegue un'operazione minima per elemento sui tensori lhs e rhs e produce un tensore result. A seconda del tipo di elemento, procedi come segue:

  • Per i valori booleani: AND logico.
  • Per i numeri interi: minimo numero intero.
  • Per i valori in virgola mobile: minimum da IEEE-754.
  • Per i numeri complessi: il valore minimo lessicografico per la coppia (real, imaginary). L'imposizione di un ordinamento sui numeri complessi comporta una semantica sorprendente, quindi in futuro abbiamo in programma di rimuovere il supporto per i numeri complessi per questa operazione (#560).
  • Per i tipi quantizzati:
    • dequantize_op_quantize(minimum, lhs, rhs, type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) lhs tensore o tensore quantizzato per tensore (C1)
(I2) rhs tensore o tensore quantizzato per tensore (C1)

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato per tensore (C1)

Vincoli

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Esempi

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]

Altri esempi

moltiplicazione

Semantica

Esegue un prodotto a livello di elemento di due tensori lhs e rhs e produce un tensore result. A seconda del tipo di elemento, procedi come segue:

  • Per i valori booleani: AND logico.
  • Per i numeri interi: moltiplicazione di numeri interi.
  • Per i valori in virgola mobile: multiplication da IEEE-754.
  • Per i numeri complessi: moltiplicazione complessa.
  • Per i tipi quantizzati:
    • dequantize_op_quantize(multiply, lhs, rhs, type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) lhs tensore o tensore quantizzato per tensore (C1)
(I2) rhs tensore o tensore quantizzato per tensore (C1)

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato per tensore (C1)

Vincoli

  • (C1) baseline_type(operand) = baseline_type(result).

Esempi

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]

Altri esempi

nega

Semantica

Esegue la negazione a livello di elemento del tensore operand e produce un tensore result. A seconda del tipo di elemento, procedi come segue:

  • Per i numeri interi firmati: negazione di numeri interi.
  • Per numeri interi senza segno: bitcast in numeri interi con segno, negazione di numeri interi e bitcast in valore intero senza segno.
  • Per i valori in virgola mobile: negate da IEEE-754.
  • Per i numeri complessi: negazione complessa.
  • Per i tipi quantizzati: dequantize_op_quantize(negate, operand, type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore di numero intero, in virgola mobile o di tipo complesso o tensore quantizzato per tensore (C1)

Output

Nome Tipo Vincoli
result tensore di numero intero, in virgola mobile o di tipo complesso o tensore quantizzato per tensore (C1)

Vincoli

  • (C1) baseline_type(operand) = baseline_type(result).

Esempi

// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]

// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]

Altri esempi

non

Semantica

Esegue NOT a livello di elemento del tensore operand e produce un tensore result. A seconda del tipo di elemento, procedi come segue:

  • Per i valori booleani: NOT logico.
  • Per i numeri interi: NOT a bit a bit.

Argomenti

Nome Tipo Vincoli
operand tensore di tipo booleano o numero intero (C1)

Output

Nome Tipo Vincoli
result tensore di tipo booleano o numero intero (C1)

Vincoli

  • (C1) type(operand) = type(result).

Esempi

// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]

// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]

Altri esempi

optimization_barrier

Semantica

Garantisce che le operazioni che producono operand vengano eseguite prima di qualsiasi operazione che dipende da result e impedisce alle trasformazioni del compilatore di spostare le operazioni oltre la barriera. A parte questo, l'operazione è un'identità, ad esempio result = operand.

Argomenti

Nome Tipo Vincoli
operand numero variadic di tensori, tensori quantizzati per tensore o token (C1)

Output

Nome Tipo Vincoli
result numero variadic di tensori, tensori quantizzati per tensore o token (C1)

Vincoli

  • (C1) type(operand...) = type(result...).

Esempi

// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0

Altri esempi

o

Semantica

Esegue l'operatore OR a livello di elemento di due tensori lhs e rhs e produce un tensore result. A seconda del tipo di elemento, procedi come segue:

  • Per i valori booleani: OR logico.
  • Per i numeri interi: OR bit a bit.

Input

Etichetta Nome Tipo Vincoli
(I1) lhs tensore di numero intero o di tipo booleano (C1)
(I2) rhs tensore di numero intero o di tipo booleano (C1)

Output

Nome Tipo Vincoli
result tensore di numero intero o di tipo booleano (C1)

Vincoli

  • (C1) type(lhs) = type(rhs) = type(result).

Esempi

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]

Altri esempi

Outfeed

Semantica

Scrive inputs nell'outfeed e produce un token result.

La semantica del campo outfeed_config è definita dall'implementazione.

Input

Etichetta Nome Tipo
(I1) inputs numero variadi di tensori o tensori quantizzati
(I2) token token
(I3) outfeed_config costante di tipo string

Output

Nome Tipo
result token

Esempi

%result = "stablehlo.outfeed"(%input0, %token) {
  outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token

Altri esempi

tappetino

Semantica

Espande operand riempiendo il tensore e tra gli elementi del tensore con il valore padding_value specificato.

edge_padding_low e edge_padding_high specificano la quantità di spaziatura interna aggiunta rispettivamente al limite inferiore (accanto all'indice 0) e alla fascia alta (accanto all'indice più alto) di ogni dimensione. La quantità di spaziatura interna può essere negativa, dove il valore assoluto di spaziatura interna indica il numero di elementi da rimuovere dalla dimensione specificata.

interior_padding specifica la quantità di spaziatura interna aggiunta tra due elementi qualsiasi in ogni dimensione che non può essere negativa. La spaziatura interna interna viene eseguita prima di quella sul bordo, in modo che quella negativa sui bordi rimuoverà gli elementi dall'operando con imbottitura interna.

Più formalmente, per result[result_index] si intende:

  • operand[operand_index] se result_index = edge_padding_low + operand_index * (interior_padding + 1).
  • padding_value in caso contrario.

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore o tensore quantizzato per tensore (C1), (C2), (C4)
(I2) padding_value tensore 0-dimensionale o tensore quantizzato per tensore (C1)
(I3) edge_padding_low Costante del tensore unidimensionale di tipo si64 (C1), (C4)
(I4) edge_padding_high Costante del tensore unidimensionale di tipo si64 (C1), (C4)
(I5) interior_padding Costante del tensore unidimensionale di tipo si64 (C2-C4)

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato per tensore (C3-C6)

Vincoli

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result).
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand).
  • (C3) 0 <= interior_padding.
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.

Esempi

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
  edge_padding_low = array<i64: 0, 1>,
  edge_padding_high = array<i64: 2, 1>,
  interior_padding = array<i64: 1, 2>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

Altri esempi

partition_id

Semantica

Produce partition_id del processo attuale.

Output

Nome Tipo
result Tensore 0-dimensionale di tipo ui32

Esempi

%result = "stablehlo.partition_id"() : () -> tensor<ui32>

Altri esempi

popcnt

Semantica

Esegue il conteggio a livello di elemento del numero di bit impostato nel tensore operand e produce un tensore result.

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore di tipo intero (C1)

Output

Nome Tipo Vincoli
result tensore di tipo intero (C1)

Vincoli

  • (C1) type(operand) = type(result).

Esempi

// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]

Altri esempi

potenza

Semantica

Esegue l'esponenziale a livello di elemento del tensore lhs per il tensore rhs e produce un tensore result. A seconda del tipo di elemento, procedi come segue:

  • Per i numeri interi: elevazione dei numeri interi.
  • Per i valori in virgola mobile: pow da IEEE-754.
  • Per i numeri complessi: esponenziale complessa.
  • Per i tipi quantizzati: dequantize_op_quantize(power, lhs, rhs, type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) lhs tensore di numero intero, in virgola mobile o di tipo complesso o tensore quantizzato per tensore (C1)
(I2) rhs tensore di numero intero, in virgola mobile o di tipo complesso o tensore quantizzato per tensore (C1)

Output

Nome Tipo Vincoli
result tensore di numero intero, in virgola mobile o di tipo complesso o tensore quantizzato per tensore (C1)

Vincoli

  • (C1) baseline_type(operand) = baseline_type(result).

Esempi

// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]

Altri esempi

reale

Semantica

Estrae la parte reale, a livello di elemento, da operand e produce un tensore result. In modo più formale, per ogni elemento x: real(x) = is_complex(x) ? real_part(x) : x.

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore di tipo in virgola mobile o complesso (C1), (C2)

Output

Nome Tipo Vincoli
result tensore di tipo in virgola mobile (C1), (C2)

Vincoli

  • (C1) shape(result) = shape(operand).
  • (C2) Per element_type(result) si intende:
    • complex_element_type(element_type(operand)) se is_complex(operand).
    • element_type(operand) in caso contrario.

Esempi

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]

Altri esempi

recv

Semantica

Riceve i dati da un canale con channel_id e produce results.

Se is_host_transfer è true, l'operazione trasferisce i dati dall'host. In caso contrario, i dati vengono trasferiti da un altro dispositivo. Questo significa definito dall'implementazione. Questo flag duplica le informazioni fornite in channel_type, quindi in futuro prevediamo di conservarne solo uno (#666).

results è costituito da valori di payload che vengono prima e un token che termina per ultimo. In futuro, abbiamo in programma di suddividere il payload e il token in due output separati per migliorare la chiarezza (#670).

Input

Etichetta Nome Tipo Vincoli
(I1) token token (C4)
(I2) channel_id costante di tipo si64
(I3) channel_type enum di DEVICE_TO_DEVICE e HOST_TO_DEVICE (C1)
(I4) is_host_transfer costante di tipo i1 (C1)

Output

Nome Tipo Vincoli
results numero variadi di tensori, tensori quantizzati o token (C2-C4)

Vincoli

  • (C1) Per channel_type si intende:
    • HOST_TO_DEVICE se is_host_transfer = true,
    • DEVICE_TO_DEVICE in caso contrario.
  • (C2) 0 < size(results).
  • (C3) is_empty(result[:-1]) o is_tensor(type(results[:-1])).
  • (C4) is_token(type(results[-1])).

Esempi

%results0, %results1 = "stablehlo.recv"(%token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
  is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)

Altri esempi

reduce

Semantica

Applica una funzione di riduzione body a inputs e init_values lungo il dimensions e produce results tensori.

L'ordine delle riduzioni è definito dall'implementazione, il che significa che body e init_values devono formare un monoide per garantire che l'operazione produca gli stessi risultati per tutti gli input in tutte le implementazioni. Tuttavia, questa condizione non è valida per molte riduzioni comuni. Ad esempio, l'aggiunta in virgola mobile per body e zero per init_values in realtà non formano un monoide perché l'aggiunta in virgola mobile non è associativa.

In forma più formale, results...[j0, ..., jR-1] = reduce(input_slices_converted) dove:

  • input_slices = inputs...[j0, ..., :, ..., jR-1], dove : sono inserite in dimensions.
  • input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...).
  • init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...).
  • reduce(input_slices_converted) = exec(schedule) per un albero binario schedule dove:
    • exec(node) = body(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule è un albero binario completo definito dall'implementazione il cui attraversamento in ordine è costituito da:
    • Valori input_slices_converted...[index], per tutti i index in index_space(input_slices_converted) nell'ordine lessicografico crescente di index.
    • alternato a un numero definito di implementazione di init_values_converted nelle posizioni definite per l'implementazione.

Input

Etichetta Nome Tipo Vincoli
(I1) inputs numero variadi di tensori o tensori quantizzati per tensore (C1-C4), (C6), (C7)
(I2) init_values numero variadi di tensori 0-dimensionali o di tensori quantizzati per tensore (C2), (C3)
(I3) dimensions Costante del tensore unidimensionale di tipo si64 (C4), (C5), (C7)
(I4) body funzione (C6)

Output

Nome Tipo Vincoli
results numero variadi di tensori o tensori quantizzati per tensore (C3), (C7), (C8)

Vincoli

  • (C1) same(shape(inputs...)).
  • (C2) element_type(inputs...) = element_type(init_values...).
  • (C3) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (C4) 0 <= dimensions < rank(inputs[0]).
  • (C5) is_unique(dimensions).
  • (C6) body ha il tipo (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>) dove is_promotable(element_type(inputs[i]), Ei).
  • (C7) shape(results...) = shape(inputs...) tranne per il fatto che le dimensioni di inputs... corrispondenti a dimensions non sono incluse.
  • (C8) element_type(results[i]) = Ei per tutti i i di [0,N).

Esempi

// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]

Altri esempi

reduce_precision

Semantica

Esegue la conversione a livello di elemento di operand in un altro tipo di virgola mobile che utilizza exponent_bits e mantissa_bits, tornando al tipo originale in virgola mobile e produce un tensore output.

In forma più formale:

  • I bit di mantissa del valore originale vengono aggiornati per arrotondare il valore originale al valore più vicino rappresentabile con mantissa_bits utilizzando la semantica roundToIntegralTiesToEven.
  • Quindi, se mantissa_bits è inferiore al numero di bit mantissa del valore originale, questi vengono troncati a mantissa_bits.
  • Quindi, se i bit di esponenti del risultato intermedio non rientrano nell'intervallo fornito da exponent_bits, il risultato intermedio supera l'infinito utilizzando il segno originale o torna a zero utilizzando il segno originale.
  • Per i tipi quantizzati, esegue dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore di tipo in virgola mobile o tensore quantizzato per tensore (C1)
(I2) exponent_bits costante di tipo si32 (C2)
(I3) mantissa_bits costante di tipo si32 (C3)

Output

Nome Tipo Vincoli
output tensore di tipo in virgola mobile o tensore quantizzato per tensore (C1)

Vincoli

  • (C1) baseline_type(operand) = baseline_type(output).
  • (C2) 1 <= exponent_bits.
  • (C3) 0 <= mantissa_bits.

Esempi

// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
  exponent_bits = 5 : i32,
  mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]

Altri esempi

reduce_scatter

Semantica

reduce_scatter

All'interno di ogni gruppo di processi nella griglia dei processi StableHLO, esegue la riduzione; utilizzando computations, sui valori del tensore operand di ciascun processo, suddivide il risultato della riduzione in parti scatter_dimension e disperde le parti divisa tra i processi per produrre result.

L'operazione suddivide la griglia del processo StableHLO in process_groups, definita come segue:

  • cross_replica(replica_groups) se channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) se channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) se channel_id > 0 and use_global_device_ids = true.

Successivamente, all'interno di ogni process_group:

  • reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation).
  • parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension).
  • result@receiver = parts@sender[receiver_index] per tutti i sender in process_group, dove receiver_index = process_group.index(receiver).

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore o tensore quantizzato per tensore (C1), (C2), (C7), (C8)
(I2) scatter_dimension costante di tipo si64 (C1), (C2), (C8)
(I3) replica_groups Costante del tensore bidimensionale di tipo si64 (C3-C5)
(I4) channel_id costante di tipo si64 (C6)
(I5) use_global_device_ids costante di tipo i1 (C6)
(I6) computation funzione (C7)

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato per tensore (C8-C9)

Vincoli

  • (C1) dim(operand, scatter_dimension) % dim(process_groups, 1) = 0.
  • (C2) 0 <= scatter_dimension < rank(operand).
  • (C3) is_unique(replica_groups).
  • (C4) Per size(replica_groups) si intende:
    • num_replicas se si utilizza cross_replica.
    • num_replicas se si utilizza cross_replica_and_partition.
    • num_processes se si utilizza flattened_ids.
  • (C5) 0 <= replica_groups < size(replica_groups).
  • (C6) Se use_global_device_ids = true, allora channel_id > 0.
  • (C7) computation ha il tipo (tensor<E>, tensor<E>) -> (tensor<E>), dove is_promotable(element_type(operand), E).
  • (C8) shape(result) = shape(operand) eccetto:
    • dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1).
  • (C9) element_type(result) = E.

Esempi

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
  "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
//                  [18, 20]]
// %result@(1, 0): [[14, 16],
//                  [22, 24]]

Altri esempi

reduce_window

Semantica

Applica una funzione di riduzione body a finestre di inputs e init_values e produce results.

Il seguente diagramma mostra come vengono calcolati gli elementi in results... da inputs... utilizzando un esempio concreto.

reduce_window

In termini più formali, results...[result_index] = reduce(windows, init_values, axes(inputs...), body) (vedi ridurre) dove:

  • padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1).
  • window_start = result_index * window_strides.
  • window_end = window_start + (window_dimensions - 1) * window_dilations + 1.
  • windows = slice(padded_inputs..., window_start, window_end, window_dilations).

Input

Etichetta Nome Tipo Vincoli
(I1) inputs numero variadi di tensori o tensori quantizzati per tensore (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15)
(I2) init_values numero variadi di tensori 0-dimensionali o di tensori quantizzati per tensore (C1), (C13)
(I3) window_dimensions Costante del tensore unidimensionale di tipo si64 (C4), (C5), (C15)
(I4) window_strides Costante del tensore unidimensionale di tipo si64 (C6), (C7), (C15)
(I5) base_dilations Costante del tensore unidimensionale di tipo si64 (C8), (C9), (C15)
(I6) window_dilations Costante del tensore unidimensionale di tipo si64 (C10), (C11), (C15)
(I7) padding Costante del tensore bidimensionale di tipo si64 (C12), (C15)
(I8) body funzione (C13)

Output

Nome Tipo Vincoli
results numero variadi di tensori o tensori quantizzati per tensore (C1), (C14-C16)

Vincoli

  • (C1) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (C2) same(shape(inputs...)).
  • (C3) element_type(inputs...) = element_type(init_values...).
  • (C4) size(window_dimensions) = rank(inputs[0]).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(inputs[0]).
  • (C7) 0 < window_strides.
  • (C8) size(base_dilations) = rank(inputs[0]).
  • (C9) 0 < base_dilations.
  • (C10) size(window_dilations) = rank(inputs[0]).
  • (C11) 0 < window_dilations.
  • (C12) shape(padding) = [rank(inputs[0]), 2].
  • (C13) body ha il tipo (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>) dove is_promotable(element_type(inputs[i]), Ei).
  • (C14) same(shape(results...)).
  • (C15) shape(results[0]) = num_windows dove:
    • dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1.
    • padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1].
    • dilated_window_shape = (window_dimensions - 1) * window_dilations + 1.
    • is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1.
  • (C16) element_type(results[i]) = Ei per tutti i i di [0,N).

Esempi

// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = array<i64: 2, 1>,
  window_strides = array<i64: 4, 1>,
  base_dilations = array<i64: 2, 1>,
  window_dilations = array<i64: 3, 1>,
  padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]

Altri esempi

resto

Semantica

Esegue la parte restante dei tensori del dividendo lhs e il divisore rhs e produce un tensore result.

Più formalmente, il segno del risultato viene preso dal dividendo e il valore assoluto del risultato è sempre inferiore al valore assoluto del divisore. Il resto viene calcolato come lhs - d * rhs, dove d è dato da:

  • Per i numeri interi: stablehlo.divide(lhs, rhs).
  • Per i numeri in virgola mobile: division(lhs, rhs) da IEEE-754 con attributo di arrotondamento roundTowardZero.
  • Per i numeri complessi: da definire (#997).
  • Per i tipi quantizzati:
    • dequantize_op_quantize(remainder, lhs, rhs, type(result)).

Per i tipi di elementi con virgola mobile, questa operazione è in contrasto con l'operazione remainder della specifica IEEE-754, in cui d è un valore integrale più vicino al valore esatto di lhs/rhs con legami al valore pari.

Input

Etichetta Nome Tipo Vincoli
(I1) lhs tensore di tipo intero, in virgola mobile o complesso o tensore quantizzato per tensore (C1)
(I2) rhs tensore di tipo intero, in virgola mobile o complesso o tensore quantizzato per tensore (C1)

Output

Nome Tipo Vincoli
result tensore di tipo intero, in virgola mobile o complesso o tensore quantizzato per tensore (C1)

Vincoli

  • (C1) baseline_type(operand) = baseline_type(result).

Esempi

// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]

Altri esempi

replica_id

Semantica

Produce replica_id del processo attuale.

Output

Nome Tipo
result Tensore 0-dimensionale di tipo ui32

Esempi

%result = "stablehlo.replica_id"() : () -> tensor<ui32>

Altri esempi

rimodellare

Semantica

Esegue la riforma del tensore operand in un tensore result. Concettualmente, equivale a mantenere la stessa rappresentazione canonica, ma modificando potenzialmente la forma, ad esempio da tensor<2x3xf32> a tensor<3x2xf32> o tensor<6xf32>.

Più formalmente, result[result_index] = operand[operand_index] dove result_index e operand_index hanno la stessa posizione nell'ordine lessicografico di index_space(result) e index_space(operand).

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore o tensore quantizzato (C1-C3)

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato (C1-C3)

Vincoli

  • (C1) element_type(result) è dato da:
    • element_type(operand), se !is_per_axis_quantized(operand).
    • element_type(operand) ad eccezione del fatto che quantization_dimension(operand) e quantization_dimension(result) potrebbero variare in altro modo.
  • (C2) size(operand) = size(result).
  • (C3) Se is_per_axis_quantized(operand):
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).

Esempi

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]

Altri esempi

inverti

Semantica

Inverte l'ordine degli elementi in operand lungo il valore dimensions specificato e produce un tensore result. In modo più formale, result[result_index] = operand[operand_index] dove:

  • operand_index[d] = dim(result, d) - result_index[d] - 1 se d in dimensions.
  • operand_index[d] = result_index[d] in caso contrario.

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore o tensore quantizzato per tensore (C1), (C3)
(I2) dimensions Costante del tensore unidimensionale di tipo si64 (C2), (C3)

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato per tensore (C1), (C3)

Vincoli

  • (C1) type(operand) = type(result).
  • (C2) is_unique(dimensions).
  • (C3) 0 <= dimensions < rank(result).

Esempi

// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
  dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]

Altri esempi

rng

Semantica

Genera numeri casuali utilizzando l'algoritmo rng_distribution e produce un tensore result di una data forma shape.

Se rng_distribution = UNIFORM, i numeri casuali vengono generati seguendo la distribuzione uniforme nell'intervallo [a, b). Se a >= b, il comportamento non è definito.

Se rng_distribution = NORMAL, i numeri casuali vengono generati seguendo la distribuzione normale con media = a e deviazione standard = b. Se b < 0, il comportamento non è definito.

Il modo esatto in cui vengono generati i numeri casuali è definito nell'implementazione. Ad esempio, potrebbero essere o meno deterministici e possono o meno utilizzare lo stato nascosto.

Nelle conversazioni con molti stakeholder, questa operazione è risultata essere ritirata in modo più efficace, quindi in futuro prevediamo di rimuoverla (#597).

Input

Etichetta Nome Tipo Vincoli
(I1) a Tensore 0-dimensionale di tipo intero, booleano o in virgola mobile (C1), (C2)
(I2) b Tensore 0-dimensionale di tipo intero, booleano o in virgola mobile (C1), (C2)
(I3) shape Costante del tensore unidimensionale di tipo si64 (C3)
(I4) rng_distribution enum di UNIFORM e NORMAL (C2)

Output

Nome Tipo Vincoli
result tensore di tipo intero, booleano o in virgola mobile (C1-C3)

Vincoli

  • (C1) element_type(a) = element_type(b) = element_type(result).
  • (C2) Se rng_distribution = NORMAL, allora is_float(a).
  • (C3) shape(result) = shape.

Esempi

// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
  rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
//           [1, 0, 1],
//           [1, 1, 1],
//           [0, 0, 0]
//          ]

rng_bit_generator

Semantica

Restituisce un elemento output pieno di bit casuali uniformi e uno stato di output aggiornato output_state utilizzando l'algoritmo del generatore di numeri pseudocasuale rng_algorithm in base a uno stato iniziale initial_state. È garantito che l'output sia una funzione deterministica di initial_state, ma non ne è garantito che sia deterministico tra le implementazioni.

rng_algorithm è uno dei seguenti:

  • DEFAULT: algoritmo definito dall'implementazione.
  • THREE_FRY: variante definita dall'implementazione dell'algoritmo Threefry.*
  • PHILOX: variante definita dall'implementazione dell'algoritmo Philox.*

* Vedi: Salmon et al. SC 2011. Numeri casuali paralleli: facili come 1, 2, 3.

Input

Etichetta Nome Tipo Vincoli
(I1) rng_algorithm enum di DEFAULT, THREE_FRY e PHILOX (C2)
(I2) initial_state Tensore unidimensionale di tipo ui64 (C1), (C2)

Output

Nome Tipo Vincoli
output_state Tensore unidimensionale di tipo ui64 (C1)
output tensore di tipo numero intero o in virgola mobile

Vincoli

  • (C1) type(initial_state) = type(output_state).
  • (C2) Per size(initial_state) si intende:
    • definita nell'implementazione se rng_algorithm = DEFAULT.
    • 2 se rng_algorithm = THREE_FRY.
    • 2 o 3 se rng_algorithm = PHILOX.

Esempi

// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
  rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
//           [9236835810183407956, 16087790271692313299],
//           [18212823393184779219, 2658481902456610144]
//          ]

round_nearest_afz

Semantica

Esegue l'arrotondamento a livello di elemento verso il numero intero più vicino, separando i legami da zero, sul tensore operand e produce un tensore result. Implementa l'operazione roundToIntegralTiesToAway dalla specifica IEEE-754. Per i tipi quantiizzati, esegue dequantize_op_quantize(round_nearest_afz, operand, type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore di tipo in virgola mobile o tensore quantizzato per tensore (C1)

Output

Nome Tipo Vincoli
result tensore di tipo in virgola mobile o tensore quantizzato per tensore (C1)

Vincoli

  • (C1) baseline_type(operand) = baseline_type(result).

Esempi

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]

Altri esempi

round_nearest_even

Semantica

Esegue l'arrotondamento in base agli elementi al numero intero più vicino, spezzando i legami con il numero intero pari sul tensore operand e produce un tensore result. Implementa l'operazione roundToIntegralTiesToEven dalla specifica IEEE-754. Per i tipi quantizzati, esegue dequantize_op_quantize(round_nearest_even, operand, type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore di tipo in virgola mobile o tensore quantizzato per tensore (C1)

Output

Nome Tipo Vincoli
result tensore di tipo in virgola mobile o tensore quantizzato per tensore (C1)

Vincoli

  • (C1) baseline_type(operand) = baseline_type(result).

Esempi

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]

Altri esempi

rsqrt

Semantica

Esegue un'operazione di radice quadrata reciproca a livello di elemento sul tensore operand e produce un tensore result. A seconda del tipo di elemento, procedi come segue:

  • Per i valori in virgola mobile: rSqrt da IEEE-754.
  • Per i numeri complessi: radice quadrata reciproca complessa.
  • Per i tipi quantizzati: dequantize_op_quantize(rsqrt, operand, type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore (C1)

Output

Nome Tipo Vincoli
result tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore (C1)

Vincoli

  • (C1) baseline_type(operand) = baseline_type(result).

Esempi

// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]

Altri esempi

scatter

Semantica

Genera tensori results che sono uguali a inputs tensori, tranne per il fatto che diverse sezioni specificate da scatter_indices vengono aggiornate con i valori updates utilizzando update_computation.

Il seguente diagramma mostra in che modo gli elementi in updates... vengono mappati sugli elementi in results... utilizzando un esempio concreto. Il diagramma seleziona alcuni indici updates... di esempio e spiega in dettaglio a quali indici results... corrispondono.

scatter

Più formalmente, per tutti i update_index in index_space(updates[0]):

  • update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims].
  • update_scatter_index = update_index[update_scatter_dims...].
  • Per start_index si intende:
    • scatter_indices[si0, ..., :, ..., siN], dove si sono singoli elementi in update_scatter_index e : viene inserito nell'indice index_vector_dim, se index_vector_dim < rank(scatter_indices).
    • [scatter_indices[update_scatter_index]] in caso contrario.
  • Per d_input in axes(inputs[0]),
    • full_start_index[d_input] = start_index[d_start] se d_input = scatter_dims_to_operand_dims[d_start].
    • full_start_index[d_input] = 0 in caso contrario.
  • Per d_input in axes(inputs[0]),
    • full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)] se d_input = input_batching_dims[i_batching] e d_start = scatter_indices_batching_dims[i_batching].
    • full_batching_index[d_input] = 0 in caso contrario.
  • update_window_index = update_index[update_window_dims...].
  • full_window_index = [wi0, ..., 0, ..., wiN], dove wi sono singoli elementi di update_window_index e 0 è inserito negli indici da inserted_window_dims e input_batching_dims.
  • result_index = full_start_index + full_batching_index + full_window_index.

Detto ciò, results = exec(schedule, inputs), dove:

  • schedule è una permutazione definita dall'implementazione di index_space(updates[0]).
  • exec([update_index, ...], results) = exec([...], updated_results) dove:
    • Se result_index è nei limiti di shape(results...)
    • updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
    • updated_values = update_computation(results...[result_index], updates_converted)
    • updated_results è una copia di results con results...[result_index] impostato su updated_values....
    • In caso contrario
    • updated_results = results.
  • exec([], results) = results.

Se indices_are_sorted è true, l'implementazione può presumere che i valori scatter_indices siano ordinati rispetto a scatter_dims_to_operand_dims, altrimenti il comportamento non è definito. Più formalmente, per tutti i i1 < i2 di indices(result), full_start_index(i1) <= full_start_index(i2).

Se unique_indices è true, l'implementazione può presumere che tutti gli indici result_index sparsi sono univoci. Se unique_indices è true, ma gli indici di dispersione non sono univoci, il comportamento è indefinito.

Input

Etichetta Nome Tipo Vincoli
(I1) inputs numero variadi di tensori o tensori quantizzati per tensore (C1), (C2), (C4-C6), (C11), (C13), (C18), (C21), (C23-C24)
(I2) scatter_indices tensore di tipo intero (C4), (C15), (C19), (C22)
(I3) updates numero variadi di tensori o tensori quantizzati per tensore (C3-C6), (C8)
(I4) update_window_dims Costante del tensore unidimensionale di tipo si64 (C2), (C4), (C7-C8)
(I5) inserted_window_dims Costante del tensore unidimensionale di tipo si64 (C2), (C4), (C9-C11)
(I6) input_batching_dims Costante del tensore unidimensionale di tipo si64 (C2), (C4), (C9), (C12-13), (C17-18), (C20)
(I7) scatter_indices_batching_dims Costante del tensore unidimensionale di tipo si64 (C14-C18)
(I8) scatter_dims_to_operand_dims Costante del tensore unidimensionale di tipo si64 (C19-C21)
(I9) index_vector_dim costante di tipo si64 (C4), (C16), (C19), (C22)
(I10) indices_are_sorted costante di tipo i1
(I11) unique_indices costante di tipo i1
(I12) update_computation funzione (C23)

Output

Nome Tipo Vincoli
results numero variadi di tensori o tensori quantizzati per tensore (C24-C25)

Vincoli

  • (C1) same(shape(inputs...)).
  • (C2) 'rank(inputs[0]) = dimensione(di_aggiornamento_finestra_dims) + dimensione(dim_finestra_inserted)
    • size(input_batching_dims)".
  • (C3) same(shape(updates...)).
  • (C4) shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes) dove:
    • update_scatter_dim_sizes = shape(scatter_indices) tranne che la dimensione di scatter_indices corrispondente a index_vector_dim non è inclusa.
    • update_window_dim_sizes <= shape(inputs[0]) tranne che le dimensioni delle dimensioni in inputs[0] corrispondenti a inserted_window_dims e input_batching_dims non sono incluse.
    • combine posiziona update_scatter_dim_sizes sugli assi corrispondenti a update_scatter_dims e update_window_dim_sizes sugli assi corrispondenti a update_window_dims.
  • (C5) 0 < size(inputs) = size(updates) = N.
  • (C6) element_type(updates...) = element_type(inputs...).
  • (C7) is_unique(update_window_dims) and is_sorted(update_window_dims).
  • (C8) 0 <= update_window_dims < rank(updates[0]).
  • (C9) is_unique(concatenate(inserted_window_dims, input_batching_dims))
  • (C10) is_sorted(inserted_window_dims).
  • (C11) 0 <= inserted_window_dims < rank(inputs[0]).
  • (C12) is_sorted(input_batching_dims).
  • (C13) 0 <= input_batching_dims < rank(inputs[0])).
  • (C14) is_unique(scatter_indices_batching_dims).
  • (C15) 0 <= scatter_indices_batching_dims < rank(scatter_indices).
  • (C16) index_vector_dim not in scatter_indices_batching_dims.
  • (C17) size(input_batching_dims) == size(scatter_indices_batching_dims).
  • (C18) dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...).
  • (C19) size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1.
  • (C20) is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims)).
  • (C21) 0 <= scatter_dims_to_operand_dims < rank(inputs[0]).
  • (C22) 0 <= index_vector_dim <= rank(scatter_indices).
  • (C23) update_computation ha il tipo (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), dove is_promotable(element_type(inputs[i]), Ei).
  • (C24) shape(inputs...) = shape(results...).
  • (C25) element_type(results[i]) = Ei per tutti i i di [0,N).

Esempi

// %input: [
//          [
//           [[1, 2], [3, 4], [5, 6], [7, 8]],
//           [[9, 10],[11, 12], [13, 14], [15, 16]],
//           [[17, 18], [19, 20], [21, 22], [23, 24]]
//          ],
//          [
//           [[25, 26], [27, 28], [29, 30], [31, 32]],
//           [[33, 34], [35, 36], [37, 38], [39, 40]],
//           [[41, 42], [43, 44], [45, 46], [47, 48]]
//          ]
//         ]
// %scatter_indices: [
//                    [
//                     [[0, 0], [1, 0], [2, 1]],
//                     [[0, 1], [1, 1], [0, 9]]
//                    ],
//                    [
//                     [[0, 0], [2, 1], [2, 2]],
//                     [[1, 2], [0, 1], [1, 0]]
//                    ]
//                   ]
// %update: [
//           [
//            [[1, 1], [1, 1], [1, 1]],
//            [[1, 1], [1, 1], [1, 1]]
//           ],
//           [
//            [[1, 1], [1, 1], [1, 1]],
//            [[1, 1], [1, 1], [1, 1]]
//           ]
//          ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension_numbers = #stablehlo.scatter<
    update_window_dims = [3, 4],
    inserted_window_dims = [1],
    input_batching_dims = [0],
    scatter_indices_batching_dims = [1],
    scatter_dims_to_operand_dims = [2, 1],
    index_vector_dim = 3>,
  indices_are_sorted = false,
  unique_indices = false
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
// %result: [
//           [
//            [[3, 4], [6, 7], [6, 7], [7, 8]],
//            [[9, 10],[11, 12], [15, 16], [17, 18]],
//            [[17, 18], [19, 20], [22, 23], [24, 25]]
//           ],
//           [
//            [[25, 26], [28, 29], [30, 31], [31, 32]],
//            [[35, 36], [38, 39], [38, 39], [39, 40]],
//            [[41, 42], [44, 45], [46, 47], [47, 48]]
//           ]
//          ]

Altri esempi

select

Semantica

Genera un tensore result in cui ogni elemento viene selezionato dal tensore on_true o on_false in base al valore dell'elemento corrispondente di pred. Più formalmente, result[result_index] = pred_element ? on_true[result_index] : on_false[result_index], dove pred_element = rank(pred) = 0 ? pred[] : pred[result_index]. Per i tipi quantizzati, esegue dequantize_select_quantize(pred, on_true, on_false, type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) pred tensore di tipo i1 (C1)
(I2) on_true tensore o tensore quantizzato per tensore (C1-C2)
(I3) on_false tensore o tensore quantizzato per tensore (C2)

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato per tensore (C2)

Vincoli

  • (C1) rank(pred) = 0 or shape(pred) = shape(on_true).
  • (C2) baseline_type(on_true) = baseline_type(on_false) = baseline_type(result).

Esempi

// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]

Altri esempi

select_and_scatter

Semantica

Distribuisce i valori del tensore source utilizzando scatter in base al risultato di reduce_window del tensore input utilizzando select e produce un tensore result.

Il seguente diagramma mostra come vengono calcolati gli elementi in result da operand e source utilizzando un esempio concreto.

select_and_scatter

In forma più formale:

  • selected_values = reduce_window_without_init(...) con i seguenti input:

    • inputs = [operand].
    • window_dimensions, window_strides e padding utilizzati così come sono.
    • base_dilations = windows_dilations = 1.
    • Per body si intende:
    def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>:
      return select(arg0, arg1) ? arg0 : arg1;
    

    dove E = element_type(operand) e reduce_window_without_init funzionano esattamente come reduce_window, tranne per il fatto che schedule dell'elemento reduce sottostante (vedi riduzione) non include valori init. Al momento non è specificato cosa succede se la finestra corrispondente non contiene valori (#731).

  • result[result_index] = reduce([source_values], [init_value], [0], scatter) dove:

    • source_values = [source[source_index] for source_index in source_indices].
    • selected_index(source_index) = operand_index se selected_values[source_index] ha l'elemento operand di operand_index.
    • source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index].

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore o tensore quantizzato per tensore (C1-C4), (C6), (C8-C11)
(I2) source tensore o tensore quantizzato per tensore (C1), (C2)
(I3) init_value tensore 0-dimensionale o tensore quantizzato per tensore (C3)
(I4) window_dimensions Costante del tensore unidimensionale di tipo si64 (C2), (C4), (C5)
(I5) window_strides Costante del tensore unidimensionale di tipo si64 (C2), (C6), (C7)
(I6) padding Costante del tensore bidimensionale di tipo si64 (C2), (C8)
(I7) select funzione (C9)
(I8) scatter funzione (C10)

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato per tensore (C11-C12)

Vincoli

  • (C1) element_type(operand) = element_type(source).
  • (C2) shape(source) = num_windows dove:
    • padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1].
    • is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1.
  • (C3) element_type(init_value) = element_type(operand).
  • (C4) size(window_dimensions) = rank(operand).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(operand).
  • (C7) 0 < window_strides.
  • (C8) shape(padding) = [rank(operand), 2].
  • (C9) select ha il tipo (tensor<E>, tensor<E>) -> tensor<i1>, dove E = element_type(operand).
  • (C10) scatter ha il tipo (tensor<E>, tensor<E>) -> tensor<E> dove is_promotable(element_type(operand), E).
  • (C11) shape(operand) = shape(result).
  • (C12) element_type(result) = E.

Esempi

// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = array<i64: 3, 1>,
  window_strides = array<i64: 2, 1>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]

Altri esempi

Invia

Semantica

Invia inputs a un canale channel_id e produce un token result.

Se is_host_transfer è true, l'operazione trasferisce i dati all'host. In caso contrario, i dati vengono trasferiti su un altro dispositivo. Questo significa definito dall'implementazione. Questo flag duplica le informazioni fornite in channel_type, quindi in futuro prevediamo di conservarne solo uno (#666).

Input

Etichetta Nome Tipo Vincoli
(I1) inputs numero variadi di tensori o tensori quantizzati
(I2) token token
(I3) channel_id costante di tipo si64
(I4) channel_type enum di DEVICE_TO_DEVICE e DEVICE_TO_HOST (C1)
(I5) is_host_transfer costante di tipo i1 (C1)

Output

Nome Tipo
result token

Vincoli

  • (C1) Per channel_type si intende:
    • DEVICE_TO_HOST se is_host_transfer = true,
    • DEVICE_TO_DEVICE in caso contrario.

Esempi

%result = "stablehlo.send"(%operand, %token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
  is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token

Altri esempi

shift_left

Semantica

Esegue un'operazione di spostamento a sinistra degli elementi sul tensore lhs per un numero di bit rhs e produce un tensore result.

Input

Etichetta Nome Tipo Vincoli
(I1) lhs tensore di tipo intero (C1)
(I2) rhs tensore di tipo intero (C1)

Output

Nome Tipo Vincoli
result tensore di tipo intero (C1)

Vincoli

  • (C1) type(lhs) = type(rhs) = type(result).

Esempi

// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]

Altri esempi

shift_right_arithmetic

Semantica

Esegue un'operazione di spostamento a destra aritmetico degli elementi sul tensore lhs per rhs numero di bit e produce un tensore result.

Input

Etichetta Nome Tipo Vincoli
(I1) lhs tensore di tipo intero (C1)
(I2) rhs tensore di tipo intero (C1)

Output

Nome Tipo Vincoli
result tensore di tipo intero (C1)

Vincoli

  • (C1) type(lhs) = type(rhs) = type(result).

Esempi

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]

Altri esempi

shift_right_logical

Semantica

Esegue un'operazione di spostamento a destra logica a livello di elemento sul tensore lhs per un numero di bit rhs e produce un tensore result.

Input

Etichetta Nome Tipo Vincoli
(I1) lhs tensore di tipo intero (C1)
(I2) rhs tensore di tipo intero (C1)

Output

Nome Tipo Vincoli
result tensore di tipo intero (C1)

Vincoli

  • (C1) type(lhs) = type(rhs) = type(result).

Esempi

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]

Altri esempi

firmare

Semantica

Restituisce il segno dell'elemento operand a livello di elemento e produce un tensore result. Più formalmente, per ogni elemento x, la semantica può essere espressa utilizzando la sintassi Python come segue:

def sign(x):
  if is_integer(x):
    if compare(x, 0, LT, SIGNED): return -1
    if compare(x, 0, EQ, SIGNED): return 0
    return 1
  elif is_float(x):
    if is_nan(x): return NaN
    if compare(x, -0.0, EQ, FLOAT): return -0.0
    if compare(x, +0.0, EQ, FLOAT): return +0.0
    if compare(x, 0.0, LT, FLOAT): return -1.0
    return 1.0
  elif is_complex(x):
    if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
    if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
    return divide(x, convert(abs(x), type(x)))

Per i tipi quantizzati, esegue dequantize_op_quantize(sign, operand, type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore di numero intero con segno, in virgola mobile o di tipo complesso oppure tensore quantizzato per tensore (C1)

Output

Nome Tipo Vincoli
result tensore di numero intero con segno, in virgola mobile o di tipo complesso oppure tensore quantizzato per tensore (C1)

Vincoli

  • (C1) baseline_type(operand) = baseline_type(result).

Esempi

// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]

Altri esempi

seno

Semantica

Esegue un'operazione seno a livello di elemento sul tensore operand e produce un tensore result. A seconda del tipo di elemento, procedi come segue:

  • Per i valori in virgola mobile: sin da IEEE-754.
  • Per i numeri complessi: seno complesso.
  • Per i tipi quantizzati: dequantize_op_quantize(sine, operand, type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore (C1)

Output

Nome Tipo Vincoli
result tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore (C1)

Vincoli

  • (C1) baseline_type(operand) = baseline_type(result).

Esempi

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]

Altri esempi

sezione

Semantica

Estrae una sezione da operand utilizzando indici iniziali calcolati in modo statico e produce un tensore result. start_indices contiene gli indici iniziali della sezione per ogni dimensione, limit_indices contiene gli indici finali (esclusivi) della sezione per ogni dimensione e strides contiene i passi per ogni dimensione.

Più formalmente, result[result_index] = operand[operand_index] dove operand_index = start_indices + result_index * strides.

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore o tensore quantizzato per tensore (C1-C3), (C5)
(I2) start_indices Costante del tensore unidimensionale di tipo si64 (C2), (C3), (C5)
(I3) limit_indices Costante del tensore unidimensionale di tipo si64 (C2), (C3), (C5)
(I4) strides Costante del tensore unidimensionale di tipo si64 (C2), (C4)

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato per tensore (C1), (C5)

Vincoli

  • (C1) element_type(operand) = element_type(result).
  • (C2) size(start_indices) = size(limit_indices) = size(strides) = rank(operand).
  • (C3) 0 <= start_indices <= limit_indices <= shape(operand).
  • (C4) 0 < strides.
  • (C5) shape(result) = ceil((limit_indices - start_indices) / strides).

Esempi

// %operand: [
//            [0, 0, 0, 0],
//            [0, 0, 1, 1],
//            [0, 0, 1, 1]
//           ]
%result = "stablehlo.slice"(%operand) {
  start_indices = array<i64: 1, 2>,
  limit_indices = array<i64: 3, 4>,
  strides = array<i64: 1, 1>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
//            [1, 1],
//            [1, 1]
//           ]

Altri esempi

ordinare

Semantica

Ordina le sezioni unidimensionali di inputs nella dimensione dimension, in base a un comparator e produce results.

A differenza di input simili in altre operazioni, dimension consente valori negativi, con la semantica descritta di seguito. In futuro, questa pratica potrebbe non essere consentita per motivi di coerenza (#1377).

Se is_stable è true, l'ordinamento è stabile, ovvero l'ordine relativo degli elementi considerati uguali dal criterio di confronto. Per il caso in cui è presente un singolo input, due elementi e1 e e2 vengono considerati uguali dal comparatore se e solo se comparator(e1, e2) = comparator(e2, e1) = false. Vedi la formalizzazione riportata di seguito per come si applica a più input.

Più formalmente, per tutti i result_index in index_space(results[0]):

  • adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension.
  • result_slice = [ri0, ..., :, ..., riR-1] dove riN sono singoli elementi in result_index e : è inserito in adjusted_dimension.
  • inputs_together = (inputs[0]..., ..., inputs[N-1]...).
  • results_together[result_slice] = sort(inputs_together[result_slice], comparator_together).
  • dove sort ordina una sezione unidimensionale in ordine non decrescente, aspettando che comparator_together restituisca true se l'argomento a sinistra è inferiore al secondo argomento a destra.
  • def comparator_together(lhs_together, rhs_together):
      args = []
      for (lhs_el, rhs_el) in zip(lhs_together, rhs_together):
        args.append(lhs_el)
        args.append(rhs_el)
      return comparator(*args)
    
  • (results[0]..., ..., results[N-1]...) = results_together.

Input

Etichetta Nome Tipo Vincoli
(I1) inputs numero variadi di tensori o tensori quantizzati per tensore (C1-C5)
(I2) dimension costante di tipo si64 (C4)
(I3) is_stable costante di tipo i1
(I4) comparator funzione (C5)

Output

Nome Tipo Vincoli
results numero variadi di tensori o tensori quantizzati per tensore (C2), (C3)

Vincoli

  • (C1) 0 < size(inputs).
  • (C2) type(inputs...) = type(results...).
  • (C3) same(shape(inputs...) + shape(results...)).
  • (C4) -R <= dimension < R, dove R = rank(inputs[0]).
  • (C5) comparator ha il tipo (tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>, dove Ei = element_type(inputs[i]).

Esempi

// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
    %predicate = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
  dimension = 0 : i64,
  is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]

Altri esempi

sqrt

Semantica

Esegue un'operazione di radice quadrata degli elementi sul tensore operand e produce un tensore result. A seconda del tipo di elemento, procedi come segue:

  • Per i valori in virgola mobile: squareRoot da IEEE-754.
  • Per i numeri complessi: radice quadrata complessa.
  • Per i tipi quantizzati: dequantize_op_quantize(sqrt, operand, type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore (C1)

Output

Nome Tipo Vincoli
result tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore (C1)

Vincoli

  • (C1) baseline_type(operand) = baseline_type(result).

Esempi

// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]

Altri esempi

subtract

Semantica

Esegue la sottrazione a livello di elemento di due tensori lhs e rhs e produce un tensore result. A seconda del tipo di elemento, procedi come segue:

  • Per i numeri interi: sottrazione di numeri interi.
  • Per i valori in virgola mobile: subtraction da IEEE-754.
  • Per i numeri complessi: sottrazione complessa.
  • Per i tipi quantizzati:
    • dequantize_op_quantize(subtract, lhs, rhs, type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) lhs tensore di numero intero, in virgola mobile o di tipo complesso o tensore quantizzato per tensore (C1)
(I2) rhs tensore di numero intero, in virgola mobile o di tipo complesso o tensore quantizzato per tensore (C1)

Output

Nome Tipo Vincoli
result tensore di numero intero, in virgola mobile o di tipo complesso o tensore quantizzato per tensore (C1)

Vincoli

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Esempi

// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]

Altri esempi

Tanh

Semantica

Esegue un'operazione di tangente iperbolica a livello di elemento sul tensore operand e produce un tensore result. A seconda del tipo di elemento, procedi come segue:

  • Per i valori in virgola mobile: tanh da IEEE-754.
  • Per i numeri complessi: tangente iperbolica complessa.
  • Per i tipi quantizzati:
    • dequantize_op_quantize(tanh, operand, type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore (C1)

Output

Nome Tipo Vincoli
result tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore (C1)

Vincoli

  • (C1) baseline_type(operand) = baseline_type(result).

Esempi

// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]

Altri esempi

trasponi

Semantica

Migliora le dimensioni del tensore operand utilizzando permutation e produce un tensore result. Più formalmente, result[result_index] = operand[operand_index] dove result_index[d] = operand_index[permutation[d]].

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore o tensore quantizzato (C1-C4)
(I2) permutation Costante del tensore unidimensionale di tipo si64 (C2-C4)

Output

Nome Tipo Vincoli
result tensore o tensore quantizzato (C1), (C3-C4)

Vincoli

  • (C1) element_type(result) è dato da:
    • element_type(operand), se !is_per_axis_quantized(operand).
    • element_type(operand) ad eccezione del fatto che quantization_dimension(operand) e quantization_dimension(result) potrebbero variare in altro modo.
  • (C2) permutation è una permutazione di range(rank(operand)).
  • (C3) shape(result) = dim(operand, permutation...).
  • (C4) Se is_per_axis_quantized(result), allora quantization_dimension(operand) = permutation(quantization_dimension(result)).

Esempi

// %operand: [
//            [[1,2], [3,4], [5,6]],
//            [[7,8], [9,10], [11,12]]
//           ]
%result = "stablehlo.transpose"(%operand) {
  permutation = array<i64: 2, 1, 0>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
//           [[1,7], [3,9], [5,11]],
//           [[2,8], [4,10], [6,12]]
//          ]

Altri esempi

triangular_solve

Semantica

Risolve batch di sistemi di equazioni lineari con matrici a coefficiente triangolare inferiore o superiore.

Più formalmente, dati a e b, result[i0, ..., iR-3, :, :] è la soluzione per op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :] quando left_side è true o x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :] quando left_side è false, risolvendo per la variabile x dove op(a) è determinato da transpose_a, che può essere uno dei seguenti:

  • NO_TRANSPOSE: esegui l'operazione utilizzando a così com'è.
  • TRANSPOSE: esegui un'operazione sulla trasposizione di a.
  • ADJOINT: esegui un'operazione sulla trasposizione coniugata di a.

I dati di input vengono letti solo dal triangolo inferiore di a, se lower è true o dal triangolo superiore di a. I dati di output vengono restituiti nello stesso triangolo; i valori nell'altro triangolo sono definiti dall'implementazione.

Se unit_diagonal è true, l'implementazione può presumere che gli elementi diagonali di a siano uguali a 1, altrimenti il comportamento non è definito.

Per i tipi quantizzati, esegue dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower, unit_diagonal, transpose_a), a, b, type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) a tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore (C1-C3)
(I2) b tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore (C1-C4)
(I3) left_side costante di tipo i1 (C3)
(I4) lower costante di tipo i1
(I5) unit_diagonal costante di tipo i1
(I6) transpose_a enum di NO_TRANSPOSE, TRANSPOSE e ADJOINT

Output

Nome Tipo Vincoli
result tensore di tipo in virgola mobile o complesso o tensore quantizzato per tensore (C1)

Vincoli

  • (C1) baseline_element_type(a) = baseline_element_type(b).
  • (C2) 2 <= rank(a) = rank(b) = R.
  • (C3) La relazione tra shape(a) e shape(b) è definita come segue:
    • shape(a)[:-3] = shape(b)[:-3].
    • dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1).
  • (C4) baseline_type(b) = baseline_type(result).

Esempi

// %a = [
//       [1.0, 0.0, 0.0],
//       [2.0, 4.0, 0.0],
//       [3.0, 5.0, 6.0]
//      ]
// %b = [
//       [2.0, 0.0, 0.0],
//       [4.0, 8.0, 0.0],
//       [6.0, 10.0, 12.0]
//      ]
%result = "stablehlo.triangular_solve"(%a, %b) {
  left_side = true,
  lower = true,
  unit_diagonal = false,
  transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
//           [2.0, 0.0, 0.0],
//           [0.0, 2.0, 0.0],
//           [0.0, 0.0, 2.0]
//          ]

tuple

Semantica

Genera una tupla result dai valori val.

Input

Etichetta Nome Tipo Vincoli
(I1) val numero variadic di valori (C1)

Output

Nome Tipo Vincoli
result tuple (C1)

Vincoli

  • (C1) result ha il tipo tuple<E0, ..., EN-1>, dove Ei = type(val[i]).

Esempi

// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))

Altri esempi

uniform_dequantize

Semantica

Esegue la conversione a livello di elemento del tensore quantizzato operand in un tensore in virgola mobile result in base ai parametri di quantizzazione definiti dal tipo operand.

In forma più formale, result = dequantize(operand).

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore quantizzato (C1), (C2)

Output

Nome Tipo Vincoli
result tensore di tipo in virgola mobile (C1), (C2)

Vincoli

  • (C1) shape(operand) = shape(result).
  • (C2) element_type(result) = expressed_type(operand).

Esempi

// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]

uniform_quantize

Semantica

Esegue la conversione a livello di elemento del tensore in virgola mobile o del tensore quantizzato operand in un tensore quantizzato result in base ai parametri di quantizzazione definiti dal tipo result.

In modo più formale,

  • Se is_float(operand):
    • result = quantize(operand, type(result)).
  • Se is_quantized(operand):
    • float_result = dequantize(operand).
    • result = quantize(float_result, type(result)).

Input

Etichetta Nome Tipo Vincoli
(I1) operand tensore di tipo in virgola mobile o quantizzato (C1), (C2)

Output

Nome Tipo Vincoli
result tensore quantizzato (C1), (C2)

Vincoli

  • (C1) shape(operand) = shape(result).
  • (C2) expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand).

Esempi

// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]

// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]

mentre

Semantica

Genera l'output dall'esecuzione della funzione body 0 o più volte, mentre la funzione cond restituisce true. Più formalmente, la semantica può essere espressa usando la sintassi Python:

internal_state = operand
while cond(*internal_state):
  internal_state = body(*internal_state)
results = internal_state

Il comportamento di un loop infinito è da definire (#383).

Input

Etichetta Nome Tipo Vincoli
(I1) operand numero variadi di tensori, tensori quantizzati o token (C1-C3)
(I2) cond funzione (C1)
(I3) body funzione (C2)

Output

Nome Tipo Vincoli
results numero variadi di tensori, tensori quantizzati o token (C3)

Vincoli

  • (C1) cond ha il tipo (T0, ..., TN-1) -> tensor<i1>, dove Ti = type(operand[i]).
  • (C2) body ha il tipo (T0, ..., TN-1) -> (T0, ..., TN-1), dove Ti = type(operand[i]).
  • (C3) type(results...) = type(operand...).

Esempi

// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %cond = "stablehlo.compare"(%arg0, %ten) {
      comparison_direction = #stablehlo<comparison_direction LT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    stablehlo.return %cond : tensor<i1>
  }, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %new_sum = stablehlo.add %arg1, %one : tensor<i64>
    %new_i = stablehlo.add %arg0, %one : tensor<i64>
    stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10

Altri esempi

Xor

Semantica

Esegue XOR a livello di elemento di due tensori lhs e rhs e produce un tensore result. A seconda del tipo di elemento, procedi come segue:

  • Per i valori booleani: XOR logico.
  • Per i numeri interi: XOR bit a bit.

Input

Etichetta Nome Tipo Vincoli
(I1) lhs tensore di tipo booleano o numero intero (C1)
(I2) rhs tensore di tipo booleano o numero intero (C1)

Output

Nome Tipo Vincoli
result tensore di tipo booleano o numero intero (C1)

Vincoli

  • (C1) type(lhs) = type(rhs) = type(result).

Esempi

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]

Altri esempi

Interoperabilità dei dialetti

Al momento, i programmi StableHLO in circolazione a volte contengono operazioni che non sono definite da StableHLO.

Modulo, funzione, chiamata e ritorno

StableHLO utilizza operazioni MLIR upstream per ModuleOp, FuncOp, CallOp e ReturnOp. Ciò è stato fatto per migliorare l'interoperabilità con i macchinari MLIR esistenti, poiché molti pass utili sono scritti mirati a FuncOp e ModuleOp e molte pipeline di compilazione si aspettano che queste operazioni siano presenti. A queste operazioni vengono applicate garanzie di compatibilità completa. Se qualcosa cambia in queste operazioni in modo incompatibile (ovvero la rimozione), gli equivalenti StableHLO verranno aggiunti per preservarne la compatibilità.

CHLO

L'opset CHLO contiene operazioni di livello superiore che si decompongono in StableHLO. Al momento non esistono garanzie di compatibilità per CHLO. Per le garanzie di compatibilità, il pass chlo-legalize-to-stablehlo deve essere utilizzato prima della serializzazione.

Operazioni forme

È un caso d'uso comune nella community utilizzare determinate operazioni dei dialetti MLIR principali nei programmi dinamici StableHLO per eseguire calcoli di forma. Più comunemente, queste includono operazioni di shape dialetto come shape_of o num_elements, operazioni di tensor dialetto come dim o from_elements e il tipo integrato index.

L'opzione Dynamism RFC > O2 indica che queste informazioni non rientrano nell'ambito, tuttavia alcuni tipi di supporto per i tipi index sono inclusi per scopi di interoperabilità. Non esistono garanzie di compatibilità per queste operazioni o questi tipi. Il pass shape-legalize-to-stablehlo può essere utilizzato per convertire queste operazioni in operazioni StableHLO completamente supportate.

Operazioni deprecate

Esistono diverse operazioni StableHLO ereditate da MHLO che sono deprecate e stanno per uscire da StableHLO. Puoi trovare i dettagli completi su queste rimozioni nella pagina StableHLO v1.0 Cleanup #2283. Il problema del tracker per queste deprecazioni è il #2340.

Queste operazioni rientrano in alcune categorie:

  • Categoria "Not in HLO" (non in HLO) delle operazioni StableHLO. Inizialmente facevano parte dell'opset StableHLO, ma in seguito sono state ritenute non adeguate: broadcast, create_token, cross-replica-sum, dot, einsum, torch_index_select, unary_einsum (3).
  • Operazioni inutilizzate: queste operazioni potrebbero essere state utili a un certo punto, ma le operazioni erano sottosviluppate oppure le pipeline che le utilizzano sono state riadattate per non richiederle più. Sono inclusi map, tuple (#598), get_tuple_element, rng, confronti tra complex #560 e convolution window_reversal (#1181).

Alcune di queste operazioni possono essere rimosse facilmente dato che possono essere espresse usando le operazioni esistenti (broadcast, create_token, cross-replica-sum, dot,unary_einsum) e verranno rimosse al termine della finestra di compatibilità esistente (6 mesi). Altre sono ancora in fase di analisi per la rimozione (einsum, get_tuple_element, map, rng torch_index_select, tuple, complex confronti, window_reversal). In attesa del feedback della community, queste operazioni verranno rimosse o aggiunte alla specifica con il supporto completo. Finché questi future operazioni non saranno noti, sono garantiti solo 6 mesi di compatibilità.

Attuazione

Esecuzione sequenziale

Un programma StableHLO viene eseguito fornendo valori di input alla funzione main e calcolando i valori di output. I valori di output di una funzione vengono calcolati eseguendo il grafico delle operazioni radicate nell'operazione return corrispondente.

L'ordine di esecuzione è definito dall'implementazione, purché sia in linea con Dataflow, ovvero se le operazioni vengono eseguite prima dei loro utilizzi. In StableHLO, tutte le operazioni con effetto laterale consumano un solo token e ne producono uno (più token possono essere multiplexati in un token tramite after_all), quindi anche l'ordine di esecuzione degli effetti collaterali è allineato al flusso di dati. Ad esempio, nel programma seguente sono possibili due ordini di esecuzione: %0%1%2return e %1%0%2return.

func.func @main() -> tensor<f64> {
  %0 = stablehlo.constant dense<1.0> : tensor<f64>
  %1 = stablehlo.constant dense<2.0> : tensor<f64>
  %2 = stablehlo.add %0, %1 : tensor<f64>
  return %2 : tensor<f64>
}

Più formalmente, un processo StableHLO è una combinazione di: 1) un programma StableHLO, 2) stati operativi (non ancora eseguiti, già eseguiti) e 3) valori intermedi a cui il processo sta lavorando. Il processo inizia con i valori di input della funzione main, avanza nel grafico delle operazioni che aggiornano gli stati delle operazioni e i valori intermedi e termina con i valori di output. Un'ulteriore formalizzazione è da definire (#484).

Esecuzione in parallelo

I programmi StableHLO possono essere eseguiti in parallelo, organizzati in una griglia di processi 2D di num_replicas da num_partitions, entrambi di tipo ui32.

Nella griglia dei processi StableHLO, num_replicas * num_partitions dei processi StableHLO vengono eseguiti contemporaneamente. Ogni processo ha un process_id = (replica_id, partition_id) univoco, dove replica_id in replica_ids = range(num_replicas) e partition_id in partition_ids = range(num_partitions), entrambi di tipo ui32.

La dimensione della griglia del processo è nota in modo statico per ogni programma (in futuro, prevediamo di renderla una parte esplicita dei programmi StableHLO #650) e la posizione all'interno della griglia del processo è nota in modo statico per ogni processo. Ogni processo ha accesso alla sua posizione all'interno della relativa griglia tramite le operazioni replica_id e partition_id.

All'interno della griglia dei processi, i programmi possono essere tutti uguali (nello stile "Programma singolo, più dati"), possono essere tutti diversi (nello stile "Programma multiplo, più dati") o uno intermedio. In futuro, abbiamo in programma di introdurre il supporto per altre idiomi della definizione di programmi StableHLO paralleli, tra cui GSPMD (#619).

All'interno della griglia dei processi, i processi sono per lo più indipendenti l'uno dall'altro: hanno stati operativi separati, valori di input/intermedi/output separati e la maggior parte delle operazioni viene eseguita separatamente tra i processi, ad eccezione del numero ridotto di operazioni collettive descritte di seguito.

Dato che l'esecuzione della maggior parte delle operazioni utilizza solo valori dello stesso processo, di solito non è ambiguo fare riferimento a questi valori per nome. Tuttavia, quando si descrive la semantica delle operazioni collettive, ciò non è sufficiente e dà origine alla notazione name@process_id per fare riferimento al valore name all'interno di un particolare processo. (Da questo punto di vista, name non qualificato può essere visualizzato come una forma abbreviata di name@(replica_id(), partition_id())).

L'ordine di esecuzione nei processi è definito dall'implementazione, ad eccezione della sincronizzazione introdotta dalla comunicazione point-to-point e dalle operazioni collettive come descritto di seguito.

Comunicazione point-to-point

I processi StableHLO possono comunicare tra loro tramite i canali StableHLO. Un canale è rappresentato da un ID positivo di tipo si64. Attraverso varie operazioni, è possibile inviare valori ai canali e riceverli dai canali.

Un'ulteriore formalizzazione, ad esempio da dove provengono gli ID dei canali, in che modo i programmi dei processi ne vengono a conoscenza e quale tipo di sincronizzazione viene loro introdotto, è da definire (#484).

Comunicazione in streaming

Ogni processo StableHLO ha accesso a due interfacce di inserimento di flussi:

  • Infeed da cui è possibile leggere.
  • Outfeed in cui è possibile scrivere.

A differenza dei canali, che vengono utilizzati per comunicare tra i processi e pertanto presentano processi a entrambi gli scopi, gli annunci infeed e in uscita vengono definiti mediante l'implementazione finale.

Un'ulteriore formalizzazione, ad esempio in che modo la comunicazione in streaming influenza l'ordine di esecuzione e il tipo di sincronizzazione che ha introdotto, è da definire (#484).

Operazioni collettive

In StableHLO sono presenti sei operazioni collettive: all_gather, all_reduce, all_to_all, collective_broadcast, collective_permute e reduce_scatter. Tutte queste operazioni suddividono i processi nella griglia di processi di StableHLO in gruppi di processi StableHLO ed eseguono un calcolo congiunto all'interno di ciascun gruppo di processi, indipendentemente dagli altri gruppi di processi.

All'interno di ogni gruppo di processi, le operazioni collettive possono introdurre una barriera di sincronizzazione. Un'ulteriore formalizzazione, ad esempio un'analisi di quando avviene esattamente questa sincronizzazione, come i processi raggiungono questa barriera e cosa succede se non lo fanno, è da definire (#484).

Se il gruppo di processi prevede una comunicazione tra partizioni, ovvero il gruppo di processi contiene processi con ID partizione diversi, l'esecuzione dell'operazione collettiva richiede un canale, mentre l'operazione collettiva deve fornire un valore channel_id positivo di tipo si64. La comunicazione con replica incrociata non ha bisogno di canali.

I calcoli eseguiti dalle operazioni collettive sono specifici delle singole operazioni e sono descritti nelle singole sezioni operative riportate sopra. Tuttavia, le strategie secondo cui la griglia dei processi viene suddivisa in gruppi di processi vengono condivise tra queste operazioni e descritte in questa sezione. In termini più formali, StableHLO supporta le seguenti quattro strategie.

cross_replica

Solo le comunicazioni con replica incrociata avvengono all'interno di ogni gruppo di processi. Questa strategia prende replica_groups, un elenco di ID di replica, e calcola un prodotto cartesiano replica_groups entro il giorno partition_ids. replica_groups deve avere elementi univoci e coprire tutti i replica_ids. In modo più formale, usando la sintassi Python:

def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    for partition_id in partition_ids:
      process_group = []
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
      yield process_group

Ad esempio, per replica_groups = [[0, 1], [2, 3]] e num_partitions = 2, cross_replica produrrà [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]].

cross_partition

Solo le comunicazioni tra partizioni avvengono all'interno di ogni gruppo di processi. Questa strategia prende partition_groups, un elenco di elenchi di ID partizioni, e calcola un prodotto cartesiano di partition_groups in base a replica_ids. partition_groups deve avere elementi univoci e coprire tutti i partition_ids. In modo più formale, usando la sintassi Python:

def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
  for partition_group in partition_groups:
    for replica_id in replica_ids:
      process_group = []
      for partition_id in partition_group:
        process_group.append((replica_id, partition_id))
      yield process_group

Ad esempio, per partition_groups = [[0, 1]] e num_replicas = 4, cross_partition produrrà [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]].

cross_replica_and_partition

All'interno di ogni gruppo di processi possono essere eseguite entrambe le comunicazioni con replica e tra partizione. Questa strategia prende replica_groups, un elenco di ID di replica, e calcola i prodotti cartesiani per ogni replica_group per partition_ids. replica_groups deve avere elementi univoci e coprire tutti replica_ids. In modo più formale, usando la sintassi Python:

def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    process_group = []
    for partition_id in partition_ids:
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
    yield process_group

Ad esempio, per replica_groups = [[0, 1], [2, 3]] e num_partitions = 2, cross_replica_and_partition produrrà [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]].

flattened_ids

Questa strategia prende flattened_id_groups, un elenco di elenchi di ID di processo "appiattiti" sotto forma di replica_id * num_partitions + partition_id, e li trasforma in ID di processo. flattened_id_groups deve avere elementi univoci e coprire tutti i process_ids. In modo più formale, usando la sintassi Python:

def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
  for flattened_id_group in flattened_id_groups:
    process_group = []
    for flattened_id in flattened_id_group:
      replica_id = flattened_id // num_partitions
      partition_id = flattened_id % num_partitions
      process_group.append((replica_id, partition_id))
    yield process_group

Ad esempio, per flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]], num_replicas = 4 e num_partitions = 2, flattened_ids produrrà [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]].

Accuratezza

Al momento, StableHLO non fornisce garanzie sull'accuratezza numerica, ma la situazione potrebbe cambiare in futuro (#1156).

semantica di esecuzione dell'operazione quantiizzata

L'interpretazione delle operazioni StableHLO quantizzate può variare a seconda dei requisiti e delle funzionalità hardware. Ad esempio, alcuni hardware potrebbero scegliere di interpretare le operazioni quantiizzate utilizzando una strategia di "dequantizzazione, esecuzione di operazioni in virgola mobile e infine quantizzazione". Altre effettuano l'intero calcolo con l'aritmetica dei numeri interi. Di conseguenza, l'interpretazione delle operazioni quantiizzate di StableHLO è determinata esclusivamente dall'implementazione specifica. L'interpretazione della quantizzazione ibrida (#1575) dovrebbe basarsi sulla sua semantica, come prescritta nella specifica (tramite 1792).

Errori

I programmi StableHLO sono convalidati attraverso una vasta serie di vincoli per le singole operazioni, il che esclude molte classi di errori prima dell'esecuzione. Tuttavia, le condizioni di errore sono ancora possibili, ad esempio tramite overflow di numeri interi, accessi fuori dai limiti e così via. A meno che non vengano esplicitamente richiamati, tutti questi errori generano un comportamento definito dall'implementazione, ma questo comportamento potrebbe cambiare in futuro (#1157).

Eccezioni punto in virgola mobile

Come eccezione a questa regola, le eccezioni in virgola mobile nei programmi StableHLO hanno un comportamento ben definito. Le operazioni che comportano eccezioni definite dallo standard IEEE-754 (operazione non valida, divisione per zero, overflow, underflow o eccezioni inesatte) producono risultati predefiniti (come definiti nello standard) e continuano l'esecuzione senza aumentare il flag di stato corrispondente, in modo simile alla gestione delle eccezioni raiseNoFlag dallo standard. Le eccezioni per le operazioni non standard (ad es. aritmetica complessa e alcune funzioni trascendentali) sono definite dall'implementazione.

Mancata corrispondenza della forma

StableHLO supporta tensori di forma dinamica. Tuttavia, le forme devono concordare in fase di runtime, altrimenti il comportamento non è definito. StableHLO non fornisce esplicitamente un'operazione che possa affermare che un tensore ha una forma specifica in fase di runtime. La generazione del codice corretto è responsabilità del producer.

Come esempio specifico, il programma riportato di seguito è valido. Tuttavia, in fase di runtime, le forme esatte di %arg0 e %arg1 dovranno essere le stesse, altrimenti il comportamento del programma non è definito:

func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
    %0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
    return %0 : tensor<?xi32>
}

Notazione

Per descrivere la sintassi, questo documento utilizza il sapore ISO modificato della sintassi EBNF (ISO/IEC 14977:1996, Wikipedia), con due modifiche: 1) le regole vengono definite utilizzando ::= anziché =,

2) la concatenazione viene espressa utilizzando la giustapposizione anziché ,.

Per descrivere la semantica (ovvero all'interno delle sezioni "Tipi", "Costanti" e "Operazioni"), utilizziamo formule basate sulla sintassi Python estesa con il supporto per esprimere in modo conciso le operazioni di array, come descritto di seguito. Questo approccio funziona bene per piccoli snippet di codice, ma in rari casi quando sono necessari snippet di codice più grandi, utilizziamo la sintassi Python vanilla, che viene sempre introdotta in modo esplicito.

Formule

Diamo un'occhiata al funzionamento delle formule in base a un esempio della specifica dot_general. Uno dei vincoli per questa operazione è il seguente: dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).

I nomi utilizzati in questa formula provengono da due fonti: 1) funzioni globali, ovvero dim, 2) definizioni dei membri dell'elemento del programma corrispondente, ovvero gli input lhs, lhs_batching_dimensions, rhs e rhs_batching_dimensions definiti nella sezione "Input" di dot_general.

Come accennato in precedenza, la sintassi di questa formula è basata su Python con alcune estensioni orientate alla concisione. Per dare un senso alla formula, trasformiamola nella sintassi Python vanilla.

A) In queste formule stiamo usando = per rappresentare l'uguaglianza, quindi il primo passaggio per ottenere la sintassi Python è la sostituzione di = con ==, come segue: dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...).

B) Inoltre, queste formule supportano i puntini di sospensione (...) che trasformano le espressioni scalari in espressioni tensori. In breve, f(xs...) significa approssimativamente "per ogni x scalare nel tensore xs, calcola un valore f(x) scalare e poi restituisci tutti questi risultati scalari insieme come risultato tensore". Nella sintassi Python vanilla, la nostra formula di esempio si trasforma in: [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] == [dim(rhs, dim2) for dim2 in rhs_batching_dimensions].

Grazie alle ellissi, spesso è possibile evitare di lavorare a livello dei singoli scalari. Tuttavia, in alcuni casi complicati, potresti utilizzare una sintassi semi-informale di livello inferiore come nella formula start_indices[bi0, ..., :, ..., biN] della specifica di gather. Al servizio della concisione, non forniamo un formalismo esatto per la traduzione di questa sintassi in Python vanilla, nella speranza che sia ancora intuitivamente comprensibile caso per caso. Facci sapere se alcune formule specifiche sembrano opache e proveremo a migliorarle.

Noterai inoltre che le formule utilizzano i puntini di sospensione per espandere tutti i tipi di elenchi, inclusi tensori, elenchi di tensori (che, ad esempio, possono derivare da un numero variadic di tensori) e così via. Questa è un'altra area in cui non forniamo un formalismo esatto (ad esempio, gli elenchi non fanno neppure parte del sistema di tipo StableHLO) e ci basiamo sulla comprensibilità intuitiva del tipo StableHLO.

C) L'ultimo mezzo notativo che utilizziamo è la trasmissione implicita. Sebbene l'opset StableHLO non supporti la trasmissione implicita, le formule sì, anche al servizio di concisione. In breve, se viene utilizzato uno scalare in un contesto in cui è previsto un tensore, lo scalare viene trasmesso alla forma prevista.

Per continuare con l'esempio dot_general, ecco un altro vincolo: 0 <= lhs_batching_dimensions < rank(lhs). Come definito nella specifica dot_general, lhs_batching_dimensions è un tensore, tuttavia sia 0 che rank(lhs) sono scalari. Dopo aver applicato la trasmissione implicita, la formula diventerà [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)].

Se applicata a una determinata operazione dot_general, questa formula valuterà un tensore di valori booleani. Quando le formule vengono utilizzate come vincoli, il vincolo viene applicato se la formula restituisce true o un tensore che ha solo elementi true.

Nomi.

Nelle formule, l'ambito lessicale include: 1) funzioni globali, 2) definizioni dei membri,

3) definizioni locali. Di seguito è riportato un elenco delle funzioni globali. L'elenco delle definizioni degli elementi dipende dall'elemento del programma a cui viene applicata la notazione:

  • Per le operazioni, le definizioni dei membri includono i nomi introdotti nelle sezioni "Input" e "Output".
  • Per tutto il resto, le definizioni dei membri includono le parti strutturali dell'elemento del programma, il cui nome è dato dai corrispondenti non terminali EBNF. La maggior parte delle volte, i nomi di queste parti strutturali si ottengono convertendo i nomi dei componenti non terminali in lettere maiuscole e minuscole (ad es. IntegerLiteral => integer_literal), ma a volte vengono abbreviati durante il processo (ad es. QuantizationStorageType => storage_type), nel qual caso vengono introdotti esplicitamente in modo simile alle sezioni "Input" / "Output" nelle specifiche dell'operazione.
  • Inoltre, le definizioni dei membri includono sempre self per fare riferimento all'elemento del programma corrispondente.

Valori

Quando vengono valutate, le formule funzionano con i seguenti tipi di valori: 1) Value (valori effettivi, ad es. dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>; ne conoscono sempre i tipi), 2) Placeholder (valori futuri, ad es. lhs, rhs o result; i valori effettivi non sono ancora noti, ma sono noti solo i tipi), 3) Type (tipi come definiti nella sezione "Tipi"), 4) Function (funzioni globali ", come definito nella sezione"Tipi).

A seconda del contesto, i nomi possono fare riferimento a valori diversi. Più specificamente, la sezione "Semmantica" per le operazioni (e gli equivalenti per altri elementi del programma) definisce la logica di runtime, pertanto tutti gli input sono disponibili come Value. Al contrario, la sezione "Vincoli" per le operazioni (e gli equivalenti) definisce la logica "tempo di compilazione", ovvero qualcosa che in genere viene eseguito prima del runtime, quindi solo gli input costanti sono disponibili come Value e gli altri input sono disponibili solo come Placeholder.

Nomi. In "Semantica" In "Vincoli"
Funzioni globali Function Function
Input costanti Value Value
Input non costanti Value Placeholder
Output Value Placeholder
Definizioni locali Dipende dalla definizione Dipende dalla definizione

Prendiamo in considerazione un'operazione transpose di esempio:

%result = "stablehlo.transpose"(%operand) {
  permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>

Per questa operazione, permutation è una costante, quindi è disponibile come Value sia in semantica che in vincoli. Al contrario, operand e result sono disponibili come Value nella semantica, ma solo come Placeholder nei vincoli.

Funzioni

Costruzione dei tipi

Non esistono funzioni che possono essere utilizzate per creare i tipi. Usiamo direttamente la sintassi di tipo perché in genere è più concisa. Ad esempio (tensor<E>, tensor<E>) -> (tensor<E>) anziché function_type( [tensor_type([], E), tensor_type([], E)], [tensor_type([], E)]).

Funzioni sui tipi

  • element_type è definito sui tipi di tensori e sui tipi di tensori quantizzati e restituisce, rispettivamente, la parte TensorElementType o QuantizedTensorElementType dei valori TensorType o QuantizedTensorType corrispondenti.
def element_type(x: Value | Placeholder | Type):
 if type(x) == TensorType:
    return tensor_element_type(x)
  if type(x) == QuantizedTensorType:
    return quantized_tensor_element_type(x)
  if type(x) is not Type:
    return element_type(type(x))
  • is_per_axis_quantized(x: Value | Placeholder | Type) -> Value è una scorciatoia per is_quantized(x) and quantization_dimension(x) is not None.

  • is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value è una scorciatoia per is_quantized(x) and quantization_dimension(x) is None.

  • is_promotable(x: Type, y: Type) -> bool controlla se il tipo x può essere promosso al tipo y. Quando x e y sono QuantizedTensorElementType, la promozione viene applicata solo a storage_type. Questa versione specifica della promozione viene attualmente utilizzata nel contesto del calcolo della riduzione (fai riferimento a RFC per ulteriori dettagli).

def is_promotable(x: Type, y: Type) -> Value:
  is_same_type = (is_bool(x) and is_bool(y)) or
    (is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
    (is_complex(x) and is_complex(y)) or
    (is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))

  if is_same_type == False:
    return False

  if is_integer(x) or is_float(x):
    return bitwidth(x) <= bitwidth(y)

  if is_complex(x):
    return bitwidth(element_type(x)) <= bitwidth(element_type(y))

  if is_quantized(x):
    return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))

  return false
  • is_quantized(x: Value | Placeholder | Type) -> Value è una scorciatoia per is_quantized_tensor_element_type(x).

  • is_type_name(x: Value | Placeholder | Type) -> Value. Disponibile per tutti i tipi. Ad esempio, is_float(x) restituisce true se x è FloatType. Se x è un valore o un segnaposto, questa funzione è una scorciatoia per is_type_name(type(x)).

  • max_value(x: Type) -> Value restituisce il valore massimo di TensorElementType. Se x non è un TensorElementType, restituisce None.

  • min_value(x: Type) -> Value restituisce il valore minimo possibile di TensorElementType. Se x non è un TensorElementType, restituisce None.

  • member_name(x: Value | Placeholder | Type) -> Any. Disponibile per tutte le definizioni dei membri member_name di tutti i tipi. Ad esempio, tensor_element_type(x) restituisce la parte TensorElementType di un valore TensorType corrispondente. Se x è un valore o un segnaposto, questa funzione è una scorciatoia per member_name(type(x)). Se x non è un tipo con un membro appropriato oppure un valore o un segnaposto di questo tipo, restituisce None.

Costruzione dei valori

  • operation_name(*xs: Value | Type) -> Value. Disponibile per tutte le operazioni. Ad esempio, add(lhs, rhs) prende due valori di tensore lhs e rhs e restituisce l'output della valutazione dell'operazione add con questi input. Per alcune operazioni, ad esempio broadcast_in_dim, i tipi di output sono "load-bearing", ossia necessari per valutare un'operazione. In questo caso, la funzione prende questi tipi come argomenti.

Funzioni sui valori

  • Sono disponibili tutti gli operatori e le funzioni di Python. Ad esempio, sia le notazioni di subscription che di slicing di Python sono disponibili per l'indicizzazione in tensori, tensori quantizzati e tuple.

  • to_destination_type(x: Value, destination_type: Type) -> Value è definito sui tensori e restituisce il valore convertito di x in base ai type(x) e destination_type come segue:

def to_destination_type(x: Value, destination_type: Type) -> Value:
  if type(x) == destination_type:
    return x

  if is_quantized(destination_type):
    if is_quantized(type(x)):
      return quantize(x, destination_type)
    assert is_float(type(x))
    return quantize(x, destination_type)

  if is_quantized(type(x)):
    assert destination_type = expressed_type(type(x))
    return dequantize(type(x))

  return convert(x, destination_type)

Si è in corso una discussione iniziale sull'unione delle operazioni convert, uniform_quantize e uniform_dequantize (#1576). Dopo l'unione non abbiamo bisogno della funzione riportata sopra e possiamo utilizzare il nome dell'operazione per convert.

  • is_nan(x: Value) -> Value è definito sui tensori e restituisce true se tutti gli elementi di x sono NaN o false. Se x non è un tensore, restituisce None.

  • is_sorted(x: Value) -> Value è definito sui tensori e restituisce true se gli elementi di x sono ordinati in ordine crescente rispetto all'ordine lessicografico crescente dei relativi indici o false in caso contrario. Se x non è un tensor, restituisce None.

  • is_unique(x: Value) -> Value è definito sui tensori e restituisce true se x non ha elementi duplicati o false in caso contrario. Se x non è un tensore, restituisce None.

  • member_name(x: Value) -> Any è definito per tutte le definizioni dei membri member_name di tutti i valori. Ad esempio, real_part(x) restituisce la parte RealPart di un valore ComplexConstant corrispondente. Se x non è un valore con un membro appropriato, restituisce None.

  • same(x: Value) -> Value è definito sui tensori e restituisce true se gli elementi di x sono tutti uguali tra loro o false negli altri casi. Se il tensore non ha elementi, viene conteggiato come "tutti uguali tra loro", ovvero la funzione restituisce true. Se x non è un tensore, restituisce None.

  • split(x: Value, num_results: Value, axis: Value) -> Value è definito sui tensori e restituisce num_results sezioni di x lungo l'asse axis. Se x non è un tensore o dim(x, axis) % num_results != 0, restituisce None.

  • is_defined_in_parent_scope(x: Value) -> Value è definito sulle stringhe e restituisce true se x è il nome di una funzione definita nello stesso ambito della funzione padre dell'operazione pertinente.

  • is_namespaced_op_name(x: Value) -> Value è definito sulle stringhe e restituisce true se x è un nome dell'operazione valido, ovvero rispetta la seguente espressione regolare: [a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+

Calcoli delle forme

  • axes(x: Value | Placeholder | Type) -> Value è una scorciatoia per range(rank(x)).

  • dim(x: Value | Placeholder | Type, axis: Value) -> Value è una scorciatoia per shape(x)[axis].

  • dims(x: Value | Placeholder | Type, axes: List) -> List è una scorciatoia per list(map(lambda axis: dim(x, axis), axes)).

  • index_space(x: Value | Placeholder | Type) -> Value è definito sui tensori e restituisce gli indici size(x) per il valore TensorType corrispondente ordinato in ordine lessicografico crescente, ad esempio [0, ..., 0], [0, ..., 1], ..., shape(x) - 1. Se x non è un tipo tensore, un tipo di tensore quantizzato oppure un valore o un segnaposto di uno di questi tipi, restituisce None.

  • rank(x: Value | Placeholder | Type) -> Value è una scorciatoia per size(shape(x)).

  • shape(x: Value | Placeholder | Type) -> Value è definito nella sezione "Funzioni sui tipi" tramite member_name.

  • size(x: Value | Placeholder | Type) -> Value è una scorciatoia per reduce(lambda x, y: x * y, shape(x)).

Calcoli di quantizzazione

  • def baseline_element_type(x: Value | Placeholder | Type) -> Type è una scorciatoia per element_type(baseline_type(x)).

  • baseline_type viene definito sui tipi di tensori e sui tipi di tensori quantizzati e li trasforma in una "base di riferimento", ovvero un tipo con la stessa forma ma con i parametri di quantizzazione del tipo di elemento reimpostati sui valori predefiniti. Questo viene utilizzato come metodo utile per confrontare in modo uniforme sia i tipi di tensori sia quelli di tensori quantizzati, cosa necessaria abbastanza spesso. Per i tipi quantiizzati, ciò consente di confrontare i tipi ignorando i parametri di quantizzazione, ovvero shape, storage_type, expressed_type, storage_min, storage_max e quantization_dimension (per il tipo quantizzato per asse) devono corrispondere tutti, ma scales e zero points potrebbero essere diversi.

def baseline_type(x: Value | Placeholder | Type) -> Type:
  if type(x) == TensorType:
    return x
  if type(x) == QuantizedTensorType:
    element_type = quantized_tensor_element_type(x)
    baseline_element_type = QuantizedTensorElementType(
      storage_type = storage_type(element_type),
      storage_min = storage_min(element_type),
      storage_max = storage_max(element_type),
      expressed_type = expressed_type(element_type),
      quantization_dimension = quantization_dimension(element_type),
      scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
      zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
    return QuantizedTensorType(shape(x), baseline_element_type)
  if type(x) is not Type:
    return baseline_element_type(type(x))
  • dequantize è definito sui tipi di tensori quantizzati e li trasforma in tipi di tensori in virgola mobile. Ciò avviene tramite la conversione di elementi quantizzati che rappresentano valori interi del tipo di archiviazione in valori in virgola mobile corrispondenti del tipo espresso utilizzando il punto zero e la scala associati al tipo di elemento quantizzato.
def compute_zero_points(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      zero_points[i] = zero_points(quantized_type)[i[d]]
    return zero_points

def compute_scales(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
            type(result_type))
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      scales[i] = scales(quantized_type)[i[d]]
    return scales

def dequantize(x: Value) -> Value:
  assert is_quantized(x)
  x_storage = bitcast_convert(x, storage_type(x))
  x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
  x_expressed_sub = convert(x_storage_sub, expressed_type(x))
  return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
  • quantize è definito su tipi di tensori in virgola mobile e li trasforma in tipi di tensori quantizzati. Ciò accade tramite la conversione dei valori in virgola mobile del tipo espresso in valori interi corrispondenti del tipo di archiviazione utilizzando il punto zero e la scala associati al tipo di elemento quantizzato.
def quantize(x: Value, result_type: Type) -> Value:
  assert is_float(x) and is_quantized(result_type)
  zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
  converted_zero_points = convert(zero_points, expressed_type(result_type))
  converted_min = convert(storage_min(result_type), expressed_type(result_type))
  converted_max = convert(storage_max(result_type), expressed_type(result_type))

  x_scaled = x / compute_scales(result_type, type(x))
  x_scaled_add_zp = x_scaled + converted_zero_points
  x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
  x_rounded = round_nearest_even(x_clamped)
  return convert(x_rounded, result_type)
  • dequantize_op_quantize viene utilizzato per specificare i calcoli a livello di elemento sui tensori quantiizzati. Dequantizza, ovvero trasforma gli elementi quantizzati nei relativi tipi espressi, esegue un'operazione e poi quantifica, ovvero trasforma i risultati nei relativi tipi di archiviazione. Al momento, questa funzione funziona solo per la quantizzazione per tensore. È in corso la quantizzazione per asse (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
  inputs = inputs_and_output_type[:-1]
  output_type = inputs_and_output_type[-1]

  float_inputs = map(dequantize, inputs)
  float_result = op(*float_inputs)
  return quantize(float_result, output_type)

def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
  inputs = inputs_and_output_type[:-3]
  float_inputs = map(dequantize, inputs)
  float_results = op(*float_inputs)
  return map(quantize, float_results, inputs_and_output_type[-3:])

def dequantize_compare(lhs, rhs, comparison_direction):
  float_lhs = dequantize(lhs)
  float_rhs = dequantize(rhs)
  return compare(float_lhs, float_rhs, comparison_direction, FLOAT)

def dequantize_select_quantize(pred, on_true, on_false, output_type):
  float_on_true = dequantize(on_true)
  float_on_false = dequantize(on_false)
  float_result = select(pred, float_on_true, float_on_false)
  return quantize(float_result, output_type)
  • hybrid_dequantize_then_op viene utilizzato per specificare la quantizzazione di solo peso per l'operazione ibrida, che accetta lh in virgola mobile e rh in tipi quantistici. Dequantizza gli input quantizzati nei tipi espressi ed esegue il calcolo in virgola mobile. Il tipo di elemento del tensore lhs di float e il tipo espresso di tensore rhs quantizzato devono essere identici.
def hybrid_dequantize_then_op(op, lhs, rhs):
  assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
  return op(lhs, dequantize(rhs))

Calcoli griglia

  • cross_partition(replica_groups: Value) -> Value. Vedi la sezione "cross_replica" sopra.

  • cross_replica(replica_groups: Value) -> Value. Vedi la sezione "cross_replica" sopra.

  • cross_replica_and_partition(replica_groups: Value) -> Value. Vedi la sezione "cross_replica_and_partition" sopra.

  • flattened_ids(replica_groups: Value) -> Value. Vedi la sezione "flattened_id" sopra.

Dinamismo

I valori StableHLO possono avere dimensioni di dimensione dinamiche, ad esempio tensor<?xi64>. Tuttavia, i valori StableHLO non possono avere un numero dinamico di dimensioni (dinamismo non classificato, ad esempio tensor<*xi64>). Gli operandi e i risultati possono utilizzare dimensioni delle dimensioni dinamiche, anche in presenza di vincoli nelle dimensioni. I vincoli verranno verificati in modo statico, se possibile, altrimenti vengono differiti al runtime e le mancate corrispondenze genereranno un comportamento indefinito. Di seguito sono riportati gli esempi.

Mancata corrispondenza della forma per operazioni con elementi unari

Considera il seguente programma per giocattoli:

func.func @foo(%arg0: tensor<?xf64>) {
  %0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
  return
}

Questo programma è insolito perché non è comune conoscere la forma del risultato ma non la forma dell'input. Ciononostante, si tratta di un programma StabileHLO valido. Non è possibile convalidare in modo statico l'operazione abs in questo programma, perché la forma esatta dell'operando è sconosciuta. Tuttavia, le forme sono certamente compatibili e questo può essere controllato in modo statico: ? potrebbe risultare 2 in fase di runtime, senza che ci siano problemi. Tuttavia, ? potrebbe essere anche un altro numero intero, nel qual caso il comportamento non è definito.

Tieni presente che se la dimensione di una dimensione è dinamica nel risultato, non può esserci un comportamento indefinito. In effetti, non esistono dimensioni "previste", quindi non possono esserci una mancata corrispondenza.

Mancata corrispondenza della forma per le operazioni con elementi binari

Considera il seguente programma per giocattoli:

func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
  %0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
  return
}

Quando si tratta di operazioni binarie in termini di elementi, le forme degli input e il risultato devono concordare in fase di runtime. Al momento della compilazione, le dimensioni statiche devono essere uguali, altrimenti devono semplicemente essere compatibili. Se qualsiasi dimensione è dinamica negli input, potrebbe esserci un comportamento indefinito in fase di runtime, poiché la dimensione dinamica potrebbe non corrispondere alla dimensione corrispondente nell'altro operando (statica o dinamica). Se tutti gli input sono statici, la presenza o meno del risultato dinamico è irrilevante: le dimensioni note in modo statico verranno controllate in modo statico e le dimensioni dinamiche non impongono vincoli.

Mancata corrispondenza della forma per le operazioni che assumono la forma di output come un operando

Considera il seguente programma per giocattoli:

func.func @foo(%arg0: tensor<2xi32>) {
  %0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
  return
}

I valori nell'operando della forma in fase di runtime devono corrispondere alla forma del risultato, altrimenti il comportamento è indefinito. In altre parole, nel runtime %arg0 deve avere un valore dense<[3, 4]> : tensor<2xi32>. Se l'operando della forma è costante, può essere verificato in modo statico. Se la forma del risultato è completamente dinamica, non possono esserci discrepanze.