正在广播

本文档介绍了 XLA 的广播语义。

什么是广播?

广播是使具有不同形状的数组具有兼容的形状以进行算术运算的过程。该术语借鉴了 NumPy 广播

对于不同秩的多维数组之间的操作或具有不同但兼容的形状的多维数组之间的操作,可能需要广播。以加法 X+v 为例,其中 X 是矩阵(秩为 2 的数组),v 是一个向量(秩为 1 的数组)。如需执行元素级添加,XLA 需要通过复制 v 一定次数,将矢量 v“广播”到与矩阵 X 相同的排名。向量的长度必须与矩阵的至少一个维度匹配。

例如:

|1 2 3| + |7 8 9|
|4 5 6|

矩阵的维度为 (2,3),向量的维度为 (3)。传播时将该矢量复制到行上,以获取:

|1 2 3| + |7 8 9| = |8  10 12|
|4 5 6|   |7 8 9|   |11 13 15|

在 NumPy 中,这称为广播

原则

XLA 语言尽可能严格和显式,避免隐式的“神奇”功能。此类特征可能会让一些计算稍微容易定义,但代价是用户代码中融入了更多假设,而从长远来看,这些假设将难以更改。如有必要,可以在客户端级封装容器中添加隐式神奇功能。

关于广播,XLA 要求对不同秩的数组之间的操作有明确的广播规范。这与 NumPy 不同,NumPy 会在可能的情况下推断规范。

将一个低阶数组广播到一个高阶数组

标量始终可以通过数组广播,而无需明确指定广播维度。标量和数组之间的元素级二元运算表示将带有标量的运算应用于数组中的每个元素。例如,向矩阵添加标量意味着生成一个矩阵,其中每个元素都是该标量和输入矩阵相应元素的总和。

|1 2 3| + 7 = |8  9  10|
|4 5 6|       |11 12 13|

大多数广播需求都可以通过在二元运算中使用维度元组来捕获。当操作的输入具有不同的排名时,此广播元组会指定 high-rank 数组中的哪些维度与 bottom-rank 数组匹配的维度。

接着前面的例子来讲。向维度矩阵 (2,3) 添加一个维度 (3) 的向量,而不是向 (2,3) 矩阵添加标量。如果未指定广播,则此操作无效。若要正确请求矩阵向量加法,请将广播维度指定为 (1),这意味着向量的维度与矩阵的维度 1 匹配。在 2D 中,如果维度 0 表示行,维度 1 表示列,这意味着矢量的每个元素都会成为大小与矩阵中的行数相匹配的列:

|7 8 9| ==> |7 8 9|
            |7 8 9|

举一个更复杂的示例,考虑向 3x3 矩阵(维度 (3,3))添加一个 3 元素向量(维度 (3))。对于此示例,广播有两种方式:

(1) 可以使用 1 的广播维度。每个向量元素都会成为一列,并且向矩阵中的每一行复制该向量。

|7 8 9| ==> |7 8 9|
            |7 8 9|
            |7 8 9|

(2) 可以使用 0 的广播维度。每个向量元素成为一行,并且向矩阵中的每一列复制向量。

 |7| ==> |7 7 7|
 |8|     |8 8 8|
 |9|     |9 9 9|

广播维度可以是元组,描述较小的排名形状如何广播为较大的排名形状。例如,假设有一个 2x3x4 长方体和一个 3x4 矩阵,广播元组 (1,2) 表示将矩阵与长方体的维度 1 和维度 2 相匹配。

如果指定了 broadcast_dimensions 参数,则会在 XlaBuilder 的二进制操作中使用此类广播。例如,请参阅 XlaBuilder::Add。在 XLA 源代码中,此类广播有时称为“InDim”广播。

正式定义

广播属性允许通过指定要匹配较高排名数组的维度来将较低排名的数组与较高排名的数组进行匹配。例如,对于维度 MxNxPxQ 的数组,维度 T 的向量可以按如下方式匹配:

          MxNxPxQ

dim 3:          T
dim 2:        T
dim 1:      T
dim 0:    T

在每种情况下,T 都必须等于排名较高的数组的匹配维度。然后,向量的值会从匹配的维度广播到所有其他维度。

为了将 TxV 矩阵与 MxNxPxQ 数组匹配,使用了一对广播维度:

          MxNxPxQ
dim 2,3:      T V
dim 1,2:    T V
dim 0,3:  T     V
etc...

广播元组中的维度顺序必须为低阶数组的维度与较高阶数组的维度匹配的顺序。元组中的第一个元素指定排名较高的数组中的哪个维度必须与较低排名数组中的维度 0 匹配。元组中的第二个元素指定排名较高的数组中的哪个维度必须与较低排名数组中的维度 1 匹配,依此类推。广播维度的顺序必须严格递增。例如,在前面的示例中,将 V 与 N 匹配,T 与 P 匹配是非法的;将 V 同时匹配到 P 和 N 也是非法的。

广播具有退化维度的相似秩数组

相关问题是广播两个具有相同排名但维度大小不同的数组。与 NumPy 一样,只有当数组兼容时,才可以这么做。如果两个数组的所有维度都兼容,则它们是兼容的。在以下情况下,两个维度可以兼容:

  • 它们相等,或者
  • 其中一个是 1(“退化”维度)

当遇到两个兼容的数组时,结果形状在每个维度索引处的两个输入中均为最大值。

示例:

  1. (2,1) 和 (2,3) 广播到 (2,3)。
  2. (1,2,5) 和 (7,2,5) 广播到 (7,2,5)。
  3. (7,2,5) 和 (7,1,5) 广播到 (7,2,5)。
  4. (7,2,5) 和 (7,2,6) 不兼容,无法广播。

出现一种特殊情况,也支持该特殊情况,其中每个输入数组都有一个不同索引的退化维度。在这种情况下,结果是“外部运算”:(2,1) 和 (1,3) 广播到 (2,3)。如需查看更多示例,请参阅有关广播的 NumPy 文档

广播合成

向高阶数组广播低阶数组和使用退化维度广播可以在同一个二元运算中执行。例如,可以使用值为 (0) 的广播维度将大小为 4 的向量和大小为 1x2 的矩阵相加:

|1 2 3 4| + [5 6]    // [5 6] is a 1x2 matrix, not a vector.

首先,使用广播维度广播向量,最高为 2 阶(矩阵)。广播维度中的单个值 (0) 表示向量的维度零与矩阵的维度零匹配。这会生成一个大小为 4xM 的矩阵,其中选择值 M 来匹配 1x2 数组中的相应维度大小。因此,生成一个 4x2 矩阵:

|1 1| + [5 6]
|2 2|
|3 3|
|4 4|

然后,“退化维度广播”广播 1x2 矩阵的维度零,以匹配右侧对应的维度大小:

|1 1| + |5 6|     |6  7|
|2 2| + |5 6|  =  |7  8|
|3 3| + |5 6|     |8  9|
|4 4| + |5 6|     |9 10|

更复杂的示例是,使用 (1, 2) 广播维度将大小为 1x2 的矩阵添加到大小为 4x3x1 的数组中。首先,使用广播维度广播 1x2 矩阵第 3 阶,以生成一个中间 Mx1x2 数组,其中维度大小 M 由较大的运算数(4x3x1 数组)的大小决定,从而生成一个 4x1x2 中间数组。M 位于维度 0(最左侧的维度),因为维度 1 和 2 映射到原始 1x2 矩阵的维度,因为广播维度为 (1, 2)。可以通过广播退化维度将此中间数组添加到 4x3x1 矩阵中,以生成 4x3x2 数组结果。