Formes et mise en page

Structure d'une opération XLA

Prenons l'exemple suivant de HLO :

add.936 = bf16[8,1,1280,16384]{3,2,0,1:T(8,128)(2,1)}
          add(exponential.183, broadcast.3115)

Il comprend les composants suivants :

  • Nom de l'opération : add.936
    • Il s'agit du nom unique de l'opération.
  • Forme : bf16[8,1,1280,16384]
    • Il s'agit de la forme de sortie de l'opération. Ici, le dtype est bf16 et la forme est [8,1,1280,16384].
  • Mise en page (avec mosaïque) : 3,2,0,1:T(8,128)(2,1)
    • Cela décrit la façon dont le tableau est stocké en mémoire. 3,2,0,1 indique l'ordre des axes en mémoire (par exemple, par colonne, par ligne, etc.) et T(8,128)(2,1) indique le tiling et le padding utilisés.
    • La mise en page est facultative. Si aucune mosaïque n'est spécifiée, les dimensions sont supposées être ordonnées de la plus grande à la plus petite.
  • Opération : add
    • Opération en cours. Ici, il s'agit de Add, qui est également mentionné dans le nom de l'opération.
  • Arguments : exponential.183, broadcast.3115
    • Cette opération prend deux arguments, spécifiés avec leurs noms uniques.

Prenons un autre exemple, une opération de fusion :

%fusion.3 = bf16[32,32,4096]{2,1,0:T(8,128)(2,1)S(1)}
            fusion(bf16[32,32,8192]{2,1,0:T(8,128)(2,1)S(1)} %fusion.32),
            kind=kCustom, calls=%all-reduce-scatter.3

En plus des composants décrits précédemment, il comprend les éléments suivants :

  • Attributs : kind et calls
    • Elles fournissent plus d'informations sur l'opération en cours, qui est la fusion dans ce cas.
  • Emplacement de la mémoire (identifiant de l'espace mémoire) : S(1)
    • Cela indique l'espace mémoire/l'emplacement où le tableau est stocké. S(1) indique ici que ce tableau réside dans la VMEM (sur une TPU).
  • Détails de la forme et de la mise en page pour l'argument d'entrée %fusion.32

Les sections suivantes décrivent les formes, la mise en page et les identifiants d'espace mémoire. Pour en savoir plus sur le mosaïquage, consultez Mise en page en mosaïque.

Formes

Le proto XLA ShapeProto (xla_data.proto) décrit le nombre de dimensions, la taille et le type de données d'un tableau à N dimensions (array en abrégé).

Terminologie, notation et conventions

  • Le nombre réel de dimensions d'un tableau correspond au nombre de dimensions dont la taille est supérieure à 1.

  • Les dimensions sont numérotées de 0 à N-1 pour un tableau à N dimensions. La taille d'une dimension est un nombre entier non négatif. En particulier, la taille 0 est valide. Les numéros de dimension sont des libellés arbitraires pour plus de commodité. L'ordre de ces nombres de dimensions n'implique pas un ordre mineur/majeur particulier dans la mise en page de la forme. La mise en page est déterminée par le proto LayoutProto.

  • Par convention, les dimensions sont listées par ordre croissant de numéro de dimension. Par exemple, pour un tableau à trois dimensions de taille [A x B x C], la dimension 0 a une taille de A, la dimension 1 a une taille de B et la dimension 2 a une taille de C.

    Certains utilitaires de XLA sont également compatibles avec l'indexation négative de type Python : la dimension -1 est la dernière dimension (équivalente à N-1 pour un tableau de dimension N). Par exemple, pour le tableau à trois dimensions décrit ci-dessus, la dimension -1 a une taille de C, la dimension -2 a une taille de B, et ainsi de suite.

  • Les tableaux à deux, trois et quatre dimensions sont souvent associés à des lettres spécifiques. Par exemple, pour un tableau 2D :

    • dimension 0 : y
    • Dimension 1 : x

    Pour un tableau 3D :

    • dimension 0 : z
    • Dimension 1 : y
    • dimension 2 : x

    Pour un tableau 4D :

    • dimension 0 : p
    • Dimension 1 : z
    • dimension 2 : y
    • Dimension 3 : x
  • Les fonctions de l'API XLA qui acceptent des dimensions le font par ordre croissant du numéro de dimension. Cela correspond à l'ordre utilisé lors de la transmission des dimensions en tant que initializer_list, par exemple :

    ShapeUtil::MakeShape(F32, {A, B, C, D})

    créera une forme dont le tableau de tailles de dimension est constitué de la séquence [A, B, C, D].

Disposition

Le proto LayoutProto décrit la façon dont un tableau est représenté en mémoire. Il comprend les champs suivants :

message LayoutProto {
  repeated int64 minor_to_major;
  int64 tail_padding_alignment_in_elements;
  ...
}

Ordre des dimensions de mineur à majeur

Le seul champ obligatoire est minor_to_major. Ce champ décrit l'ordre mineur à majeur des dimensions d'une forme. Les valeurs de minor_to_major sont un ordre des dimensions du tableau (de 0 à N-1 pour un tableau à N dimensions), la première valeur étant la dimension la moins importante et la dernière valeur étant la dimension la plus importante. La dimension la moins importante est celle qui change le plus rapidement lorsque vous parcourez les éléments du tableau disposés dans la mémoire linéaire.

Par exemple, considérons le tableau 2D suivant de taille [2 x 3] :

a b c
d e f

Ici, la dimension 0 correspond à la taille 2 et la dimension 1 à la taille 3. Si le champ minor_to_major de la mise en page est [0, 1], la dimension 0 est la dimension la moins importante et la dimension 1 est la dimension la plus importante. Cela correspond à la mise en page suivante dans la mémoire linéaire :

a d b e c f

Cet ordre de dimensions mineur à majeur de 0 à N-1 est semblable à column-major (pour les dimensions bidimensionnelles). En supposant un ordre monotone des dimensions, une autre façon de désigner cette mise en page dans le code est simplement "dim 0 est mineur".

En revanche, si le champ minor_to_major de la mise en page est [1, 0], la mise en page dans la mémoire linéaire est la suivante :

a b c d e f

Un ordre de dimensions mineur à majeur de N-1 à 0 pour un tableau de dimensions N est semblable à row-major (pour les tableaux à deux dimensions). En supposant un ordre monotone des dimensions, une autre façon de désigner cette mise en page dans le code est simplement "la dimension 0 est principale".

Ordre par défaut du mineur au majeur

La mise en page par défaut des formes nouvellement créées est "l'ordre des dimensions est du plus grand au plus petit" (c'est-à-dire [N-1, ..., 0]).

Marges intérieures

Le champ tail_padding_alignment_in_elements définit l'alignement du tableau tiled en termes de nombre d'éléments. Après l'application du tiling, des éléments de remplissage sont ajoutés à la fin de la mise en page jusqu'à ce que le nombre total d'éléments soit un multiple de cette valeur.

Indexation dans les tableaux

La classe IndexUtil dans index_util.h fournit des utilitaires pour la conversion entre des index multidimensionnels et des index linéaires, en fonction d'une forme et d'une mise en page. Les index multidimensionnels incluent un index int64 pour chaque dimension. Les indices linéaires sont une valeur int64 unique qui indexe le tampon contenant le tableau. Consultez shape_util.h et layout_util.h dans le même répertoire pour découvrir des utilitaires qui simplifient la création et la manipulation de formes et de mises en page.

Identifiants d'espace mémoire

Dans HLO, chaque tableau peut être annoté avec un identifiant d'espace mémoire, écrit sous la forme S(n).

  • S(0) (souvent omis) désigne la mémoire à haut débit (HBM) de l'appareil.
  • S(1) représente la mémoire virtuelle (VMEM) sur l'appareil.
  • S(2), S(3), etc., correspondent à des espaces de mémoire supplémentaires spécifiques à l'appareil.
  • S(5) indique la mémoire de l'hôte.