Стабильный переводчик HLO

Основная цель интерпретатора StableHLO — предоставить эталонную реализацию семантики набора операций StableHLO в соответствии с его спецификацией. Второстепенная цель — обеспечить точное соответствие реализации спецификации, отдавая приоритет читаемости, а не производительности, чтобы обеспечить дополнительную ясность семантики даже самых сложных операций, таких как Convolution , Gather / Scatter и DotGeneral .

В настоящий момент OpenXLA поддерживает интерпретацию 91 из 96 описанных операций StableHLO. Семантика оставшихся 2 операций ( RngOp и RngBitGeneratorOp ) описана в spec.md , и завершены предварительные исследования дальнейших шагов (полный список операций и их текущий статус см. в status.md ). Эти окончательные улучшения будут внедряться по мере необходимости сообществом.

Объем

Мы разделили набор операций StableHLO на 11 категорий, включающих в общей сложности 118 операций (см. Приложение ). Рабочая группа по эталонной реализации организует работу по реализации интерпретатора для 100% операций StableHLO, как определено в спецификации StableHLO. Мы планируем завершить всю или почти всю работу в этой группе в StableHLO v1.0. Из 96 операций, для которых в настоящее время существует спецификация, мы можем интерпретировать 91 операцию через OpenXLA (см. Особые случаи для оставшихся 5).

Спецификация

Основное требование к интерпретатору — это соответствие спецификации в соотношении 1:1. Спецификация позволяет стандартизировать интерпретатор для аналогичных операций, что приводит к модульной и высококачественной реализации интерпретатора.

Особые случаи

Разнообразный

В этой категории находятся разложимые операции, будущее которых на данный момент неясно. В этой категории есть три спецификации операций, которые интерпретатор в данный момент не поддерживает:

  • RngOp
  • RngBitGeneratorOp

RngOp и RngBitGeneratorOp можно разложить на операции MHLO, но это разложение вводит операцию XlaRngGetAndUpdateStateOp , которая является специфичной для MHLO. Интерпретация этих двух операций находится в стадии разработки.

Инструмент для преобразования оставшихся операций этой категории в операции StableHLO, поддерживаемые интерпретатором, находится в файле hlo_expand_main.cc .

Не в HLO

Помимо указанных в спецификации операций, эта категория включает 8 неуказанных операций (см. Категории операций StableHLO ), которые планируется вывести из StableHLO. Для большинства из этих операций уже существуют проходы в mhlo для преобразования их в эквивалентные операции StableHLO.

Инструмент для преобразования оставшихся операций этой категории в эквивалентные операции StableHLO, поддерживаемые интерпретатором, находится в файле mlir-hlo-opt.cc .

Квантование

Поддержка интерпретатором операции stablehlo.constant с квантованным типом отсутствует и отслеживается в рамках задачи #1691 .

Инструкция по применению

Создание интерпретатора ссылок

Интерпретатор можно собрать и протестировать с помощью Bazel или CMake (предпочтительно). Полные инструкции см. в файле README.md .

Базель:

bazel build //...

CMake:

mkdir -p build && cd build

cmake .. -GNinja \
  -DLLVM_ENABLE_LLD="$LLVM_ENABLE_LLD" \
  -DCMAKE_BUILD_TYPE=Release \
  -DLLVM_ENABLE_ASSERTIONS=On \
  -DMLIR_DIR=${PWD}/../llvm-build/lib/cmake/mlir

Для запуска интерпретатора у нас есть инструмент перевода, позволяющий интерпретировать программы на языке StableHLO, написанные на MLIR.

stablehlo-translate --interpret <path/to/program>

Диалект переводчика

Диалект Interpreter содержит различные вспомогательные операции, связанные с интерпретатором. В частности, операция interpreter.run_parallel (см. InterpreterOps.td для семантики операций и примеров использования) позволяет интерпретировать операции дистрибутива, и планируется добавить больше вспомогательных операций в зависимости от потребностей сообщества.

Чек-диалект

Диалект Check используется для сравнения значений, полученных интерпретатором во время выполнения, с ожидаемыми значениями. Выходные данные программ StableHLO можно проверить с помощью различных операций проверки (см. CheckOps.td для семантики операций и примеров использования).

Программы для тестирования навыков письма

Мы используем инструмент lit из LLVM для запуска и сравнения с сгенерированным файлом, чтобы выявить различия в выводе интерпретатора (примеры тестов см. в stablehlo/tests/interpret ).

Тестирование AddOp (пример из interpret_add.mlir ):

// RUN: stablehlo-translate --interpret %s

func.func @add_op_scalar() {
  %0 = stablehlo.constant dense<2> : tensor<i4>
  %1 = stablehlo.constant dense<3> : tensor<i4>
  %2 = stablehlo.add %0, %1 : tensor<i4>
  check.expect_eq_const %2, dense<5> : tensor<i4>
  func.return
}

Для тестирования операций в категории «Распространение» необходимо запустить их с помощью утилиты interpreter.run_parallel .

Тестирование AllReduceOp (пример из файла all_reduce.mlir ):

// RUN: stablehlo-translate --interpret %s

module @cross_replica {
  func.func public @all_reduce(%operand : tensor<4xi64>) -> tensor<4xi64> {
    %result = "stablehlo.all_reduce"(%operand) ({
      ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
        %0 = stablehlo.add %arg0, %arg1 : tensor<i64>
        stablehlo.return %0 : tensor<i64>
    }) {
      replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
      channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
    } : (tensor<4xi64>) -> tensor<4xi64>
    return %result : tensor<4xi64>
  }
  func.func public @main() {
    %inputs0 = stablehlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
    %inputs1 = stablehlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64>
    %results:2 = "interpreter.run_parallel"(%inputs0, %inputs1) {
      programs=[[@all_reduce], [@all_reduce]]
    } : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
    check.expect_eq_const %results#0, dense<[6, 8, 10, 12]> : tensor<4xi64>
    check.expect_eq_const %results#1, dense<[6, 8, 10, 12]> : tensor<4xi64>
    func.return
  }
}

Отладка StableHLO

После выполнения шагов сборки StableHLO, исполняемые файлы инструментов StableHLO из stablehlo/tools должны находиться в каталоге /build/bin . Для пошагового выполнения кода можно использовать распространенные инструменты отладки, такие как GDB:

gdb --args ./build/bin/stablehlo-translate -allow-unregistered-dialect --interpret ./stablehlo/tests/interpret/<test>.mlir

Приложение

Преобразовать различные операции

# batch_norm_grad
hlo-expand --batch_norm_grad_expander <path/to/hlo_module>

# batch_norm_inference
hlo-expand --batch_norm_inference_expander <path/to/hlo_module>

# batch_norm_training
hlo-expand --batch_norm_training_expander <path/to/hlo_module>

# cholesky
hlo-expand --cholesky_expander <path/to/hlo_module>

# constant
# Supported in StableHLO interpreter.

# fft
# TBD

# iota
# Supported in StableHLO interpreter.

# rng
# TBD

# rng_bit_generator
# TBD

# triangular_solve
hlo-expand --triangular_solve_expander <path/to/hlo_module>

Convert Not In HLO Ops

# dot
mlir-hlo-opt -mhlo-legalize-dot-to-dot-general <path/to/input>

# einsum
mlir-hlo-opt -mhlo-legalize-einsum-to-dot-general <path/to/input>

# torch_index_select
mlir-hlo-opt -mhlo-legalize-torch-index-select-to-gather <path/to/input>

# unary_einsum
mlir-hlo-opt --canonicalize -mhlo-legalize-einsum-to-dot-general <path/to/input>

Категории операций StableHLO

Категории Мнемоника Общий
119
Управление потоком after_all, case, if, optimization_barrier, while 5
Перемещение данных broadcast_in_dim, concatenate, dynamic_slice, dynamic_update_slice, gather, pad, reshape, reverse, scatter, slice, sort, transpose 12
Распределение all_gather, all_reduce, all_to_all, collective_permute, infeed, outfeed, partition_id, recv, reduce_scatter, replica_id, send 11
Динамизм dynamic_broadcast_in_dim, dynamic_conv, dynamic_gather, dynamic_iota, dynamic_pad, dynamic_reshape, get_dimension_size, real_dynamic_slice, set_dimension_size 9
Элементарно abs, add, and, atan2, bitcast_convert, cbrt, ceil, clamp, compare, complex, convert, cosine, count_leading_zeros, divide, exponential, exponential_minus_one, floor, imag, is_finite, log, log_plus_one, logistic, map, maximum, minimum, multiply, negate, not, or, popcnt, power, real, reduce_precision, remainder, round_nearest_afz, round_nearest_even, rsqrt, select, shift_left, shift_right_arithmetic, shift_right_logical, sign, sine, sqrt, subtract, tan, tanh, xor 48
Расширяемость custom_call, get_tuple_element, tuple 3
Разнообразный batch_norm_grad, batch_norm_inference, batch_norm_training, cholesky, constant, fft, iota, rng, rng_bit_generator, triangular_solve 10
Модульность вызов, функция, модуль, возврат 4
Не в HLO broadcast, create_token, cross-replica-sum, dot, einsum, torch_index_select, unary_einsum 8
Квантование uniform_dequantize, uniform_quantize 2
Снижение свертка, dot_general, reduce, reduce_window, select_and_scatter 5