The main goal of the StableHLO interpreter is to provide a reference
implementation of the semantics of the StableHLO opset according to its
specification. The secondary goal is for the implementation to closely follow
the spec, favoring readability over performance, to provide additional clarity
to the semantics of even the most involved operations like Convolution,
Gather/Scatter, and DotGeneral.
At the moment, OpenXLA supports the interpretation of 91 out of 96 specced
StableHLO ops. Three of the remaining ops (FftOp, RngOp, RngBitGeneratorOp)
have their semantics documented in spec.md, and initial investigations on how
to move forward have been completed (see status.md for a complete list of ops
and their latest status). These final enhancements will be implemented on an
as-needed community basis.
Scope
We categorized the StableHLO opset into 11 categories consisting of 118 ops in total (see Appendix). The Reference Implementation workstream organizes the work on implementing an interpreter for 100% of StableHLO ops as defined in the StableHLO specification. We plan to complete all or almost all of the work in this workstream in StableHLO v1.0. Of the 96 ops that currently have a spec, we can interpret 91 through OpenXLA (see Special Cases for the remaining 5).
Specification
The main requirement for the interpreter is a 1:1 correspondence with the spec. The spec allows standardization of the interpreter across similar ops, which leads to a modular, high-quality implementation of the interpreter.
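To illustrate how standardizing across similar ops can yield a modular implementation, here is a minimal Python sketch. It is not the interpreter's actual C++ code: the `Tensor` class and the `eval_elementwise`, `eval_add`, `eval_multiply`, and `eval_negate` helpers are hypothetical. It assumes static shapes and row-major storage, and it ignores element-type details such as integer overflow and floating-point NaN handling.

```python
import math
import operator
from typing import Callable, List, Sequence


class Tensor:
    """Hypothetical minimal tensor: a static shape plus row-major elements."""

    def __init__(self, shape: Sequence[int], elements: List):
        assert math.prod(shape) == len(elements), "shape/element count mismatch"
        self.shape = tuple(shape)
        self.elements = list(elements)


def eval_elementwise(fn: Callable, *operands: Tensor) -> Tensor:
    """Shared helper: result[i] = fn(operands[0][i], operands[1][i], ...) for every index i."""
    shape = operands[0].shape
    # Elementwise ops require all operands to have the same shape (no implicit broadcasting).
    assert all(o.shape == shape for o in operands), "operands must have the same shape"
    return Tensor(shape, [fn(*elems) for elems in zip(*(o.elements for o in operands))])


# Each op in the category is a thin specialization of the shared helper.
def eval_add(lhs: Tensor, rhs: Tensor) -> Tensor:
    return eval_elementwise(operator.add, lhs, rhs)


def eval_multiply(lhs: Tensor, rhs: Tensor) -> Tensor:
    return eval_elementwise(operator.mul, lhs, rhs)


def eval_negate(operand: Tensor) -> Tensor:
    return eval_elementwise(operator.neg, operand)
```

In this sketch, adding another op such as subtract only needs `eval_elementwise(operator.sub, lhs, rhs)`. The per-element iteration and shape checks live in one place, and each op contributes only its scalar computation.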
Special Cases
Miscellaneous
This category has decomposable ops whose future is unclear at the moment. There are three specced ops in this category that the interpreter does not support yet:
FftOp
RngOp
RngBitGeneratorOp
FftOp is categorized as Miscellaneous, but unlike other ops in this category,
it does not have an expander pass, and supporting it in StableHLO is a work in
progress.
RngOp and RngBitGeneratorOp can be decomposed into MHLO ops, but the
decomposition introduces an XlaRngGetAndUpdateStateOp, which is an
MHLO-specific op. Supporting interpretation of these two ops is a work in
progress.
The tool to convert the remaining ops in this category to StableHLO ops that the interpreter supports resides in hlo_expand_main.cc.
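The kind of decomposition these expanders perform can be illustrated with batch_norm_inference. Per feature, it reduces to elementwise subtract, add, sqrt, divide, multiply, and add, which are all ops the interpreter already supports. The Python sketch below is a simplified model, not the output of hlo-expand, and the expander's actual op order and broadcasting may differ. The function name is hypothetical, and the sketch assumes a rank-2 `[batch, feature]` operand with `feature_index = 1`, using plain lists instead of tensors.

```python
import math
from typing import List


def batch_norm_inference(operand: List[List[float]], scale: List[float],
                         offset: List[float], mean: List[float],
                         variance: List[float], epsilon: float) -> List[List[float]]:
    """Decomposed batch_norm_inference for a [batch, feature] operand (feature_index = 1).

    Per feature f: result = (x - mean[f]) / sqrt(variance[f] + epsilon) * scale[f] + offset[f]
    """
    result = []
    for row in operand:
        out_row = []
        for f, x in enumerate(row):
            centered = x - mean[f]                              # subtract
            stddev = math.sqrt(variance[f] + epsilon)           # add, then sqrt
            normalized = centered / stddev                      # divide
            out_row.append(normalized * scale[f] + offset[f])   # multiply, then add
        result.append(out_row)
    return result
```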
Not In HLO
Apart from the specced ops, this category consists of 8 unspecced ops (see StableHLO Ops Categories) which are planned to be moved out of StableHLO. Most of these ops have existing passes in MHLO to convert them to equivalent StableHLO ops.
The tool to convert the remaining ops in this category to equivalent StableHLO ops that the interpreter supports resides in mlir-hlo-opt.cc.
Quantization
Interpreter support for the stablehlo.constant operation with a quantized type
is not yet implemented and is tracked in #1691.
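For context on what such support involves: a constant of a quantized type holds integer storage values, and its meaning is given by the expressed (for example, floating-point) values they represent. Under per-tensor uniform quantization, each element maps as `(storage - zero_point) * scale`. The Python sketch below shows only that mapping. The function name and its i8 default range are assumptions, and per-axis quantization is not covered.

```python
from typing import List


def dequantize_per_tensor(storage: List[int], scale: float, zero_point: int,
                          storage_min: int = -128, storage_max: int = 127) -> List[float]:
    """Map per-tensor quantized storage values to the expressed values they represent."""
    expressed = []
    for q in storage:
        # Storage values must fit the storage type (assumed to be i8 by default here).
        if not storage_min <= q <= storage_max:
            raise ValueError(f"storage value {q} is out of range")
        expressed.append((q - zero_point) * scale)
    return expressed
```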
Usage Instructions
Building the Reference Interpreter
The interpreter can be built and tested via Bazel or CMake (preferred). For full instructions, see README.md.
Bazel:
bazel build //...
CMake:
mkdir -p build && cd build
cmake .. -GNinja \
  -DLLVM_ENABLE_LLD="$LLVM_ENABLE_LLD" \
  -DCMAKE_BUILD_TYPE=Release \
  -DLLVM_ENABLE_ASSERTIONS=On \
  -DMLIR_DIR=${PWD}/../llvm-build/lib/cmake/mlir
cmake --build .
To run the interpreter, use the stablehlo-translate tool, which interprets StableHLO programs written in MLIR:
stablehlo-translate --interpret <path/to/program>
The Interpreter Dialect
The Interpreter dialect contains various utility ops related to the
interpreter. Specifically, the interpreter.run_parallel op (see
InterpreterOps.td for op semantics and example usage) allows interpretation of
Distribution ops, and more utilities are planned based on community needs.
The Check Dialect
The Check dialect is used to compare interpreter runtime values to expected
values. StableHLO program outputs can be tested via various check ops (see
CheckOps.td for op semantics and example usage).
Writing Test Programs
We use LLVM's lit tool to run test programs through the interpreter and check its output against the expected results (see stablehlo/tests/interpret for example tests).
Testing AddOp (sample from interpret_add.mlir):
// RUN: stablehlo-translate --interpret %s
func.func @add_op_scalar() {
%0 = stablehlo.constant dense<2> : tensor<i4>
%1 = stablehlo.constant dense<3> : tensor<i4>
%2 = stablehlo.add %0, %1 : tensor<i4>
check.expect_eq_const %2, dense<5> : tensor<i4>
func.return
}
Testing ops in the Distribution category requires running them via the
interpreter.run_parallel utility op.
Testing AllReduceOp (sample from all_reduce.mlir):
// RUN: stablehlo-translate --interpret %s
module @cross_replica {
func.func public @all_reduce(%operand : tensor<4xi64>) -> tensor<4xi64> {
%result = "stablehlo.all_reduce"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.add %arg0, %arg1 : tensor<i64>
stablehlo.return %0 : tensor<i64>
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<4xi64>) -> tensor<4xi64>
return %result : tensor<4xi64>
}
func.func public @main() {
%inputs0 = stablehlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
%inputs1 = stablehlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64>
%results:2 = "interpreter.run_parallel"(%inputs0, %inputs1) {
programs=[[@all_reduce], [@all_reduce]]
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
check.expect_eq_const %results#0, dense<[6, 8, 10, 12]> : tensor<4xi64>
check.expect_eq_const %results#1, dense<[6, 8, 10, 12]> : tensor<4xi64>
func.return
}
}
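The expected values in this test follow from the semantics of all_reduce. With `replica_groups = [[0, 1]]`, both replicas receive the elementwise sum of their operands, so `[1, 2, 3, 4] + [5, 6, 7, 8] = [6, 8, 10, 12]`. The Python sketch below models only that arithmetic. It assumes a single partition and one input per replica, matching how interpreter.run_parallel feeds `%inputs0` and `%inputs1` in the test. The `all_reduce` function is hypothetical and does not model channel handles or reduction order.

```python
from typing import Callable, Dict, List


def all_reduce(inputs: Dict[int, List[int]],
               replica_groups: List[List[int]],
               computation: Callable[[int, int], int]) -> Dict[int, List[int]]:
    """Every replica in a group receives the elementwise reduction of the
    operands of all replicas in that group."""
    results: Dict[int, List[int]] = {}
    for group in replica_groups:
        operands = [inputs[replica_id] for replica_id in group]
        reduced = list(operands[0])
        for operand in operands[1:]:
            # Combine pairwise with the reduction body (stablehlo.add in the test).
            reduced = [computation(a, b) for a, b in zip(reduced, operand)]
        for replica_id in group:
            results[replica_id] = list(reduced)
    return results


# Mirror the test: replica 0 gets %inputs0 and replica 1 gets %inputs1.
results = all_reduce({0: [1, 2, 3, 4], 1: [5, 6, 7, 8]},
                     replica_groups=[[0, 1]],
                     computation=lambda a, b: a + b)
print(results)  # {0: [6, 8, 10, 12], 1: [6, 8, 10, 12]}
```

Splitting the same four replicas into two groups, for example `[[0, 1], [2, 3]]`, would reduce each pair independently. That is why the expected values of a Distribution test depend on `replica_groups` as well as on the inputs.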
Debugging StableHLO
Following the StableHLO build steps, the binaries for the tools in
stablehlo/tools should reside in ./build/bin. Common debugging tools like
GDB can be used to step through the code:
gdb --args ./build/bin/stablehlo-translate -allow-unregistered-dialect --interpret ./stablehlo/tests/interpret/<test>.mlir
Appendix
Convert Miscellaneous Ops
# batch_norm_grad
hlo-expand --batch_norm_grad_expander <path/to/hlo_module>
# batch_norm_inference
hlo-expand --batch_norm_inference_expander <path/to/hlo_module>
# batch_norm_training
hlo-expand --batch_norm_training_expander <path/to/hlo_module>
# cholesky
hlo-expand --cholesky_expander <path/to/hlo_module>
# constant
# Supported in StableHLO interpreter.
# fft
# TBD
# iota
# Supported in StableHLO interpreter.
# rng
# TBD
# rng_bit_generator
# TBD
# triangular_solve
hlo-expand --triangular_solve_expander <path/to/hlo_module>
Convert Not In HLO Ops
# dot
mlir-hlo-opt -mhlo-legalize-dot-to-dot-general <path/to/input>
# einsum
mlir-hlo-opt -mhlo-legalize-einsum-to-dot-general <path/to/input>
# torch_index_select
mlir-hlo-opt -mhlo-legalize-torch-index-select-to-gather <path/to/input>
# unary_einsum
mlir-hlo-opt --canonicalize -mhlo-legalize-einsum-to-dot-general <path/to/input>
StableHLO Ops Categories
| Category | Mnemonics | Count |
|---|---|---|
| Control Flow | after_all, case, if, optimization_barrier, while | 5 |
| Data Movement | broadcast_in_dim, concatenate, dynamic_slice, dynamic_update_slice, gather, pad, reshape, reverse, scatter, slice, sort, transpose | 12 |
| Distribution | all_gather, all_reduce, all_to_all, collective_permute, infeed, outfeed, partition_id, recv, reduce_scatter, replica_id, send | 11 |
| Dynamism | dynamic_broadcast_in_dim, dynamic_conv, dynamic_gather, dynamic_iota, dynamic_pad, dynamic_reshape, get_dimension_size, real_dynamic_slice, set_dimension_size | 9 |
| Elementwise | abs, add, and, atan2, bitcast_convert, cbrt, ceil, clamp, compare, complex, convert, cosine, count_leading_zeros, divide, exponential, exponential_minus_one, floor, imag, is_finite, log, log_plus_one, logistic, map, maximum, minimum, multiply, negate, not, or, popcnt, power, real, reduce_precision, remainder, round_nearest_afz, round_nearest_even, rsqrt, select, shift_left, shift_right_arithmetic, shift_right_logical, sign, sine, sqrt, subtract, tan, tanh, xor | 48 |
| Extensibility | custom_call, get_tuple_element, tuple | 3 |
| Miscellaneous | batch_norm_grad, batch_norm_inference, batch_norm_training, cholesky, constant, fft, iota, rng, rng_bit_generator, triangular_solve | 10 |
| Modularity | call, func, module, return | 4 |
| Not In HLO | broadcast, create_token, cross-replica-sum, dot, einsum, torch_index_select, unary_einsum | 8 |
| Quantization | uniform_dequantize, uniform_quantize | 2 |
| Reduction | convolution, dot_general, reduce, reduce_window, select_and_scatter | 5 |
| Total | | 119 |