StableHLO 解释器

StableHLO 解释器的主要目标是提供 StableHLO 运算集的语义 规范次要目标是在实施过程中密切关注 注重可读性而非性能,以进一步阐明规范 即使是最复杂的运算(例如 Convolution)的语义, Gather/Scatter,以及 DotGeneral

目前,OpenXLA 支持以 StableHLO 操作。其余 3 个操作(FftOpRngOpRngBitGeneratorOp) 请参阅 spec.md,并且 已完成有关如何继续操作的初步调查(请参阅 status.md 获取操作的完整列表及其最新状态)。这些最终 我们会根据需要社区实施改进措施。

范围

我们将 StableHLO 运算集分为 11 个类别,包括 (请参阅附录)。 参考实现 Workstream 会组织实现解释器的工作 100% StableHLO 操作(如 StableHLO 规范中所定义)。我们是 计划在 StableHLO 中完成此工作流中的所有或几乎所有工作 1.0 版。在目前具有规范的 96 个操作中,我们可以将 91 个操作解读为 OpenXLA(对于其余 5 种,请参阅特殊情况)。

规范

对译员的主要要求是与 规范该规范允许在执行如下操作时, 有助于实现高质量、模块化的解释器。

特殊情况

其他

此类别具有可组合操作,其未来尚不明确。那里 解释器不支持此类别中的三种特定操作 这一刻:

  • FftOp
  • RngOp
  • RngBitGeneratorOp

FftOp 属于“其他”类别,但与这一类别中的其他操作不同, 此操作没有扩展器传递,在 StableHLO 中支持扩展是 制作中。

RngOpRngBitGeneratorOp 可以分解为 MHLO 操作,但 分解引入了 XlaRngGetAndUpdateStateOp,这是 MHLO 专用 操作支持解释这两种操作是一项 WIP。

此工具可将此类别中的其余操作转换为 解释器支持的位于 hlo_expand_main.cc 中。

不在 HLO 中

除了指定操作外,该类别还包括 8 个未指定操作(请参阅 StableHLO Ops 类别) 已从 StableHLO 中移出其中的大多数运营机构都有 mhlo 到 将它们转换为 StableHLO 等效操作。

用于将此类别中的其余操作转换为等效的 StableHLO 操作的工具 位于 mlir-hlo-opt.cc 中。

量化

对量化类型的 stablehlo.constant 运算的解释器支持现为 不受支持,通过 #1691

使用说明

构建参考解释器

解释器可以通过 Bazel 或 CMake 进行构建和测试(首选)。完整 请参阅 README.md

Bazel:

bazel build //...

CMake:

mkdir -p build && cd build

cmake .. -GNinja \
  -DLLVM_ENABLE_LLD="$LLVM_ENABLE_LLD" \
  -DCMAKE_BUILD_TYPE=Release \
  -DLLVM_ENABLE_ASSERTIONS=On \
  -DMLIR_DIR=${PWD}/../llvm-build/lib/cmake/mlir

为了运行解释器,我们提供了一个转换工具来解释 StableHLO 程序 用 MLIR 编写。

stablehlo-translate --interpret <path/to/program>

口译方言

Interpreter 方言包含与 翻译。具体而言,interpreter.run_parallel(请参阅 InterpreterOps.td 操作语义和示例用法)操作允许解释分布操作,等等 公共事业服务提供商。

校对方言

Check 方言用于将解释器运行时值与预期值进行比较 值。可以通过各种检查操作来测试 StableHLO 程序输出(请参阅 CheckOps.td 了解操作语义和用法示例)。

编写测试计划

我们使用 LLVM 的 lit 工具来运行和 与生成的文件进行比较,以与解释器的输出进行比较 (请参阅 stablehlo/tests/interpret 例如测试)。

测试 AddOp(取自 interpret_add.mlir):

// RUN: stablehlo-translate --interpret %s

func.func @add_op_scalar() {
  %0 = stablehlo.constant dense<2> : tensor<i4>
  %1 = stablehlo.constant dense<3> : tensor<i4>
  %2 = stablehlo.add %0, %1 : tensor<i4>
  check.expect_eq_const %2, dense<5> : tensor<i4>
  func.return
}

Distribution 类别中的测试操作需要通过 interpreter.run_parallel 实用程序操作。

测试 AllReduceOp(取自 all_reduce.mlir):

// RUN: stablehlo-translate --interpret %s

module @cross_replica {
  func.func public @all_reduce(%operand : tensor<4xi64>) -> tensor<4xi64> {
    %result = "stablehlo.all_reduce"(%operand) ({
      ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
        %0 = stablehlo.add %arg0, %arg1 : tensor<i64>
        stablehlo.return %0 : tensor<i64>
    }) {
      replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
      channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
    } : (tensor<4xi64>) -> tensor<4xi64>
    return %result : tensor<4xi64>
  }
  func.func public @main() {
    %inputs0 = stablehlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
    %inputs1 = stablehlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64>
    %results:2 = "interpreter.run_parallel"(%inputs0, %inputs1) {
      programs=[[@all_reduce], [@all_reduce]]
    } : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
    check.expect_eq_const %results#0, dense<[6, 8, 10, 12]> : tensor<4xi64>
    check.expect_eq_const %results#1, dense<[6, 8, 10, 12]> : tensor<4xi64>
    func.return
  }
}

调试 StableHLO

按照 StableHLO 构建步骤, stablehlo/tools 应位于 /build/bin 中。常用的调试工具,例如 GDB 可用于单步调试代码:

gdb --args ./build/bin/stablehlo-translate -allow-unregistered-dialect --interpret ./stablehlo/tests/interpret/<test>.mlir

附录

转换其他操作

# batch_norm_grad
hlo-expand --batch_norm_grad_expander <path/to/hlo_module>

# batch_norm_inference
hlo-expand --batch_norm_inference_expander <path/to/hlo_module>

# batch_norm_training
hlo-expand --batch_norm_training_expander <path/to/hlo_module>

# cholesky
hlo-expand --cholesky_expander <path/to/hlo_module>

# constant
# Supported in StableHLO interpreter.

# fft
# TBD

# iota
# Supported in StableHLO interpreter.

# rng
# TBD

# rng_bit_generator
# TBD

# triangular_solve
hlo-expand --triangular_solve_expander <path/to/hlo_module>

非 HLO Ops 转化

# broadcast
mlir-hlo-opt -mhlo-legalize-broadcast-to-broadcast-in-dim <path/to/input>

# create_token
mlir-hlo-opt -mhlo-legalize-create-token-to-after-all <path/to/input>

# cross-replica-sum
mlir-hlo-opt -mhlo-legalize-cross-replica-sum-to-all-reduce <path/to/input>

# dot
mlir-hlo-opt -mhlo-legalize-dot-to-dot-general <path/to/input>

# einsum
mlir-hlo-opt -mhlo-legalize-einsum-to-dot-general <path/to/input>

# torch_index_select
mlir-hlo-opt -mhlo-legalize-torch-index-select-to-gather <path/to/input>

# unary_einsum
mlir-hlo-opt --canonicalize -mhlo-legalize-einsum-to-dot-general <path/to/input>

StableHLO 操作类别

类别 助记符 总计
119
控制流 after_all、case、if、optimization_步长、同时 5
数据移动 Broadcast_in_dim、concatenate、 dynamic_slice、 dynamic_update_slice、聚集、填充、重塑、反转、散点、切片、排序、转置 12
分布 all_gather、all_reduce、all_to_all、collective_permute、infeed、outfeed、partition_id、recv、reduce_scatter、replica_id、send 11
动感 dynamic_broadcast_in_dim、 dynamic_conv、 dynamic_gather、 dynamic_iota、 dynamic_pad、 dynamic_reshape、get_dimension_size、real_dynamic_slice、set_dimension_size 9
Elementwise {0} 48
可扩展性 custom_call、get_tuple_element、tuple 3
其他 batch_norm_grad、batch_norm_inference、batch_norm_training、cholesky、常量、fft、iota、rng、rng_bit_generator、triangular_solve 10
模块化 调用, 模块, 返回, call, func, module, return 4
不在 HLO 中 Broadcast、create_token、cross-replica-sum、dot、einsum、trch_index_select、unary_einsum 8
量化 uniform_dequantize、uniform_quantize 2
弱读 convolution、dot_general、reduce、reduce_window、select_and_scatter 5