正在广播

本文档介绍了 XLA 的广播语义。

什么是广播?

广播是使具有不同形状的数组具有兼容的形状以进行算术运算的过程。该术语借鉴了 NumPy 广播

对于不同秩的多维数组之间或具有不同但兼容的多维数组之间的操作,可能需要广播。以添加 X+v 为例,其中 X 是一个矩阵(2 阶数组),v 是一个向量(1 阶数组)。为了执行元素级加法,XLA 需要通过复制 v 一定次数,将矢量 v“广播”到与矩阵 X 相同的秩。向量的长度必须与矩阵的至少一个维度匹配。

例如:

|1 2 3| + |7 8 9|
|4 5 6|

矩阵的维度为 (2,3),向量的维度为 (3)。向量通过对行进行复制来广播,以获取:

|1 2 3| + |7 8 9| = |8  10 12|
|4 5 6|   |7 8 9|   |11 13 15|

在 NumPy 中,这称为广播

原则

XLA 语言尽可能严格且明确,避免隐式的“神奇”功能。此类功能可能会使一些计算更易于定义,但以用户代码中融入更多假设为代价,这些假设从长期来看很难更改。如有必要,可以在客户端级封装容器中添加隐式神奇功能。

对于广播,XLA 要求对不同秩的数组之间的操作有明确的广播规范。它与 NumPy 不同,它在可能的情况下推断规范。

将低阶数组广播到高阶数组

标量始终可以在数组上广播,而无需明确指定广播维度。标量和数组之间的元素级二元运算意味着将带有标量的运算应用于数组中的每个元素。例如,向矩阵添加标量意味着生成一个矩阵,其中每个元素都是输入矩阵标量和相应元素的总和。

|1 2 3| + 7 = |8  9  10|
|4 5 6|       |11 12 13|

通过在二元运算中使用维度元组来捕获大多数广播需求。当操作的输入具有不同的秩时,此广播元组会指定 high-rank 数组中的哪些维度与 lower-rank 数组匹配。

接着前面的例子来讲。将维度 (3) 的向量添加到维度矩阵 (2,3),而不是向 (2,3) 矩阵添加标量。如果未指定广播,则此操作无效。若要正确请求矩阵向量加法,请将广播维度指定为 (1),这意味着向量的维度与矩阵的维度 1 匹配。在 2D 中,如果维度 0 表示行,维度 1 表示列,这意味着向量的每个元素都变成了大小与矩阵中的行数相匹配的列:

|7 8 9| ==> |7 8 9|
            |7 8 9|

举一个更复杂的例子,可以考虑将一个 3 元素向量(维度 (3))添加到 3x3 矩阵(维度 (3,3))。在此示例中,可以通过两种方法进行广播:

(1) 可以使用 1 的广播维度。每个向量元素都会成为列,并且向量会针对矩阵中的每一行进行重复。

|7 8 9| ==> |7 8 9|
            |7 8 9|
            |7 8 9|

(2) 可以使用广播维度 0。每个向量元素成为一行,向矩阵中的每一列复制该向量。

 |7| ==> |7 7 7|
 |8|     |8 8 8|
 |9|     |9 9 9|

广播维度可以是元组,用于描述较小的排名形状如何广播为较大的排名形状。例如,给定一个 2x3x4 长方体和一个 3x4 矩阵,广播元组 (1,2) 表示将矩阵与长方体的维度 1 和 2 相匹配。

如果提供了 broadcast_dimensions 参数,则会在 XlaBuilder 的二进制操作中使用此类广播。有关示例,请参阅 XlaBuilder::Add。在 XLA 源代码中,这种类型的广播有时称为“InDim”广播。

正式定义

广播属性允许通过指定要匹配较高秩数组的哪些维度来将低阶数组与高阶数组进行匹配。例如,对于维度为 MxNxPxQ 的数组,维度为 T 的向量可以按如下方式匹配:

          MxNxPxQ

dim 3:          T
dim 2:        T
dim 1:      T
dim 0:    T

在每种情况下,T 都必须等于较高秩数组的匹配维度。然后,向量的值会从匹配的维度广播到所有其他维度。

为了将 TxV 矩阵与 MxNxPxQ 数组匹配,需要使用一对广播维度:

          MxNxPxQ
dim 2,3:      T V
dim 1,2:    T V
dim 0,3:  T     V
etc...

广播元组中维度的顺序必须为低阶数组的维度与较高阶数组的维度匹配的顺序。元组中的第一个元素指定高阶数组中的哪个维度必须与低阶数组中的维度 0 匹配。元组中的第二个元素指定高阶数组中的哪个维度必须与低阶数组中的维度 1 匹配,依此类推。广播维度的顺序必须严格递增。例如,在前面的示例中,将 V 与 N 和 T 匹配 P 是非法的;同时将 V 与 P 和 N 匹配也是非法的。

广播具有退化维度的相似秩数组

一个相关问题是广播两个具有相同秩但维度大小不同的数组。与 NumPy 一样,只有当数组兼容时,才可以这么做。如果两个数组的所有维度都兼容,则它们是兼容的。在以下情况下,两个维度可以兼容:

  • 它们相等,或者
  • 其中一个是 1(“退化”维度)

当遇到两个兼容的数组时,结果形状在每个维度索引处最多有两个输入。

示例:

  1. (2,1) 和 (2,3) 广播到 (2,3)。
  2. (1,2,5) 和 (7,2,5) 广播到 (7,2,5)。
  3. (7,2,5) 和 (7,1,5) 广播到 (7,2,5)。
  4. (7,2,5) 和 (7,2,6) 不兼容,无法广播。

一种特殊情况,也受支持,其中每个输入数组都有一个位于不同索引的退化维度。在本例中,结果是“外部运算”:(2,1) 和 (1,3) 广播到 (2,3)。如需查看更多示例,请参阅有关广播的 NumPy 文档

广播组合

向高阶数组广播低阶数组和使用退化维度进行广播都可以在同一二元运算中执行。例如,可以使用值为 (0) 的广播维度将大小为 4 的向量和大小为 1x2 的矩阵相加:

|1 2 3 4| + [5 6]    // [5 6] is a 1x2 matrix, not a vector.

首先,使用广播维度广播矢量,最高为 2 阶(矩阵)。广播维度中的单个值 (0) 表示向量的维度 0 与矩阵的维度 0 匹配。这会生成一个大小为 4xM 的矩阵,其中值 M 选择与 1x2 数组中的相应维度大小相匹配。因此会生成一个 4x2 矩阵:

|1 1| + [5 6]
|2 2|
|3 3|
|4 4|

然后,“退化维度广播”广播 1x2 矩阵的维度 0,以匹配右侧对应的维度大小:

|1 1| + |5 6|     |6  7|
|2 2| + |5 6|  =  |7  8|
|3 3| + |5 6|     |8  9|
|4 4| + |5 6|     |9 10|

一个更复杂的示例是,使用 (1, 2) 的广播维度将大小为 1x2 的矩阵添加到大小为 4x3x1 的数组中。首先,使用广播维度广播 1x2 矩阵第 3 阶,以生成一个中间 Mx1x2 数组,其中维度大小 M 由产生 4x1x2 中间数组的较大运算数(4x3x1 数组)的大小决定。M 处于维度 0(最左边的维度),因为维度 1 和 2 映射到原始 1x2 矩阵的维度,因为广播维度为 (1, 2)。可以通过广播退化维度来将此中间数组添加到 4x3x1 矩阵中,以生成 4x3x2 数组结果。