Interpréteur StableHLO

L'objectif principal de l'interpréteur StableHLO est de fournir une référence à la sémantique de l'opset StableHLO selon sa spécifique. L'objectif secondaire est que l'implémentation suive de près la spécification, en privilégiant la lisibilité par rapport aux performances, pour plus de clarté. à la sémantique des opérations les plus complexes, comme Convolution, Gather/Scatter et DotGeneral.

Actuellement, OpenXLA prend en charge l'interprétation de 91 spécifiés sur 96 StableHLO Ops Les trois opérations restantes (FftOp, RngOp, RngBitGeneratorOp) ont leur sémantique documentée dans spec.md et avoir mené des recherches initiales afin de déterminer la marche à suivre (voir status.md pour obtenir la liste complète des opérations et leur état le plus récent). Ces dernières améliorations seront apportées au besoin par la communauté.

Champ d'application

Nous avons classé l'opset StableHLO en 11 catégories, soit 118 ops dans (voir l'annexe). Implémentation de référence organise le travail d'implémentation d'un interprète. pour la totalité des opérations StableHLO, comme défini dans la spécification StableHLO. Nous prévoient de terminer la totalité ou presque tous les travaux de cet flux de travail dans StableHLO version 1.0. Sur les 96 opérations associées actuellement à une spécification, nous pouvons en interpréter 91 via OpenXLA (voir les cas particuliers pour les cinq autres).

Spécification

La principale condition pour l'interprète est d'avoir une correspondance 1:1 avec . Cette spécification permet de standardiser l'interpréteur pour des opérations similaires conduisent à une implémentation modulaire et de haute qualité de l'interprète.

Cas particuliers

Divers

Cette catégorie comprend des opérations décomposables dont l'avenir n'est pas clair pour le moment. Il y trois opérations spécifiées dans cette catégorie ne sont pas prises en charge par l'interprète pour le moment:

  • FftOp
  • RngOp
  • RngBitGeneratorOp

FftOp appartient à la catégorie "Divers", mais contrairement aux autres opérations de cette catégorie, l'opération n'a pas de pass d'expansion, et la prise en charge dans StableHLO est En cours.

RngOp et RngBitGeneratorOp peuvent être décomposés en opérations MHLO, mais la décomposition introduit un XlaRngGetAndUpdateStateOp, qui est un MHLO spécifique op. L'interprétation de ces deux opérations est en cours.

Outil permettant de convertir les opérations restantes dans cette catégorie en opérations StableHLO pris en charge par l'interpréteur se trouve dans hlo_expand_main.cc.

Pas dans HLO

Outre les opérations spécifiées, cette catégorie se compose de huit opérations non spécifiées (voir la StableHLO Ops Categories), qui devraient être hors de StableHLO. La plupart de ces opérations ont des cartes existantes dans mhlo à les convertir en opérations équivalentes StableHLO.

Outil permettant de convertir les opérations restantes dans cette catégorie en opérations StableHLO équivalentes pris en charge par l'interpréteur se trouve dans mlir-hlo-opt.cc.

Quantification

La compatibilité de l'interpréteur pour l'opération stablehlo.constant avec un type quantifié est non pris en charge et suivi via #1691.

Instructions d'utilisation

Créer l'interpréteur de référence

L'interprète peut être créé et testé via Bazel ou CMake (recommandé). Pleine instructions, consultez README.md.

Bazel:

bazel build //...

CMake:

mkdir -p build && cd build

cmake .. -GNinja \
  -DLLVM_ENABLE_LLD="$LLVM_ENABLE_LLD" \
  -DCMAKE_BUILD_TYPE=Release \
  -DLLVM_ENABLE_ASSERTIONS=On \
  -DMLIR_DIR=${PWD}/../llvm-build/lib/cmake/mlir

Pour exécuter l'interpréteur, nous disposons d'un outil de traduction permettant d'interpréter les programmes StableHLO. écrites en MLIR.

stablehlo-translate --interpret <path/to/program>

Dialecte de l'interprète

Le dialecte Interpreter contient diverses opérations utilitaires liées au un interprète. Plus précisément, interpreter.run_parallel (consultez InterpreterOps.td pour la sémantique des opérations et les exemples d'utilisation), l'opération permet d'interpréter les opérations de distribution, et plus encore les fournisseurs d'énergie prévoient d'être ajoutés en fonction des besoins de la communauté.

Dialecte de vérification

Le dialecte Check permet de comparer les valeurs d'exécution de l'interpréteur aux valeurs attendues. valeurs. Les sorties du programme StableHLO peuvent être testées via différentes opérations de vérification (voir CheckOps.td pour la sémantique des opérations et les exemples d'utilisation).

Écriture de programmes de test

Nous utilisons l'outil lit de LLVM pour exécuter comparer avec le fichier généré pour différencier les différences avec la sortie de l'interpréteur (voir stablehlo/tests/interpret). (par exemple, pour les tests).

Test de AddOp (échantillon de interpret_add.mlir):

// RUN: stablehlo-translate --interpret %s

func.func @add_op_scalar() {
  %0 = stablehlo.constant dense<2> : tensor<i4>
  %1 = stablehlo.constant dense<3> : tensor<i4>
  %2 = stablehlo.add %0, %1 : tensor<i4>
  check.expect_eq_const %2, dense<5> : tensor<i4>
  func.return
}

Les opérations de test de la catégorie "Distribution" nécessitent de les exécuter via la Opération utilitaire interpreter.run_parallel.

Test de AllReduceOp (échantillon de all_reduce.mlir):

// RUN: stablehlo-translate --interpret %s

module @cross_replica {
  func.func public @all_reduce(%operand : tensor<4xi64>) -> tensor<4xi64> {
    %result = "stablehlo.all_reduce"(%operand) ({
      ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
        %0 = stablehlo.add %arg0, %arg1 : tensor<i64>
        stablehlo.return %0 : tensor<i64>
    }) {
      replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
      channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
    } : (tensor<4xi64>) -> tensor<4xi64>
    return %result : tensor<4xi64>
  }
  func.func public @main() {
    %inputs0 = stablehlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
    %inputs1 = stablehlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64>
    %results:2 = "interpreter.run_parallel"(%inputs0, %inputs1) {
      programs=[[@all_reduce], [@all_reduce]]
    } : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
    check.expect_eq_const %results#0, dense<[6, 8, 10, 12]> : tensor<4xi64>
    check.expect_eq_const %results#1, dense<[6, 8, 10, 12]> : tensor<4xi64>
    func.return
  }
}

Déboguer StableHLO

Après les étapes de compilation StableHLO, les binaires StableHLO pour les outils de stablehlo/tools doit se trouver dans /build/bin. Les outils de débogage courants tels que GDB peut être utilisé pour parcourir le code:

gdb --args ./build/bin/stablehlo-translate -allow-unregistered-dialect --interpret ./stablehlo/tests/interpret/<test>.mlir

Annexe

Opérations de conversion diverses

# batch_norm_grad
hlo-expand --batch_norm_grad_expander <path/to/hlo_module>

# batch_norm_inference
hlo-expand --batch_norm_inference_expander <path/to/hlo_module>

# batch_norm_training
hlo-expand --batch_norm_training_expander <path/to/hlo_module>

# cholesky
hlo-expand --cholesky_expander <path/to/hlo_module>

# constant
# Supported in StableHLO interpreter.

# fft
# TBD

# iota
# Supported in StableHLO interpreter.

# rng
# TBD

# rng_bit_generator
# TBD

# triangular_solve
hlo-expand --triangular_solve_expander <path/to/hlo_module>

Opérations de conversion Not In HLO

# broadcast
mlir-hlo-opt -mhlo-legalize-broadcast-to-broadcast-in-dim <path/to/input>

# create_token
mlir-hlo-opt -mhlo-legalize-create-token-to-after-all <path/to/input>

# cross-replica-sum
mlir-hlo-opt -mhlo-legalize-cross-replica-sum-to-all-reduce <path/to/input>

# dot
mlir-hlo-opt -mhlo-legalize-dot-to-dot-general <path/to/input>

# einsum
mlir-hlo-opt -mhlo-legalize-einsum-to-dot-general <path/to/input>

# torch_index_select
mlir-hlo-opt -mhlo-legalize-torch-index-select-to-gather <path/to/input>

# unary_einsum
mlir-hlo-opt --canonicalize -mhlo-legalize-einsum-to-dot-general <path/to/input>

Catégories d'opérations StableHLO

Catégories Mnémonique Total
119
Flux de contrôle after_all, case, if, optimization_barrier, tandis que 5
Transfert de données broadcast_in_dim, concatène, dynamique_slice, dynamic_update_slice, rassembler, pad, remodeler, inverser, disperser, trancher, trier, transposer 12
Distribution all_gather, all_reduce, all_to_all, collective_permute, infeed, outfeed, partition_id, recv, Reduce_ dispersion, instance_id, envoyer 11
Dynamique dynamic_broadcast_in_dim, dynamic_conv, dynamic_gather, dynamic_iota, dynamic_pad, dynamic_reshape, get_dimension_size, real_dynamic_slice, set_dimension_size 9
Elementwise abs, add, and, atan2, bitcast_convert, cbrt, ceil, clamp, compare, complex, convert, cosine, count_leading_zeros, divide, exponential, exponential_minus_one, floor, imag, is_finite, log, log_plus_one, logistic, map, maximum, minimum, multiply, negate, not, or, popcnt, power, real, reduce_precision, remainder, round_nearest_afz, round_nearest_even, rsqrt, select, shift_left, shift_right_arithmetic, shift_right_logical, sign, sine, sqrt, subtract, tan, tanh, xor 48
Extensibilité appel personnalisé, get_tuple_element, tuple 3
Divers batch_norm_grad, batch_norm_inference, batch_norm_training, cholesky, constante, fft, iota, rng, rng_bit_generator, solution_triangulaire 10
Modularité appel, func, module, retour 4
Pas dans HLO diffusion, create_token, cross-replica-sum, point, einsum, torch_index_select, unary_einsum 8
Quantification "uniform_dequantize" et "uniform_quantize" 2
Réduction convolution, point_general, réduire, réduire_fenêtre, select_and_ disper 5