Interprete StableHLO

L'obiettivo principale dell'interprete StableHLO è fornire un riferimento la semantica dell'opset StableHLO in base al suo e la specifica del prodotto. L'obiettivo secondario è che l'implementazione segua attentamente la specifica, prediligendo la leggibilità rispetto alle prestazioni, per fornire ulteriore chiarezza alla semantica anche delle operazioni più complesse come Convolution, Gather/Scatter e DotGeneral.

Al momento, OpenXLA supporta l’interpretazione di 91 su 96 specificati operazioni StableHLO. Le tre operazioni rimanenti (FftOp, RngOp, RngBitGeneratorOp) hanno registrato la loro semantica documentata spec.md e hanno hanno completato le indagini iniziali su come procedere (vedi status.md per un elenco completo delle operazioni e il relativo stato più recente). Queste ultime i miglioramenti verranno implementati in base alle esigenze della community.

Ambito

Abbiamo categorizzato l'opset StableHLO in 11 categorie costituite da 118 operazioni in totali (vedi Appendice). Implementazione dei riferimenti workstream organizza il lavoro sull'implementazione di un interprete per il 100% delle operazioni StableHLO, come definito nella specifica StableHLO. Stiamo pianificare di completare tutto o quasi tutto il lavoro in questo flusso di lavoro in StableHLO Versione 1.0. Delle 96 operazioni che hanno una specifica attualmente, possiamo interpretarne 91 tramite OpenXLA (vedi Casi speciali per i restanti 5).

Specifica

Il requisito principale per l'interprete è avere una corrispondenza 1:1 con del modello. Le specifiche consentono la standardizzazione dell'interprete in operazioni simili che portano a un'implementazione modulare e di alta qualità dell'interprete.

Casi speciali

Varie

Questa categoria comprende operazioni scomponibili il cui futuro non è chiaro al momento. Là sono tre operazioni specifiche di questa categoria che l'interprete non supporta momento:

  • FftOp
  • RngOp
  • RngBitGeneratorOp

FftOp è classificato come Varie, ma a differenza di altre operazioni in questa categoria, questa operazione non ha un pass di espansione e il supporto in StableHLO è In fase di elaborazione.

RngOp e RngBitGeneratorOp possono essere scomposti in operazioni MHLO, ma di decomposizione introduce un XlaRngGetAndUpdateStateOp, che è specifico per l'MHLO op. L'interpretazione di supporto di queste due operazioni è in fase di elaborazione.

Lo strumento per convertire le operazioni rimanenti di questa categoria in operazioni StableHLO l'interpreter supporta risiede in hlo_expand_main.cc.

Non in HLO

A parte le operazioni specifiche, questa categoria è composta da 8 operazioni non specifiche (vedi Categorie di operazioni SttableHLO), che dovrebbero essere è stato spostato fuori da StableHLO. La maggior parte di queste operazioni ha pass esistenti in mhlo a convertile in operazioni equivalenti StableHLO.

Lo strumento per convertire le operazioni rimanenti di questa categoria in operazioni StableHLO equivalenti supportata dall'interprete si trova in mlir-hlo-opt.cc.

Quantizzazione

Il supporto dell'interprete per l'operazione stablehlo.constant con il tipo quantizzato è non supportato e monitorato tramite #1691.

Istruzioni per l'uso

Creare l'interprete di riferimento

L'interprete può essere creato e testato tramite Bazel o CMake (opzione preferita). Per la carica completa istruzioni, vedi README.md.

Bazel:

bazel build //...

CMake:

mkdir -p build && cd build

cmake .. -GNinja \
  -DLLVM_ENABLE_LLD="$LLVM_ENABLE_LLD" \
  -DCMAKE_BUILD_TYPE=Release \
  -DLLVM_ENABLE_ASSERTIONS=On \
  -DMLIR_DIR=${PWD}/../llvm-build/lib/cmake/mlir

Per eseguire l'interprete, abbiamo uno strumento di traduzione per interpretare i programmi StableHLO scritto in MLIR.

stablehlo-translate --interpret <path/to/program>

Il dialetto dell'interprete

Il dialetto Interpreter contiene varie operazioni di utilità relative all' come interprete. In particolare, interpreter.run_parallel (vedi InterpreterOps.td per la semantica delle operazioni e l'utilizzo degli esempi), l'operazione consente l'interpretazione delle operazioni di distribuzione e altro ancora aziende di pubblici servizi prevedono di essere aggiunti in base alle esigenze della comunità.

Controlla dialetto

Il dialetto Check viene utilizzato per confrontare i valori di runtime dell'interprete con quelli previsti e i relativi valori. Gli output del programma StableHLO possono essere testati tramite varie operazioni di controllo (vedi CheckOps.td per la semantica delle operazioni e per l'utilizzo di esempio).

Redazione di programmi di test

Usiamo lo strumento lit di LLVM per eseguire confronta con il file generato per confrontare le differenze con l'output dell'interprete (vedi stablehlo/tests/interpret ad esempio nei test).

Test di AddOp in corso (esempio di interpret_add.mlir):

// RUN: stablehlo-translate --interpret %s

func.func @add_op_scalar() {
  %0 = stablehlo.constant dense<2> : tensor<i4>
  %1 = stablehlo.constant dense<3> : tensor<i4>
  %2 = stablehlo.add %0, %1 : tensor<i4>
  check.expect_eq_const %2, dense<5> : tensor<i4>
  func.return
}

Le operazioni di test nella categoria Distribuzione richiedono l'esecuzione tramite la interpreter.run_parallel utility op.

Test di AllReduceOp in corso (esempio di all_reduce.mlir):

// RUN: stablehlo-translate --interpret %s

module @cross_replica {
  func.func public @all_reduce(%operand : tensor<4xi64>) -> tensor<4xi64> {
    %result = "stablehlo.all_reduce"(%operand) ({
      ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
        %0 = stablehlo.add %arg0, %arg1 : tensor<i64>
        stablehlo.return %0 : tensor<i64>
    }) {
      replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
      channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
    } : (tensor<4xi64>) -> tensor<4xi64>
    return %result : tensor<4xi64>
  }
  func.func public @main() {
    %inputs0 = stablehlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
    %inputs1 = stablehlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64>
    %results:2 = "interpreter.run_parallel"(%inputs0, %inputs1) {
      programs=[[@all_reduce], [@all_reduce]]
    } : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
    check.expect_eq_const %results#0, dense<[6, 8, 10, 12]> : tensor<4xi64>
    check.expect_eq_const %results#1, dense<[6, 8, 10, 12]> : tensor<4xi64>
    func.return
  }
}

Debug di StableHLO

Seguendo i passaggi di build di StableHLO, i file binari StableHLO per gli strumenti stablehlo/tools deve risiedere in /build/bin. Strumenti di debug comuni come È possibile usare GDB per analizzare il codice:

gdb --args ./build/bin/stablehlo-translate -allow-unregistered-dialect --interpret ./stablehlo/tests/interpret/<test>.mlir

Appendice

Converti operazioni varie

# batch_norm_grad
hlo-expand --batch_norm_grad_expander <path/to/hlo_module>

# batch_norm_inference
hlo-expand --batch_norm_inference_expander <path/to/hlo_module>

# batch_norm_training
hlo-expand --batch_norm_training_expander <path/to/hlo_module>

# cholesky
hlo-expand --cholesky_expander <path/to/hlo_module>

# constant
# Supported in StableHLO interpreter.

# fft
# TBD

# iota
# Supported in StableHLO interpreter.

# rng
# TBD

# rng_bit_generator
# TBD

# triangular_solve
hlo-expand --triangular_solve_expander <path/to/hlo_module>

Converti non in operazioni HLO

# broadcast
mlir-hlo-opt -mhlo-legalize-broadcast-to-broadcast-in-dim <path/to/input>

# create_token
mlir-hlo-opt -mhlo-legalize-create-token-to-after-all <path/to/input>

# cross-replica-sum
mlir-hlo-opt -mhlo-legalize-cross-replica-sum-to-all-reduce <path/to/input>

# dot
mlir-hlo-opt -mhlo-legalize-dot-to-dot-general <path/to/input>

# einsum
mlir-hlo-opt -mhlo-legalize-einsum-to-dot-general <path/to/input>

# torch_index_select
mlir-hlo-opt -mhlo-legalize-torch-index-select-to-gather <path/to/input>

# unary_einsum
mlir-hlo-opt --canonicalize -mhlo-legalize-einsum-to-dot-general <path/to/input>

Categorie di operazioni StableHLO

Categorie Mnemotecniche Totale
119
Flusso di controllo dopo_tutto, caso, if, barriera_ottimizzazione, mentre 5
Spostamento dei dati broadcast_in_dim, concatenare, dynamic_slice, dynamic_update_slice, raccogliere, pad, rimodellare, invertire, dispersione, sezione, ordinare, trasposti 12
Distribuzione all_gather, all_reduce, all_to_all, collective_permute, infeed, outfeed, partition_id, recv, ridurre_scatter, replica_id, inviare 11
Dinamismo dynamic_broadcast_in_dim, dynamic_conv, dynamic_gather, dynamic_iota, dynamic_pad, dynamic_reshape, get_dimension_size, real_Dynamic_slice e set_dimension_size 9
Elementwise o, 48
Estensibilità custom_call, get_tuple_element, tupla 3
Varie batch_norm_grad, batch_norm_inference, batch_norm_training, cholesky, costante, fft, iota, rng, rng_bit_generator, triangular_solve 10
Modularità chiamata, funzione, modulo, ritorna 4
Non presente nell'HLO broadcast, create_token, cross-replica-sum, punto, einsum, torch_index_select, unary_einsum 8
Quantizzazione uniform_dequantize, uniform_quantize 2
Riduzione convoluzione, punto_generale, riduci, riduci_finestra, seleziona_e_dispersione 5