XLA 操作的结构
请看一个 HLO 示例:
add.936 = bf16[8,1,1280,16384]{3,2,0,1:T(8,128)(2,1)}
add(exponential.183, broadcast.3115)
这包括以下组件:
- 操作名称:
add.936- 这是操作的唯一名称。
- 形状:
bf16[8,1,1280,16384]- 这是相应操作的输出形状。此处的 dtype 为 bf16,形状为
[8,1,1280,16384]。
- 这是相应操作的输出形状。此处的 dtype 为 bf16,形状为
- 布局(采用平铺):
3,2,0,1:T(8,128)(2,1)- 此图描述了数组在内存中的存储方式。
3,2,0,1表示内存中轴的顺序(即列优先、行优先等),T(8,128)(2,1)表示所用的平铺和填充。 - 布局是可选的。如果未指定,则不进行平铺,并且假定维度按从最主要到最次要的顺序排列。
- 此图描述了数组在内存中的存储方式。
- 操作:
add- 正在执行的操作。在此处,它是 Add,在操作名称中也有提及。
- 实参:
exponential.183、broadcast.3115- 此操作接受两个实参,并使用其唯一名称指定。
我们再来看一个示例,即融合操作:
%fusion.3 = bf16[32,32,4096]{2,1,0:T(8,128)(2,1)S(1)}
fusion(bf16[32,32,8192]{2,1,0:T(8,128)(2,1)S(1)} %fusion.32),
kind=kCustom, calls=%all-reduce-scatter.3
除了之前描述的组件之外,该架构还包含:
- 属性:
kind和calls- 这些参数可提供有关正在执行的操作(在本例中为融合)的更多信息。
- 内存位置(内存空间标识符):
S(1)- 表示数组的存储内存空间/位置。
S(1)表示此数组位于 VMEM 中(在 TPU 上)。
- 表示数组的存储内存空间/位置。
- 输入实参
%fusion.32的形状和布局详细信息
以下部分将介绍形状、布局和内存空间标识符。如需详细了解平铺,请参阅平铺布局。
形状
XLA ShapeProto proto (xla_data.proto) 描述了 N 维数组(简称数组)的维度数、大小和数据类型。
术语、符号和惯例
数组的真实维度数是指大小大于 1 的维度数。
对于
N维数组,维度的编号从0到N-1。 维度的大小是一个非负整数。特别是,大小为 0 是有效的。维度编号是为方便起见而随意指定的标签。这些维度编号的顺序并不意味着形状布局中存在特定的次要/主要顺序。布局由LayoutProtoproto 确定。按照惯例,维度按维度编号的升序排列。例如,对于大小为
[A x B x C]的 3 维数组,维度 0 的大小为A,维度 1 的大小为B,维度 2 的大小为C。XLA 中的某些实用程序还支持类似 Python 的负索引:维度 -1 是最后一个维度(对于
N维数组,相当于N-1)。例如,对于上述 3 维数组,维度 -1 的大小为C,维度 -2 的大小为B,依此类推。二维、三维和四维数组通常具有与维度关联的特定字母。例如,对于一个二维数组:
- 维度 0:
y - 维度 1:
x
对于 3D 数组:
- 维度 0:
z - 维度 1:
y - 维度 2:
x
对于 4D 数组:
- 维度 0:
p - 维度 1:
z - 维度 2:
y - 维度 3:
x
- 维度 0:
XLA API 中接受维度的函数会按维度编号的升序接受维度。这与将维度作为
initializer_list传递时使用的顺序一致;例如ShapeUtil::MakeShape(F32, {A, B, C, D})将创建一个维度大小数组由序列
[A, B, C, D]组成的形状。
布局
LayoutProto proto 描述了数组在内存中的表示方式。它包含以下字段:
message LayoutProto {
repeated int64 minor_to_major;
int64 tail_padding_alignment_in_elements;
...
}
从次要到主要维度的排序
唯一的必填字段是 minor_to_major。此字段描述了形状内维度的次要到主要顺序。minor_to_major 中的值是数组维度的排序(对于 N 维数组,为 0 到 N-1),第一个值是最小维度,最后一个值是最大维度。最小维度是指在遍历以线性内存布局的数组元素时变化最快的维度。
例如,请考虑以下大小为 [2 x 3] 的二维数组:
a b c
d e f
在此示例中,维度 0 的大小为 2,维度 1 的大小为 3。如果布局中的 minor_to_major 字段为 [0, 1],则维度 0 是最次要的维度,而维度 1 是最主要的维度。这对应于线性内存中的以下布局:
a d b e c f
这种从次要到主要的维度顺序(从 0 到 N-1)类似于列优先(对于二维)。假设维度是单调排序的,那么在代码中,我们还可以简单地将此布局称为“dim 0 是次要维度”。
另一方面,如果布局中的 minor_to_major 字段为 [1, 0],则线性内存中的布局为:
a b c d e f
对于 N 维数组,从次要维度到主要维度的维度顺序(即从 N-1 到 0)类似于行优先(对于二维数组)。假设维度是单调排序的,那么在代码中,我们还可以简单地将此布局称为“dim 0 is major”。
默认的次要到主要排序方式
新创建的形状的默认布局为“维度顺序是从大到小”(即 [N-1, ..., 0])。
内边距
tail_padding_alignment_in_elements 字段用于定义平铺数组的对齐方式(以元素数量为单位)。应用平铺后,系统会在布局末尾添加填充元素,直到元素总数是此值的倍数。
为数组编制索引
index_util.h 中的类 IndexUtil 提供了一些实用程序,用于在给定形状和布局的情况下,在多维索引和线性索引之间进行转换。多维索引包含每个维度的 int64 索引。线性索引是一个 int64 值,用于索引保存数组的缓冲区。在同一目录中查看 shape_util.h 和 layout_util.h,了解可简化形状和布局创建及操作的实用程序。
记忆空间标识符
在 HLO 中,每个数组都可以使用内存空间标识符进行注释,写为 S(n)。
S(0)(通常省略)表示设备高带宽内存 (HBM)。S(1)表示设备上的虚拟内存 (VMEM)。S(2)、S(3)等对应于其他特定于设备的内存空间。S(5)表示主机内存。