Formas e layout

Estrutura de uma operação XLA

Considere um exemplo de HLO:

add.936 = bf16[8,1,1280,16384]{3,2,0,1:T(8,128)(2,1)}
          add(exponential.183, broadcast.3115)

Isso consiste nos seguintes componentes:

  • Nome da operação: add.936
    • Esse é o nome exclusivo da operação.
  • Forma: bf16[8,1,1280,16384]
    • Este é o formato de saída da operação. Aqui, o dtype é bf16 e o formato é [8,1,1280,16384].
  • Layout (com mosaico): 3,2,0,1:T(8,128)(2,1)
    • Isso descreve como a matriz é armazenada na memória. 3,2,0,1 indica a ordem dos eixos na memória (por exemplo, coluna principal, linha principal etc.), e T(8,128)(2,1) indica o ajuste de blocos e o padding usados.
    • O layout é opcional. Se não for especificado, não haverá mosaico, e as dimensões serão consideradas ordenadas da mais importante à menos importante.
  • Operação: add
    • A operação que está sendo realizada. Aqui, é Add, que também é mencionado no nome da operação.
  • Argumentos: exponential.183, broadcast.3115
    • Essa operação usa dois argumentos, especificados com nomes exclusivos.

Vamos considerar outro exemplo, uma operação de fusão:

%fusion.3 = bf16[32,32,4096]{2,1,0:T(8,128)(2,1)S(1)}
            fusion(bf16[32,32,8192]{2,1,0:T(8,128)(2,1)S(1)} %fusion.32),
            kind=kCustom, calls=%all-reduce-scatter.3

Além dos componentes descritos anteriormente, ela consiste em:

  • Atributos: kind e calls
    • Elas fornecem mais informações sobre a operação que está sendo realizada, neste caso, a fusão.
  • Localização da memória (identificador de espaço da memória): S(1)
    • Isso indica o espaço/local de memória em que a matriz está armazenada. S(1) aqui indica que essa matriz reside na VMEM (em uma TPU).
  • Detalhes de forma e layout para o argumento de entrada %fusion.32

As seções a seguir descrevem Layout, Identificadores de espaço de memória e formas. Saiba mais sobre o mosaico em Layout em mosaico.

Formas

O proto XLA ShapeProto (xla_data.proto) descreve o número de dimensões, o tamanho e o tipo de dados de uma matriz N-dimensional (matriz, em resumo).

Terminologia, notação e convenções

  • O número verdadeiro de dimensões de uma matriz é o número de dimensões que têm um tamanho maior que 1.

  • As dimensões são numeradas de 0 até N-1 para uma matriz dimensional N. O tamanho de uma dimensão é um número inteiro não negativo. Em particular, o tamanho 0 é válido. Os números das dimensões são rótulos arbitrários para conveniência. A ordem desses números de dimensão não implica uma ordenação específica de menor/maior no layout da forma. O layout é determinado pelo proto LayoutProto.

  • Por convenção, as dimensões são listadas em ordem crescente de número. Por exemplo, para uma matriz tridimensional de tamanho [A x B x C], a dimensão 0 tem tamanho A, a dimensão 1 tem tamanho B e a dimensão 2 tem tamanho C.

    Alguns utilitários no XLA também aceitam indexação negativa semelhante ao Python: a dimensão -1 é a última dimensão (equivalente a N-1 para uma matriz dimensional N). Por exemplo, para a matriz tridimensional descrita acima, a dimensão -1 tem tamanho C, a dimensão -2 tem tamanho B e assim por diante.

  • Matrizes bidimensionais, tridimensionais e quadridimensionais costumam ter letras específicas associadas às dimensões. Por exemplo, para uma matriz 2D:

    • dimensão 0: y
    • dimensão 1: x

    Para uma matriz 3D:

    • dimensão 0: z
    • dimensão 1: y
    • Dimensão 2: x

    Para uma matriz 4D:

    • dimensão 0: p
    • dimensão 1: z
    • Dimensão 2: y
    • Dimensão 3: x
  • As funções na API XLA que usam dimensões fazem isso em ordem crescente de número de dimensão. Isso corresponde à ordem usada ao transmitir dimensões como um initializer_list. Por exemplo:

    ShapeUtil::MakeShape(F32, {A, B, C, D})

    vai criar uma forma cuja matriz de tamanho de dimensão consiste na sequência [A, B, C, D].

Layout

O proto LayoutProto descreve como uma matriz é representada na memória. Ele inclui os seguintes campos:

message LayoutProto {
  repeated int64 minor_to_major;
  int64 tail_padding_alignment_in_elements;
  ...
}

Ordenação de dimensão de secundária para principal

O único campo obrigatório é minor_to_major. Esse campo descreve a ordenação de dimensões de menor para maior em uma forma. Os valores em minor_to_major são uma ordenação das dimensões da matriz (0 a N-1 para uma matriz de N dimensões), sendo o primeiro valor a dimensão mais secundária e o último valor a dimensão mais principal. A dimensão mais secundária é aquela que muda mais rapidamente ao percorrer os elementos da matriz dispostos na memória linear.

Por exemplo, considere a seguinte matriz 2D de tamanho [2 x 3]:

a b c
d e f

Aqui, a dimensão 0 é o tamanho 2, e a dimensão 1 é o tamanho 3. Se o campo minor_to_major no layout for [0, 1], a dimensão 0 será a mais secundária e a dimensão 1 será a mais principal. Isso corresponde ao seguinte layout na memória linear:

a d b e c f

Essa ordem de dimensão de menor para maior de 0 até N-1 é semelhante à ordem das colunas (para bidimensionais). Supondo uma ordenação monotônica de dimensões, outra maneira de se referir a esse layout no código é simplesmente "a dimensão 0 é secundária".

Por outro lado, se o campo minor_to_major no layout for [1, 0], o layout na memória linear será:

a b c d e f

Uma ordem de dimensão de secundária para principal de N-1 até 0 para uma matriz dimensional N é semelhante a linha principal (para 2 dimensões). Supondo uma ordenação monotônica de dimensões, outra maneira de se referir a esse layout no código é simplesmente "a dimensão 0 é principal".

Ordem padrão de menor para maior

O layout padrão para formas recém-criadas é "a ordem das dimensões é de maior para menor" (ou seja, [N-1, ..., 0]).

Padding

O campo tail_padding_alignment_in_elements define o alinhamento da matriz em blocos em termos do número de elementos. Depois de aplicar o mosaico, os elementos com padding serão adicionados ao final do layout até que o número total de elementos seja um múltiplo desse valor.

Indexação em matrizes

A classe IndexUtil em index_util.h fornece utilitários para conversão entre índices multidimensionais e lineares dado um formato e um layout. Os índices multidimensionais incluem um índice int64 para cada dimensão. Os índices lineares são um único valor int64 que indexa o buffer que contém a matriz. Consulte shape_util.h e layout_util.h no mesmo diretório para ver utilitários que simplificam a criação e a manipulação de formas e layouts.

Identificadores de espaço de memória

Em HLO, cada matriz pode ser anotada com um identificador de espaço de memória, escrito como S(n).

  • S(0) (geralmente omitido) indica a memória de alta largura de banda (HBM) do dispositivo.
  • S(1) representa a memória virtual (VMEM) no dispositivo.
  • S(2), S(3) etc., correspondem a espaços de memória adicionais específicos do dispositivo.
  • S(5) indica a memória do host.