形状和布局

XLA 操作的结构

请看一个 HLO 示例:

add.936 = bf16[8,1,1280,16384]{3,2,0,1:T(8,128)(2,1)}
          add(exponential.183, broadcast.3115)

这包括以下组件:

  • 操作名称:add.936
    • 这是操作的唯一名称。
  • 形状:bf16[8,1,1280,16384]
    • 这是相应操作的输出形状。此处的 dtype 为 bf16,形状为 [8,1,1280,16384]
  • 布局(采用平铺):3,2,0,1:T(8,128)(2,1)
    • 此图描述了数组在内存中的存储方式。3,2,0,1 表示内存中轴的顺序(即列优先、行优先等),T(8,128)(2,1) 表示所用的平铺和填充。
    • 布局是可选的。如果未指定,则不进行平铺,并且假定维度按从最主要到最次要的顺序排列。
  • 操作:add
    • 正在执行的操作。在此处,它是 Add,在操作名称中也有提及。
  • 实参:exponential.183broadcast.3115
    • 此操作接受两个实参,并使用其唯一名称指定。

我们再来看一个示例,即融合操作:

%fusion.3 = bf16[32,32,4096]{2,1,0:T(8,128)(2,1)S(1)}
            fusion(bf16[32,32,8192]{2,1,0:T(8,128)(2,1)S(1)} %fusion.32),
            kind=kCustom, calls=%all-reduce-scatter.3

除了之前描述的组件之外,该架构还包含:

  • 属性:kindcalls
    • 这些参数可提供有关正在执行的操作(在本例中为融合)的更多信息。
  • 内存位置(内存空间标识符):S(1)
    • 表示数组的存储内存空间/位置。S(1) 表示此数组位于 VMEM 中(在 TPU 上)。
  • 输入实参 %fusion.32 的形状和布局详细信息

以下部分将介绍形状、布局内存空间标识符。如需详细了解平铺,请参阅平铺布局

形状

XLA ShapeProto proto (xla_data.proto) 描述了 N 维数组(简称数组)的维度数、大小和数据类型。

术语、符号和惯例

  • 数组的真实维度数是指大小大于 1 的维度数。

  • 对于 N 维数组,维度的编号从 0N-1。 维度的大小是一个非负整数。特别是,大小为 0 是有效的。维度编号是为方便起见而随意指定的标签。这些维度编号的顺序并不意味着形状布局中存在特定的次要/主要顺序。布局由 LayoutProto proto 确定。

  • 按照惯例,维度按维度编号的升序排列。例如,对于大小为 [A x B x C] 的 3 维数组,维度 0 的大小为 A,维度 1 的大小为 B,维度 2 的大小为 C

    XLA 中的某些实用程序还支持类似 Python 的负索引:维度 -1 是最后一个维度(对于 N 维数组,相当于 N-1)。例如,对于上述 3 维数组,维度 -1 的大小为 C,维度 -2 的大小为 B,依此类推。

  • 二维、三维和四维数组通常具有与维度关联的特定字母。例如,对于一个二维数组:

    • 维度 0:y
    • 维度 1:x

    对于 3D 数组:

    • 维度 0:z
    • 维度 1:y
    • 维度 2:x

    对于 4D 数组:

    • 维度 0:p
    • 维度 1:z
    • 维度 2:y
    • 维度 3:x
  • XLA API 中接受维度的函数会按维度编号的升序接受维度。这与将维度作为 initializer_list 传递时使用的顺序一致;例如

    ShapeUtil::MakeShape(F32, {A, B, C, D})

    将创建一个维度大小数组由序列 [A, B, C, D] 组成的形状。

布局

LayoutProto proto 描述了数组在内存中的表示方式。它包含以下字段:

message LayoutProto {
  repeated int64 minor_to_major;
  int64 tail_padding_alignment_in_elements;
  ...
}

从次要到主要维度的排序

唯一的必填字段是 minor_to_major。此字段描述了形状内维度的次要到主要顺序。minor_to_major 中的值是数组维度的排序(对于 N 维数组,为 0N-1),第一个值是最小维度,最后一个值是最大维度。最小维度是指在遍历以线性内存布局的数组元素时变化最快的维度。

例如,请考虑以下大小为 [2 x 3] 的二维数组:

a b c
d e f

在此示例中,维度 0 的大小为 2,维度 1 的大小为 3。如果布局中的 minor_to_major 字段为 [0, 1],则维度 0 是最次要的维度,而维度 1 是最主要的维度。这对应于线性内存中的以下布局:

a d b e c f

这种从次要到主要的维度顺序(从 0N-1)类似于列优先(对于二维)。假设维度是单调排序的,那么在代码中,我们还可以简单地将此布局称为“dim 0 是次要维度”。

另一方面,如果布局中的 minor_to_major 字段为 [1, 0],则线性内存中的布局为:

a b c d e f

对于 N 维数组,从次要维度到主要维度的维度顺序(即从 N-10)类似于行优先(对于二维数组)。假设维度是单调排序的,那么在代码中,我们还可以简单地将此布局称为“dim 0 is major”。

默认的次要到主要排序方式

新创建的形状的默认布局为“维度顺序是从大到小”(即 [N-1, ..., 0])。

内边距

tail_padding_alignment_in_elements 字段用于定义平铺数组的对齐方式(以元素数量为单位)。应用平铺后,系统会在布局末尾添加填充元素,直到元素总数是此值的倍数。

为数组编制索引

index_util.h 中的类 IndexUtil 提供了一些实用程序,用于在给定形状和布局的情况下,在多维索引和线性索引之间进行转换。多维索引包含每个维度的 int64 索引。线性索引是一个 int64 值,用于索引保存数组的缓冲区。在同一目录中查看 shape_util.hlayout_util.h,了解可简化形状和布局创建及操作的实用程序。

记忆空间标识符

在 HLO 中,每个数组都可以使用内存空间标识符进行注释,写为 S(n)。

  • S(0)(通常省略)表示设备高带宽内存 (HBM)。
  • S(1) 表示设备上的虚拟内存 (VMEM)。
  • S(2)S(3) 等对应于其他特定于设备的内存空间。
  • S(5) 表示主机内存。