Interprete

Modello dati

I programmi StableHLO sono calcoli su tensori (matrici n-dimensionali), che, nel modello attuale, vengono implementati utilizzando la classe Tensor. La classe di archiviazione sottostante per un oggetto Tensor, detail::Buffer, archivia il mlir::ShapedType del tensore insieme a un oggetto mlir::HeapAsmResourceBlob che rappresenta un blob mutabile di dati tensore disposti come array di byte contigui in ordine maggiore-minore. detail::Buffer oggetti vengono conteggiati per riferimento per semplificare la gestione della memoria.

I singoli elementi di un tensore sono rappresentati utilizzando la classe Element, che utilizza un'unione discriminata che contiene APInt, APFloat o pair<APFloat,APFloat> per l'archiviazione. L'ultimo viene usato per archiviare elementi con tipi complessi.

Tensor dispone delle seguenti API per interagire con i singoli elementi:

  • Element Tensor::get(llvm::ArrayRef<int64_t> index): per estrarre un singolo elemento tensore nell'indice multidimensionale index come oggetto Element.
  • void Tensor::set(llvm::ArrayRef<int64_t> index, Element element);: per aggiornare un oggetto Element element in un tensore nell'indice multidimensionale index.

Come funziona l'interprete

La funzione di immissione per l'interprete è

SmallVector<Tensor> eval(func::FuncOp func, ArrayRef<Tensor> args);

che esegue le seguenti operazioni:

  1. Tiene traccia degli argomenti SSA di func e dei valori Tensor di runtime associati, forniti in args, utilizzando una mappa della tabella di simboli M.
  2. Per ogni operazione in func, in ordine SSACFG:
    • Richiama eval sull'operazione. Per ogni operando SSA dell'operazione, estrai il relativo valore di runtime da M da fornire come argomento alla chiamata eval.
    • Tiene traccia dei risultati SSA dell'operazione e del valore valutato in M.

Il livello dell'operazione eval menzionato nel punto (2) è responsabile dell'implementazione della semantica di esecuzione dell'operazione. Di seguito è riportato un esempio per stablehlo::AddOp. Nell'esempio, i singoli elementi dei tensori lhs e rhs vengono estratti a coppie come oggetti Element che vengono poi aggiunti. Il risultato dell'aggiunta, un oggetto Element, viene archiviato nel tensore result finale.

Tensor eval(AddOp op, const Tensor &lhs, const Tensor &rhs) {
  Tensor result(op.getType());

  for (auto it = result.index_begin(); it != result.index_end(); ++it)
    result.set(*it, lhs.get(*it) + rhs.get(*it));

  return result;
}

Nel complesso, il design dell'interprete è ottimizzato per la leggibilità delle implementazioni delle funzioni eval per le singole operazioni perché è destinato a fungere da implementazione di riferimento per StableHLO. Ad esempio, anziché definire eval come funzione per il modello e parametrizzarlo con i tipi di elementi, incapsulamo i dettagli su come vengono gestiti i diversi tipi di elementi in Element::operator+ e così via, semplificando l'implementazione di eval.

Utilizzare l'interprete per la chiusura continua

Possiamo utilizzare il meccanismo dell'interprete per comprimere le operazioni con valori operandi costanti. Il seguente snippet di codice mostra un'idea dell'implementazione per il piegamento di stablehlo::AddOp con operandi digitati in virgola mobile:

OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  auto attrs = adaptor.getOperands();
  DenseElementsAttr lhsData = attrs[0].dyn_cast<DenseElementsAttr>();
  DenseElementsAttr rhsData = attrs[1].dyn_cast<DenseElementsAttr>();
  if (!lhsData || !rhsData) return {};

  auto lhs = Tensor(lhsData);
  auto rhs = Tensor(rhsData);
  auto result = eval(*this, lhs, rhs);

  SmallVector<APFloat> values;
  for (auto i = 0; i < result.getNumElements(); ++i) {
    Element element = result.get(i);
    values.push_back(element.getValue().cast<FloatAttr>().getValue());
  }

  return DenseElementsAttr::get(result.getType(), values);
}

Al momento non stiamo lavorando attivamente all'integrazione dell'interprete nella piegatura costante perché non abbiamo intenzione di implementare la cartella per StableHLO. Tuttavia, in futuro, abbiamo in programma di utilizzare l'interprete per la piegatura costante in MHLO, per migliorare l'ergonomia dello snippet di codice riportato sopra (ad esempio, potremmo avere una funzione helper che raggruppa gli operandi costanti in oggetti Tensor e decompone i risultati di Tensor in OpFoldResult).

Test dell'interprete StableHLO

L'interprete prende come input (A) un programma StableHLO e (B) i valori dei dati da fornire al programma e genera i valori dei dati di output, che vengono abbinati ai valori dei dati previsti forniti dall'utente. I valori dei dati (B) sono hardcoded nel programma stesso utilizzando operazioni stablehlo.constant. L'interprete valuta il programma di input. Gli output dell'operazione sottoposta a test vengono controllati tramite controlli (ad es. check.expect_eq, check.expect_almost_eq), come mostrato di seguito. check.expect_eq e check.expect_eq_const verificano l'uguaglianza a livello di bit per qualsiasi tipo supportato, mentre check.expect_almost_eq e check.expect_almost_eq_const verificano l'uguaglianza quasi all'interno di una tolleranza, come spiegato nelle linee guida di test (G6), per i tipi in virgola mobile e complessi.

// CHECK-LABEL: Evaluated results of function: add_op_test_ui4
func.func @add_op_test_ui4() {
  %0 = stablehlo.constant dense<[0, 2]> : tensor<2xui4>
  %1 = stablehlo.constant dense<[15, 3]> : tensor<2xui4>
  %2 = stablehlo.add %0, %1 : tensor<2xui4>
  check.expect_eq_const %2, [15, 5] : tensor<2xui4>
  func.return
}

Un'utilità di test stablehlo-translate --interpret (codice) è responsabile dell'analisi del programma, interpretando ogni funzione, comprese le operazioni che la costituiscono. Disponiamo di una suite di test dedicata, composta da diversi test che esercitano vari comportamenti di runtime, per ogni operazione StableHLO. I test sono disponibili qui (ad es. interpret_*.mlir).

Linee guida per l'esecuzione dei test

(G1) Dobbiamo testare tutti i tipi supportati per ogni operazione?

Per decidere possiamo utilizzare una combinazione delle seguenti regole:

  1. Durante l'implementazione di un'operazione, se esiste un codice nella funzione eval corrispondente per gestire un determinato tipo, è imperativo disporre di test per coprire quel tipo. Ad esempio, per l'operazione add esiste un codice esclusivo per gestire tipi interi, booleani, a virgola mobile e complessi, quindi è necessario un test per ogni categoria di tipo.

  2. Se un insieme di tipi viene gestito in modo uniforme nella funzione eval corrispondente, dovrebbe essere sufficiente un unico test per tutti questi tipi. Ad esempio, per l'operazione add, tutte le varianti di tipi interi (si4, u4, si8, u8 e così via) vengono gestite allo stesso modo utilizzando le API llvm::APInt, perciò possiamo saltare l'aggiunta dei test per ciascuna di queste varianti e invece aggiungere un singolo test rappresentativo. Per evitare ambiguità nella selezione del rappresentante, è necessario utilizzare le seguenti linee guida:

    • Se tutti i tipi, gestiti in modo uniforme, hanno lo stesso tipo primitivo (se sono tutti di tipo intero, in virgola mobile o complesso), scegli quello con larghezza in bit massima.
    • Se tutti i tipi, gestiti in modo uniforme, hanno un mix di tipi primitivi, scegli quello con il seguente tipo primitivo, in ordine decrescente di preferenza: numero intero, in virgola mobile, booleano, complesso.

(G2) Come decidiamo il numero di test necessari per coprire il comportamento di un'operazione?

L'obiettivo è coprire in modo esauriente la logica dell'interprete per l'operazione (ovvero tutti i casi limite dell'implementazione) con un numero minimo di test. Ridurre al minimo il numero di test è importante per la manutenibilità. Minore è il numero di test, più è facile esaminarli e verificare che coprano in modo completo l'operazione. Di conseguenza, ci aspettiamo che la maggior parte delle operazioni più semplici finisca per avere un solo test. Se, per qualche buona ragione, una copertura completa non è praticabile, non è un problema fermarsi a >= 90%. Questo verrà deciso caso per caso durante la revisione delle richieste di pull.

(G3) Che ne dici di aggiungere test per l'infrastruttura dell'interprete?

L'infrastruttura dell'interprete è per lo più semplice e può essere aggiunta alla nostra base di affidabilità. L'unica parte non banale è il modo in cui i vari tipi vengono pacchettizzati e decompressi dallo spazio di archiviazione sottostante dell'interprete. Come discusso in (G1), testeremo solo i tipi di operazioni gestiti in modo diverso. Di conseguenza, è possibile che il codice di pacchettizzazione/scomposizione, corrispondente a diverse varianti dei tipi di numeri interi/in virgola mobile, non venga completamente coperto durante il test. Per garantire la copertura completa, possiamo scegliere un'operazione come constant che supporti tutti i tipi di elementi StableHLO e scrivere test esaustivi.

(G4) Se l'implementazione di un'operazione dipende da altre operazioni, dobbiamo scrivere i test per queste ultime?

No. Ad esempio, l'implementazione di batch_norm_grad può basarsi su divide, subtract, multiply e altri. Dovremmo evitare di testare le ultime operazioni durante il test della prima.

(G5) Dovremmo scrivere test per esercitare i comportamenti definiti / non definiti nell'implementazione?

Non dobbiamo scrivere test che esercitano i comportamenti definiti o non definiti dell'implementazione. I test che applicano comportamenti definiti dall'implementazione mostrano un comportamento locale dell'interprete che non deve essere generalizzato. I test che esercitano un comportamento indefinito non contribuiscono alla comprensione del comportamento dell'operazione.

(G6) Durante la scrittura dei test per i tipi con virgola mobile, con quale precisione deve essere specificato il risultato previsto nei controlli?

Per le operazioni elementari (addizione, sottrazione, moltiplicazione, divisione e quadrato), un'implementazione conforme alla specifica IEEE dovrebbe fornire un risultato arrotondato entro 0,5 ULP del risultato matematicamente esatto. Detto questo, possiamo immaginare con sicurezza che il risultato previsto proveniente da queste operazioni sia di un valore massimo di 1 ULP. Tuttavia, potrebbe non funzionare per le funzioni trascendenti (sine, cosine e così via) per le quali le garanzie di precisione sono definite dall'implementazione (razionali).

L'attuale implementazione utilizza un valore di tolleranza "universale" pari a 0,0001. Il seguente esempio mostra la tolleranza di cui sopra in azione.

func.func @check_tolerance() {
  %0 = stablehlo.constant dense<0.2> : tensor<f32>

  // The following check succeeds as %0 is almost equal to the provided
  // constant modulo the tolerance, mentioned above.
  check.expect_almost_eq_const %0, dense<0.19999> : tensor<f32>

  // The following check fails as %0 is not bitwise equal to the provided
  // constant.
  check.expect_eq_const %0, dense<0.19999> : tensor<f32>

  func.return
}

Questo è solo il primo passo per testare l'accuratezza numerica delle operazioni StableHLO. Al momento, questa è un'area sottospecificata delle specifiche StableHLO e c'è un lavoro continuo per risolverlo #1156 in base alla nostra esperienza nell'utilizzo di StableHLO nella pratica e al feedback degli stakeholder. Man mano che questa procedura procede, aggiorneremo l'infrastruttura di conseguenza.

(G7) Qualcosa in merito allo stile di programmazione dei test?

  1. Assicurati di utilizzare il nome effettivo degli ingressi/uscite anziché impostare come valore predefinito i valori SSA (ad es. %0, %1 e così via)
  2. Assicurati che per i test venga utilizzato un formato stampato, se esistente.

(G8) Devo includere l'esempio già fornito nelle specifiche? Sì (per la completezza dei test).