Análise de indexação

Este documento descreve a análise de indexação HLO, que permite usar simbolicamente computar mapas de indexação para operações HLO. O mapa de indexação é uma função que mapeia índices de um tensor para os índices de outro, por exemplo, índices de uma saída de instrução HLO para índices de entradas de instrução HLO ou vice-versa.

Exemplo

Para uma transmissão de tensor<20xf32> para tensor<10x20x30xf32>

p0 = f32[20] parameter(0)
bc0 = f32[10, 20, 30] broadcast(p0), dimensions={1}

O mapa de indexação da saída para a entrada é (i, j, k) -> (j) para i in [0, 10], j in [0, 20] e k in [0, 30].

Motivação

A GPU XLA usa várias soluções personalizadas para raciocinar sobre esquemas de mesclagem, utilização de operando e ladrilhos (mais detalhes abaixo). O objetivo da análise de indexação é fornecer um componente reutilizável para esses casos de uso. Análise de indexação foi desenvolvido com base na infraestrutura de mapas físicos do MLIR e inclui semântica HLO.

Coalescência

Refletir sobre a coalescência de memória se torna viável para casos não triviais, quando sabemos quais elementos/fatias das entradas são lidos para computar um elemento do saída.

Utilização de operandos

A utilização de operandos na XLA indica o quanto cada entrada da instrução é usada, supondo que a saída seja totalmente usada. No momento, a utilização também não é calculada para um caso genérico. A análise de indexação permite calcular a utilização com precisão.

Divisão

Um bloco/corte é um subconjunto hiperretangular de um tensor parametrizado por deslocamentos, tamanhos e passos. A propagação de blocos é uma maneira de calcular os parâmetros de blocos do produtor/consumidor da operação usando os parâmetros de blocos da própria operação. já é um biblioteca que faz isso para softmax e ponto. A propagação de blocos pode ser mais genérica e robusta se for expressa por mapas de indexação.

Função e domínio

O mapa de indexação é uma função f(x) = f(d, r, rt) que mapeia um índice múltiplo d de um tensor A para elementos/intervalos de tensor B. O parâmetro r se refere aos intervalos de índices das dimensões presentes no tensor B, mas não no tensor A​. O parâmetro rt se refere aos valores de execução, por exemplo, índices para uma operação de coleta.

Por exemplo, se tivermos uma redução de tensor<2x4x8x16xf32> para tensor<4x8xf32>, o mapa de indexação da saída 2D para a entrada 4D será (d0, d1) -> (r0, d0, d1, r1), em que d_i são as variáveis de dimensão que correspondem aos índices do tensor de saída. Variáveis de intervalo codificadas por r_j vários valores, ou seja, para calcular um elemento (d0, d1) da saída, precisamos Elementos (r0, d0, d1, r1) da entrada, em que r0 in [0, 1] e r1 in [0, 15].

Esse mapeamento pode ser construído com base nos atributos das instruções HLO ou os mapeamentos de instruções não fundidas podem ser compostos para conseguir a indexação de uma fusão. O mapeamento também tem um domínio, que especifica para quais elementos do tensor o mapeamento existe.

f(x) s.t.

lb <= g(x) <= ub

Como queremos minimizar a recomputação, precisamos de uma biblioteca para computacionais. O XLA já depende do MLIR, então usamos mlir::AffineMap em vez de escrever uma outra biblioteca aritmética simbólica.

Um AffineMap típico tem esta aparência:

(d0)[s0, s1] -> (s0 + 5, d0 * 2, s1 * 3 + 50)

AffineMap tem dois tipos de parâmetros: dimensões e símbolos. O dimensions correspondem às variáveis de dimensão d, símbolos correspondem a as variáveis de intervalo r e as variáveis RT rt. AffineMap não contém nenhum metadados sobre intervalos de dimensões, por isso, temos que fornecer esses dados nós mesmos.

struct Interval {
 int64_t lower;
 int64_t upper;
};

// Dimension variable represents a dimension of a tensor or a GPU grid.
struct DimVar {
  Interval bounds;
};

// RangeVar variable represents a range of values, e.g. to compute a single
// element of the reduction's result we need a range of values from the input
// tensor.
struct RangeVar {
  Interval range;
};

// RTVar represents a runtime value, e.g. a dynamic offset in
// HLO dynamic-update-slice op.
struct RTVar {
  Interval feasible_values;
  const HloInstruction* hlo;
  // This is a map from the iteration space of the corresponding indexing map to
  // the iteration space of `hlo`. It shows what element of `hlo` we need to
  // extract to get the runtime value for the RTVar.
  mlir::AffineMap map;
};

class IndexingMap {
  mlir::AffineMap affine_map_;
  std::vector<DimVar> dim_vars_;
  std::vector<RangeVar> range_vars_;
  std::vector<RTVar> rt_vars_;
  llvm::DenseMap<mlir::AffineExpr, Interval> constraints_;
};

dim_vars_ codifica as restrições de caixa inclusivas para a dimensão. variáveis d do mapa de indexação, que geralmente coincidem com o forma do tensor de saída para operações como transpose, Reduce, elementwise, ponto, mas há algumas exceções, como HloConcatenateInstruction.

range_vars_ codifica os possíveis valores que os parâmetros r podem receber.

rt_vars_ armazena as instruções hlo associadas com o acesso e os valores viáveis no ambiente de execução. Por exemplo, o deslocamento é dinâmico para um HloDynamicSliceInstruction de 1D. O RTVar correspondente terá um HloInstruction* que produz um tensor de rank 0 com o padrão de acesso (d0) -> (), porque para cada elemento da saída, extraímos o mesmo elemento do tensor de deslocamento para calcular o índice da entrada. Também podemos supor que o deslocamento da fatia está sempre entre 0 e tensor_size - slice_size - 1.

Vamos estudar cada exemplo para entender o que todas as opções acima realmente significam.

Como indexar mapas para operações não mescladas

Elementwise

Para operações elementwise, o mapa de indexação é uma identidade.

  p0 = f32[10, 20] parameter(0)
  p1 = f32[10, 20] parameter(1)
  add = f32[10, 20] add(p0, p1)

A saída para os mapas de entrada:

  • saída -> input_i:
(d0, d1) -> (d0, d1)
domain:
d0 in [0, 9]
d1 in [0, 19]

Os mapas de entrada e saída

  • input_i -> saída:
(d0, d1) -> (d0, d1)
domain:
d0 in [0, 9]
d1 in [0, 19]

Removido

Transmissão significa que algumas das dimensões serão removidas quando mapearmos saída para entrada e adicionada quando mapeamos a entrada para a saída.

p0 = f32[20] parameter(0)
bc0 = f32[10, 20, 30] broadcast(p0), dimensions={1}

O mapa de saída para entrada:

(d0, d1, d2) -> (d1)
domain:
d0 in [0, 9]
d1 in [0, 19]
d2 in [0, 29]

O mapa de entrada para saída

(d0)[s0, s1] -> (s0, d0, s1)
domain:
d0 in [0, 19]
s0 in [0, 9]
s1 in [0, 29]

Observe que agora temos s no lado direito para a entrada para saída mapeamento. Esses são os símbolos que representam intervalos de valores. Por exemplo, nesse caso específico, cada elemento de entrada com índice d0 é mapeado para uma faixa de 10 x 1 x 30 da saída.

Constante e Iota

Convenientemente, eles não têm parâmetros de entrada, então não há nada para calcular a indexação.

DynamicSlice

O DynamicSlice é como o Slice, mas os deslocamentos são dinâmicos.

src = s32[2,2,258] parameter(0)
of1 = s32[] parameter(1)
of2 = s32[] parameter(2)
of3 = s32[] parameter(3)
ds = dynamic-slice(s32[2,2,258] src, s32[] of1, s32[] of2, s32[] of3), dynamic_slice_sizes={1, 2, 32}

O mapa de saída para entrada de src:

(d0, d1, d2)[s0, s1, s2] -> (d0 + s0, d1 + s1, d2 + s2)
domain:
d0 in [0, 0]
d1 in [0, 1]
d2 in [0, 31]
s0 in [0, 1]
  hlo: of1 = s32[] parameter(1)
  (d0, d1, d2)  -> ()
s1 in [0, 0]
  hlo: of2 = s32[] parameter(2)
  (d0, d1, d2)  -> ()
s2 in [0, 226]
  hlo: of3 = s32[] parameter(3)
  (d0, d1, d2) -> ()

Observe que agora temos s no lado direito para o mapeamento de entrada para saída. Esses são os símbolos que representam valores de execução. Por exemplo, neste caso específico, para cada elemento da saída com índices d0, d1, d2, acessamos os deslocamentos de fatia of1, of2 e of3 para calcular o índice da entrada. Os intervalos das variáveis de execução são derivados, supondo que todo o segmento permanece nos limites.

A saída do mapa de entrada para of1, of2 e of3:

(d0, d1, d2)  -> ()
domain:
d0 in [0, 0]
d1 in [0, 1]
d2 in [0, 31]

DynamicUpdateSlice

src = s32[20,30] parameter(0)
upd = s32[5,10] parameter(1)
of1 = s32[] parameter(2)
of2 = s32[] parameter(3)
dus = s32[20,30] dynamic-update-slice(
    s32[20,30] src, s32[5,10] upd, s32[] of1, s32[] of2)

O mapa de saída para entrada de src é trivial. Ela pode ser mais precisa restringindo o domínio a índices não atualizados, mas indexando mapas no momento não aceitam restrições de igualdade.

(d0, d1) -> (d0, d1)
domain:
d0 in [0, 19]
d1 in [0, 29]

O mapa de saída para entrada de upd:

(d0, d1)[s0, s1]  -> (d0 - s0, d1 - s1)
domain:
d0 in [0, 19]
d1 in [0, 29]
s0 in [0, 15]
  hlo: of1 = s32[] parameter(2)
  (d0, d1)  -> ()
s1 in [0, 20]
  hlo: of2 = s32[] parameter(3)
  (d0, d1)  -> ()

Agora temos s no lado direito para o mapeamento de entrada para saída. Esses são os símbolos que representam valores de ambiente de execução. Por exemplo, nesta um caso específico para cada elemento da saída com índices d0, d1 que acessamos deslocamentos de fração of1 e of2 para calcular o índice da entrada. Os intervalos para as variáveis de execução são derivados ao presumir que o segmento inteiro permanece nos limites.

A saída do mapa de entrada para of1 e of2:

(d0, d1)  -> ()
domain:
d0 in [0, 19]
d1 in [0, 29]

Gather

Somente a coleta simplificada é aceita. Consulte [gather_simplifier].(https://github.com/openxla/xla/blob/main/xla/hlo/transforms/simplifiers/gather_simplifier.h).

operand = f32[33,76,70] parameter(0)
indices = s32[1806,2] parameter(1)
gather = f32[1806,7,8,4] gather(operand, indices),
  offset_dims={1,2,3},
  collapsed_slice_dims={},
  start_index_map={0,1},
  index_vector_dim=1,
  slice_sizes={7,8,4}

O mapa de saída para entrada de operand:


(d0, d1, d2, d3)[s0, s1] -> (d1 + s0, d2 + s1, d3)
domain:
d0 in [0, 1805]
d1 in [0, 6]
d2 in [0, 7]
d3 in [0, 3]
s0 in [0, 26]
  hlo: indices = s32[1806,2]{1,0} parameter(1)
  (d0, d1, d2, d3) -> (d0, 0)
s1 in [0, 68]
  hlo: indices = s32[1806,2]{1,0} parameter(1)
  (d0, d1, d2, d3) -> (d0, 1)

Agora temos s no lado direito para o mapeamento de entrada para saída. Esses são os símbolos que representam valores de execução. Por exemplo, nesta um caso específico para cada elemento da saída com índices d0, d1, d2, d3 que extrair elementos (d0, 0) e (d0, 1) do tensor indices.

A saída do mapa de entrada para indices:

  (d0, d1, d2, d3)[s0] -> (d0, s0)
  domain:
  d0 in [0, 1805]
  d1 in [0, 6]
  d2 in [0, 7]
  d3 in [0, 3]
  s0 in [0, 1]

A variável de intervalo s0 mostra que precisamos de toda a linha (d0, *) do tensor indices para calcular um elemento da saída.

Transposição

O mapa de indexação para transposição é uma permutação das dimensões de entrada/saída.

p0 = f32[3, 12288, 6, 128] parameter(0)
transpose = f32[3, 6, 128, 12288] transpose(p0), dimensions={0, 2, 3, 1}

O mapa de saída para entrada:

(d0, d1, d2, d3) -> (d0, d3, d1, d2)
domain:
d0 in [0, 2]
d1 in [0, 5]
d2 in [0, 127]
d3 in [0, 12287]

O mapa de entrada para saída:

(d0, d1, d2, d3) -> (d0, d2, d3, d1)
domain:
d0 in [0, 2]
d1 in [0, 12287]
d2 in [0, 5]
d3 in [0, 127]

Reverse

O mapa de indexação para reverter as dimensões revertidas para upper_bound(d_i) - d_i:

p0 = f32[1, 17, 9, 9] parameter(0)
reverse = f32[1, 17, 9, 9] reverse(p0), dimensions={1, 2}

A saída para o mapa de entrada:

(d0, d1, d2, d3) -> (d0, -d1 + 16, -d2 + 8, d3)
domain:
d0 in [0, 0]
d1 in [0, 16]
d2 in [0, 8]
d3 in [0, 8]

O mapa de entrada para saída:

(d0, d1, d2, d3) -> (d0, -d1 + 16, -d2 + 8, d3)
domain:
d0 in [0, 0]
d1 in [0, 16]
d2 in [0, 8]
d3 in [0, 8]

(variação)Reduzir

A redução variável tem várias entradas e várias inits, o mapa da saída adiciona as dimensões reduzidas. Ele se comporta como um inverso de uma transmissão de certa forma.

p0 = f32[256,10] parameter(0)
p0_init = f32[] constant(-inf)
p1 = s32[256,10] parameter(1)
p1_init = s32[] constant(0)
reduce = (f32[10], s32[10]) reduce(p0, p1, p0_init, p1_init),
  dimensions={0}, to_apply=max

A saída para os mapas de entrada:

  • output -> input_j:
(d0)[s0] -> (s0, d0)
domain:
d0 in [0, 9]
s0 in [0, 255]
  • output -> init_j:
(d0) -> ()
domain:
d0 in [0, 9]

A entrada para mapas de saída:

  • input_i -> output_j:
(d0, d1) -> (d1)
domain:
d0 in [0, 255]
d1 in [0, 9]
  • init_i -> output_j:
()[s0] -> (s0)
domain:
s0 in [0, 9]

para i, j = 0, ... INPUT_COUNT.

Fatia

A indexação da saída para a entrada para fatias resulta em um mapa de indexação fragmentado que é válido para cada elemento da saída. O mapeamento da entrada para a saída é restringido a um intervalo de elementos com passo na entrada.

p0 = f32[10, 20, 50] parameter(0)
slice = f32[5, 3, 25] slice(f32[10, 20, 50] p0),
  slice={[5:10:1], [3:20:7], [0:50:2]}

A saída para o mapa de entrada:

(d0, d1, d2) -> (d0 + 5, d1 * 7 + 3, d2 * 2)
domain:
d0 in [0, 4]
d1 in [0, 2]
d2 in [0, 24]

O mapa de entrada para saída:

(d0, d1, d2) -> (d0 - 5, (d1 - 3) floordiv 7, d2 floordiv 2)
domain:
d0 in [5, 9]
d1 in [3, 17]
d2 in [0, 48]
(d1 - 3) mod 7 in [0, 0]
d2 mod 2 in [0, 0]

Redefinir

As reformulações têm diferentes sabores.

Forma de fechamento

Esta é uma remodelagem "linearização" de N-D para 1D.

p0 = f32[4,8] parameter(0)
reshape = f32[32] reshape(p0)

A saída para o mapa de entrada:

(d0) -> (d0 floordiv 8, d0 mod 8)
domain:
d0 in [0, 31]

O mapa de entrada para saída:

(d0, d1) -> (d0 * 8 + d1)
domain:
d0 in [0, 3]
d1 in [0, 7]

Expandir forma

Esta é uma "forma de recolhimento" inversa ela remodela uma entrada 1D em uma saída N-D.

p0 = f32[32] parameter(0)
reshape = f32[4, 8] reshape(p0)

A saída para o mapa de entrada:

(d0, d1) -> (d0 * 8 + d1)
domain:
d0 in [0, 3]
d1 in [0, 7]

O mapa de entrada para saída:

(d0) -> (d0 floordiv 8, d0 mod 8)
domain:
d0 in [0, 31]

Reformulação genérica

Essas são as operações de remodelação que não podem ser representadas como uma única expansão ou recolher forma. Eles só podem ser representados como uma composição de 2 ou mais expandir ou recolher formas.

Exemplo 1: linearização-deslinearização.
p0 = f32[4,8] parameter(0)
reshape = f32[2, 4, 4] reshape(p0)

Essa remodelagem pode ser representada como uma composição de forma de colapso de tensor<4x8xf32> para tensor<32xf32> e, em seguida, uma expansão de forma para tensor<2x4x4xf32>.

A saída para o mapa de entrada:

(d0, d1, d2) -> (d0 * 2 + d1 floordiv 2, d2 + (d1 mod 2) * 4)
domain:
d0 in [0, 1]
d1 in [0, 3]
d2 in [0, 3]

O mapa de entrada para saída:

(d0, d1) -> (d0 floordiv 2, d1 floordiv 4 + (d0 mod 2) * 2, d1 mod 4)
domain:
d0 in [0, 3]
d1 in [0, 7]
Exemplo 2: subformas expandidas e recolhidas
p0 = f32[4, 8, 12] parameter(0)
reshape = f32[32, 3, 4] reshape(p0)

Essa remodelação pode ser representada como uma composição de duas remodelações. A primeira reduz as dimensões mais externas tensor<4x8x12xf32> para tensor<32x12xf32> e a segunda expande a dimensão mais interna tensor<32x12xf32> para tensor<32x3x4xf32>.

O mapa de saída para entrada:

(d0, d1, d2) -> (d0 floordiv 8, d0 mod 8, d1 * 4 + d2)
domain:
d0 in [0, 31]
d1 in [0, 2]
d2 in [0, 3]

O mapa de entrada para saída:

(d0, d1, d2) -> (d0 * 8 + d1, d2 floordiv 4, d2 mod 4)
domain:
d0 in [0, 3]
d1 in [0, 7]
d2 in [0, 11]

Bitcast

Uma operação bitcast pode ser representada como sequência de transposição-reshape-transpose. Portanto, seus mapas de indexação são apenas uma composição de mapas de indexação para este sequência.

Concatenação

O mapeamento de saída para entrada para concat é definido para todas as entradas, domínios não sobrepostos, ou seja, apenas uma das entradas será usada por vez.

p0 = f32[2, 5, 7] parameter(0)
p1 = f32[2, 11, 7] parameter(1)
p2 = f32[2, 17, 7] parameter(2)
ROOT concat = f32[2, 33, 7] concatenate(f32[2, 5, 7] p0, f32[2, 11, 7] p1, f32[2, 17, 7] p2), dimensions={1}

A saída para entradas mapeia:

  • output -> entrada 1:
(d0, d1, d2) -> (d0, d1, d2)
domain:
d0 in [0, 1]
d1 in [0, 4]
d2 in [0, 6]
  • saída -> entrada 2:
(d0, d1, d2) -> (d0, d1 - 5, d2)
domain:
d0 in [0, 1]
d1 in [5, 15]
d2 in [0, 6]
  • output -> entrada 3:
(d0, d1, d2) -> (d0, d1 - 16, d2)
domain:
d0 in [0, 1]
d1 in [16, 32]
d2 in [0, 6]

As entradas dos mapas de saída:

  • Entrada 1 -> saída:
(d0, d1, d2) -> (d0, d1, d2)
domain:
d0 in [0, 1]
d1 in [0, 4]
d2 in [0, 6]
  • Entrada 2 -> saída:
(d0, d1, d2) -> (d0, d1 + 5, d2)
domain:
d0 in [0, 1]
d1 in [0, 10]
d2 in [0, 6]
  • Entrada 3 -> saída:
(d0, d1, d2) -> (d0, d1 + 16, d2)
domain:
d0 in [0, 1]
d1 in [0, 16]
d2 in [0, 6]

Ponto

Os mapas de indexação para ponto são muito semelhantes aos de redução.

p0 = f32[4, 128, 256] parameter(0)
p1 = f32[4, 256, 64] parameter(1)
dot = f32[4, 128, 64] dot(p0, p1),
  lhs_batch_dims={0}, rhs_batch_dims={0},
  lhs_contracting_dims={2}, rhs_contracting_dims={1}

A saída para entradas mapeia:

  • saída -> input_1:
(d0, d1, d2)[s0] -> (d0, d1, s0)
domain:
d0 in [0, 3]
d1 in [0, 127]
d2 in [0, 63]
s0 in [0, 255]
  • saída -> input_2:
(d0, d1, d2)[s0] -> (d0, s0, d2)
domain:
d0 in [0, 3]
d1 in [0, 127]
d2 in [0, 63]
s0 in [0, 255]

As entradas dos mapas de saída:

  • input_1 -> saída:
(d0, d1, d2)[s0] -> (d0, d1, s0)
domain:
d0 in [0, 3]
d1 in [0, 127]
d2 in [0, 255]
s0 in [0, 63]
  • input_2 -> saída:
(d0, d1, d2)[s0] -> (d0, s0, d1)
domain:
d0 in [0, 3]
d1 in [0, 255]
d2 in [0, 63]
s0 in [0, 127]

Pad

A indexação de PadOp é inversa à indexação de SliceOp.

p0 = f32[4, 4] parameter(0)
p1 = f32[] parameter(1)
pad = f32[12, 16] pad(p0, p1), padding=1_4_1x4_8_0

A configuração de padding 1_4_1x4_8_0 indica lowPad_highPad_interiorPad_dim_0 x lowPad_highPad_interiorPad_dim_1.

A saída para mapas de entrada:

  • output -> entrada:
(d0, d1) -> ((d0 - 1) floordiv 2, d1 - 4)
domain:
d0 in [1, 7]
d1 in [4, 7]
(d0 - 1) mod 2 in [0, 0]
  • output -> init:
(d0, d1) -> ()
domain:
d0 in [0, 11]
d1 in [0, 15]

ReduceWindow

O ReduceWindow no XLA também executa o preenchimento. Portanto, os mapas de indexação podem ser calculados como uma composição de indexação de ReduceWindow que não faz nenhum preenchimento e indexação de PadOp.

c_inf = f32[] constant(-inf)
p0 = f32[1024, 514] parameter(0)
reduce-window = f32[1024, 3] reduce-window(p0, c_inf),
  window={size=1x512 pad=0_0x0_0}, to_apply=max

A saída para mapas de entrada:

  • output -> entrada:
(d0, d1)[s0] -> (d0, d1 + s0)
domain:
d0 in [0, 1023]
d1 in [0, 2]
s0 in [0, 511]
  • output -> init:
(d0, d1) -> ()
domain:
d0 in [0, 1023]
d1 in [0, 2]

Como indexar o Maps para a Fusion

O mapa de indexação da operação de fusão é uma composição de mapas de indexação para cada operação no aglomerado. Pode acontecer de algumas entradas serem lidas várias vezes com diferentes padrões de acesso.

Uma entrada, vários mapas de indexação

Confira um exemplo de p0 + transpose(p0).

f {
  p0 = f32[1000, 1000] parameter(0)
  transpose_p0 = f32[1000, 1000]{0, 1} transpose(p0), dimensions={1, 0}
  ROOT a0 = f32[1000, 1000] add(p0, transpose_p0)
}

Os mapas de indexação de saída para entrada de p0 serão (d0, d1) -> (d0, d1) e (d0, d1) -> (d1, d0). Isso significa que, para calcular um elemento da saída, talvez seja necessário ler o parâmetro de entrada duas vezes.

Um mapa de indexação de entrada sem duplicação

img

Há casos em que os mapas de indexação são iguais, mesmo que não seja imediatamente óbvio.

f {
  p0 = f32[20, 10, 50] parameter(0)
  lhs_transpose_1 = f32[10, 20, 50] transpose(p0), dimensions={1, 0, 2}
  lhs_e = f32[10, 20, 50] exponential(lhs_transpose_1)
  lhs_transpose_2 = f32[10, 50, 20] transpose(lhs_e), dimensions={0, 2, 1}
  rhs_transpose_1 = f32[50, 10, 20] transpose(p0), dimensions={2, 1, 0}
  rhs_log = f32[50, 10, 20] exponential(rhs_transpose_1)
  rhs_transpose_2 = f32[10, 50, 20] transpose(rhs_log), dimensions={1, 0, 2}
  ROOT add = f32[10, 50, 20] add(lhs_transpose_2, rhs_transpose_2)
}

Neste caso, o mapa de indexação de saída para entrada de p0 é apenas (d0, d1, d2) -> (d2, d0, d1).

Softmax

img

Os mapas de indexação de saída para entrada de parameter 0 para softmax:

(d0, d1, d2)[s0] -> (d0, d1, s0)
domain:
d0 in [0, 1]
d1 in [0, 64]
d2 in [0, 124]
s0 in [0, 124]

e

(d0, d1, d2) -> (d0, d1, d2)
domain:
d0 in [0, 1]
d1 in [0, 64]
d2 in [0, 124]

em que s0 se refere à dimensão mais interna da entrada.

Indexação do Map Simplifier

O simplificador padrão para upstream mlir::AffineMap não pode fazer suposições sobre os intervalos de dimensões/símbolos. Portanto, não é possível simplificar expressões com mod e div de maneira eficiente.

Podemos aproveitar o conhecimento sobre os limites inferior e superior das subexpressões nos mapas afins para simplificar ainda mais.

O simplificador pode reescrever as seguintes expressões.

  1. (d0, d1) -> (d0 + d1 floordiv 16, d1 mod 16) para d em [0, 6] x [0, 14] passa a ser (d0, d1) -> (d0, d1)
  2. O (d0, d1, d2) -> ((100d0 + 10d1 + d2) floorDiv 100, ((100d0 + 10d1 + d2) mod 100) floordiv 10, d2 mod 10) para di in [0, 9] se torna (d0, d1, d2) -> (d0, d1, d2).
  3. (d0, d1, d2) -> ((16d0 + 4d1 + d2) floordiv 8, (16d0 + 4d1 + d2) mod 8) para d_i in [0, 9] passa a ser (d0, d1, d2) -> (2d0 + (4d1 + d2) floordiv 8,(4d1 + d2) mod 8).
  4. (d0, d1) -> (-(-11d0 - d1 + 109) floordiv 11 + 9) para d em [0, 9] x [0, 10] se torna (d0, d1) -> (d0).

O simplificador de mapas de indexação nos permite entender que algumas das as remodelações no HLO cancelam umas às outras.

p0 = f32[10, 10, 10] parameter(0)
reshape1 = f32[50, 20] reshape(p0)
reshape2 = f32[10, 10, 10] reshape(reshape1)

Depois da composição dos mapas de indexação e da simplificação, vamos ter

(d0, d1, d2) -> (d0, d1, d2).

A simplificação do mapa de indexação também simplifica as restrições.

  1. Restrições do tipo lower_bound <= affine_expr (floordiv, +, -, *) constant <= upper_bound são reescritas como updated_lower_bound <= affine_expr <= updated_upped_bound.
  2. Restrições que são sempre satisfeitas, por exemplo, d0 + s0 in [0, 20] para d0 in [0, 5] e s0 in [1, 3] são eliminados.
  3. Expressões afins nas restrições são otimizadas como o mapa acima.

Para ver mais exemplos, consulte indexing_map_test.cc.