Динамизм в StableHLO

Текущее состояние динамизма более формально изложено в документе RFC по динамизму . На этой странице представлен общий обзор RFC, а также обсуждаются важные API и инструменты для взаимодействия с динамическими программами.

Обзор терминологии и поддержки динамизма

Сначала рассмотрим несколько терминов, которые появятся в этом документе, а также кратко расскажем об их поддержке в StableHLO:

Динамические измерения

Динамические измерения относятся к любому измерению, размер которого неизвестен. В StableHLO динамические измерения обозначаются символом ? , то есть tensor<16x?xf32> .

Ограниченный динамизм

Ограниченный динамизм относится к динамическому измерению, значение которого имеет известную верхнюю границу. Обычно это полезно для дополнения тензора во время выполнения. В StableHLO ограниченный динамизм представлен с помощью #stablehlo.bounds в качестве тензорного кодирования, то есть тензор ранга 2, где одно динамическое измерение ограничено 16, а другое не ограничено, можно представить как tensor<?x?xf32, #stablehlo.bounds<16, ?>> .

StableHLO способен представлять ограниченный динамизм, но поддержка фреймворка ограничена, начиная с TensorFlow и с некоторой поддержкой в ​​PyTorch/XLA.

Безграничный динамизм

Неограниченная динамизация, как следует из названия, относится к динамическому измерению без известных ограничений по размеру. Этот тип динамизма очень распространён в StableHLO с поддержкой JAX, PyTorch/XLA и TF, и часто используется для экспорта моделей с динамическим размером пакета или длиной последовательности.

В StableHLO мы просто опускаем кодирование границ для этой формы динамизма, т. е. tensor<?x?xf32> .

Полиморфизм формы

Полиморфизм форм — термин, унаследованный нами от JAX .

Существует два ключевых следствия для формирования полиморфизма:

  1. Вся динамика программы обусловлена ​​ее входными аргументами.
  2. Весь динамизм относится только к тензорным формам , т.е. не зависит от данных.

Используя эти два правила, как только станут известны статические формы программы, мы можем взять динамическую программу и полностью преобразовать ее в статическую программу для компиляции (см. «Проходы компилятора для уточнения динамических программ» ).

Обычно полиморфизм форм использует неограниченную динамичность: если известные формы аргументов могут привести к полностью статической программе, нет необходимости гадать, как ограничивать значения.

Динамизм, зависящий от данных

Динамизм, зависящий от данных, относится к динамическим размерам измерений, относящимся к данным внутри тензора. Каноническим примером является nonzeros функция, которая возвращает индексы всех элементов тензора, равных 0 . Форма не может быть определена без оценки данных, но её часто можно скомпилировать с использованием ограниченного динамизма, затратив дополнительную память на потенциальный размер выходного тензора.

Многие динамические операции, зависящие от данных, можно смоделировать с помощью ограниченной динамизации, где указывается верхняя граница размера тензора, и аппаратное обеспечение обычно реализует это посредством тензорного заполнения. В настоящее время в PyTorch/XLA и TensorFlow имеется некоторая поддержка динамизации, зависящей от данных, но JAX в настоящее время не трассирует операции, которые приводят к динамизации, зависящей от данных.

Экспорт программ с динамическими измерениями

Информацию о том, как экспортировать программы с динамическими размерами пакетов или длиной последовательностей, см. в наших руководствах по StableHLO:

Компилятор проходит уточнение динамических программ

Удалить динамизм проходного конвейера

Есть несколько полезных проходов для уточнения форм, для удобства они все объединены в конвейер проходов createStablehloRemoveDynamismPipeline :

void createStablehloRemoveDynamismPipeline(OpPassManager &pm,
                                           TypeRange refinedTypes);

Индивидуальные проходы для улучшения динамики

По отдельности, проходы, которые, как правило, полезны для уточнения формы, следующие:

  • stablehlo-refine-arguments для замены входных аргументов конкретными тензорными типами.
  • stablehlo-refine-shapes для распространения информации о форме нового входного аргумента по всей программе.
  • stablehlo-canonicalize-dynamism для замены динамических операций их статическими вариантами.
  • stablehlo-check-shape-assertions для проверки и удаления пользовательских вызовов утверждений формы.

Актуальную информацию и примеры смотрите в связанной документации.

Пример: Чем полезен динамизм и как его можно использовать?

Динамизм имеет множество применений, здесь мы в основном сосредоточимся на общем случае использования полиморфизма форм — создании гибкого экспортируемого представления модели, обычно используемого для представления динамического размера партии или длины последовательности.

Статическая модель add_one

Для демонстрации этого мы воспользуемся следующей простой моделью add_one :

def add_one(x):
  return x + 1

При трассировке с использованием tensor<4xf32> получим следующую программу StableHLO:

// File: add_one.mlir
func.func @add_one(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<4xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<4xf32>
  return %0 : tensor<4xf32>
}

Эта модель будет работать только для входных аргументов, имеющих форму tensor<4xf32> . Если мы когда-либо изменим размер пакета или длину последовательности, нам придётся заново трассировать исходный код и переходить на StableHLO, и нет никакой гарантии, что у нас вообще останется доступ к исходному коду!

Динамическая модель add_one

Здесь в игру вступает полиморфный динамизм формы. Вместо этого JAX и PyTorch/XLA могут генерировать модель add_one с динамически корректным IR, которая будет транслировать константу, соответствующую динамической входной форме, следующим образом:

// File: add_one_dynamic.mlir
func.func public @main(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %cst = stablehlo.constant dense<1.0> : tensor<f32>
  %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf32>) -> tensor<i32>
  %1 = stablehlo.reshape %0 : (tensor<i32>) -> tensor<1xi32>
  %2 = stablehlo.dynamic_broadcast_in_dim %cst, %1, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %3 = stablehlo.add %arg0, %2 : tensor<?xf32>
  return %3 : tensor<?xf32>
}

Такое представление модели гораздо более гибко и допускает отложенное задание таких значений, как размер партии или длина последовательности. Эту модель можно развернуть на платформах с поддержкой динамических фигур (например, AI Edge ), или доработать её с помощью динамических проходов, упомянутых в этой документации.

Уточнение динамической модели

Например, следующий порядок проходов может полностью усовершенствовать эту программу:

stablehlo-opt add_one_dynamic.mlir \
  --stablehlo-refine-arguments='types=tensor<16xf32>' \
  --stablehlo-refine-shapes \
  --stablehlo-canonicalize-dynamism

Постепенно программа преобразуется следующим образом:

// After stablehlo-refine-arguments: Inputs updated, shapes not propagated
func.func public @main(%arg0: tensor<16xf32>) -> tensor<?xf32> {
  %c = stablehlo.constant dense<16> : tensor<1xi64>
  %0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {indices_of_shape_operands = dense<1> : tensor<1xi64>} : (tensor<16xf32>, tensor<1xi64>) -> tensor<?xf32>
  ...
  %3 = stablehlo.dynamic_broadcast_in_dim %cst, %2, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %4 = stablehlo.add %0, %3 : tensor<?xf32>
  return %4 : tensor<?xf32>
}

// After stablehlo-refine-shapes: Shapes propagated, dynamic ops still exist
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %c = stablehlo.constant dense<16> : tensor<1xi32>
  %0 = stablehlo.dynamic_broadcast_in_dim %cst, %c, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// After stablehlo-canonicalize-dynamism: Dynamic ops replaced with static ops
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// (Bonus) Use ` --stablehlo-aggressive-simplification` pass to canonicalize the
// constant broadcast, leaving us with the original static program in this case.
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<16xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<16xf32>
  return %0 : tensor<16xf32>
}