Intérprete de StableHLO

El objetivo principal del intérprete de StableHLO es proporcionar una referencia implementación de la semántica del opset StableHLO de acuerdo con su especificación. El objetivo secundario es que la implementación siga de cerca la especificación, lo que favorece la legibilidad sobre el rendimiento para brindar más claridad a la semántica, incluso de las operaciones más complejas, como Convolution Gather/Scatter y DotGeneral.

Por el momento, OpenXLA admite la interpretación de 91 de 96 Operaciones de StableHLO. Las 3 operaciones restantes (FftOp, RngOp, RngBitGeneratorOp) tienen su semántica documentada en spec.md y tienen hemos realizado investigaciones iniciales sobre cómo avanzar (consulta status.md para obtener una lista completa de las operaciones y su estado más reciente). Estos cambios finales se implementarán mejoras según sea necesario en la comunidad.

Alcance

Clasificamos el conjunto de operaciones StableHLO en 11 categorías con 118 operaciones en totales (consulta el Apéndice). Implementación de referencia de trabajo organiza el trabajo para implementar un intérprete para el 100% de las operaciones de StableHLO, según se define en la especificación de StableHLO. Somos planea completar la totalidad o casi todo el trabajo en este flujo de trabajo en StableHLO. Versión 1.0. De las 96 operaciones que tienen una especificación actualmente, podemos interpretar 91 operaciones OpenXLA (consulta los casos especiales para los 5 restantes).

Especificación

El requisito principal del intérprete es tener correspondencia 1:1 con el específica. La especificación permite la estandarización del intérprete en operaciones similares que conducir a una implementación modular y de alta calidad del intérprete.

Casos especiales

Varios

Esta categoría tiene operaciones descomponibles cuyo futuro no está claro en este momento. Hay hay tres operaciones específicas en esta categoría que el intérprete no admite en el momento:

  • FftOp
  • RngOp
  • RngBitGeneratorOp

FftOp se clasifica como Varias, pero, a diferencia de otras operaciones de esta categoría, esta op no tiene un pase de expansión, y admitirlo en StableHLO es En construcción.

RngOp y RngBitGeneratorOp se pueden dividir en operaciones de MHLO, pero la descomposición introduce un XlaRngGetAndUpdateStateOp, que es un MHLO específico op. La interpretación de respaldo de estas dos operaciones es un trabajo en curso.

La herramienta para convertir las operaciones restantes de esta categoría en operaciones de StableHLO que que admite el intérprete reside en hlo_expand_main.cc.

No está en HLO

Además de las operaciones especificadas, esta categoría consta de 8 operaciones no especificadas (consulta categorías de operaciones de StableHLO), que están planificadas para y salí del StableHLO. La mayoría de estas operaciones tienen pases existentes en mhlo a y convertirlos en operaciones equivalentes a StableHLO.

La herramienta para convertir las operaciones restantes de esta categoría en operaciones de StableHLO equivalentes. que el intérprete admite reside en mlir-hlo-opt.cc.

Cuantización

La compatibilidad con el intérprete para la operación stablehlo.constant con tipo cuantizado es y se hace un seguimiento mediante #1691:

Instrucciones de uso

Creación del intérprete de referencia

El intérprete se puede compilar y probar a través de Bazel o CMake (opción preferida). Total consulta README.md.

Bazel:

bazel build //...

CMake:

mkdir -p build && cd build

cmake .. -GNinja \
  -DLLVM_ENABLE_LLD="$LLVM_ENABLE_LLD" \
  -DCMAKE_BUILD_TYPE=Release \
  -DLLVM_ENABLE_ASSERTIONS=On \
  -DMLIR_DIR=${PWD}/../llvm-build/lib/cmake/mlir

Para ejecutar el intérprete, tenemos una herramienta de traducción a fin de interpretar los programas de StableHLO. escritas en MLIR.

stablehlo-translate --interpret <path/to/program>

Dialecto de intérprete

El dialecto Interpreter contiene varias operaciones de utilidad relacionadas con la de un intérprete. Específicamente, el interpreter.run_parallel (consulta InterpreterOps.td (para semántica de op y uso de ejemplo), op permite interpretar las operaciones de distribución y mucho más de servicios públicos nuevos según las necesidades de la comunidad.

El dialecto de verificación

El dialecto Check se usa para comparar los valores del entorno de ejecución del intérprete con los esperados. de salida. Los resultados del programa StableHLO se pueden probar a través de varias operaciones de verificación (consulta CheckOps.td para la semántica de op y el uso de ejemplo).

Redacción de programas de prueba

Usamos la herramienta lit de LLVM para ejecutar y comparar con el archivo generado para diferenciar con el resultado del intérprete (consulta stablehlo/tests/interpret). como las pruebas).

Probando AddOp (muestra de interpret_add.mlir):

// RUN: stablehlo-translate --interpret %s

func.func @add_op_scalar() {
  %0 = stablehlo.constant dense<2> : tensor<i4>
  %1 = stablehlo.constant dense<3> : tensor<i4>
  %2 = stablehlo.add %0, %1 : tensor<i4>
  check.expect_eq_const %2, dense<5> : tensor<i4>
  func.return
}

Para probar las operaciones de la categoría Distribución, es necesario ejecutarla a través del Operación de utilidad interpreter.run_parallel.

Probando AllReduceOp (muestra de all_reduce.mlir):

// RUN: stablehlo-translate --interpret %s

module @cross_replica {
  func.func public @all_reduce(%operand : tensor<4xi64>) -> tensor<4xi64> {
    %result = "stablehlo.all_reduce"(%operand) ({
      ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
        %0 = stablehlo.add %arg0, %arg1 : tensor<i64>
        stablehlo.return %0 : tensor<i64>
    }) {
      replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
      channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
    } : (tensor<4xi64>) -> tensor<4xi64>
    return %result : tensor<4xi64>
  }
  func.func public @main() {
    %inputs0 = stablehlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
    %inputs1 = stablehlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64>
    %results:2 = "interpreter.run_parallel"(%inputs0, %inputs1) {
      programs=[[@all_reduce], [@all_reduce]]
    } : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
    check.expect_eq_const %results#0, dense<[6, 8, 10, 12]> : tensor<4xi64>
    check.expect_eq_const %results#1, dense<[6, 8, 10, 12]> : tensor<4xi64>
    func.return
  }
}

Cómo depurar StableHLO

Siguiendo los pasos de compilación de StableHLO, los objetos binarios de StableHLO para herramientas en stablehlo/tools debe estar en /build/bin. Herramientas de depuración comunes, como GDB se puede utilizar para recorrer el código:

gdb --args ./build/bin/stablehlo-translate -allow-unregistered-dialect --interpret ./stablehlo/tests/interpret/<test>.mlir

Apéndice

Convertir otras operaciones

# batch_norm_grad
hlo-expand --batch_norm_grad_expander <path/to/hlo_module>

# batch_norm_inference
hlo-expand --batch_norm_inference_expander <path/to/hlo_module>

# batch_norm_training
hlo-expand --batch_norm_training_expander <path/to/hlo_module>

# cholesky
hlo-expand --cholesky_expander <path/to/hlo_module>

# constant
# Supported in StableHLO interpreter.

# fft
# TBD

# iota
# Supported in StableHLO interpreter.

# rng
# TBD

# rng_bit_generator
# TBD

# triangular_solve
hlo-expand --triangular_solve_expander <path/to/hlo_module>

Convertir en operaciones de HLO

# broadcast
mlir-hlo-opt -mhlo-legalize-broadcast-to-broadcast-in-dim <path/to/input>

# create_token
mlir-hlo-opt -mhlo-legalize-create-token-to-after-all <path/to/input>

# cross-replica-sum
mlir-hlo-opt -mhlo-legalize-cross-replica-sum-to-all-reduce <path/to/input>

# dot
mlir-hlo-opt -mhlo-legalize-dot-to-dot-general <path/to/input>

# einsum
mlir-hlo-opt -mhlo-legalize-einsum-to-dot-general <path/to/input>

# torch_index_select
mlir-hlo-opt -mhlo-legalize-torch-index-select-to-gather <path/to/input>

# unary_einsum
mlir-hlo-opt --canonicalize -mhlo-legalize-einsum-to-dot-general <path/to/input>

Categorías de operaciones de StableHLO

Categorías Mnemotécnicas Total
119
Flujo de control after_all, case, if, optimization_barrier, mientras que 5
Transferencia de datos broadcast_in_dim, concatenar, porción_dinámica, porción_dinámica_de_actualización, recopilar, padding, redimensionar, revertir, dispersión, porción, ordenar, transponer 12
Distribución all_regan, all_reduce, all_to_all, collective_permute, infeed, outfeed, partition_id, recv, reducir_scatter, replica_id, send 11
Dinamismo dynamic_broadcast_in_dim, dynamic_conv, dynamic_gather, dynamic_iota, dynamic_pad, dynamic_reshape, get_dimension_size, real_dynamic_slice, set_dimension_size 9
Elementwise abs, add, and, atan2, bitcast_convert, cbrt, ceil, clamp, compare, complex, convert, cosine, count_leading_zeros, divide, exponential, exponential_minus_one, floor, imag, is_finite, log, log_plus_one, logistic, map, maximum, minimum, multiply, negate, not, or, popcnt, power, real, reduce_precision, remainder, round_nearest_afz, round_nearest_even, rsqrt, select, shift_left, shift_right_arithmetic, shift_right_logical, sign, sine, sqrt, subtract, tan, tanh, xor 48
Extensibilidad custom_call, get_tuple_element, tupla 3
Varios Batch_norm_grad, batch_norm_inference, batch_norm_training, cholesky, constante, fft, iota, rng, rng_bit_generator, triangular_solve 10
Modularidad call, func, module, return 4
No está en HLO transmisión, create_token, suma-réplica cruzada, punto, einsum, torch_index_select, unary_einsum 8
Cuantización uniform_dequantize, uniform_quantize 2
Reducción convolution, punto_general, reduce, reduce_window, select_and_scatter 5