Formas y diseño

Estructura de una operación de XLA

Considera el siguiente ejemplo de HLO:

add.936 = bf16[8,1,1280,16384]{3,2,0,1:T(8,128)(2,1)}
          add(exponential.183, broadcast.3115)

Consta de los siguientes componentes:

  • Nombre de la operación: add.936
    • Es el nombre único de la operación.
  • Forma: bf16[8,1,1280,16384]
    • Esta es la forma de salida de la operación. Aquí, el dtype es bf16 y la forma es [8,1,1280,16384].
  • Diseño (con mosaicos): 3,2,0,1:T(8,128)(2,1)
    • Describe cómo se almacena el array en la memoria. 3,2,0,1 denota el orden de los ejes en la memoria (es decir, columna principal, fila principal, etcétera) y T(8,128)(2,1) denota el relleno y la división en mosaicos que se usan.
    • El diseño es opcional. Si no se especifica, no hay segmentación y se supone que las dimensiones están ordenadas de la más importante a la menos importante.
  • Operación: add
    • Operación que se está realizando. Aquí, es Add, que también se menciona en el nombre de la operación.
  • Argumentos: exponential.183, broadcast.3115
    • Esta operación toma dos argumentos, especificados con sus nombres únicos.

Veamos otro ejemplo, un Op de fusión:

%fusion.3 = bf16[32,32,4096]{2,1,0:T(8,128)(2,1)S(1)}
            fusion(bf16[32,32,8192]{2,1,0:T(8,128)(2,1)S(1)} %fusion.32),
            kind=kCustom, calls=%all-reduce-scatter.3

Además de los componentes descritos anteriormente, esto incluye lo siguiente:

  • Atributos: kind y calls
    • Proporcionan más información sobre la operación que se realiza, en este caso, la fusión.
  • Ubicación de la memoria (identificador de espacio de memoria): S(1)
    • Esto denota el espacio o la ubicación de memoria en el que se almacena el array. Aquí, S(1) denota que este array reside en la VMEM (en una TPU).
  • Detalles de la forma y el diseño del argumento de entrada %fusion.32

En las siguientes secciones, se describen las formas, el diseño y los identificadores de espacio de memoria. Puedes obtener más información sobre el mosaico en Diseño en mosaico.

Formas

El arquetipo ShapeProto de XLA (xla_data.proto) describe la cantidad de dimensiones, el tamaño y el tipo de datos de un array N-dimensional (array en forma abreviada).

Terminología, notación y convenciones

  • La cantidad real de dimensiones de un array es la cantidad de dimensiones que tienen un tamaño mayor que 1.

  • Las dimensiones se numeran del 0 al N-1 para un array de N dimensiones. El tamaño de una dimensión es un número entero no negativo. En particular, el tamaño 0 es válido. Los números de dimensión son etiquetas arbitrarias para mayor comodidad. El orden de estos números de dimensión no implica un orden secundario o principal en particular en el diseño de la forma. El diseño se determina según el arquetipo LayoutProto.

  • Por convención, las dimensiones se enumeran en orden creciente según su número. Por ejemplo, para un array tridimensional de tamaño [A x B x C], la dimensión 0 tiene el tamaño A, la dimensión 1 tiene el tamaño B y la dimensión 2 tiene el tamaño C.

    Algunas utilidades en XLA también admiten la indexación negativa similar a Python: la dimensión -1 es la última dimensión (equivalente a N-1 para un array de dimensión N). Por ejemplo, para el array de 3 dimensiones descrito anteriormente, la dimensión -1 tiene el tamaño C, la dimensión -2 tiene el tamaño B, y así sucesivamente.

  • Los arrays de dos, tres y cuatro dimensiones suelen tener letras específicas asociadas a las dimensiones. Por ejemplo, para un array 2D:

    • dimensión 0: y
    • Dimensión 1: x

    Para un array 3D, haz lo siguiente:

    • dimensión 0: z
    • Dimensión 1: y
    • Dimensión 2: x

    Para un array de 4 dimensiones, se aplica lo siguiente:

    • dimensión 0: p
    • Dimensión 1: z
    • Dimensión 2: y
    • dimensión 3: x
  • Las funciones de la API de XLA que toman dimensiones lo hacen en orden creciente del número de dimensión. Esto coincide con el orden que se usa cuando se pasan dimensiones como un initializer_list, p.ej.:

    ShapeUtil::MakeShape(F32, {A, B, C, D})

    creará una forma cuyo array de tamaños de dimensión consta de la secuencia [A, B, C, D].

Diseño

El .proto LayoutProto describe cómo se representa un array en la memoria. Incluye los siguientes campos:

message LayoutProto {
  repeated int64 minor_to_major;
  int64 tail_padding_alignment_in_elements;
  ...
}

Ordenamiento de dimensiones de menor a mayor

El único campo obligatorio es minor_to_major. Este campo describe el orden de las dimensiones de menor a mayor dentro de una forma. Los valores en minor_to_major son un ordenamiento de las dimensiones del array (0 a N-1 para un array de dimensión N), en el que el primer valor es la dimensión menos importante y el último valor es la dimensión más importante. La dimensión más secundaria es la que cambia más rápidamente cuando se recorren los elementos del array dispuestos en la memoria lineal.

Por ejemplo, considera el siguiente array bidimensional de tamaño [2 x 3]:

a b c
d e f

Aquí, la dimensión 0 tiene un tamaño de 2, y la dimensión 1 tiene un tamaño de 3. Si el campo minor_to_major en el diseño es [0, 1], la dimensión 0 es la dimensión secundaria más importante y la dimensión 1 es la dimensión principal más importante. Esto corresponde al siguiente diseño en la memoria lineal:

a d b e c f

Este orden de dimensión de menor a mayor de 0 a N-1 es similar al orden principal por columna (para 2 dimensiones). Si se supone un ordenamiento monotónico de las dimensiones, otra forma en que podemos referirnos a este diseño en el código es simplemente "la dimensión 0 es secundaria".

Por otro lado, si el campo minor_to_major en el diseño es [1, 0], el diseño en la memoria lineal es el siguiente:

a b c d e f

Un orden de dimensión secundaria a principal de N-1 a 0 para un array de N dimensiones es similar a row-major (para 2 dimensiones). Si se supone un ordenamiento monótono de las dimensiones, otra forma en que podemos referirnos a este diseño en el código es simplemente "la dimensión 0 es principal".

Orden predeterminado de la versión secundaria a la principal

El diseño predeterminado para las formas recién creadas es "El orden de las dimensiones es de mayor a menor" (es decir, [N-1, ..., 0]).

Padding

El campo tail_padding_alignment_in_elements define la alineación del array en mosaico en términos de la cantidad de elementos. Después de aplicar el mosaico, se agregarán elementos con padding al final del diseño hasta que la cantidad total de elementos sea un múltiplo de este valor.

Indexación en arrays

La clase IndexUtil en index_util.h proporciona utilidades para convertir entre índices multidimensionales e índices lineales, dada una forma y un diseño. Los índices multidimensionales incluyen un índice int64 para cada dimensión. Los índices lineales son un solo valor de int64 que indexa el búfer que contiene el array. Consulta shape_util.h y layout_util.h en el mismo directorio para ver utilidades que simplifican la creación y manipulación de formas y diseños.

Identificadores de espacios de memoria

En HLO, cada array se puede anotar con un identificador de espacio de memoria, que se escribe como S(n).

  • S(0) (a menudo, se omite) denota la memoria de alto ancho de banda (HBM) del dispositivo.
  • S(1) representa la memoria virtual (VMEM) en el dispositivo.
  • S(2), S(3), etc., corresponden a espacios de memoria adicionales específicos del dispositivo.
  • S(5) indica la memoria del host.