Анализ индексации

В этом документе описывается анализ индексации HLO, который позволяет вам символически вычислять карты индексации для операций HLO. Карта индексирования — это функция, которая отображает индексы одного тензора в индексы другого, например индексы вывода инструкции HLO в индексы входов инструкции HLO или наоборот.

Пример

Для трансляции с tensor<20xf32> на tensor<10x20x30xf32>

p0 = f32[20] parameter(0)
bc0 = f32[10, 20, 30] broadcast(p0), dimensions={1}

карта индексации вывода на вход равна (i, j, k) -> (j) для i in [0, 10] , j in [0, 20] и k in [0, 30] .

Мотивация

XLA GPU использует несколько специальных решений для анализа объединений, использования операндов и схем листов (подробнее ниже). Целью анализа индексации является предоставление компонента многократного использования для таких случаев использования. Анализ индексирования построен на инфраструктуре Affine Map MLIR и добавляет семантику HLO.

Объединение

Рассуждения об объединении памяти становятся возможными в нетривиальных случаях, когда мы знаем, какие элементы/срезы входных данных считываются для вычисления элемента выходных данных.

Использование операндов

Использование операнда в XLA указывает, насколько сильно используется каждый вход инструкции, при условии, что ее выход полностью использован. В настоящее время использование также не рассчитывается для общего случая. Анализ индексации позволяет точно рассчитать загрузку.

Укладка плитки

Плитка/срез — это гиперпрямоугольное подмножество тензора, параметризованное смещениями, размерами и шагами. Распространение плиток — это способ вычисления параметров плитки производителя/потребителя операции с использованием параметров мозаики самой операции. Уже существует библиотека , которая делает это для softmax и dot. Распространение тайлов можно сделать более универсальным и надежным, если оно выражается через карты индексации.

Функция и домен

Карта индексации — это функция f ( d , s ), которая отображает мультииндекс d тензора A в элементы/диапазоны тензора B Параметр s относится к диапазонам индексов размерностей, которые присутствуют в тензоре B , но отсутствуют в тензоре A

Например, если у нас есть сокращение от tensor<2x4x8x16xf32> до tensor<4x8xf32> , то карта индексации от 2D-выхода до 4D-входа равна (d0, d1) -> (s0, d0, d1, s1) , где d_i — параметры размерности, соответствующие индексам выходного тензора. Параметры s_j кодируют несколько значений, т.е. для вычисления элемента (d0, d1) вывода нам нужны (s0, d0, d1, s1) элементы входа, где s0 in [0, 2) и s1 in [0, 16) .

Это отображение может быть построено на основе атрибутов инструкций HLO или может быть составлено отображение неслитых инструкций для получения индексации для слияния. Отображение также имеет область определения, которая указывает, для каких элементов тензора существует отображение.

ж ( д , с ) ст

фунт _d <= d <= ub _d

фунт _s <= s <= ub _s

фунт _g <= г <= ub _g

Поскольку мы хотим минимизировать повторные вычисления, нам нужна библиотека для символьных вычислений. XLA уже зависит от MLIR, поэтому мы используем mlir::AffineMap вместо написания библиотеки символьной арифметики.

Типичный AffineMap выглядит так

(d0)[s0, s1] -> (s0 + 5, d0 * 2, s1 * 3 + 50)

AffineMap имеет два типа параметров: размеры и символы , которые мы можем использовать для d и s соответственно. AffineMap не содержит метаданных о диапазонах измерений, поэтому нам приходится предоставлять эти данные самостоятельно.

struct Range {
 int64_t lower_bound;
 int64_t upper_bound;
};

struct IndexingMap {
 mlir::AffineMap affine_map;
 std::vector<Range> dim_ranges;
 std::vector<Range> symbol_ranges;
 llvm::DenseMap<mlir::AffineExpr, Range> expr_ranges;
};

dim_ranges кодирует ограничения инклюзивного блока для параметров измерения d карты индексации, которые обычно совпадают с формой выходного тензора для таких операций, как транспонирование, уменьшение, поэлементное, точечное, но есть некоторые исключения, такие как HloConcatenateInstruction .

symbol_ranges кодируют возможные значения, которые могут принимать параметры .

Давайте рассмотрим пример, чтобы понять, что на самом деле означает все вышеперечисленное.

Индексирование карт для Unfused Ops

Поэлементно

Для поэлементных операций карта индексации является идентификатором.

  p0 = f32[10, 20] parameter(0)
  p1 = f32[10, 20] parameter(1)
  add = f32[10, 20] add(p0, p1)

Вывод для ввода карт:

  • вывод -> вход_i:
(d0, d1) -> (d0, d1)
domain:
d0 in [0, 19]
d1 in [0, 19]

Входные и выходные карты

  • input_i -> вывод:
(d0, d1) -> (d0, d1)
domain:
d0 in [0, 19]
d1 in [0, 19]

Транслировать

Широковещательная передача означает, что некоторые измерения будут удалены, когда мы сопоставляем вывод с вводом, и добавлены, когда мы сопоставляем ввод с выводом.

p0 = f32[20] parameter(0)
bc0 = f32[10, 20, 30] broadcast(p0), dimensions={1}

Вывод на входную карту:

(d0, d1, d2) -> (d1)
domain:
d0 in [0, 9]
d1 in [0, 19]
d2 in [0, 29]

Входная и выходная карта

(d0)[s0, s1] -> (s0, d0, s1)
domain:
d0 in [0, 19]
s0 in [0, 9]
s1 in [0, 29]

Обратите внимание, что теперь у нас есть s справа для сопоставления ввода-вывода. Это символы, обозначающие диапазоны значений. Например, в этом конкретном случае каждый элемент ввода с индексом d0 сопоставляется с срезом вывода размером 10x1x30.

Константа и Йота

Удобно, что у них нет входных параметров, поэтому индексацию вычислять не для чего.

Транспонировать

Карта индексирования для транспонирования представляет собой перестановку измерений ввода/вывода.

p0 = f32[3, 12288, 6, 128] parameter(0)
transpose = f32[3, 6, 128, 12288] transpose(p0), dimensions={0, 2, 3, 1}

Вывод на входную карту:

(d0, d1, d2, d3) -> (d0, d3, d1, d2)
domain:
d0 in [0, 2]
d1 in [0, 5]
d2 in [0, 127]
d3 in [0, 12287]

Входная и выходная карта:

(d0, d1, d2, d3) -> (d0, d2, d3, d1)
domain:
d0 in [0, 2]
d1 in [0, 12287]
d2 in [0, 5]
d3 in [0, 127]

Обеспечить регресс

Карта индексирования для обратного изменения изменяет возвращенные размеры на upper_bound(d_i) - d_i :

p0 = f32[1, 17, 9, 9] parameter(0)
reverse = f32[1, 17, 9, 9] reverse(p0), dimensions={1, 2}

Вывод на входную карту:

(d0, d1, d2, d3) -> (d0, -d1 + 16, -d2 + 8, d3)
domain:
d0 in [0, 0]
d1 in [0, 16]
d2 in [0, 8]
d3 in [0, 8]

Входная и выходная карта:

(d0, d1, d2, d3) -> (d0, -d1 + 16, -d2 + 8, d3)
domain:
d0 in [0, 0]
d1 in [0, 16]
d2 in [0, 8]
d3 in [0, 8]

(Вариативный)Уменьшить

Вариативное сокращение имеет несколько входов и несколько инициализаций, карта от выхода ко входу добавляет уменьшенные измерения. Таким образом, в некотором смысле он ведет себя как обратный широковещанию.

p0 = f32[256,10] parameter(0)
p0_init = f32[] constant(-inf)
p1 = s32[256,10] parameter(1)
p1_init = s32[] constant(0)
reduce = (f32[10], s32[10]) reduce(p0, p1, p0_init, p1_init),
  dimensions={0}, to_apply=min

Вывод для ввода карт:

  • вывод -> вход_j:
(d0)[s0] -> (s0, d0)
domain:
d0 in [0, 9]
s0 in [0, 255]
  • вывод -> init_j:
(d0) -> ()
domain:
d0 in [0, 9]

Входные и выходные карты:

  • вход_я -> выход_j:
(d0, d1) -> (d1)
domain:
d0 in [0, 255]
d1 in [0, 9]
  • init_i -> output_j:
()[s0] -> (s0)
domain:
s0 in [0, 9]

для я, j = 0, ... INPUT_COUNT.

Кусочек

Индексация от вывода ко входу для среза приводит к созданию карты пошаговой индексации, которая действительна для каждого элемента вывода. Сопоставление входа с выходом ограничено шагомерным диапазоном элементов на входе.

p0 = f32[10, 20, 50] parameter(0)
slice = f32[5, 3, 25] slice(f32[10, 20, 50] p0),
  slice={[5:10:1], [3:20:7], [0:50:2]}

Вывод на входную карту:

(d0, d1, d2) -> (d0 + 5, d1 * 7 + 3, d2 * 2)
domain:
d0 in [0, 4]
d1 in [0, 2]
d2 in [0, 24]

Входная и выходная карта:

TBD : индексация ввода-вывода

Изменить форму

Изменения бывают разных вкусов.

Свернуть фигуру

Это «линеаризация» изменения формы от ND до 1D.

p0 = f32[4,8] parameter(0)
reshape = f32[32] reshape(p0)

Вывод на входную карту:

(d0) -> (d0 floordiv 8, d0 mod 8)
domain:
d0 in [0, 31]

Входная и выходная карта:

(d0) -> (d0 floordiv 8, d0 mod 8)
domain:
d0 in [0, 31]

Развернуть фигуру

Это обратная операция «свертывания формы», она преобразует входной сигнал 1D в выходной сигнал ND.

p0 = f32[32] parameter(0)
reshape = f32[4, 8] reshape(p0)

Вывод на входную карту:

(d0, d1, d2) -> (d0 floordiv 8, d0 mod 8, d1 * 4 + d2)
domain:
d0 in [0, 31]
d1 in [0, 2]
d2 in [0, 3]

Входная и выходная карта:

(d0, d1, d2) -> (d0 * 8 + d1, d2 floordiv 4, d2 mod 4)
domain:
d0 in [0, 3]
d1 in [0, 7]
d2 in [0, 11]

Общее изменение формы

Это операции изменения формы, которые нельзя представить как одну фигуру развертывания или свертывания. Их можно представить только как композицию из двух или более фигур развертывания или свертывания.

Пример 1: Линеаризация-делинеаризация.
p0 = f32[4,8] parameter(0)
reshape = f32[2, 4, 4] reshape(p0)

Это изменение формы можно представить как комбинацию формы свертывания tensor<4x8xf32> в tensor<32xf32> и последующего расширения формы до tensor<2x4x4xf32> .

Вывод на входную карту:

(d0, d1, d2) -> (d0 * 2 + d1 floordiv 2, d2 + (d1 mod 2) * 4) 
domain:
d0 in [0, 1]
d1 in [0, 3]
d2 in [0, 3]

Входная и выходная карта:

(d0, d1) -> (d0 floordiv 2, d1 floordiv 4 + (d0 mod 2) * 2, d1 mod 4)
domain:
d0 in [0, 3]
d1 in [0, 7]
Пример 2. Развернутые и свернутые подфигуры
p0 = f32[4, 8, 12] parameter(0)
reshape = f32[32, 3, 4] reshape(p0)

Это изменение формы можно представить как композицию двух изменений. Первый сжимает tensor<4x8x12xf32> самых внешних измерений до tensor<32x12xf32> , а второй расширяет tensor<32x12xf32> самых внутренних измерений в tensor<32x3x4xf32> .

Вывод на входную карту:

(d0, d1, d2) -> (d0 floordiv 8, d0 mod 8, d1 * 4 + d2)
domain:
d0 in [0, 31]
d1 in [0, 2]
d2 in [0, 3]

Входная и выходная карта:

(d0, d1, d2) -> (d0 * 8 + d1, d2 floordiv 4, d2 mod 4)
domain:
d0 in [0, 3]
d1 in [0, 7]
d2 in [0, 11]

Биткаст

Операцию битового вещания можно представить как последовательность транспонирования-изменения-транспонирования . Следовательно, его карты индексации представляют собой просто композицию карт индексации для этой последовательности.

Объединить

Сопоставление выходов и входов для concat определяется для всех входов, но с непересекающимися доменами, т. е. одновременно будет использоваться только один из входов.

p0 = f32[2, 5, 7] parameter(0)
p1 = f32[2, 11, 7] parameter(1)
p2 = f32[2, 17, 7] parameter(2)
ROOT concat = f32[2, 33, 7] concatenate(f32[2, 5, 7] p0, f32[2, 11, 7] p1, f32[2, 17, 7] p2), dimensions={1}

Вывод на входные карты:

  • выход -> вход 1:
(d0, d1, d2) -> (d0, d1, d2)
domain:
d0 in [0, 1]
d1 in [0, 4]
d2 in [0, 6]
  • выход -> вход 2:
(d0, d1, d2) -> (d0, d1 - 5, d2)
domain:
d0 in [0, 1]
d1 in [5, 15]
d2 in [0, 6]
  • выход -> вход 3:
(d0, d1, d2) -> (d0, d1 - 16, d2)
domain:
d0 in [0, 1]
d1 in [16, 32]
d2 in [0, 6]

Входные данные для выходных карт:

  • вход 1 -> выход:
(d0, d1, d2) -> (d0, d1, d2)
domain:
d0 in [0, 1]
d1 in [0, 4]
d2 in [0, 6]
  • вход 2 -> выход:
(d0, d1, d2) -> (d0, d1 + 5, d2)
domain:
d0 in [0, 1]
d1 in [0, 10]
d2 in [0, 6]
  • вход 3 -> выход:
(d0, d1, d2) -> (d0, d1 + 16, d2)
domain:
d0 in [0, 1]
d1 in [0, 16]
d2 in [0, 6]

Точка

Карты индексации для точки очень похожи на карты для уменьшения.

p0 = f32[4, 128, 256] parameter(0)
p1 = f32[4, 256, 64] parameter(1)
dot = f32[4, 128, 64] dot(p0, p1),
  lhs_batch_dims={0}, rhs_batch_dims={0},
  lhs_contracting_dims={2}, rhs_contracting_dims={1}

Вывод на входные карты:

  • вывод -> вход_1:
(d0, d1, d2)[s0] -> (d0, d1, s0)
domain:
d0 in [0, 3]
d1 in [0, 127]
d2 in [0, 63]
s0 in [0, 255]
  • вывод -> вход_2:
(d0, d1, d2)[s0] -> (d0, s0, d2)
domain:
d0 in [0, 3]
d1 in [0, 127]
d2 in [0, 63]
s0 in [0, 255]

Входные данные для выходных карт:

  • вход_1 -> вывод:
(d0, d1, d2) -> (d0, d1, s0)
domain:
d0 in [0, 3]
d1 in [0, 127]
d2 in [0, 255]
s0 in [0, 63]
  • вход_2 -> вывод:
(d0, d1, d2) -> (d0, s_0, d1)
domain:
d0 in [0, 3]
d1 in [0, 255]
d2 in [0, 63]
s0 in [0, 127]

Подушка

Индексация PadOp обратна индексации SliceOp.

p0 = f32[4, 4] parameter(0)
p1 = f32[] parameter(1)
pad = f32[12, 16] pad(p0, p1), padding=1_4_1x4_8_0

Конфигурация заполнения 1_4_1x4_8_0 обозначает lowPad_highPad_interiorPad_dim_0 x lowPad_highPad_interiorPad_dim_1 .

Вывод для ввода карт:

  • вывод -> ввод:
(d0, d1) -> ((d0 - 1) floordiv 2, d1 - 4)
domain:
d0 in [1, 7]
d1 in [4, 7]
(d0 - 1) mod 2 in [0, 0]
  • вывод -> инициализация:
(d0, d1) -> ()
domain:
d0 in [0, 11]
d1 in [0, 15]

Уменьшить окно

РедуцированиеWindow в XLA также выполняет заполнение. Таким образом, карты индексации могут быть вычислены как комбинация индексации РедукторВиндов, которая не выполняет никаких дополнений, и индексации PadOp.

c_inf = f32[] constant(-inf)
p0 = f32[1024, 514] parameter(0)
reduce-window = f32[1024, 3] reduce-window(p0, c_inf),
  window={size=1x512 pad=0_0x0_0}, to_apply=max

Вывод для ввода карт:

  • вывод -> ввод:
(d0, d1)[s0] -> (d0, d1 + s0)
domain:
d0 in [0, 1023]
d1 in [0, 2]
s0 in [0, 511]
  • вывод -> инициализация:
(d0, d1) -> ()
domain:
d0 in [0, 1023]
d1 in [0, 2]

Индексирование карт для Fusion

Карта индексации для операции слияния — это композиция карт индексации для каждой операции в кластере. Может случиться так, что некоторые входные данные читаются несколько раз с разными шаблонами доступа.

Один вход, несколько карт индексации

Вот пример для p0 + transpose(p0) .

f {
  p0 = f32[1000, 1000] parameter(0)
  transpose_p0 = f32[1000, 1000]{0, 1} transpose(p0), dimensions={1, 0}
  ROOT a0 = f32[1000, 1000] add(p0, transpose_p0)
}

Карты индексации вывода-ввода для p0 будут (d0, d1) -> (d0, d1) и (d0, d1) -> (d1, d0) . Это означает, что для вычисления одного элемента вывода нам может потребоваться дважды прочитать входной параметр.

Один вход, дедуплицированная карта индексации

изображение

Бывают случаи, когда карты индексации на самом деле одинаковы, хотя это не сразу очевидно.

f {
  p0 = f32[20, 10, 50] parameter(0)
  lhs_transpose_1 = f32[10, 20, 50] transpose(p0), dimensions={1, 0, 2}
  lhs_e = f32[10, 20, 50] exponential(lhs_transpose_1)
  lhs_transpose_2 = f32[10, 50, 20] transpose(lhs_e), dimensions={0, 2, 1}
  rhs_transpose_1 = f32[50, 10, 20] transpose(p0), dimensions={2, 1, 0}
  rhs_log = f32[50, 10, 20] exponential(rhs_transpose_1)
  rhs_transpose_2 = f32[10, 50, 20] transpose(rhs_log), dimensions={1, 0, 2}
  ROOT add = f32[10, 50, 20] add(lhs_transpose_2, rhs_transpose_2)
}

Карта индексации вывода-ввода для p0 в этом случае равна просто (d0, d1, d2) -> (d2, d0, d1) .

Софтмакс

изображение

Сопоставления индексации вывода и ввода для parameter 0 для softmax:

(d0, d1, d2)[s0] -> (d0, d1, s0)
domain:
d0 in [0, 1]
d1 in [0, 64]
d2 in [0, 124]
s0 in [0, 124]

и

(d0, d1, d2) -> (d0, d1, d2)
domain:
d0 in [0, 1]
d1 in [0, 64]
d2 in [0, 124]

где s_0 относится к самому внутреннему измерению ввода.

Упроститель карт индексирования

Упроститель по умолчанию для mlir::AffineMap в исходном коде не может делать никаких предположений о диапазонах размеров/символов. Следовательно, он не может эффективно упрощать выражения с помощью mod и div .

Мы можем использовать знания о нижних и верхних границах подвыражений в аффинных картах, чтобы еще больше их упростить.

Упроститель может переписать следующие выражения.

  1. (d0, d1) -> (d0 + d1 floordiv 16, d1 mod 16) для d в [0, 6] x [0, 14] становится (d0, d1) -> (d0, d1)
  2. (d0, d1, d2) -> ((100d0 + 10d1 + d2) floorDiv 100, ((100d0 + 10d1 + d2) mod 100) floordiv 10, d2 mod 10) для di in [0, 9] становится (d0, d1, d2) -> (d0, d1, d2) .
  3. (d0, d1, d2) -> ((16d0 + 4d1 + d2) floordiv 8, (16d0 + 4d1 + d2) mod 8) для d_i in [0, 9] становится (d0, d1, d2) -> (2d0 + (4d1 + d2) floordiv 8,(4d1 + d2) mod 8) .
  4. (d0, d1) -> (-(-11d0 - d1 + 109) floordiv 11 + 9) для d в [0, 9] x [0, 10] становится (d0, d1) -> (d0) .

Упроститель карты индексирования позволяет нам понять, что некоторые цепочки изменений в HLO отменяют друг друга.

p0 = f32[10, 10, 10] parameter(0)
reshape1 = f32[50, 20] reshape(p0)
reshape2 = f32[10, 10, 10] reshape(reshape1)

После составления карт индексации и их упрощения получим

(d0, d1, d2) -> (d0, d1, d2) .

Упрощение карты индексации также упрощает ограничения.

  1. Ограничения типа lower_bound <= affine_expr (floordiv, +, -, *) constant <= upper_bound перезаписываются как updated_lower_bound <= affine_expr <= updated_upped_bound .
  2. Ограничения, которые всегда выполняются, например d0 + s0 in [0, 20] для d0 in [0, 5] и s0 in [1, 3] устраняются.
  3. Аффинные выражения в ограничениях оптимизируются как аффинная карта индексации, приведенная выше.