Modello dati
I programmi StableHLO sono calcoli su tensori
(matrici n-dimensionali), che, nel modello attuale, vengono implementati utilizzando la
classe Tensor
. La classe di archiviazione sottostante per un oggetto Tensor
,
detail::Buffer
, archivia il mlir::ShapedType
del tensore insieme a un
oggetto mlir::HeapAsmResourceBlob
che rappresenta un blob mutabile di dati tensore disposti come array di byte contigui in
ordine maggiore-minore.
detail::Buffer
oggetti vengono conteggiati per riferimento per semplificare la gestione della memoria.
I singoli elementi di un tensore sono rappresentati utilizzando la classe Element
, che utilizza un'unione discriminata che contiene APInt
, APFloat
o pair<APFloat,APFloat>
per l'archiviazione. L'ultimo viene usato per archiviare
elementi con tipi complessi.
Tensor
dispone delle seguenti API per interagire con i singoli elementi:
Element Tensor::get(llvm::ArrayRef<int64_t> index)
: per estrarre un singolo elemento tensore nell'indice multidimensionaleindex
come oggettoElement
.void Tensor::set(llvm::ArrayRef<int64_t> index, Element element);
: per aggiornare un oggettoElement
element
in un tensore nell'indice multidimensionaleindex
.
Come funziona l'interprete
La funzione di immissione per l'interprete è
SmallVector<Tensor> eval(func::FuncOp func, ArrayRef<Tensor> args);
che esegue le seguenti operazioni:
- Tiene traccia degli argomenti SSA di
func
e dei valoriTensor
di runtime associati, forniti inargs
, utilizzando una mappa della tabella di simboli M. - Per ogni operazione in
func
, in ordine SSACFG:- Richiama
eval
sull'operazione. Per ogni operando SSA dell'operazione, estrai il relativo valore di runtime da M da fornire come argomento alla chiamataeval
. - Tiene traccia dei risultati SSA dell'operazione e del valore valutato in M.
- Richiama
Il livello dell'operazione eval
menzionato nel punto (2) è responsabile dell'implementazione della
semantica di esecuzione dell'operazione. Di seguito è riportato un esempio per stablehlo::AddOp
.
Nell'esempio, i singoli elementi dei tensori lhs
e rhs
vengono estratti a coppie
come oggetti Element
che vengono poi aggiunti. Il risultato dell'aggiunta, un oggetto Element
, viene archiviato nel tensore result
finale.
Tensor eval(AddOp op, const Tensor &lhs, const Tensor &rhs) {
Tensor result(op.getType());
for (auto it = result.index_begin(); it != result.index_end(); ++it)
result.set(*it, lhs.get(*it) + rhs.get(*it));
return result;
}
Nel complesso, il design dell'interprete è ottimizzato per la leggibilità delle implementazioni delle funzioni eval
per le singole operazioni perché è destinato a fungere da implementazione di riferimento per StableHLO. Ad esempio, anziché definire eval
come funzione per il modello e parametrizzarlo con i tipi di elementi, incapsulamo i dettagli su come vengono gestiti i diversi tipi di elementi in Element::operator+
e così via, semplificando l'implementazione di eval
.
Utilizzare l'interprete per la chiusura continua
Possiamo utilizzare il meccanismo dell'interprete per comprimere le operazioni con valori operandi costanti. Il seguente snippet di codice mostra un'idea dell'implementazione per il piegamento di stablehlo::AddOp
con operandi digitati in virgola mobile:
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
auto attrs = adaptor.getOperands();
DenseElementsAttr lhsData = attrs[0].dyn_cast<DenseElementsAttr>();
DenseElementsAttr rhsData = attrs[1].dyn_cast<DenseElementsAttr>();
if (!lhsData || !rhsData) return {};
auto lhs = Tensor(lhsData);
auto rhs = Tensor(rhsData);
auto result = eval(*this, lhs, rhs);
SmallVector<APFloat> values;
for (auto i = 0; i < result.getNumElements(); ++i) {
Element element = result.get(i);
values.push_back(element.getValue().cast<FloatAttr>().getValue());
}
return DenseElementsAttr::get(result.getType(), values);
}
Al momento non stiamo lavorando attivamente all'integrazione dell'interprete nella piegatura costante perché non abbiamo intenzione di implementare la cartella per StableHLO.
Tuttavia, in futuro, abbiamo in programma di utilizzare l'interprete per la piegatura costante in MHLO, per migliorare l'ergonomia dello snippet di codice riportato sopra (ad esempio, potremmo avere una funzione helper che raggruppa gli operandi costanti in oggetti Tensor
e decompone i risultati di Tensor
in OpFoldResult
).
Test dell'interprete StableHLO
L'interprete prende come input (A) un programma StableHLO e (B) i valori dei dati da fornire al programma e genera i valori dei dati di output, che vengono abbinati ai valori dei dati previsti forniti dall'utente. I valori dei dati (B) sono
hardcoded nel programma stesso utilizzando operazioni stablehlo.constant
. L'interprete valuta il programma di input. Gli output dell'operazione sottoposta a test vengono controllati tramite controlli (ad es. check.expect_eq
, check.expect_almost_eq
), come mostrato di seguito. check.expect_eq
e check.expect_eq_const
verificano l'uguaglianza a livello di bit
per qualsiasi tipo supportato, mentre check.expect_almost_eq
e
check.expect_almost_eq_const
verificano l'uguaglianza quasi all'interno di una tolleranza, come spiegato nelle linee guida di test (G6), per i tipi in virgola mobile e complessi.
// CHECK-LABEL: Evaluated results of function: add_op_test_ui4
func.func @add_op_test_ui4() {
%0 = stablehlo.constant dense<[0, 2]> : tensor<2xui4>
%1 = stablehlo.constant dense<[15, 3]> : tensor<2xui4>
%2 = stablehlo.add %0, %1 : tensor<2xui4>
check.expect_eq_const %2, [15, 5] : tensor<2xui4>
func.return
}
Un'utilità di test stablehlo-translate --interpret
(codice) è responsabile dell'analisi del programma, interpretando ogni funzione, comprese le operazioni che la costituiscono. Disponiamo di una suite di test dedicata, composta da diversi test che esercitano vari comportamenti di runtime, per ogni operazione StableHLO. I test sono disponibili qui (ad es. interpret_*.mlir).
Linee guida per l'esecuzione dei test
(G1) Dobbiamo testare tutti i tipi supportati per ogni operazione?
Per decidere possiamo utilizzare una combinazione delle seguenti regole:
Durante l'implementazione di un'operazione, se esiste un codice nella funzione
eval
corrispondente per gestire un determinato tipo, è imperativo disporre di test per coprire quel tipo. Ad esempio, per l'operazioneadd
esiste un codice esclusivo per gestire tipi interi, booleani, a virgola mobile e complessi, quindi è necessario un test per ogni categoria di tipo.Se un insieme di tipi viene gestito in modo uniforme nella funzione
eval
corrispondente, dovrebbe essere sufficiente un unico test per tutti questi tipi. Ad esempio, per l'operazioneadd
, tutte le varianti di tipi interi (si4
,u4
,si8
,u8
e così via) vengono gestite allo stesso modo utilizzando le APIllvm::APInt
, perciò possiamo saltare l'aggiunta dei test per ciascuna di queste varianti e invece aggiungere un singolo test rappresentativo. Per evitare ambiguità nella selezione del rappresentante, è necessario utilizzare le seguenti linee guida:- Se tutti i tipi, gestiti in modo uniforme, hanno lo stesso tipo primitivo (se sono tutti di tipo intero, in virgola mobile o complesso), scegli quello con larghezza in bit massima.
- Se tutti i tipi, gestiti in modo uniforme, hanno un mix di tipi primitivi, scegli quello con il seguente tipo primitivo, in ordine decrescente di preferenza: numero intero, in virgola mobile, booleano, complesso.
(G2) Come decidiamo il numero di test necessari per coprire il comportamento di un'operazione?
L'obiettivo è coprire in modo esauriente la logica dell'interprete per l'operazione (ovvero tutti i casi limite dell'implementazione) con un numero minimo di test. Ridurre al minimo il numero di test è importante per la manutenibilità. Minore è il numero di test, più è facile esaminarli e verificare che coprano in modo completo l'operazione. Di conseguenza, ci aspettiamo che la maggior parte delle operazioni più semplici finisca per avere un solo test. Se, per qualche buona ragione, una copertura completa non è praticabile, non è un problema fermarsi a >= 90%. Questo verrà deciso caso per caso durante la revisione delle richieste di pull.
(G3) Che ne dici di aggiungere test per l'infrastruttura dell'interprete?
L'infrastruttura dell'interprete è per lo più semplice e può essere aggiunta alla nostra base di affidabilità. L'unica parte non banale è il modo in cui i vari tipi vengono pacchettizzati
e decompressi dallo spazio di archiviazione sottostante dell'interprete. Come discusso in (G1), testeremo solo i tipi di operazioni gestiti in modo diverso. Di conseguenza, è possibile che il codice di pacchettizzazione/scomposizione, corrispondente a diverse varianti dei tipi di numeri interi/in virgola mobile, non venga completamente coperto durante il test. Per garantire la copertura completa, possiamo scegliere un'operazione come constant
che supporti tutti i tipi di elementi StableHLO e scrivere test esaustivi.
(G4) Se l'implementazione di un'operazione dipende da altre operazioni, dobbiamo scrivere i test per queste ultime?
No. Ad esempio, l'implementazione di batch_norm_grad
può basarsi su divide
, subtract
, multiply
e altri. Dovremmo evitare di testare le ultime operazioni
durante il test della prima.
(G5) Dovremmo scrivere test per esercitare i comportamenti definiti / non definiti nell'implementazione?
Non dobbiamo scrivere test che esercitano i comportamenti definiti o non definiti dell'implementazione. I test che applicano comportamenti definiti dall'implementazione mostrano un comportamento locale dell'interprete che non deve essere generalizzato. I test che esercitano un comportamento indefinito non contribuiscono alla comprensione del comportamento dell'operazione.
(G6) Durante la scrittura dei test per i tipi con virgola mobile, con quale precisione deve essere specificato il risultato previsto nei controlli?
Per le operazioni elementari (addizione, sottrazione, moltiplicazione, divisione e quadrato), un'implementazione conforme alla specifica IEEE dovrebbe fornire un risultato arrotondato entro 0,5 ULP del risultato matematicamente esatto. Detto questo, possiamo immaginare con sicurezza che il risultato previsto proveniente da queste operazioni sia
di un valore massimo di 1 ULP. Tuttavia, potrebbe non funzionare per le funzioni trascendenti (sine
, cosine
e così via) per le quali le garanzie di precisione sono definite dall'implementazione (razionali).
L'attuale implementazione utilizza un valore di tolleranza "universale" pari a 0,0001. Il seguente esempio mostra la tolleranza di cui sopra in azione.
func.func @check_tolerance() {
%0 = stablehlo.constant dense<0.2> : tensor<f32>
// The following check succeeds as %0 is almost equal to the provided
// constant modulo the tolerance, mentioned above.
check.expect_almost_eq_const %0, dense<0.19999> : tensor<f32>
// The following check fails as %0 is not bitwise equal to the provided
// constant.
check.expect_eq_const %0, dense<0.19999> : tensor<f32>
func.return
}
Questo è solo il primo passo per testare l'accuratezza numerica delle operazioni StableHLO. Al momento, questa è un'area sottospecificata delle specifiche StableHLO e c'è un lavoro continuo per risolverlo #1156 in base alla nostra esperienza nell'utilizzo di StableHLO nella pratica e al feedback degli stakeholder. Man mano che questa procedura procede, aggiorneremo l'infrastruttura di conseguenza.
(G7) Qualcosa in merito allo stile di programmazione dei test?
- Assicurati di utilizzare il nome effettivo degli ingressi/uscite anziché impostare come valore predefinito i valori SSA (ad es. %0, %1 e così via)
- Assicurati che per i test venga utilizzato un formato stampato, se esistente.
(G8) Devo includere l'esempio già fornito nelle specifiche? Sì (per la completezza dei test).