StableHLO 翻譯模式

StableHLO 解譯器的主要目的在於提供參照 根據其 StableHLO 運算的語意 規格。次要目標是 讓導入作業密切遵循 同樣是涵蓋最複雜運算 (例如 Convolution) 的語意 Gather/ScatterDotGeneral

目前 OpenXLA 支援解讀 96 片規格中的 91 StableHLO 運算。其餘 3 項作業 (FftOpRngOpRngBitGeneratorOp) 語氣記錄 spec.md,同時將 已完成有關如何繼續的初步調查 (請參閱 status.md ,即可查看完整的作業清單及其最新狀態)。最終 我們也會視需要根據需要的社群加強宣傳功能。

範圍

我們將 StableHLO 對機會分類為 11 種類別,包括 118 次操作 (請見附錄)。 參考實作 工作流程會統整導入翻譯工具的工作 適用於 StableHLO 規格中定義的 100% StableHLO 運算。我們 打算在 StableHLO 中完成這個工作流程中的所有或幾乎所有工作 1.0 版。目前在 96 個設有規格的運算中,我們能透過這個規格解讀 91 次操作 OpenXLA (如要瞭解其餘 5 個,請參閱特殊情況)。

規格

翻譯模式的主要需求為 1:1 對應 規格。這個規格允許將翻譯器標準化 導向模組化的優質翻譯實作。

特殊情況

其他

這個類別包含可分解的運算,但尚不清楚。有 是這個類別中的三個特定作業,翻譯程式不支援 關鍵時刻:

  • FftOp
  • RngOp
  • RngBitGeneratorOp

FftOp 會歸類為「其他」,但與這個類別中的其他作業不同: 這個運算沒有展開票證,在 StableHLO 中支援這個 建構中。

RngOpRngBitGeneratorOp 可分解為 MHLO 運算,但 分解引入了 XlaRngGetAndUpdateStateOp,這是 MHLO 特有的 運算「WIP」有助於解讀這兩項作業。

這項工具可將這個類別中其餘的運算轉換成 StableHLO 運算, 翻譯模式支援位於 hlo_expand_main.cc 中的位置。

不在 HLO 上

除了指定的運算之外,這個類別還包含 8 個未指定的作業 (請參閱 StableHLO 作業類別)。 已從 StableHLO 中移出這些作業大多已有票證 mhlo 到 將其轉換為 StableHLO 對等的運算。

這項工具可將這個類別中其餘運算轉換成對等 StableHLO 運算 位於 mlir-hlo-opt.cc 中。

量化

量化類型的 stablehlo.constant 作業翻譯支援為 不支援並追蹤 #1691

使用說明

建立參考解譯器

您可以透過 Bazel 或 CMake (建議) 建構及測試解譯器。全螢幕 操作說明,請參閱 README.md

Bazel:

bazel build //...

CMake:

mkdir -p build && cd build

cmake .. -GNinja \
  -DLLVM_ENABLE_LLD="$LLVM_ENABLE_LLD" \
  -DCMAKE_BUILD_TYPE=Release \
  -DLLVM_ENABLE_ASSERTIONS=On \
  -DMLIR_DIR=${PWD}/../llvm-build/lib/cmake/mlir

如要執行口譯,我們有翻譯工具可解讀 StableHLO 程式 以 MLIR 撰寫而成

stablehlo-translate --interpret <path/to/program>

翻譯方言

Interpreter 方言包含各種與 翻譯。具體來說,interpreter.run_parallel (請參閱 InterpreterOps.td for op 語意和範例用途) op 允許解讀 Distribution Ops 配合社區需求而新增的公用事業計劃。

檢查方言

Check 方言是用來比較翻譯執行階段值與預期值 輕鬆分配獎金StableHLO 程式輸出內容可透過各種檢查作業測試 (請參閱 CheckOps.td (適用於運算語意及使用範例)。

撰寫測試計畫

我們使用 LLVM 的「lit」工具 並比對產生的檔案,以便與解譯器輸出內容進行比較 (請參閱 stablehlo/tests/interpret)。 例如測試範例)。

正在測試 AddOp (樣本來源: interpret_add.mlir):

// RUN: stablehlo-translate --interpret %s

func.func @add_op_scalar() {
  %0 = stablehlo.constant dense<2> : tensor<i4>
  %1 = stablehlo.constant dense<3> : tensor<i4>
  %2 = stablehlo.add %0, %1 : tensor<i4>
  check.expect_eq_const %2, dense<5> : tensor<i4>
  func.return
}

如要執行「發行」類別中的測試作業,您必須透過 「interpreter.run_parallel」公用程式運算。

正在測試 AllReduceOp (樣本來源: all_reduce.mlir):

// RUN: stablehlo-translate --interpret %s

module @cross_replica {
  func.func public @all_reduce(%operand : tensor<4xi64>) -> tensor<4xi64> {
    %result = "stablehlo.all_reduce"(%operand) ({
      ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
        %0 = stablehlo.add %arg0, %arg1 : tensor<i64>
        stablehlo.return %0 : tensor<i64>
    }) {
      replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
      channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
    } : (tensor<4xi64>) -> tensor<4xi64>
    return %result : tensor<4xi64>
  }
  func.func public @main() {
    %inputs0 = stablehlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
    %inputs1 = stablehlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64>
    %results:2 = "interpreter.run_parallel"(%inputs0, %inputs1) {
      programs=[[@all_reduce], [@all_reduce]]
    } : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
    check.expect_eq_const %results#0, dense<[6, 8, 10, 12]> : tensor<4xi64>
    check.expect_eq_const %results#1, dense<[6, 8, 10, 12]> : tensor<4xi64>
    func.return
  }
}

對 StableHLO 進行偵錯

按照 StableHLO 建構步驟,StableHLO 二進位檔適用於 stablehlo/tools 應位於 /build/bin。常見的偵錯工具,例如 GDB 可逐步執行程式碼:

gdb --args ./build/bin/stablehlo-translate -allow-unregistered-dialect --interpret ./stablehlo/tests/interpret/<test>.mlir

附錄

轉換其他作業

# batch_norm_grad
hlo-expand --batch_norm_grad_expander <path/to/hlo_module>

# batch_norm_inference
hlo-expand --batch_norm_inference_expander <path/to/hlo_module>

# batch_norm_training
hlo-expand --batch_norm_training_expander <path/to/hlo_module>

# cholesky
hlo-expand --cholesky_expander <path/to/hlo_module>

# constant
# Supported in StableHLO interpreter.

# fft
# TBD

# iota
# Supported in StableHLO interpreter.

# rng
# TBD

# rng_bit_generator
# TBD

# triangular_solve
hlo-expand --triangular_solve_expander <path/to/hlo_module>

不在 HLO 作業中轉換

# broadcast
mlir-hlo-opt -mhlo-legalize-broadcast-to-broadcast-in-dim <path/to/input>

# create_token
mlir-hlo-opt -mhlo-legalize-create-token-to-after-all <path/to/input>

# cross-replica-sum
mlir-hlo-opt -mhlo-legalize-cross-replica-sum-to-all-reduce <path/to/input>

# dot
mlir-hlo-opt -mhlo-legalize-dot-to-dot-general <path/to/input>

# einsum
mlir-hlo-opt -mhlo-legalize-einsum-to-dot-general <path/to/input>

# torch_index_select
mlir-hlo-opt -mhlo-legalize-torch-index-select-to-gather <path/to/input>

# unary_einsum
mlir-hlo-opt --canonicalize -mhlo-legalize-einsum-to-dot-general <path/to/input>

StableHLO 作業類別

類別 記憶術 總計
119
控制流程 after_all、case、if、Optimization_barrier,而 5
資料遷移 Broadcast_in_dim, 串連, 動態_slice, dynamic_update_slice, 收集, 填充, 填色, 重塑, 反向, 散佈, 切片, 排序, 轉置 12
發布 all_gather、all_reduce、all_to_all、collective_permute、infeed、outfeed、partition_id、recv、recv、縮減_scatter、replica_id、send 11
動態 dynamic_broadcast_in_dim、 dynamic_conv、 dynamic_gather、 dynamic_iota、 dynamic_pad、 dynamic_reshape、get_dimension_size、real_dynamic_slice、set_dimension_size 9
Elementwise abs, add, and, atan2, bitcast_convert, cbrt, ceil, clamp, compare, complex, convert, cosine, count_leading_zeros, divide, exponential, exponential_minus_one, floor, imag, is_finite, log, log_plus_one, logistic, map, maximum, minimum, multiply, negate, not, or, popcnt, power, real, reduce_precision, remainder, round_nearest_afz, round_nearest_even, rsqrt, select, shift_left, shift_right_arithmetic, shift_right_logical, sign, sine, sqrt, subtract, tan, tanh, xor 48
擴充性 custom_call、get_tuple_element、tuple 3
其他 batch_norm_grad、batch_norm_inference,batch_norm_training, cholesky, 常數, fft, iota, rng, rng_bit_generator, triangular_solve 10
模組性 通話, ficc, 模組, 返回 4
不在 HLO 上 廣播、create_token、cross-replica-sum、dot、einsum、torch_index_select、unary_einsum 8
量化 uniform_dequantize、uniform_quantize 2
縮小 卷積, 點_一般, 縮減, 縮減視窗, 選取_and_scatter 5