本文档介绍了 XLA 的广播语义。
什么是广播?
广播是让不同形状的数组 兼容的形状。该术语源于 NumPy 广播。
多维数组之间的操作可能需要广播
或者位于不同但不同值的多维数组之间,
兼容的形状。考虑加法 X+v
,其中 X
是矩阵(数组)
,v
是一个矢量(1 阶的数组)。按元素执行
另外,XLA 需要通过“广播”功能将矢量 v
设为与
将 v
复制到矩阵 X
中特定次数。向量的长度
必须至少匹配矩阵的一个维度。
例如:
|1 2 3| + |7 8 9|
|4 5 6|
矩阵的维度为 (2,3),向量的维度为 (3)。矢量 进行广播,方法是对行进行复制,以获取:
|1 2 3| + |7 8 9| = |8 10 12|
|4 5 6| |7 8 9| |11 13 15|
在 NumPy 中,这称为 广播。
原则
XLA 语言尽可能严格和明确,避免隐式 “神奇”功能。此类特征可能会使一些计算稍微简单一些 但其代价是,要在用户代码中嵌入更多的假设, 也很难改变内含神奇功能(如有必要) 可以在客户端级封装容器中添加
关于广播,XLA 要求明确广播规范 不同秩的数组之间的运算。它与 NumPy 不同 此方法会尽可能推断出具体规范
将一个较低阶的数组广播到一个较高阶的数组上
Scalars始终可通过数组广播,而无需明确指定 多个广播维度标量之间的元素级二元运算 数组是指将带标量的运算应用到 数组。例如,将标量添加到矩阵意味着 其中每个元素都是 输入矩阵。
|1 2 3| + 7 = |8 9 10|
|4 5 6| |11 12 13|
通过在 二元运算。当运算的输入具有不同的秩时, 广播元组指定 higher-rank 数组中的哪个维度 与优先级较低的数组匹配。
接着前面的例子来讲。不要将标量添加到 (2,3) 矩阵,而是将 将维度 (3) 转换为维度矩阵 (2,3)。未指定 广播,则此操作无效。为了正确请求矩阵矢量 加上,将广播维度指定为 (1),即向量 与矩阵的维度 1 相匹配。在 2D 中,如果维度 0 维度 1 表示行,维度 1 表示列,这表示每个元素 会变成一个列,其大小与 矩阵:
|7 8 9| ==> |7 8 9|
|7 8 9|
举个更复杂的例子,假设将一个三元素矢量(维度 (3))添加到 3x3 矩阵(维度 (3,3))。可以通过两种方式进行广播 示例:
(1) 可以使用广播维度 1。每个矢量元素都会成为 列并向矩阵中的每一行复制矢量。
|7 8 9| ==> |7 8 9|
|7 8 9|
|7 8 9|
(2) 可以使用广播维度 0。每个矢量元素成一行 并向矩阵中的每一列重复该向量。
|7| ==> |7 7 7|
|8| |8 8 8|
|9| |9 9 9|
广播维度可以是一个元组,它描述了较小的排名 形状被传播到一个更大的秩形状中。例如,假设有一个 2x3x4 长方体 而 3x4 矩阵,则广播元组 (1,2) 表示将矩阵 长方体的尺寸 1 和 2。
如果XlaBuilder
已指定 broadcast_dimensions
参数。有关示例,请参见
XlaBuilder::Add。
在 XLA 源代码中,这种类型的广播有时称为“InDim”
广播。
正式定义
广播属性允许将排名较低的数组与较高排名的数组进行匹配 数组,方法是指定要匹配的高阶数组的哪些维度。对于 例如,对于维度为 MxNxPxQ 的数组,维度为 T 的向量可以是 匹配如下:
MxNxPxQ
dim 3: T
dim 2: T
dim 1: T
dim 0: T
在每种情况下,T 都必须等于更高排名的匹配维度 数组。然后,该向量的值会从匹配的维度广播到 其他维度
为了将 TxV 矩阵与 MxNxPxQ 数组匹配,需要一对广播维度 :
MxNxPxQ
dim 2,3: T V
dim 1,2: T V
dim 0,3: T V
etc...
广播元组中维度的顺序必须与 较低阶数组的维度预计会与 更高阶的数组。元组中的第一个元素指定 而排名较高的数组必须与排名较低的数组中的维度 0 匹配。通过 元组中的第二个元素指定更高阶数组中的哪个维度 必须与排名较低的数组中的维度 1 匹配,依此类推。添加 直播维度必须严格增加。例如,在上一个 例如,将 V 与 N 匹配,将 T 与 P 匹配是违法的;匹配 V 值也是非法的, 到 P 和 N。
广播具有退化维度的相似秩数组
一个相关问题是,广播两个具有相同秩序的数组, 不同的尺寸与 NumPy 一样,这只有在数组 兼容。当两个数组的所有维度均为 兼容。在下列情况下,两个维度是兼容的:
- 它们相等,或
- 其中一项值为 1(“退化”维度)
当遇到两个兼容的数组时,结果形状的 每个维度索引处的两个输入。
示例:
- (2,1) 和 (2,3) 广播到 (2,3)。
- (1,2,5) 和 (7,2,5) 广播到 (7,2,5)。
- (7,2,5) 和 (7,1,5) 广播到 (7,2,5)。
- (7,2,5) 和 (7,2,6) 不兼容,无法广播。
这时会出现一种特殊情况,该情况也受支持,其中每个输入数组都具有 不同索引下的退化维度。在此示例中,结果是 “外部运算”:(2,1) 和 (1,3) 广播到 (2,3)。如需查看更多示例 请查阅 有关广播的 NumPy 文档。
广播合成
将较低阶的数组广播到较高阶的数组,并广播 使用退化维度的方法都可以在相同的二元运算中执行。 例如,可以将一个大小为 4 的向量和一个大小为 1x2 的矩阵相加 使用值为 (0) 的广播维度:
|1 2 3 4| + [5 6] // [5 6] is a 1x2 matrix, not a vector.
首先,使用广播将向量广播到秩(矩阵) 维度。广播维度中的单个值 (0) 表示 向量的维度 0 与矩阵的维度 0 匹配。这会生成 大小为 4xM 的矩阵,其中选择值 M 以匹配相应的 尺寸为 1x2 数组。因此,可生成 4x2 矩阵:
|1 1| + [5 6]
|2 2|
|3 3|
|4 4|
然后选择“退化维度广播”广播尺寸为 1x2 的 0 矩阵,以匹配右侧相应维度大小:
|1 1| + |5 6| |6 7|
|2 2| + |5 6| = |7 8|
|3 3| + |5 6| |8 9|
|4 4| + |5 6| |9 10|
一个更复杂的示例是将大小为 1x2 的矩阵与大小数组相加 4x3x1,使用 (1, 2) 的广播尺寸。首先,1x2 矩阵向上广播 使用广播维度排名 3,以生成中间 Mx1x2 数组 其中,维度大小 M 由较大操作数( 4x3x1 数组)生成一个 4x1x2 中间数组。M 位于维度 0(即 最左边的维度),因为维度 1 和 2 映射到了维度 原始 1x2 矩阵的权重,因为广播维度为 (1, 2)。这个 中间数组就可以添加到 4x3x1 矩阵, 退化尺寸以生成 4x3x2 数组结果。