本文档介绍了 XLA 的广播语义。
什么是广播?
广播是指使形状不同的数组具有兼容的形状,以便进行算术运算的过程。此术语借鉴自 NumPy 广播。
对于不同秩的多维数组之间或具有不同但兼容形状的多维数组之间的运算,可能需要广播。考虑加法 X+v,其中 X 是矩阵(具有 2 个维度的数组),v 是向量(具有 1 个维度的数组)。为了执行按元素相加,XLA 需要通过将向量 v 复制一定次数,将其“广播”到与矩阵 X 相同的维度数。v向量的长度必须与矩阵的至少一个维度相匹配。
例如:
|1 2 3| + |7 8 9|
|4 5 6|
矩阵的维度为 (2,3),向量的维度为 (3)。通过在行中复制向量来广播向量,得到:
|1 2 3| + |7 8 9| = |8 10 12|
|4 5 6| |7 8 9| |11 13 15|
在 NumPy 中,这称为广播。
原则
XLA 语言尽可能严格和明确,避免使用隐式的“神奇”功能。此类功能可能会使某些计算的定义稍微容易一些,但代价是用户代码中会包含更多假设,从长远来看,这些假设很难更改。如有必要,可以在客户端级封装容器中添加隐式神奇功能。
对于广播,XLA 要求对不同秩的数组之间的运算进行显式广播规范。这与 NumPy 不同,后者会在可能的情况下推断规范。
将低维数组广播到高维数组
标量始终可以广播到数组,而无需明确指定广播维度。标量与数组之间的逐元素二元运算是指将标量与数组中的每个元素进行运算。例如,将标量添加到矩阵意味着生成一个矩阵,其中每个元素都是标量与输入矩阵的相应元素的和。
|1 2 3| + 7 = |8 9 10|
|4 5 6| |11 12 13|
大多数广播需求都可以通过对二元运算使用维度元组来满足。当运算的输入具有不同的秩时,此广播元组会指定将较高维度数组中的哪些维度与较低维度数组相匹配。
接着前面的例子来讲。将维度为 (3) 的向量添加到维度为 (2,3) 的矩阵,而不是将标量添加到维度为 (2,3) 的矩阵。如果不指定广播,此操作无效。若要正确请求矩阵-向量加法,请将广播维度指定为 (1),这意味着向量的维度与矩阵的维度 1 相匹配。在二维中,如果维度 0 表示行,维度 1 表示列,这意味着向量的每个元素都会成为大小与矩阵行数相匹配的列:
|7 8 9| ==> |7 8 9|
|7 8 9|
再举一个更复杂的例子,假设要将一个 3 元素向量(维度为 (3))添加到一个 3x3 矩阵(维度为 (3,3))。在此示例中,广播可以通过以下两种方式进行:
(1) 可以使用广播维度为 1 的张量。每个向量元素都会成为一列,并且该向量会针对矩阵中的每一行进行复制。
|7 8 9| ==> |7 8 9|
|7 8 9|
|7 8 9|
(2) 可以使用广播维度 0。每个向量元素都成为一行,并且该向量会针对矩阵中的每一列进行复制。
|7| ==> |7 7 7|
|8| |8 8 8|
|9| |9 9 9|
广播维度可以是一个元组,用于描述如何将较低维度的形状广播到较高维度的形状中。例如,给定一个 2x3x4 的长方体和一个 3x4 的矩阵,广播元组 (1,2) 表示将矩阵与长方体的维度 1 和 2 相匹配。
如果提供了 broadcast_dimensions 实参,则 XlaBuilder 中的二元运算会使用这种广播。例如,请参阅 XlaBuilder::Add。
在 XLA 源代码中,这种类型的广播有时称为“InDim”广播。
正式定义
通过广播属性,您可以指定要匹配的高维数组的哪些维度,从而将低维数组与高维数组进行匹配。例如,对于维度为 MxNxPxQ 的数组,维度为 T 的向量可以按如下方式进行匹配:
MxNxPxQ
dim 3: T
dim 2: T
dim 1: T
dim 0: T
在每种情况下,T 都必须等于高维数组的匹配维度。然后,向量的值会从匹配的维度广播到所有其他维度。
为了将 TxV 矩阵与 MxNxPxQ 数组匹配,使用了一对广播维度:
MxNxPxQ
dim 2,3: T V
dim 1,2: T V
dim 0,3: T V
etc...
广播元组中的维度顺序必须与低维数组的维度与高维数组的维度相匹配的预期顺序一致。元组中的第一个元素指定了高维数组中的哪个维度必须与低维数组中的维度 0 相匹配。元组中的第二个元素指定了高维数组中的哪个维度必须与低维数组中的维度 1 相匹配,依此类推。广播维度的顺序必须严格递增。例如,在前面的示例中,将 V 与 N 和 T 与 P 相匹配是违规的;将 V 同时与 P 和 N 相匹配也是违规的。
广播具有退化维度的相似维度数组
一个相关的问题是广播两个维度数相同但维度大小不同的数组。与 NumPy 一样,只有当数组兼容时,才能进行这种操作。如果两个数组的所有维度都兼容,则这两个数组兼容。如果满足以下条件,则两个维度兼容:
- 它们相等,或者
- 其中一个为 1(“退化”维度)
如果遇到两个兼容的数组,结果形状在每个维度索引处都具有两个输入的最大值。
示例:
- (2,1) 和 (2,3) 广播到 (2,3)。
- (1,2,5) 和 (7,2,5) 广播到 (7,2,5)。
- (7,2,5) 和 (7,1,5) 广播到 (7,2,5)。
- (7,2,5) 和 (7,2,6) 不兼容,无法广播。
还有一种特殊情况(也受支持),即每个输入数组在不同的索引处都有一个退化维度。在这种情况下,结果是“外部运算”:(2,1) 和 (1,3) 广播到 (2,3)。如需查看更多示例,请参阅 NumPy 广播文档。
广播内容组成
将低维数组广播到高维数组和使用退化维度的广播都可以在同一二元运算中执行。例如,可以使用广播维度值 (0) 将大小为 4 的向量和大小为 1x2 的矩阵相加:
|1 2 3 4| + [5 6] // [5 6] is a 1x2 matrix, not a vector.
首先,使用广播维度将向量广播到 2 维(矩阵)。广播维度中的单个值 (0) 表示向量的维度 0 与矩阵的维度 0 相匹配。这会生成一个大小为 4xM 的矩阵,其中 M 的值会选择为与 1x2 数组中的相应维度大小相匹配。因此,系统会生成一个 4x2 矩阵:
|1 1| + [5 6]
|2 2|
|3 3|
|4 4|
然后,“退化维度广播”会广播 1x2 矩阵的维度 0,以匹配右侧的相应维度大小:
|1 1| + |5 6| |6 7|
|2 2| + |5 6| = |7 8|
|3 3| + |5 6| |8 9|
|4 4| + |5 6| |9 10|
一个更复杂的示例是,使用广播维度 (1, 2) 将大小为 1x2 的矩阵添加到大小为 4x3x1 的数组。首先,使用广播维度将 1x2 矩阵广播到 3 维,以生成中间 Mx1x2 数组,其中维度大小 M 由较大操作数(即 4x3x1 数组)的大小决定,从而生成 4x1x2 中间数组。M 位于维度 0(最左侧的维度),因为维度 1 和 2 映射到原始 1x2 矩阵的维度,广播维度为 (1, 2)。可以使用简并维度的广播将此中间数组添加到 4x3x1 矩阵,以生成 4x3x2 数组结果。