StableHLO 인터프리터

StableHLO 인터프리터의 주요 목표는 의미 체계에 따라 StableHLO 오퍼레이션의 의미 체계에 지정할 수도 있습니다 두 번째 목표는 구현의 각 단계에서 성능보다 가독성을 우선시하여 더욱 명확성을 높입니다. Convolution과 같이 가장 관련된 작업의 시맨틱스까지 Gather/ScatterDotGeneral입니다.

현재 OpenXLA는 96개의 파일 중 91개의 텍스트에 대한 해석을 StableHLO 오퍼레이션입니다. 나머지 3개 작업 (FftOp, RngOp, RngBitGeneratorOp)은 해당 시맨틱스는 spec.md를 사용하고 있고 진행 방법에 대한 초기 조사가 완료된 것입니다 (자세한 내용은 status.md 를 참조하세요. 이러한 최종 개선사항이 필요에 따라 커뮤니티 기반으로 구현될 것입니다.

범위

우리는 StableHLO 작업을 합계 (부록 참고) 참조 구현 워크스트림은 통역사를 구현하는 작업을 StableHLO 사양에 정의된 StableHLO 작업의 100% 에 적용됩니다. Google은 StableHLO에서 이 작업 흐름의 모든 작업 또는 거의 모든 작업을 완료할 계획을 세웁니다. v1.0 현재 사양이 있는 96개의 작업 중에서 91개의 작업을 OpenXLA (나머지 5개는 특수 케이스 참고)

사양

통역사의 주요 요구사항은 사양 사양을 사용하면 다음 작업을 수행하는 유사한 작업 간에 인터프리터를 표준화할 수 있습니다. 인터프리터를 고품질 모듈식으로 구현할 수 있습니다.

특수 사례

기타

이 카테고리에는 현재 미래가 불확실한 분해 가능한 작업이 있습니다. 거기 통역사가 지원하지 않는 세 가지 특수 작업이라고 할 수 있습니다. 지금:

  • FftOp
  • RngOp
  • RngBitGeneratorOp

FftOp는 기타로 분류되지만 이 카테고리의 다른 오퍼레이션과 달리 이 작업에는 확장기 패스가 없으며 StableHLO에서 이를 지원하는 것은 WIP

RngOpRngBitGeneratorOp는 MHLO 작업으로 분해할 수 있지만 분해는 MHLO 전용 XlaRngGetAndUpdateStateOp를 도입합니다. op. 이 두 작업의 해석을 지원하는 것이 WIP입니다.

이 카테고리의 나머지 작업을 인터프리터는 hlo_expand_main.cc에 상주합니다. <ph type="x-smartling-placeholder"></ph>

HLO에 없음

특수 Ops를 제외하고 이 범주는 8개의 비사양 op으로 구성됩니다. StableHLO 운영 카테고리)에 이미 적용될 수 있습니다. StableHLO에서 이동했습니다. 이러한 작업에는 대부분 mhlo~ StableHLO에 상응하는 작업으로 변환합니다.

이 카테고리의 나머지 작업을 동등한 StableHLO 작업으로 변환하는 도구입니다. 인터프리터가 지원하는 라이브러리는 mlir-hlo-opt.cc에 있습니다. <ph type="x-smartling-placeholder"></ph>

양자화

양자화된 유형의 stablehlo.constant 연산에 대한 인터프리터 지원: 지원되지 않으며 다음을 통해 추적됨 #1691

사용 지침

참조 인터프리터 빌드

인터프리터는 Bazel 또는 CMake (권장)를 통해 빌드하고 테스트할 수 있습니다. 완전형 README.md를 참고하세요.

바젤:

bazel build //...

CMake:

mkdir -p build && cd build

cmake .. -GNinja \
  -DLLVM_ENABLE_LLD="$LLVM_ENABLE_LLD" \
  -DCMAKE_BUILD_TYPE=Release \
  -DLLVM_ENABLE_ASSERTIONS=On \
  -DMLIR_DIR=${PWD}/../llvm-build/lib/cmake/mlir

인터프리터를 실행할 수 있도록 StableHLO 프로그램을 해석하는 번역 도구가 있습니다. MLIR에서 작성해야 합니다.

stablehlo-translate --interpret <path/to/program>

통역 방언

Interpreter 언어에는 통역해 줍니다. 특히 interpreter.run_parallel( InterpreterOps.td 작업의 시맨틱스 및 예시 사용 예시에 대해) : 분산 오퍼레이션 등에 대한 해석을 허용합니다. 유틸리티는 커뮤니티의 필요에 따라 추가될 계획입니다

확인 언어

Check 언어는 인터프리터 런타임 값을 예상과 비교하는 데 사용됩니다. 값으로 사용됩니다. StableHLO 프로그램 출력은 다양한 검사 작업을 통해 테스트할 수 있습니다( CheckOps.td 참조). <ph type="x-smartling-placeholder"></ph>

테스트 프로그램 작성

LLVM의 lit 도구를 사용하여 생성된 파일과 비교하여 인터프리터의 출력과 비교합니다. (stablehlo/tests/interpret 참고) 참조).

AddOp 테스트( interpret_add.mlir):

// RUN: stablehlo-translate --interpret %s

func.func @add_op_scalar() {
  %0 = stablehlo.constant dense<2> : tensor<i4>
  %1 = stablehlo.constant dense<3> : tensor<i4>
  %2 = stablehlo.add %0, %1 : tensor<i4>
  check.expect_eq_const %2, dense<5> : tensor<i4>
  func.return
}

배포 카테고리에서 작업을 테스트하려면 다음을 통해 interpreter.run_parallel 유틸리티 작업

AllReduceOp 테스트( all_reduce.mlir)가 됩니다.

// RUN: stablehlo-translate --interpret %s

module @cross_replica {
  func.func public @all_reduce(%operand : tensor<4xi64>) -> tensor<4xi64> {
    %result = "stablehlo.all_reduce"(%operand) ({
      ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
        %0 = stablehlo.add %arg0, %arg1 : tensor<i64>
        stablehlo.return %0 : tensor<i64>
    }) {
      replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
      channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
    } : (tensor<4xi64>) -> tensor<4xi64>
    return %result : tensor<4xi64>
  }
  func.func public @main() {
    %inputs0 = stablehlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
    %inputs1 = stablehlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64>
    %results:2 = "interpreter.run_parallel"(%inputs0, %inputs1) {
      programs=[[@all_reduce], [@all_reduce]]
    } : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
    check.expect_eq_const %results#0, dense<[6, 8, 10, 12]> : tensor<4xi64>
    check.expect_eq_const %results#1, dense<[6, 8, 10, 12]> : tensor<4xi64>
    func.return
  }
}

StableHLO 디버깅

StableHLO 빌드 단계에 따라 다음의 툴에 대한 StableHLO 바이너리는 stablehlo/tools/build/bin에 있어야 합니다. 일반적인 디버깅 도구, GDB를 사용하여 코드를 단계별로 실행할 수 있습니다.

gdb --args ./build/bin/stablehlo-translate -allow-unregistered-dialect --interpret ./stablehlo/tests/interpret/<test>.mlir

부록

기타 작업 변환

# batch_norm_grad
hlo-expand --batch_norm_grad_expander <path/to/hlo_module>

# batch_norm_inference
hlo-expand --batch_norm_inference_expander <path/to/hlo_module>

# batch_norm_training
hlo-expand --batch_norm_training_expander <path/to/hlo_module>

# cholesky
hlo-expand --cholesky_expander <path/to/hlo_module>

# constant
# Supported in StableHLO interpreter.

# fft
# TBD

# iota
# Supported in StableHLO interpreter.

# rng
# TBD

# rng_bit_generator
# TBD

# triangular_solve
hlo-expand --triangular_solve_expander <path/to/hlo_module>

HLO 운영에 있지 않은 전환

# broadcast
mlir-hlo-opt -mhlo-legalize-broadcast-to-broadcast-in-dim <path/to/input>

# create_token
mlir-hlo-opt -mhlo-legalize-create-token-to-after-all <path/to/input>

# cross-replica-sum
mlir-hlo-opt -mhlo-legalize-cross-replica-sum-to-all-reduce <path/to/input>

# dot
mlir-hlo-opt -mhlo-legalize-dot-to-dot-general <path/to/input>

# einsum
mlir-hlo-opt -mhlo-legalize-einsum-to-dot-general <path/to/input>

# torch_index_select
mlir-hlo-opt -mhlo-legalize-torch-index-select-to-gather <path/to/input>

# unary_einsum
mlir-hlo-opt --canonicalize -mhlo-legalize-einsum-to-dot-general <path/to/input>

StableHLO 작업 카테고리

카테고리 연상 기호 합계
119
제어 흐름 after_all, case, if, optimization_barrier, 동안 5
데이터 이동 broadcast_in_dim, 연결, dynamic_slice, dynamic_update_slice, 수집, 패딩, 형태 변경, 역방향, 분산형, 슬라이스, 정렬, 바꾸기 12
배포 all_gather, all_reduce, all_to_all, collective_permute, infeed, outfeed, partition_id, recv, 감소_scatter, 복제본 ID, 전송 11
역동주의 dynamic_broadcast_in_dim, dynamic_conv, dynamic_gather, dynamic_iota, dynamic_pad, dynamic_reshape, get_Dimensions_size, real_dynamic_slice, set_Dimensions_size 9
Elementwise 추가, 추가, 추가, 및,을, 더 48
확장성 커스텀 호출, get_tuple_element, 튜플 3
기타 batch_norm_grad, batch_norm_inference, batch_norm_training, cholesky, 상수, fft, iota, rng, rng_bit_generator, triangular_solve 10
모듈성 호출, 함수, 모듈, 반환 4
HLO에 포함되지 않음 broadcast, create_token, cross-replica-sum, dot, einsum, torch_index_select, unary_einsum 8
양자화 유니폼_비양자화, 균일_양자화 2
절감 컨볼루션, 점_일반, 감소, 감소_창, 선택_및_분해 5