StableHLO インタープリタ

StableHLO インタープリタの主な目的は、 を StableHLO opset のセマンティクスに 仕様。二次的な目標は、これらの目標に忠実に従って より明確になるように、パフォーマンスよりも読みやすさを優先しています。 Convolution のような最も複雑なオペレーションのセマンティクスまで、 Gather/ScatterDotGeneral

現在、OpenXLA は 96 個のうち 91 個を解釈できます。 StableHLO オペレーション。残りの 3 つのオペレーション(FftOpRngOpRngBitGeneratorOp)は、 まとめられています。 spec.md、 先へ進む方法に関する初期調査を完了( status.md をご覧ください)。これらのファイナル 機能強化は必要に応じてコミュニティベースで実装されます。

範囲

Google は StableHLO オプセットを 11 のカテゴリに分類し、 合計(付録を参照)。 リファレンス実装 インタープリタの実装に関する作業を整理するワークストリーム StableHLO 仕様で定義されている StableHLO オペレーションの 100% 向けです。私たちは このワークストリームのすべてまたはほぼすべての作業を StableHLO で完了する予定 。現在仕様を持っている 96 個の op のうち、 OpenXLA(残りの 5 つのケースについては、特殊なケースをご覧ください)。

仕様

通訳者の主な要件は、 仕様。この仕様により、類似のオペレーション全体でインタープリタを標準化できます。 モジュール化された高品質なインタープリタ実装につながります。

特殊なケース

その他

このカテゴリには分解可能な op があり、現時点では将来が不確定です。そこで、 このカテゴリの特定の 3 つの演算のうち、インタープリタがサポートしていないもの 目を向けます。

  • FftOp
  • RngOp
  • RngBitGeneratorOp

FftOp は「その他」に分類されますが、このカテゴリの他のオペレーションとは異なり、 このオペレーションにはエクスパンダー パスがありません。StableHLO でこれをサポートするには、 作業中。

RngOpRngBitGeneratorOp は MHLO 演算に分解できますが、 分解により、MHLO 固有の XlaRngGetAndUpdateStateOp が導入されます。 演算これら 2 つの op の解釈のサポートは WIP です。

このカテゴリの残りのオペレーションを StableHLO オペレーションに変換するためのツールです。 インタープリタのサポートは hlo_expand_main.cc にあります。

HLO にはありません

このカテゴリは、指定された op とは別に、8 つの指定されていない op で構成されます( StableHLO 運用カテゴリ)。 StableHLO から移動しました。これらの運用のほとんどには mhlo から StableHLO の同等のオペレーションに変換します。

このカテゴリの残りのオペレーションを同等の StableHLO オペレーションに変換するツール mlir-hlo-opt.cc にあります。

量子化

量子化型を使用した stablehlo.constant 演算のインタープリタ サポートは次のとおりです。 非対応で、次を使用してトラッキング: #1691

使用方法

リファレンス インタープリタの構築

Bazel または CMake(推奨)を使用してインタープリタのビルドとテストを行うことができます。フル README.md をご覧ください。

Bazel:

bazel build //...

CMake:

mkdir -p build && cd build

cmake .. -GNinja \
  -DLLVM_ENABLE_LLD="$LLVM_ENABLE_LLD" \
  -DCMAKE_BUILD_TYPE=Release \
  -DLLVM_ENABLE_ASSERTIONS=On \
  -DMLIR_DIR=${PWD}/../llvm-build/lib/cmake/mlir

インタープリタを実行するために、StableHLO プログラムを解釈するための翻訳ツールが用意されています。 MLIR で記述します

stablehlo-translate --interpret <path/to/program>

通訳言語

Interpreter 言語には、 説明します。具体的には interpreter.run_parallelInterpreterOps.td (オペレーション セマンティクスと使用例など)を確認できます。 公益事業はコミュニティのニーズに基づいて追加される予定です。

チェック言語

Check 言語は、インタープリタのランタイム値を期待される値と比較するために使用されます。 使用できます。StableHLO プログラムの出力は、さまざまなチェック オペレーション( CheckOps.td をご覧ください)。

テストプログラムの作成

LLVM の lit ツールを使用して 生成されたファイルと比較し、インタープリタの出力と比較する (stablehlo/tests/interpret を参照) (テストなど)。

AddOp のテスト( interpret_add.mlir):

// RUN: stablehlo-translate --interpret %s

func.func @add_op_scalar() {
  %0 = stablehlo.constant dense<2> : tensor<i4>
  %1 = stablehlo.constant dense<3> : tensor<i4>
  %2 = stablehlo.add %0, %1 : tensor<i4>
  check.expect_eq_const %2, dense<5> : tensor<i4>
  func.return
}

分布カテゴリの運用をテストするには、 interpreter.run_parallel ユーティリティ オペレーション。

AllReduceOp のテスト( all_reduce.mlir):

// RUN: stablehlo-translate --interpret %s

module @cross_replica {
  func.func public @all_reduce(%operand : tensor<4xi64>) -> tensor<4xi64> {
    %result = "stablehlo.all_reduce"(%operand) ({
      ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
        %0 = stablehlo.add %arg0, %arg1 : tensor<i64>
        stablehlo.return %0 : tensor<i64>
    }) {
      replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
      channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
    } : (tensor<4xi64>) -> tensor<4xi64>
    return %result : tensor<4xi64>
  }
  func.func public @main() {
    %inputs0 = stablehlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
    %inputs1 = stablehlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64>
    %results:2 = "interpreter.run_parallel"(%inputs0, %inputs1) {
      programs=[[@all_reduce], [@all_reduce]]
    } : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
    check.expect_eq_const %results#0, dense<[6, 8, 10, 12]> : tensor<4xi64>
    check.expect_eq_const %results#1, dense<[6, 8, 10, 12]> : tensor<4xi64>
    func.return
  }
}

StableHLO のデバッグ

StableHLO ビルド手順の後、StableHLO バイナリが stablehlo/tools/build/bin に配置する必要があります。たとえば、 GDB を使用すると、コードをステップ実行することができます。

gdb --args ./build/bin/stablehlo-translate -allow-unregistered-dialect --interpret ./stablehlo/tests/interpret/<test>.mlir

付録

その他のオペレーションの変換

# batch_norm_grad
hlo-expand --batch_norm_grad_expander <path/to/hlo_module>

# batch_norm_inference
hlo-expand --batch_norm_inference_expander <path/to/hlo_module>

# batch_norm_training
hlo-expand --batch_norm_training_expander <path/to/hlo_module>

# cholesky
hlo-expand --cholesky_expander <path/to/hlo_module>

# constant
# Supported in StableHLO interpreter.

# fft
# TBD

# iota
# Supported in StableHLO interpreter.

# rng
# TBD

# rng_bit_generator
# TBD

# triangular_solve
hlo-expand --triangular_solve_expander <path/to/hlo_module>

HLO オペレーションにないコンバージョン

# broadcast
mlir-hlo-opt -mhlo-legalize-broadcast-to-broadcast-in-dim <path/to/input>

# create_token
mlir-hlo-opt -mhlo-legalize-create-token-to-after-all <path/to/input>

# cross-replica-sum
mlir-hlo-opt -mhlo-legalize-cross-replica-sum-to-all-reduce <path/to/input>

# dot
mlir-hlo-opt -mhlo-legalize-dot-to-dot-general <path/to/input>

# einsum
mlir-hlo-opt -mhlo-legalize-einsum-to-dot-general <path/to/input>

# torch_index_select
mlir-hlo-opt -mhlo-legalize-torch-index-select-to-gather <path/to/input>

# unary_einsum
mlir-hlo-opt --canonicalize -mhlo-legalize-einsum-to-dot-general <path/to/input>

StableHLO オペレーションのカテゴリ

カテゴリ 記憶 合計
119
制御フロー after_all、case、if、optimization_barrier、 5
データの移動 Broadcast_in_dim、concatenate、dynamic_slice、dynamic_update_slice、収集、パッド、再形成、逆方向、散布性、スライス、並べ替え、転置 12
分布 all_gather、all_reduce、all_to_all、collective_permute、infeed、outfeed、partition_id、recv、reduce_scatter、replica_id、send 11
ダイナミズム dynamic_broadcast_in_dim、dynamic_conv、dynamic_gather、dynamic_iota、dynamic_pad、dynamic_reshape、get_dimension_size、real_dynamic_slice、set_dimension_size 9
Elementwise abs, add, and, atan2, bitcast_convert, cbrt, ceil, clamp, compare, complex, convert, cosine, count_leading_zeros, divide, exponential, exponential_minus_one, floor, imag, is_finite, log, log_plus_one, logistic, map, maximum, minimum, multiply, negate, not, or, popcnt, power, real, reduce_precision, remainder, round_nearest_afz, round_nearest_even, rsqrt, select, shift_left, shift_right_arithmetic, shift_right_logical, sign, sine, sqrt, subtract, tan, tanh, xor 48
拡張性 custom_call、get_tuple_element、tuple 3
その他 batch_norm_grad、batch_norm_inference、batch_norm_training、コレスキー、定数、fft、iota、rng、rng_bit_generator、triangular_solve 10
モジュール性 call、func、module、return 4
HLO になし subnet、create_token、cross-replica-sum、dot、einsum、torch_index_select、unary_einsum 8
量子化 uniform_dequantize、uniform_quantize 2
削減 convolution、dot_general、reduce、reduce_window、select_and_scatter 5