广播

本文档介绍了 XLA 的广播语义。

什么是广播?

广播是让不同形状的数组 兼容的形状。该术语源于 NumPy 广播

多维数组之间的操作可能需要广播 或者位于不同但不同值的多维数组之间, 兼容的形状。考虑加法 X+v,其中 X 是矩阵(数组) ,v 是一个矢量(1 阶的数组)。按元素执行 另外,XLA 需要通过“广播”功能将矢量 v 设为与 将 v 复制到矩阵 X 中特定次数。向量的长度 必须至少匹配矩阵的一个维度。

例如:

|1 2 3| + |7 8 9|
|4 5 6|

矩阵的维度为 (2,3),向量的维度为 (3)。矢量 进行广播,方法是对行进行复制,以获取:

|1 2 3| + |7 8 9| = |8  10 12|
|4 5 6|   |7 8 9|   |11 13 15|

在 NumPy 中,这称为 广播

原则

XLA 语言尽可能严格和明确,避免隐式 “神奇”功能。此类特征可能会使一些计算稍微简单一些 但其代价是,要在用户代码中嵌入更多的假设, 也很难改变内含神奇功能(如有必要) 可以在客户端级封装容器中添加

关于广播,XLA 要求明确广播规范 不同秩的数组之间的运算。它与 NumPy 不同 此方法会尽可能推断出具体规范

将一个较低阶的数组广播到一个较高阶的数组上

Scalars始终可通过数组广播,而无需明确指定 多个广播维度标量之间的元素级二元运算 数组是指将带标量的运算应用到 数组。例如,将标量添加到矩阵意味着 其中每个元素都是 输入矩阵。

|1 2 3| + 7 = |8  9  10|
|4 5 6|       |11 12 13|

通过在 二元运算。当运算的输入具有不同的秩时, 广播元组指定 higher-rank 数组中的哪个维度 与优先级较低的数组匹配。

接着前面的例子来讲。不要将标量添加到 (2,3) 矩阵,而是将 将维度 (3) 转换为维度矩阵 (2,3)。未指定 广播,则此操作无效。为了正确请求矩阵矢量 加上,将广播维度指定为 (1),即向量 与矩阵的维度 1 相匹配。在 2D 中,如果维度 0 维度 1 表示行,维度 1 表示列,这表示每个元素 会变成一个列,其大小与 矩阵:

|7 8 9| ==> |7 8 9|
            |7 8 9|

举个更复杂的例子,假设将一个三元素矢量(维度 (3))添加到 3x3 矩阵(维度 (3,3))。可以通过两种方式进行广播 示例:

(1) 可以使用广播维度 1。每个矢量元素都会成为 列并向矩阵中的每一行复制矢量。

|7 8 9| ==> |7 8 9|
            |7 8 9|
            |7 8 9|

(2) 可以使用广播维度 0。每个矢量元素成一行 并向矩阵中的每一列重复该向量。

 |7| ==> |7 7 7|
 |8|     |8 8 8|
 |9|     |9 9 9|

广播维度可以是一个元组,它描述了较小的排名 形状被传播到一个更大的秩形状中。例如,假设有一个 2x3x4 长方体 而 3x4 矩阵,则广播元组 (1,2) 表示将矩阵 长方体的尺寸 1 和 2。

如果XlaBuilder 已指定 broadcast_dimensions 参数。有关示例,请参见 XlaBuilder::Add。 在 XLA 源代码中,这种类型的广播有时称为“InDim” 广播。

正式定义

广播属性允许将排名较低的数组与较高排名的数组进行匹配 数组,方法是指定要匹配的高阶数组的哪些维度。对于 例如,对于维度为 MxNxPxQ 的数组,维度为 T 的向量可以是 匹配如下:

          MxNxPxQ

dim 3:          T
dim 2:        T
dim 1:      T
dim 0:    T

在每种情况下,T 都必须等于更高排名的匹配维度 数组。然后,该向量的值会从匹配的维度广播到 其他维度

为了将 TxV 矩阵与 MxNxPxQ 数组匹配,需要一对广播维度 :

          MxNxPxQ
dim 2,3:      T V
dim 1,2:    T V
dim 0,3:  T     V
etc...

广播元组中维度的顺序必须与 较低阶数组的维度预计会与 更高阶的数组。元组中的第一个元素指定 而排名较高的数组必须与排名较低的数组中的维度 0 匹配。通过 元组中的第二个元素指定更高阶数组中的哪个维度 必须与排名较低的数组中的维度 1 匹配,依此类推。添加 直播维度必须严格增加。例如,在上一个 例如,将 V 与 N 匹配,将 T 与 P 匹配是违法的;匹配 V 值也是非法的, 到 P 和 N。

广播具有退化维度的相似秩数组

一个相关问题是,广播两个具有相同秩序的数组, 不同的尺寸与 NumPy 一样,这只有在数组 兼容。当两个数组的所有维度均为 兼容。在下列情况下,两个维度是兼容的:

  • 它们相等,或
  • 其中一项值为 1(“退化”维度)

当遇到两个兼容的数组时,结果形状的 每个维度索引处的两个输入。

示例:

  1. (2,1) 和 (2,3) 广播到 (2,3)。
  2. (1,2,5) 和 (7,2,5) 广播到 (7,2,5)。
  3. (7,2,5) 和 (7,1,5) 广播到 (7,2,5)。
  4. (7,2,5) 和 (7,2,6) 不兼容,无法广播。

这时会出现一种特殊情况,该情况也受支持,其中每个输入数组都具有 不同索引下的退化维度。在此示例中,结果是 “外部运算”:(2,1) 和 (1,3) 广播到 (2,3)。如需查看更多示例 请查阅 有关广播的 NumPy 文档

广播合成

将较低阶的数组广播到较高阶的数组,广播 使用退化维度的方法都可以在相同的二元运算中执行。 例如,可以将一个大小为 4 的向量和一个大小为 1x2 的矩阵相加 使用值为 (0) 的广播维度:

|1 2 3 4| + [5 6]    // [5 6] is a 1x2 matrix, not a vector.

首先,使用广播将向量广播到秩(矩阵) 维度。广播维度中的单个值 (0) 表示 向量的维度 0 与矩阵的维度 0 匹配。这会生成 大小为 4xM 的矩阵,其中选择值 M 以匹配相应的 尺寸为 1x2 数组。因此,可生成 4x2 矩阵:

|1 1| + [5 6]
|2 2|
|3 3|
|4 4|

然后选择“退化维度广播”广播尺寸为 1x2 的 0 矩阵,以匹配右侧相应维度大小:

|1 1| + |5 6|     |6  7|
|2 2| + |5 6|  =  |7  8|
|3 3| + |5 6|     |8  9|
|4 4| + |5 6|     |9 10|

一个更复杂的示例是将大小为 1x2 的矩阵与大小数组相加 4x3x1,使用 (1, 2) 的广播尺寸。首先,1x2 矩阵向上广播 使用广播维度排名 3,以生成中间 Mx1x2 数组 其中,维度大小 M 由较大操作数( 4x3x1 数组)生成一个 4x1x2 中间数组。M 位于维度 0(即 最左边的维度),因为维度 1 和 2 映射到了维度 原始 1x2 矩阵的权重,因为广播维度为 (1, 2)。这个 中间数组就可以添加到 4x3x1 矩阵, 退化尺寸以生成 4x3x2 数组结果。