下面介绍了 XlaBuilder
接口中定义的操作的语义。通常,这些操作会与 xla_data.proto
中的 RPC 接口中定义的操作一对一映射。
关于命名法的小注:XLA 处理的泛化数据类型是一个 N 维数组,其中包含某种统一类型(例如 32 位浮点数)的元素。在本文档中,数组用于表示任意维数的数组。为方便起见,特殊情况具有更具体的熟悉名称;例如,矢量是 1 维数组,矩阵是 2 维数组。
AfterAll
另请参阅 XlaBuilder::AfterAll
。
AfterAll 接受可变数量的令牌,并生成单个令牌。令牌是基元类型,可在有副作用的操作之间串行化以强制执行排序。AfterAll
可用作令牌联接,用于在集合操作之后对操作进行排序。
AfterAll(operands)
参数 | 类型 | 语义 |
---|---|---|
operands |
XlaOp |
可变数量的令牌 |
AllGather
另请参阅 XlaBuilder::AllGather
。
跨副本执行串联。
AllGather(operand, all_gather_dim, shard_count, replica_group_ids,
channel_id)
参数 | 类型 | 语义 |
---|---|---|
operand
|
XlaOp
|
要在各个副本之间串联的数组 |
all_gather_dim |
int64 |
串联维度 |
replica_groups
|
int64 向量的向量 |
在哪些组之间执行串联 |
channel_id
|
可选 int64
|
用于跨模块通信的可选通道 ID |
replica_groups
是用于执行串联的副本组的列表(可以使用ReplicaId
检索当前副本的副本 ID)。每个组中的副本顺序决定了其输入在结果中的顺序。replica_groups
必须为空(在这种情况下,所有副本都属于一个组,按0
到N - 1
的顺序排列),或者包含与副本数量相同的元素。例如,replica_groups = {0, 2}, {1, 3}
会在副本0
和2
以及1
和3
之间执行串联操作。shard_count
是每个副本组的大小。在replica_groups
为空的情况下,我们需要这样做。channel_id
用于跨模块通信:只有具有相同channel_id
的all-gather
操作才能相互通信。
输出形状是输入形状,其中 all_gather_dim
扩大了 shard_count
倍。例如,如果有两个副本,并且两个副本上的操作数的值为 [1.0, 2.5]
和 [3.0, 5.25]
,则此操作(all_gather_dim
为 0
)的输出值在两个副本上都将为 [1.0, 2.5, 3.0,
5.25]
。
AllReduce
另请参阅 XlaBuilder::AllReduce
。
跨副本执行自定义计算。
AllReduce(operand, computation, replica_group_ids, channel_id)
参数 | 类型 | 语义 |
---|---|---|
operand
|
XlaOp
|
要跨副本减少的数组或非空数组元组 |
computation |
XlaComputation |
归约计算 |
replica_groups
|
int64 的矢量矢量 |
执行缩减操作的组 |
channel_id
|
可选 int64
|
用于跨模块通信的可选通道 ID |
- 当
operand
是数组元组时,会对元组的每个元素执行 all-reduce。 replica_groups
是执行求和的副本组列表(可以使用ReplicaId
检索当前副本的副本 ID)。replica_groups
必须为空(在这种情况下,所有副本都属于单个组),或者包含与副本数量相同的元素数量。例如,replica_groups = {0, 2}, {1, 3}
会在副本0
和2
以及1
和3
之间执行求值。channel_id
用于跨模块通信:只有具有相同channel_id
的all-reduce
操作才能相互通信。
输出形状与输入形状相同。例如,如果有两个副本,并且运算数在两个副本上的值分别为 [1.0, 2.5]
和 [3.0, 5.25]
,则这两个副本中此运算和求和计算的输出值均为 [4.0, 7.75]
。如果输入是元组,则输出也是元组。
计算 AllReduce
的结果需要从每个副本获取一个输入,因此,如果一个副本执行 AllReduce
节点的次数多于另一个副本,则前一个副本将永远等待。由于所有副本都运行相同的程序,因此发生这种情况的方式并不多,但如果 while 循环的条件依赖于 infeed 中的数据,并且 infeed 中的数据导致 while 循环在一个副本上迭代次数比另一个副本多,就有可能发生这种情况。
AllToAll
另请参阅 XlaBuilder::AllToAll
。
AllToAll 是一种集合操作,用于将数据从所有核心发送到所有核心。该方法分两个阶段来实现:
- 散射阶段。在每个核心上,操作数会沿着
split_dimensions
拆分为数量为split_count
的块,并且这些块分散到所有核心,例如,第 i 个块发送到第 i 个核心。 - 收集阶段。每个核心都会沿
concat_dimension
串联收到的块。
参与的核心可以通过以下方式进行配置:
replica_groups
:每个 ReplicaGroup 都包含参与计算的副本 ID 的列表(可以使用ReplicaId
检索当前副本的副本 ID)。AllToAll 将按指定顺序在子组内应用。例如,replica_groups = { {1,2,3}, {4,5,0} }
表示将在副本{1, 2, 3}
内以及在收集阶段应用 AllToAll,并且收到的块将按 1、2、3 的相同顺序串联。然后,系统会在副本 4、5、0 中应用另一个 AllToAll,并且串联顺序也是 4、5、0。如果replica_groups
为空,则所有副本都属于一个组,按其出现的串联顺序排列。
前提条件:
split_dimension
上的运算数的维度大小可被split_count
整除。- 操作数的形状不是元组。
AllToAll(operand, split_dimension, concat_dimension, split_count,
replica_groups)
参数 | 类型 | 语义 |
---|---|---|
operand |
XlaOp |
维输入数组 |
split_dimension
|
int64
|
一个介于 [0,
n) 之间的值,用于命名操作数沿其拆分的数据维度 |
concat_dimension
|
int64
|
一个介于 [0,
n) 之间的值,用于指定沿哪个维度串联分块 |
split_count
|
int64
|
参与此操作的核心数。如果 replica_groups 为空,则此值应为副本数量;否则,此值应等于每个组中的副本数量。 |
replica_groups
|
ReplicaGroup 矢量
|
每个组都包含一个副本 ID 列表。 |
以下是 Alltoall 的示例。
XlaBuilder b("alltoall");
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, /*split_count=*/4);
在此示例中,有 4 个核心参与 Alltoall。在每个核心上,操作数会沿着维度 0 拆分为 4 个部分,因此每个部分的形状为 f32[4,4]。4 个部分会分散到所有核心。然后,每个核心按照核心 0-4 的顺序沿维度 1 串联接收的部分。因此,每个核心上的输出形状为 f32[16,4]。
BatchNormGrad
有关该算法的详细说明,另请参阅 XlaBuilder::BatchNormGrad
和原始批量归一化白皮书。
计算批处理正则化的梯度。
BatchNormGrad(operand, scale, mean, variance, grad_output, epsilon, feature_index)
参数 | 类型 | 语义 |
---|---|---|
operand |
XlaOp |
要归一化的 n 维数组 (x) |
scale |
XlaOp |
一维数组 (\(\gamma\)) |
mean |
XlaOp |
一维数组 (\(\mu\)) |
variance |
XlaOp |
一维数组 (\(\sigma^2\)) |
grad_output |
XlaOp |
传递给 BatchNormTraining (\(\nabla y\)) 的渐变 |
epsilon |
float |
Epsilon 值 (\(\epsilon\)) |
feature_index |
int64 |
operand 中特征维度的索引 |
对于特征维度中的每个特征(feature_index
是 operand
中特征维度的索引),该操作会计算相对于所有其他维度的 operand
、offset
和 scale
的梯度。feature_index
必须是 operand
中特征维度的有效索引。
这三个梯度由以下公式定义(假设 4 维数组为 operand
,特征维度索引为 l
,批处理大小为 m
,空间大小为 w
和 h
):
\[ \begin{split} c_l&= \frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \left( \nabla y_{ijkl} \frac{x_{ijkl} - \mu_l}{\sigma^2_l+\epsilon} \right) \\\\ d_l&= \frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \nabla y_{ijkl} \\\\ \nabla x_{ijkl} &= \frac{\gamma_{l} }{\sqrt{\sigma^2_{l}+\epsilon} } \left( \nabla y_{ijkl} - d_l - c_l (x_{ijkl} - \mu_{l}) \right) \\\\ \nabla \gamma_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \left( \nabla y_{ijkl} \frac{x_{ijkl} - \mu_l}{\sqrt{\sigma^2_{l}+\epsilon} } \right) \\\\\ \nabla \beta_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \nabla y_{ijkl} \end{split} \]
输入 mean
和 variance
表示跨批量和空间维度的时刻值。
输出类型是三个句柄的元组:
输出 | 类型 | 语义 |
---|---|---|
grad_operand
|
XlaOp
|
相对于输入 operand ($\nabla x$) 的梯度 |
grad_scale
|
XlaOp
|
相对于输入 scale 的梯度 ($\nabla
\gamma$) |
grad_offset
|
XlaOp
|
相对于输入 offset 的梯度 ($\nabla
\beta$) |
BatchNormInference
有关该算法的详细说明,另请参阅 XlaBuilder::BatchNormInference
和原始批量归一化白皮书。
对批量和空间维度中的数组进行归一化。
BatchNormInference(operand, scale, offset, mean, variance, epsilon, feature_index)
参数 | 类型 | 语义 |
---|---|---|
operand |
XlaOp |
要归一化的 n 维数组 |
scale |
XlaOp |
一维数组 |
offset |
XlaOp |
一维数组 |
mean |
XlaOp |
一维数组 |
variance |
XlaOp |
1 维数组 |
epsilon |
float |
ε 值 |
feature_index |
int64 |
operand 中特征维度的索引 |
对于特征维度中的每个特征(feature_index
是 operand
中特征维度的索引),该操作会计算所有其他维度的均值和方差,并使用均值和方差对 operand
中的每个元素进行归一化处理。feature_index
必须是 operand
中地图项维度的有效索引。
BatchNormInference
等同于调用 BatchNormTraining
,而无需为每个批处理计算 mean
和 variance
。它会改用输入 mean
和 variance
作为估算值。此运算的目的是减少推理延迟时间,因此得名为 BatchNormInference
。
输出是一个 n 维归一化数组,其形状与输入 operand
相同。
BatchNormTraining
有关该算法的详细说明,另请参阅 XlaBuilder::BatchNormTraining
和 the original batch normalization paper
。
对批量和空间维度中的数组进行归一化。
BatchNormTraining(operand, scale, offset, epsilon, feature_index)
参数 | 类型 | 语义 |
---|---|---|
operand |
XlaOp |
要归一化的 n 维数组 (x) |
scale |
XlaOp |
一维数组 (\(\gamma\)) |
offset |
XlaOp |
一维数组 (\(\beta\)) |
epsilon |
float |
Epsilon 值 (\(\epsilon\)) |
feature_index |
int64 |
operand 中特征维度的索引 |
对于特征维度中的每个特征(feature_index
是 operand
中特征维度的索引),该操作会计算所有其他维度的均值和方差,并使用均值和方差对 operand
中的每个元素进行归一化处理。feature_index
必须是 operand
中地图项维度的有效索引。
对于 operand
\(x\) 中的每个批次,如果该批次包含 m
元素,且空间维度的大小为 w
和 h
(假设 operand
是一个 4 维数组),则算法如下所示:
计算特征维度中每个特征
l
的批处理均值 \(\mu_l\) :\(\mu_l=\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h x_{ijkl}\)计算批量方差 \(\sigma^2_l\): $\sigma^2l=\frac{1}{mwh}\sum{i=1}^m\sum{j=1}^w\sum{k=1}^h (x_{ijkl} - \mu_l)^2$
标准化、缩放和移位:\(y_{ijkl}=\frac{\gamma_l(x_{ijkl}-\mu_l)}{\sqrt[2]{\sigma^2_l+\epsilon} }+\beta_l\)
添加 epsilon 值(通常为小数)是为了避免除零错误。
输出类型是一个包含三个 XlaOp
的元组:
输出 | 类型 | 语义 |
---|---|---|
output
|
XlaOp
|
与输入 operand (y) 具有相同形状的 n 维数组 |
batch_mean |
XlaOp |
一维数组 (\(\mu\)) |
batch_var |
XlaOp |
一维数组 (\(\sigma^2\)) |
batch_mean
和 batch_var
是使用上述公式在批处理和空间维度上计算的瞬时值。
BitcastConvertType
另请参阅 XlaBuilder::BitcastConvertType
。
与 TensorFlow 中的 tf.bitcast
类似,执行从数据形状到目标形状的逐元素按位转换运算。输入和输出大小必须匹配:例如,s32
元素通过 Bitcast 例程变为 f32
元素,而一个 s32
元素将变为四个 s8
元素。位转换是作为低级转换实现的,因此具有不同浮点表示形式的机器会给出不同的结果。
BitcastConvertType(operand, new_element_type)
参数 | 类型 | 语义 |
---|---|---|
operand |
XlaOp |
类型为 T 且维度为 D 的数组 |
new_element_type |
PrimitiveType |
类型 U |
除了最后一个维度(将按转换前后基元大小的比率而变化)外,运算数和目标形状的维度必须匹配。
源和目标元素类型不得为元组。
对不同宽度的基元类型进行位转换
BitcastConvert
HLO 指令支持输出元素类型 T'
的大小不等于输入元素 T
的大小的情况。由于从概念上来讲,整个操作是 bitcast,并且不会更改底层的字节,因此必须更改输出元素的形状。对于 B = sizeof(T), B' =
sizeof(T')
,可能有两种情况。
首先,当 B > B'
时,输出形状会获得大小为 B/B'
的新最小次元。例如:
f16[10,2]{1,0} %output = f16[10,2]{1,0} bitcast-convert(f32[10]{0} %input)
对于有效标量,规则保持不变:
f16[2]{0} %output = f16[2]{0} bitcast-convert(f32[] %input)
或者,对于 B' > B
,该指令要求输入形状的最后一个逻辑维度等于 B'/B
,并且此维度会在转换期间被舍弃:
f32[10]{0} %output = f32[10]{0} bitcast-convert(f16[10,2]{1,0} %input)
请注意,不同位宽之间的转换不是按元素进行的。
广播
另请参阅 XlaBuilder::Broadcast
。
通过复制数组中的数据,向数组添加维度。
Broadcast(operand, broadcast_sizes)
参数 | 类型 | 语义 |
---|---|---|
operand |
XlaOp |
要复制的数组 |
broadcast_sizes |
ArraySlice<int64> |
新维度的大小 |
新维度会插入到左侧,即如果 broadcast_sizes
的值为 {a0, ..., aN}
,并且操作数形状的维度为 {b0, ..., bM}
,则输出的形状的维度为 {a0, ..., aN, b0, ..., bM}
。
新的维度会对运算数的副本进行编制索引,即
output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
例如,如果 operand
是值为 2.0f
的标量 f32
,并且 broadcast_sizes
为 {2, 3}
,则结果将是形状为 f32[2, 3]
的数组,并且结果中的所有值均为 2.0f
。
BroadcastInDim
另请参阅 XlaBuilder::BroadcastInDim
。
通过复制数组中的数据来扩大数组的大小和阶数。
BroadcastInDim(operand, out_dim_size, broadcast_dimensions)
参数 | 类型 | 语义 |
---|---|---|
operand |
XlaOp |
要复制的数组 |
out_dim_size |
ArraySlice<int64> |
目标形状的尺寸 |
broadcast_dimensions |
ArraySlice<int64> |
操作数形状的每个维度对应于目标形状中的哪个维度 |
与广播类似,但允许在任何位置添加维度,并扩展大小为 1 的现有维度。
operand
将广播到 out_dim_size
所描述的形状。broadcast_dimensions
会将 operand
的维度映射到目标形状的维度,即操作数的第 i 个维度会映射到输出形状的 broadcast_dimension[i]维度。operand
的维度必须为 1 或与其映射到的输出形状中的维度相同。其余尺寸会填入尺寸为 1 的尺寸。然后,退化维度广播会沿这些退化维度广播,以达到输出形状。广播页面详细介绍了这些语义。
致电
另请参阅 XlaBuilder::Call
。
使用给定参数调用计算。
Call(computation, args...)
参数 | 类型 | 语义 |
---|---|---|
computation |
XlaComputation |
包含 N 个任意类型的参数的 T_0, T_1, ..., T_{N-1} -> S 类型的计算 |
args |
一系列 N 个 XlaOp |
任意类型的 N 个实参 |
args
的 arity 和类型必须与 computation
的参数相匹配。它可以不包含 args
。
Cholesky
另请参阅 XlaBuilder::Cholesky
。
计算一组对称(赫尔米特)正定矩阵的 Cholesky 分解。
Cholesky(a, lower)
参数 | 类型 | 语义 |
---|---|---|
a |
XlaOp |
复杂类型或浮点类型的秩大于 2 的数组。 |
lower |
bool |
是否使用 a 的上三角或下三角。 |
如果 lower
为 true
,则计算下三角矩阵 l
,使 $a = l$。l^T$。如果 lower
为 false
,则计算上三角矩阵 u
,使得\(a = u^T . u\)。
系统仅从 a
的下/上三角形读取输入数据,具体取决于 lower
的值。其他三角形的值会被忽略。输出数据会在同一三角形中返回;另一个三角形中的值由实现定义,可以是任何值。
如果 a
的秩大于 2,则 a
会被视为一批矩阵,其中除了次要 2 个维度之外,所有维度都是批量维度。
如果 a
不是对称(赫尔米特)正定矩阵,则结果由实现定义。
夹子
另请参阅 XlaBuilder::Clamp
。
将操作数限制在最小值和最大值之间的范围内。
Clamp(min, operand, max)
参数 | 类型 | 语义 |
---|---|---|
min |
XlaOp |
类型为 T 的数组 |
operand |
XlaOp |
类型为 T 的数组 |
max |
XlaOp |
类型为 T 的数组 |
给定运算数以及最小值和最大值,当操作数在最小值和最大值之间时返回最小值,如果操作数低于此范围则返回最小值,如果操作数超出此范围则返回最大值。即 clamp(a, x, b) = min(max(a, x), b)
。
这三个数组的形状必须相同。或者,作为广播的一种受限形式,min
和/或 max
可以是类型为 T
的标量。
使用标量 min
和 max
的示例:
let operand: s32[3] = {-1, 5, 9};
let min: s32 = 0;
let max: s32 = 6;
==>
Clamp(min, operand, max) = s32[3]{0, 5, 6};
收起
另请参阅 XlaBuilder::Collapse
和 tf.reshape
运算。
将数组的维度合并为一维。
Collapse(operand, dimensions)
参数 | 类型 | 语义 |
---|---|---|
operand |
XlaOp |
类型为 T 的数组 |
dimensions |
int64 矢量 |
T 维度的有序、连续子集。 |
收起会将运算对象维度的给定子集替换为单个维度。输入参数是类型为 T 的任意数组,以及维度索引的编译时常量矢量。维度索引必须是 T 维度的连续子集,且必须按顺序(从低到高维度编号)排列。因此,{0, 1, 2}、{0, 1} 或 {1, 2} 都是有效的维度组合,但 {1, 0} 或 {0, 2} 则不是。这些维度会被一个新维度取代,且新维度在维度序列中的位置与所替换维度相同,新维度大小等于原始维度大小的乘积。dimensions
中的维度编号越低,表示在用于收起这些维度的循环嵌套中,该维度的变化越慢(最主要),维度编号越高,表示该维度的变化越快(最次要)。如果需要更通用的收起排序,请参阅 tf.reshape
运算符。
例如,让 v 成为包含 24 个元素的数组:
let v = f32[4x2x3] { { {10, 11, 12}, {15, 16, 17} },
{ {20, 21, 22}, {25, 26, 27} },
{ {30, 31, 32}, {35, 36, 37} },
{ {40, 41, 42}, {45, 46, 47} } };
// Collapse to a single dimension, leaving one dimension.
let v012 = Collapse(v, {0,1,2});
then v012 == f32[24] {10, 11, 12, 15, 16, 17,
20, 21, 22, 25, 26, 27,
30, 31, 32, 35, 36, 37,
40, 41, 42, 45, 46, 47};
// Collapse the two lower dimensions, leaving two dimensions.
let v01 = Collapse(v, {0,1});
then v01 == f32[4x6] { {10, 11, 12, 15, 16, 17},
{20, 21, 22, 25, 26, 27},
{30, 31, 32, 35, 36, 37},
{40, 41, 42, 45, 46, 47} };
// Collapse the two higher dimensions, leaving two dimensions.
let v12 = Collapse(v, {1,2});
then v12 == f32[8x3] { {10, 11, 12},
{15, 16, 17},
{20, 21, 22},
{25, 26, 27},
{30, 31, 32},
{35, 36, 37},
{40, 41, 42},
{45, 46, 47} };
CollectivePermute
另请参阅 XlaBuilder::CollectivePermute
。
CollectivePermute 是一种集合操作,用于跨副本发送和接收数据。
CollectivePermute(operand, source_target_pairs)
参数 | 类型 | 语义 |
---|---|---|
operand |
XlaOp |
维输入数组 |
source_target_pairs |
<int64, int64> 矢量 |
(source_replica_id, target_replica_id) 对的列表。对于每对操作数,系统都会将操作数从源副本发送到目标副本。 |
请注意,source_target_pair
存在以下限制:
- 任何两个对都不应具有相同的目标副本 ID,并且来源副本 ID 也不应相同。
- 如果某个副本 ID 不是任何对中的目标,则该副本上的输出是一个由 0 组成的张量,其形状与输入相同。
串联
另请参阅 XlaBuilder::ConcatInDim
。
串联用于从多个数组运算对象组合数组。该数组的秩与每个输入数组运算数(必须具有相同的秩)相同,并且包含按指定顺序的参数。
Concatenate(operands..., dimension)
参数 | 类型 | 语义 |
---|---|---|
operands |
一系列 N 个 XlaOp |
类型为 T 且维度为 [L0, L1, ...] 的 N 个数组。要求 N >= 1。 |
dimension |
int64 |
一个介于 [0, N) 之间的值,用于为要在 operands 之间串联的维度命名。 |
除了 dimension
之外,所有维度都必须相同。这是因为 XLA 不支持“带有空格”的数组。另请注意,排名为 0 的值无法串联(因为无法为串联的维度命名)。
一维示例:
Concat({ {2, 3}, {4, 5}, {6, 7} }, 0)
>>> {2, 3, 4, 5, 6, 7}
二维示例:
let a = {
{1, 2},
{3, 4},
{5, 6},
};
let b = {
{7, 8},
};
Concat({a, b}, 0)
>>> {
{1, 2},
{3, 4},
{5, 6},
{7, 8},
}
示意图:
基于条件
另请参阅 XlaBuilder::Conditional
。
Conditional(pred, true_operand, true_computation, false_operand,
false_computation)
参数 | 类型 | 语义 |
---|---|---|
pred |
XlaOp |
类型为 PRED 的标量 |
true_operand |
XlaOp |
类型为 \(T_0\)的参数 |
true_computation |
XlaComputation |
\(T_0 \to S\)类型的 XlaComputation |
false_operand |
XlaOp |
类型为 \(T_1\)的参数 |
false_computation |
XlaComputation |
类型为 \(T_1 \to S\)的 XlaComputation |
如果 pred
为 true
,则执行 true_computation
;如果 pred
为 false
,则执行 false_computation
,并返回结果。
true_computation
必须接受 \(T_0\) 类型的单个参数,并将通过 true_operand
调用(必须为同一类型)。false_computation
必须接受 \(T_1\) 类型的单个参数,并将通过 false_operand
(必须为同一类型)调用。true_computation
和 false_computation
返回值的类型必须相同。
请注意,根据 pred
的值,系统仅会执行 true_computation
和 false_computation
中的一个。
Conditional(branch_index, branch_computations, branch_operands)
参数 | 类型 | 语义 |
---|---|---|
branch_index |
XlaOp |
类型为 S32 的标量 |
branch_computations |
N 个XlaComputation 的序列 |
类型为 \(T_0 \to S , T_1 \to S , ..., T_{N-1} \to S\)的 XlaComputation |
branch_operands |
N 个XlaOp 的序列 |
\(T_0 , T_1 , ..., T_{N-1}\)类型的参数 |
执行 branch_computations[branch_index]
并返回结果。如果 branch_index
是一个小于 0 或大于等于 N 的 S32
,则会将 branch_computations[N-1]
作为默认分支进行执行。
每个 branch_computations[b]
都必须接受 \(T_b\) 类型的一个参数,并将通过 branch_operands[b]
调用(必须为同一类型)。每个 branch_computations[b]
的返回值类型必须相同。
请注意,系统只会根据 branch_index
的值执行其中一个 branch_computations
。
转化(卷积)
另请参阅 XlaBuilder::Conv
。
与 ConvWithGeneralPadding 相同,但以简写方式(SAME 或 VALID)指定了填充。SAME 填充会使用零填充输入 (lhs
),以便在不考虑步长的情况下,输出与输入具有相同的形状。VALID 填充表示没有填充。
ConvWithGeneralPadding(卷积)
另请参阅 XlaBuilder::ConvWithGeneralPadding
。
计算神经网络中所用类型的卷积。在这里,卷积可以被视为在 n 维基准区域中移动的 n 维窗口,并且系统会针对窗口的每个可能位置执行计算。
参数 | 类型 | 语义 |
---|---|---|
lhs |
XlaOp |
秩 n+2 输入数组 |
rhs |
XlaOp |
对内核权重数组进行排序(n+2 个元素) |
window_strides |
ArraySlice<int64> |
内核步长的 n 维数组 |
padding |
ArraySlice< pair<int64,int64>> |
一个包含(低、高)内边距的 n 维数组 |
lhs_dilation |
ArraySlice<int64> |
n-d 左侧边界扩张因子数组 |
rhs_dilation |
ArraySlice<int64> |
n-d 右侧边界膨胀因子数组 |
feature_group_count |
int64 | 特征组的数量 |
batch_group_count |
int64 | 批处理组的数量 |
设 n 为空间维度数。lhs
参数是一个秩为 n+2 的数组,用于描述基底区域。这称为输入,尽管 rhs 当然也是输入。在神经网络中,这些是输入激活。n+2 个维度按以下顺序排列:
batch
:此维度中的每个坐标都代表一个要执行卷积的独立输入。z/depth/features
:基准区域中的每个 (y,x) 位置都与一个矢量相关联,该矢量会进入此维度。spatial_dims
:描述n
空间维度,用于定义窗口移动的基本区域。
rhs
参数是一个秩为 n+2 的数组,用于描述卷积滤波器/核/窗口。维度如下:
output-z
:输出的z
维度。input-z
:此维度的大小乘以feature_group_count
应等于 lhs 中的z
维度的大小。spatial_dims
:描述用于定义在基准区域中移动的 n 维窗口的n
空间维度。
window_strides
参数用于指定空间维度中的卷积窗口的步长。例如,如果第一个空间维度的步长为 3,则窗口只能放置在第一个空间索引可被 3 整除的坐标处。
padding
参数用于指定要应用于基准区域的零内边距量。填充量可以为负值,负值填充的绝对值表示在执行卷积之前要从指定维度中移除的元素数量。padding[0]
用于指定尺寸 y
的内边距,padding[1]
用于指定尺寸 x
的内边距。每个对的第一个元素是下内边距,第二个元素是上内边距。低内边距会应用在索引较低的方向,而高内边距会应用在索引较高的方向。例如,如果 padding[1]
为 (2,3)
,则第二个空间维度将在左侧添加 2 个零,在右侧添加 3 个零作为内边距。使用填充等同于在进行卷积之前将这些相同的零值插入输入 (lhs
)。
lhs_dilation
和 rhs_dilation
参数分别指定要应用于每个空间维度的左侧和右侧的膨胀因子。如果空间维度的膨胀系数为 d,则会在该维度的每个条目之间隐式放置 d-1 个孔,从而增加数组的大小。这些空洞会填充无操作值,对于卷积,这意味着填充零值。
rhs 扩张也称为 atrous 卷积。如需了解详情,请参阅 tf.nn.atrous_conv2d
。左侧的膨胀也称为转置卷积。如需了解详情,请参阅 tf.nn.conv2d_transpose
。
feature_group_count
参数(默认值为 1)可用于分组卷积。feature_group_count
需要同时是输入特征维度和输出特征维度的除数。如果 feature_group_count
大于 1,则表示从概念上讲,输入和输出特征维度以及 rhs
输出特征维度被平均拆分为多个 feature_group_count
组,每组由特征的连续子序列组成。rhs
的输入特征维度需要等于 lhs
输入特征维度除以 feature_group_count
(因此它已经具有一组输入特征的大小)。将第 i 组组合使用,可为许多单独的卷积计算 feature_group_count
。这些卷积的结果会在输出特征维度中串联在一起。
对于深度卷积,feature_group_count
参数将设置为输入特征维度,并且过滤器将从 [filter_height, filter_width, in_channels, channel_multiplier]
重塑为 [filter_height, filter_width, 1, in_channels * channel_multiplier]
。如需了解详情,请参阅 tf.nn.depthwise_conv2d
。
batch_group_count
(默认值 1)参数可用于反向传播算法期间的已分组过滤器。batch_group_count
必须是 lhs
(输入)批量维度大小的除数。如果 batch_group_count
大于 1,则表示输出批次维度应为大小为 input batch
/ batch_group_count
的向量。batch_group_count
必须是输出特征大小的分母。
输出形状具有以下维度,按以下顺序排列:
batch
:此维度的大小乘以batch_group_count
应等于 lhs 中的batch
维度的大小。z
:与内核上的output-z
相同大小 (rhs
)。spatial_dims
:对于卷积窗口的每个有效放置,都有一个值。
上图展示了 batch_group_count
字段的运作方式。实际上,我们会将每个 lhs 批处理切片为 batch_group_count
组,并对输出特征执行相同的操作。然后,对于这些组中的每个组,我们都会进行对偶卷积,并沿输出特征维度串联输出。所有其他维度(特征和空间)的操作语义保持不变。
卷积窗口的有效位置由步长和填充后底区域的大小决定。
为了说明卷积的运作方式,我们考虑一个二维卷积,并在输出中选择一些固定的 batch
、z
、y
、x
坐标。然后,(y,x)
是基础区域内窗口一角的位置(例如左上角,具体取决于您如何解释空间维度)。现在,我们有一个从基准区域获取的二维窗口,其中每个二维点都与一个一维矢量相关联,因此我们得到了一个三维盒子。在卷积核中,由于我们固定了输出坐标 z
,因此我们还拥有一个 3D 盒子。这两个盒子具有相同的维度,因此我们可以对这两个盒子之间的元素级乘积求和(类似于点积)。这就是输出值。
请注意,如果 output-z
为5,则窗口的每个位置都会在输出的 z
维度中生成 5 个值。这些值的不同之处在于所使用的卷积核的部分 - 每个 output-z
坐标都有一个单独的 3D 值盒。因此,您可以将其视为 5 个单独的卷积,每个卷积都有不同的滤镜。
以下是带有填充和步长的 2D 卷积的伪代码:
for (b, oz, oy, ox) { // output coordinates
value = 0;
for (iz, ky, kx) { // kernel coordinates and input z
iy = oy*stride_y + ky - pad_low_y;
ix = ox*stride_x + kx - pad_low_x;
if ((iy, ix) inside the base area considered without padding) {
value += input(b, iz, iy, ix) * kernel(oz, iz, ky, kx);
}
}
output(b, oz, oy, ox) = value;
}
ConvertElementType
另请参阅 XlaBuilder::ConvertElementType
。
与 C++ 中的元素级 static_cast
类似,执行从数据形状到目标形状的元素级转换操作。维度必须匹配,并且是元素级转换;例如,s32
元素通过 s32
到 f32
的转换例程变为 f32
元素。
ConvertElementType(operand, new_element_type)
参数 | 类型 | 语义 |
---|---|---|
operand |
XlaOp |
类型为 T 且维度为 D 的数组 |
new_element_type |
PrimitiveType |
类型 U |
运算数和目标形状的维度必须匹配。源和目标元素类型不得为元组。
诸如 T=s32
到 U=f32
的转换将执行标准化 int 到浮点数转换例程,例如舍入到最近偶数。
let a: s32[3] = {0, 1, 2};
let b: f32[3] = convert(a, f32);
then b == f32[3]{0.0, 1.0, 2.0}
CrossReplicaSum
使用求和计算执行 AllReduce
。
CustomCall
另请参阅 XlaBuilder::CustomCall
。
在计算中调用用户提供的函数。
CustomCall(target_name, args..., shape)
参数 | 类型 | 语义 |
---|---|---|
target_name |
string |
函数的名称。系统将发出一条针对此符号名称的调用指令。 |
args |
一系列 N 个 XlaOp |
任意类型的 N 个参数,将传递给函数。 |
shape |
Shape |
函数的输出形状 |
无论参数的个数或类型如何,函数签名都是相同的:
extern "C" void target_name(void* out, void** in);
例如,如果按如下方式使用 CustomCall:
let x = f32[2] {1,2};
let y = f32[2x3] { {10, 20, 30}, {40, 50, 60} };
CustomCall("myfunc", {x, y}, f32[3x3])
下面是 myfunc
实现的示例:
extern "C" void myfunc(void* out, void** in) {
float (&x)[2] = *static_cast<float(*)[2]>(in[0]);
float (&y)[2][3] = *static_cast<float(*)[2][3]>(in[1]);
EXPECT_EQ(1, x[0]);
EXPECT_EQ(2, x[1]);
EXPECT_EQ(10, y[0][0]);
EXPECT_EQ(20, y[0][1]);
EXPECT_EQ(30, y[0][2]);
EXPECT_EQ(40, y[1][0]);
EXPECT_EQ(50, y[1][1]);
EXPECT_EQ(60, y[1][2]);
float (&z)[3][3] = *static_cast<float(*)[3][3]>(out);
z[0][0] = x[1] + y[1][0];
// ...
}
用户提供的函数不得有副作用,并且其执行必须是幂等的。
点
另请参阅 XlaBuilder::Dot
。
Dot(lhs, rhs)
参数 | 类型 | 语义 |
---|---|---|
lhs |
XlaOp |
类型为 T 的数组 |
rhs |
XlaOp |
类型为 T 的数组 |
此运算的确切语义取决于运算元的阶数:
输入 | 输出 | 语义 |
---|---|---|
矢量 [n] dot 矢量 [n] |
标量 | 矢量点积 |
矩阵 [m x k] dot 矢量 [k] |
矢量 [m] | 矩阵-向量乘法 |
矩阵 [m x k] dot 矩阵 [k x n] |
矩阵 [m x n] | 矩阵-矩阵乘法 |
该运算会对 lhs
的第二个维度(如果其秩为 1,则为第一个维度)和 rhs
的第一个维度执行乘积求和。这些是“收缩”的维度。lhs
和 rhs
的收缩维度必须具有相同的大小。在实践中,它可用于执行矢量之间的点积、矢量/矩阵乘法或矩阵/矩阵乘法。
DotGeneral
另请参阅 XlaBuilder::DotGeneral
。
DotGeneral(lhs, rhs, dimension_numbers)
参数 | 类型 | 语义 |
---|---|---|
lhs |
XlaOp |
类型为 T 的数组 |
rhs |
XlaOp |
类型为 T 的数组 |
dimension_numbers |
DotDimensionNumbers |
合同和批量维度编号 |
与 Dot 类似,但允许为 lhs
和 rhs
指定收缩和批量维度编号。
DotDimensionNumbers 字段 | 类型 | 语义 |
---|---|---|
lhs_contracting_dimensions
|
repeated int64 | lhs 缩减维度编号 |
rhs_contracting_dimensions
|
重复的 int64 | rhs 个合同维度编号 |
lhs_batch_dimensions
|
repeated int64 | lhs 批量维度编号 |
rhs_batch_dimensions
|
重复的 int64 | rhs 批量维度编号 |
DotGeneral 会按 dimension_numbers
中指定的合同规定的尺寸对产品求和。
lhs
和 rhs
中关联的收缩维度编号不必相同,但必须具有相同的维度大小。
缩略维度编号示例:
lhs = { {1.0, 2.0, 3.0},
{4.0, 5.0, 6.0} }
rhs = { {1.0, 1.0, 1.0},
{2.0, 2.0, 2.0} }
DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(1);
dnums.add_rhs_contracting_dimensions(1);
DotGeneral(lhs, rhs, dnums) -> { {6.0, 12.0},
{15.0, 30.0} }
lhs
和 rhs
中关联的批次维度编号必须具有相同的维度大小。
包含批量维度数的示例(批量大小为 2,2x2 矩阵):
lhs = { { {1.0, 2.0},
{3.0, 4.0} },
{ {5.0, 6.0},
{7.0, 8.0} } }
rhs = { { {1.0, 0.0},
{0.0, 1.0} },
{ {1.0, 0.0},
{0.0, 1.0} } }
DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(2);
dnums.add_rhs_contracting_dimensions(1);
dnums.add_lhs_batch_dimensions(0);
dnums.add_rhs_batch_dimensions(0);
DotGeneral(lhs, rhs, dnums) -> { { {1.0, 2.0},
{3.0, 4.0} },
{ {5.0, 6.0},
{7.0, 8.0} } }
输入 | 输出 | 语义 |
---|---|---|
[b0, m, k] dot [b0, k, n] |
[b0、m、n] | 批量 matmul |
[b0、b1、m、k] dot [b0、b1、k、n] |
[b0, b1, m, n] | 批处理 matmul |
因此,生成的维度编号从批处理维度开始,然后是 lhs
非收缩/非批处理维度,最后是 rhs
非收缩/非批处理维度。
DynamicSlice
另请参阅 XlaBuilder::DynamicSlice
。
DynamicSlice 从位于动态 start_indices
的输入数组中提取子数组。每个维度中的 slice 的大小会在 size_indices
中传递,其中指定了每个维度中不含边界值的 slice 间隔的结束点:[start, start + size)。start_indices
的形状必须为秩 == 1,且维度大小等于 operand
的排名。
DynamicSlice(operand, start_indices, size_indices)
参数 | 类型 | 语义 |
---|---|---|
operand |
XlaOp |
T 型 N 维数组 |
start_indices |
一系列 N 个 XlaOp |
包含每个维度 slice 的起始索引的 N 个标量整数的列表。值必须大于或等于零。 |
size_indices |
ArraySlice<int64> |
包含每个维度的 slice 大小的 N 个整数的列表。每个值都必须严格大于零,并且 start + size 必须小于或等于维度的大小,以避免模除维度大小的封装。 |
有效的 slice 索引是通过在执行 slice 之前对 [1, N)
中的每个索引 i
应用以下转换计算得出的:
start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - size_indices[i])
这可确保提取的 slice 始终相对于操作数数组在边界内。如果应用转换之前切片已进入边界,则转换不会产生任何影响。
一维示例:
let a = {0.0, 1.0, 2.0, 3.0, 4.0}
let s = {2}
DynamicSlice(a, s, {2}) produces:
{2.0, 3.0}
二维示例:
let b =
{ {0.0, 1.0, 2.0},
{3.0, 4.0, 5.0},
{6.0, 7.0, 8.0},
{9.0, 10.0, 11.0} }
let s = {2, 1}
DynamicSlice(b, s, {2, 2}) produces:
{ { 7.0, 8.0},
{10.0, 11.0} }
DynamicUpdateSlice
另请参阅 XlaBuilder::DynamicUpdateSlice
。
DynamicUpdateSlice 会生成一条结果,该结果是输入数组 operand
的值,并且 Slice update
已被覆盖 start_indices
。update
的形状决定了要更新的结果的子数组的形状。start_indices
的形状必须为 rank == 1,且维度大小等于 operand
的 rank。
DynamicUpdateSlice(operand, update, start_indices)
参数 | 类型 | 语义 |
---|---|---|
operand |
XlaOp |
类型为 T 的 N 维数组 |
update |
XlaOp |
类型 T 的 N 维数组,其中包含切片更新。更新形状的每个维度都必须严格大于零,并且 start + update 必须小于或等于每个维度的运算数大小,以避免生成超出范围的更新索引。 |
start_indices |
N 个XlaOp 的序列 |
包含每个维度 slice 的起始索引的 N 个标量整数的列表。值必须大于或等于零。 |
在执行切片之前,通过对 [1, N)
中的每个索引 i
应用以下转换来计算有效切片索引:
start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - update.dimension_size[i])
这可确保更新后的 slice 始终在操作数数组的边界内。如果在应用转换之前 slice 在边界内,则转换无效。
一维示例:
let a = {0.0, 1.0, 2.0, 3.0, 4.0}
let u = {5.0, 6.0}
let s = {2}
DynamicUpdateSlice(a, u, s) produces:
{0.0, 1.0, 5.0, 6.0, 4.0}
二维示例:
let b =
{ {0.0, 1.0, 2.0},
{3.0, 4.0, 5.0},
{6.0, 7.0, 8.0},
{9.0, 10.0, 11.0} }
let u =
{ {12.0, 13.0},
{14.0, 15.0},
{16.0, 17.0} }
let s = {1, 1}
DynamicUpdateSlice(b, u, s) produces:
{ {0.0, 1.0, 2.0},
{3.0, 12.0, 13.0},
{6.0, 14.0, 15.0},
{9.0, 16.0, 17.0} }
元素级二元算术运算
另请参阅 XlaBuilder::Add
。
支持一组元素级二进制算术运算。
Op(lhs, rhs)
其中 Op
为以下之一:Add
(加法)、Sub
(减法)、Mul
(乘法)、Div
(除法)、Pow
(幂)、Rem
(余数)、Min
(最小值)、And
(逻辑 AND)、Or
(逻辑 OR)、Xor
(逻辑 XOR)、ShiftLeft
(左移或 ShiftRightArithmetic
(虚数右移)和 Complex
(虚数右移)ShiftRightLogical
(实数右移)、ShiftRightLogical
(实数右移)Max
Atan2
参数 | 类型 | 语义 |
---|---|---|
lhs |
XlaOp |
左侧操作数:T 型数组 |
rhs |
XlaOp |
右侧操作数:类型为 T 的数组 |
参数的形状必须相似或兼容。如需了解形状兼容的含义,请参阅广播文档。操作的结果的形状是广播两个输入数组的结果。此变体不支持不同秩的数组之间的运算,除非其中一个运算数是标量。
当 Op
为 Rem
时,从被除数中提取结果的符号,且结果的绝对值始终小于除数的绝对值。
整数除法溢出(对零进行有符号/无符号除法/求余,或对 INT_SMIN
与 -1
进行有符号除法/求余)会产生实现定义的值。
以下操作具有支持不同秩广播的替代变体:
Op(lhs, rhs, broadcast_dimensions)
其中,Op
与上面相同。此运算变体应用于不同秩数的数组之间的算术运算(例如向向量添加矩阵)。
额外的 broadcast_dimensions
运算数是一个整数切片,用于将低阶运算数的秩扩展到高阶运算数的秩。broadcast_dimensions
会将低阶形状的维度映射到高阶形状的维度。已展开形状的未映射尺寸会填充大小为 1 的尺寸。然后,退化维度广播会沿这些退化维度广播形状,以使两个运算数的形状相等。广播页面详细介绍了这些语义。
元素级比较运算
另请参阅 XlaBuilder::Eq
。
支持一组标准的元素级二元比较运算。请注意,在比较浮点类型时,标准 IEEE 754 浮点比较语义适用。
Op(lhs, rhs)
其中,Op
是 Eq
(等于)、Ne
(不等于)、Ge
(大于或等于)、Gt
(大于)、Le
(小于或等于)、Lt
(小于)之一。另一组运算符(EqTotalOrder、NeTotalOrder、GeTotalOrder、GtTotalOrder、LeTotalOrder 和 LtTotalOrder)提供相同的功能,但它们还支持对浮点数进行总排序,方法是强制执行 -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN。
参数 | 类型 | 语义 |
---|---|---|
lhs |
XlaOp |
左侧操作数:T 型数组 |
rhs |
XlaOp |
右侧操作数:类型为 T 的数组 |
参数的形状必须相似或兼容。如需了解形状兼容的含义,请参阅广播文档。操作结果具有的形状是广播两个元素类型为 PRED
的输入数组的结果。在此变体中,不支持不同秩数数组之间的运算,除非其中一个运算数是标量。
对于这些操作,存在支持不同等级广播的替代变体:
Op(lhs, rhs, broadcast_dimensions)
其中,Op
与上面相同。此运算变体应用于不同秩的数组之间的比较运算(例如,向向量添加矩阵)。
额外的 broadcast_dimensions
运算数是整数切片,用于指定用于广播运算数的维度。广播页面上详细介绍了语义。
元素级一元函数
XlaBuilder 支持以下元素级一元函数:
Abs(operand)
元素级绝对 x -> |x|
。
Cbrt(operand)
元素级立方根运算 x -> cbrt(x)
。
Ceil(operand)
按元素取天花板值 x -> ⌈x⌉
。
Clz(operand)
按元素统计前导零数。
Cos(operand)
按元素计算的余弦 x -> cos(x)
。
Erf(operand)
元素级误差函数 x -> erf(x)
,其中
\(\text{erf}(x) = \frac{2}{\sqrt{\pi} }\int_0^x e^{-t^2} \, dt\)。
Exp(operand)
元素级自然指数 x -> e^x
。
Expm1(operand)
按元素计算自然指数减一 x -> e^x - 1
。
Floor(operand)
按元素取整 x -> ⌊x⌋
。
Imag(operand)
复杂(或实)形状的元素级虚部。x -> imag(x)
。如果操作数是浮点类型,则返回 0。
IsFinite(operand)
测试 operand
的每个元素是否有限,即非正无穷大或负无穷大,以及是否非 NaN
。返回一个与输入相同形状的 PRED
值数组,其中每个元素都为 true
,且仅当相应的输入元素为有限值时才为 true
。
Log(operand)
按元素计算的自然对数 x -> ln(x)
。
Log1p(operand)
按元素偏移的自然对数 x -> ln(1+x)
。
Logistic(operand)
元素级逻辑函数计算 x ->
logistic(x)
。
Neg(operand)
元素级否定 x -> -x
。
Not(operand)
元素级逻辑非 x -> !(x)
。
PopulationCount(operand)
计算 operand
的每个元素中设置的位数。
Real(operand)
复杂(或实数)形状的元素级实部。x -> real(x)
。如果操作数是浮点类型,则返回相同的值。
Round(operand)
按元素舍入,平分时向远离 0 的方向舍入。
RoundNearestEven(operand)
元素级舍入,接近最接近的偶数。
Rsqrt(operand)
对平方根运算 x -> 1.0 / sqrt(x)
执行元素级反向运算。
Sign(operand)
元素级符号运算 x -> sgn(x)
,其中
\[\text{sgn}(x) = \begin{cases} -1 & x < 0\\ -0 & x = -0\\ NaN & x = NaN\\ +0 & x = +0\\ 1 & x > 0 \end{cases}\]
使用 operand
元素类型的比较运算符。
Sin(operand)
按元素计算的正弦 x -> sin(x)
。
Sqrt(operand)
元素级平方根运算 x -> sqrt(x)
。
Tan(operand)
元素级正切 x -> tan(x)
。
Tanh(operand)
元素级双曲正切 x -> tanh(x)
。
参数 | 类型 | 语义 |
---|---|---|
operand |
XlaOp |
函数的操作数 |
该函数会应用于 operand
数组中的每个元素,从而生成形状相同的数组。operand
可以是标量(秩为 0)。
福音
XLA FFT 运算会针对实数和复杂输入/输出实现正向和反向傅里叶变换。支持对最多 3 个轴进行多维 FFT。
另请参阅 XlaBuilder::Fft
。
参数 | 类型 | 语义 |
---|---|---|
operand |
XlaOp |
我们要进行傅里叶变换的数组。 |
fft_type |
FftType |
请参见下表。 |
fft_length |
ArraySlice<int64> |
要转换的轴的时间域长度。由于 RFFT(fft_length=[16]) 的输出形状与 RFFT(fft_length=[17]) 相同,因此 IRFFT 特别需要执行此操作,以便将最内侧轴调整为合适大小。 |
FftType |
语义 |
---|---|
FFT |
转发复杂到复杂的 FFT。形状保持不变。 |
IFFT |
复杂数到复杂数的反 FFT。形状保持不变。 |
RFFT |
正向实数到复数 FFT。如果 fft_length[-1] 为非零值,则最内轴的形状会缩减为 fft_length[-1] // 2 + 1 ,并省略已转换信号中超出奈奎斯特频率的反转共振部分。 |
IRFFT |
实数到复数的 FFT 的逆运算(即接受复数,返回实数)。如果 fft_length[-1] 为非零值,则最内轴的形状会展开为 fft_length[-1] ,从而根据 1 与 fft_length[-1] // 2 + 1 条目的反向共振推断出超过奈奎斯特频率的转换信号部分。 |
多维 FFT
如果提供多个 fft_length
,则相当于对最内侧的每个轴应用一系列 FFT 操作。请注意,对于实数到复数和复数到实数的情况,最内层轴变换(实际上)会先执行(RFFT;IRFFT 为最后),因此最内层轴会发生大小变化。然后,其他轴转换将是复杂->复杂。
实现细节
CPU FFT 由 Eigen 的 TensorFFT 提供支持。GPU FFT 使用 cuFFT。
Gather
XLA 汇总操作会将输入数组的多个切片(每个切片的运行时偏移可能不同)缝合在一起。
一般语义学
另请参阅 XlaBuilder::Gather
。
如需了解更直观的说明,请参阅下文中的“非正式说明”部分。
gather(operand, start_indices, offset_dims, collapsed_slice_dims, slice_sizes, start_index_map)
参数 | 类型 | 语义 |
---|---|---|
operand |
XlaOp |
我们要从中收集数据的数组。 |
start_indices |
XlaOp |
包含我们收集的切片的起始索引的数组。 |
index_vector_dim |
int64 |
start_indices 中“包含”起始索引的维度。请参阅下文了解详情。 |
offset_dims |
ArraySlice<int64> |
输出形状中偏移到从运算数切片的数组中的一组维度。 |
slice_sizes |
ArraySlice<int64> |
slice_sizes[i] 是维度 i 的切片的边界。 |
collapsed_slice_dims |
ArraySlice<int64> |
每个 Slice 中被收起的一组维度。这些维度必须为大小为 1 的数组。 |
start_index_map |
ArraySlice<int64> |
一个映射,描述了如何将 start_indices 中的索引映射到操作数中的合法索引。 |
indices_are_sorted |
bool |
是否保证调用方对索引进行排序。 |
为方便起见,我们将输出数组中不在 offset_dims
中的维度标记为 batch_dims
。
输出是秩 batch_dims.size
+ offset_dims.size
的数组。
operand.rank
必须等于 offset_dims.size
和 collapsed_slice_dims.size
的总和。此外,slice_sizes.size
必须等于 operand.rank
。
如果 index_vector_dim
等于 start_indices.rank
,我们会隐式认为 start_indices
具有尾随 1
维度(即,如果 start_indices
的形状为 [6,7]
且 index_vector_dim
为 2
,则我们隐式认为 start_indices
的形状为 [6,7,1]
)。
沿维度 i
的输出数组的边界计算方式如下:
如果
batch_dims
中存在i
(即对于某些k
,等于batch_dims[k]
),则我们会从start_indices.shape
中选择相应的维度边界,跳过index_vector_dim
(即如果k
<index_vector_dim
,则选择start_indices.shape.dims
[k
];否则,选择start_indices.shape.dims
[k
+1
])。如果
i
存在于offset_dims
中(即对于某个k
,等于offset_dims
[k
]),则我们会在考虑collapsed_slice_dims
后从slice_sizes
中选择相应的边界(即我们选择adjusted_slice_sizes
[k
],其中adjusted_slice_sizes
是移除了索引collapsed_slice_dims
的边界的slice_sizes
)。
正式地,与给定输出编号 Out
对应的操作数编号 In
的计算方式如下:
设
G
= {Out
[k
] fork
inbatch_dims
}。使用G
切片出矢量S
,使得S
[i
] =start_indices
[Combine(G
,i
)],其中 Combine(A, b) 会将 b 插入到 A 的index_vector_dim
位置。请注意,即使G
为空,也明确定义:如果G
为空,则S
=start_indices
。使用
start_index_map
将S
分散到operand
中,然后使用S
在operand
中创建起始索引S
in
。更具体地说:S
如果k
<start_index_map.size
,则in
[start_index_map
[k
]] =S
[k
]。S
in
[_
] =0
,否则。
通过根据
collapsed_slice_dims
集在Out
中的偏移维度上分散索引,在operand
中创建索引O
in
。更具体地说:O
如果k
<offset_dims.size
(remapped_offset_dims
在下文中定义),则in
[remapped_offset_dims
(k
)] =Out
[offset_dims
[k
]]。O
否则,in
[_
] =0
。
In
是O
in
+S
in
,其中 + 是元素级加法。
remapped_offset_dims
是一个单调函数,其域为 [0
, offset_dims.size
),范围为 [0
, operand.rank
) \ collapsed_slice_dims
。因此,例如,offset_dims.size
为 4
,operand.rank
为 6
,collapsed_slice_dims
为 {0
, 2
},然后 remapped_offset_dims
为 {0
→1
,
1
→3
, 2
→4
, 3
→5
}。
如果 indices_are_sorted
设置为 true,则 XLA 可以假定 start_indices
已由用户排序(按升序排序,在根据 start_index_map
散布其值之后)。否则,语义就是实现定义的。
非正式说明和示例
非正式地,输出数组中的每个索引 Out
都对应于运算数数组中的元素 E
,计算方式如下:
我们使用
Out
中的批处理维度从start_indices
中查找起始索引。我们使用
start_index_map
将起始索引(其大小可能小于 operand.rank)映射到operand
中的“full”起始索引。我们使用完整起始索引对大小为
slice_sizes
的切片进行动态切片。我们通过收起
collapsed_slice_dims
维度来调整 slice 的形状。由于所有已收起的 slice 维度都必须具有 1 的边界,因此此重塑始终是合法的。我们使用
Out
中的偏移量维度对此 slice 进行编制索引,以获取与输出索引Out
对应的输入元素E
。
在以下所有示例中,index_vector_dim
均设置为 start_indices.rank
-1
。index_vector_dim
的更有趣的值不会从根本上改变该操作,但会使视觉表示更复杂。
为了直观地了解上述所有内容如何协同工作,我们来看一个示例,该示例从 [16,11]
数组中收集 5 个形状为 [8,6]
的 slice。[16,11]
数组中 slice 的位置可以表示为形状为 S64[2]
的索引矢量,因此这组 5 个位置可以表示为 S64[5,2]
数组。
然后,汇总操作的行为可以描述为一个索引转换,该转换接受 [G
,O
0
,O
1
](输出形状中的索引),并按如下方式将其映射到输入数组中的元素:
我们首先使用 G
从汇总索引数组中选择一个 (X
,Y
) 矢量。然后,输出数组中索引为 [G
,O
0
,O
1
] 的元素就是输入数组中索引为 [X
+O
0
,Y
+O
1
] 的元素。
slice_sizes
为 [8,6]
,它决定了 O0
和 O1
的范围,而这反过来又决定了 slice 的边界。
此收集操作充当批量动态切片,使用 G
作为批量维度。
汇总索引可以是多维的。例如,上述示例的更通用版本使用形状为 [4,5,2]
的“gather 索引”数组,会按如下方式转换索引:
同样,这会充当批量动态 slice G
0
,并将 G
1
用作批量维度。切片大小仍为 [8,6]
。
XLA 中的汇总操作会以以下方式对上述非正式语义进行泛化:
我们可以配置输出形状中的哪些维度是偏移维度(上一个示例中包含
O
0
、O
1
的维度)。输出批次维度(上例中包含G
0
、G
1
的维度)定义为非偏移维度的输出维度。输出形状中明确存在的输出偏移量维数的数量可能小于输入秩。这些“缺失”维度(明确列为
collapsed_slice_dims
)的 slice 大小必须为1
。由于它们的 slice 大小为1
,因此唯一有效的索引是0
,并且省略它们不会引入歧义。从“Gather Indices”(上例中的 (
X
,Y
))数组中提取的切片可能包含的元素数量少于输入数组的秩,而显式映射决定了索引应如何扩展以与输入具有相同的秩。
最后,我们使用 (2) 和 (3) 来实现 tf.gather_nd
:
G
0
和 G
1
像往常一样用于从收集索引数组中切除起始索引,但起始索引只有一个元素,即 X
。同样,只有一个输出偏移量索引,其值为 O
0
。然而,在用作输入数组的索引之前,[[G
s] [将 [G
s] 在用作输入数组的索引中,这些 [G
的 Index Mapping(“Gather Index Mapping”(它的start_index_map
remapped_offset_dims
X
X
0
0
0
0
0
0
0
0
0
O
O
O
O
G
1
1
GatherIndices
tf.gather_nd
此支持请求的slice_sizes
为 [1,11]
。直观地讲,这意味着汇总索引数组中的每个索引 X
都会选择一整行,而结果是所有这些行的串联。
GetDimensionSize
另请参阅 XlaBuilder::GetDimensionSize
。
返回运算数的给定维度的大小。运算数必须为数组形状。
GetDimensionSize(operand, dimension)
参数 | 类型 | 语义 |
---|---|---|
operand |
XlaOp |
维输入数组 |
dimension |
int64 |
一个介于 [0, n) 之间的值,用于指定维度 |
SetDimensionSize
另请参阅 XlaBuilder::SetDimensionSize
。
设置 XlaOp 给定维度的动态大小。运算数必须为数组形式。
SetDimensionSize(operand, size, dimension)
参数 | 类型 | 语义 |
---|---|---|
operand |
XlaOp |
n 维输入数组。 |
size |
XlaOp |
表示运行时动态大小的 int32。 |
dimension |
int64 |
[0, n) 区间内用于指定维度的值。 |
将运算数作为结果传递,并由编译器跟踪动态维度。
下游缩减运算会忽略填充的值。
let v: f32[10] = f32[10]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
let five: s32 = 5;
let six: s32 = 6;
// Setting dynamic dimension size doesn't change the upper bound of the static
// shape.
let padded_v_five: f32[10] = set_dimension_size(v, five, /*dimension=*/0);
let padded_v_six: f32[10] = set_dimension_size(v, six, /*dimension=*/0);
// sum == 1 + 2 + 3 + 4 + 5
let sum:f32[] = reduce_sum(padded_v_five);
// product == 1 * 2 * 3 * 4 * 5
let product:f32[] = reduce_product(padded_v_five);
// Changing padding size will yield different result.
// sum == 1 + 2 + 3 + 4 + 5 + 6
let sum:f32[] = reduce_sum(padded_v_six);
GetTupleElement
另请参阅 XlaBuilder::GetTupleElement
。
对具有编译时常量值的元组编入索引。
该值必须是编译时常量,以便形状推理可以确定生成的值的类型。
这类似于 C++ 中的 std::get<int N>(t)
。从概念上讲:
let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
let s: s32 = 5;
let t: (f32[10], s32) = tuple(v, s);
let element_1: s32 = gettupleelement(t, 1); // Inferred shape matches s32.
另请参阅 tf.tuple
。
Infeed
另请参阅 XlaBuilder::Infeed
。
Infeed(shape)
参数 | 类型 | 语义 |
---|---|---|
shape |
Shape |
从信息流接口读取的数据的形状。形状的布局字段必须设置为与发送到设备的数据的布局一致;否则,其行为将不明确。 |
从设备的隐式信息流流接口读取单个数据项,将数据解释为给定形状及其布局,并返回数据的 XlaOp
。计算中允许有多个 infeed 操作,但这些 infeed 操作之间必须有总顺序。例如,由于 while 循环之间存在依赖关系,因此以下代码中的两个信息流具有总顺序。
result1 = while (condition, init = init_value) {
Infeed(shape)
}
result2 = while (condition, init = result1) {
Infeed(shape)
}
不支持嵌套元组形状。对于空元组形状,插播操作实际上是无操作,并且会继续执行,而不会从设备的插播内容中读取任何数据。
Iota
另请参阅 XlaBuilder::Iota
。
Iota(shape, iota_dimension)
在设备上构建常量字面量,而不是可能庞大的主机传输。创建一个具有指定形状的数组,其中包含从零开始并沿指定维度递增 1 的值。对于浮点类型,生成的数组等同于 ConvertElementType(Iota(...))
,其中 Iota
为整数类型,并且转换为浮点类型。
参数 | 类型 | 语义 |
---|---|---|
shape |
Shape |
Iota() 创建的数组的形状 |
iota_dimension |
int64 |
要递增的维度。 |
例如,Iota(s32[4, 8], 0)
会返回
[[0, 0, 0, 0, 0, 0, 0, 0 ],
[1, 1, 1, 1, 1, 1, 1, 1 ],
[2, 2, 2, 2, 2, 2, 2, 2 ],
[3, 3, 3, 3, 3, 3, 3, 3 ]]
退货费用 Iota(s32[4, 8], 1)
[[0, 1, 2, 3, 4, 5, 6, 7 ],
[0, 1, 2, 3, 4, 5, 6, 7 ],
[0, 1, 2, 3, 4, 5, 6, 7 ],
[0, 1, 2, 3, 4, 5, 6, 7 ]]
地图
另请参阅 XlaBuilder::Map
。
Map(operands..., computation)
参数 | 类型 | 语义 |
---|---|---|
operands |
一系列 N 个 XlaOp |
N 个类型为 T0..T{N-1} 的数组 |
computation |
XlaComputation |
类型为 T_0, T_1, .., T_{N + M -1} -> S 的计算,其中包含类型为 T 的 N 个参数和类型为任意的 M 个参数 |
dimensions |
int64 数组 |
地图尺寸数组 |
对给定的 operands
数组应用标量函数,生成一个维度相同的数组,其中每个元素都是应用于输入数组中相应元素的映射函数的结果。
映射的函数是任意计算,但有以下限制:它有 N 个标量类型 T
的输入,以及一个类型为 S
的输出。输出的维度与运算数相同,但元素类型 T 会替换为 S。
例如:Map(op1, op2, op3, computation, par1)
会将 elem_out <-
computation(elem1, elem2, elem3, par1)
映射到输入数组中的每个(多维)索引,以生成输出数组。
OptimizationBarrier
阻止任何优化阶段将计算移出屏障。
确保在任何依赖于屏障输出的运算符之前评估所有输入。
垫子
另请参阅 XlaBuilder::Pad
。
Pad(operand, padding_value, padding_config)
参数 | 类型 | 语义 |
---|---|---|
operand |
XlaOp |
T 类型的数组 |
padding_value |
XlaOp |
T 类型的标量,用于填充添加的内边距 |
padding_config |
PaddingConfig |
两侧的内边距(低、高)以及每个维度元素之间的内边距 |
使用给定 padding_value
在数组周围以及数组元素之间添加内边距,从而扩展给定 operand
数组。padding_config
指定每个维度的边缘内边距和内部内边距。
PaddingConfig
是 PaddingConfigDimension
的重复字段,其中包含每个维度的三个字段:edge_padding_low
、edge_padding_high
和 interior_padding
。
edge_padding_low
和 edge_padding_high
分别指定在每个维度的低端(靠近编号 0)和高端(靠近最高编号)添加的填充量。边缘内边距可以是负数,负内边距的绝对值表示要从指定尺寸中移除的元素数量。
interior_padding
用于指定在每个维度中的任意两个元素之间添加的内边距;该值不得为负。从逻辑上讲,内部内边距会先于边缘内边距发生,因此在负边缘内边距的情况下,系统会从内边距补充的运算对象中移除元素。
如果边缘填充对全部 (0, 0) 且内部填充值均为 0,则此操作为空操作。下图显示了二维数组的不同 edge_padding
和 interior_padding
值的示例。
Recv
另请参阅 XlaBuilder::Recv
。
Recv(shape, channel_handle)
参数 | 类型 | 语义 |
---|---|---|
shape |
Shape |
要接收的数据的形状 |
channel_handle |
ChannelHandle |
每个发送/接收对的唯一标识符 |
从共用同一通道句柄的另一个计算中的 Send
指令接收给定形状的数据。针对已接收的数据返回 XlaOp。
Recv
操作的客户端 API 表示同步通信。不过,该指令会在内部分解为 2 个 HLO 指令(Recv
和 RecvDone
),以实现异步数据传输。另请参阅 HloInstruction::CreateRecv
和 HloInstruction::CreateRecvDone
。
Recv(const Shape& shape, int64 channel_id)
分配从具有相同 channel_id 的 Send
指令接收数据所需的资源。返回已分配资源的上下文,后续的 RecvDone
指令会使用该上下文来等待数据传输完成。上下文是 {接收缓冲区(形状)、请求标识符 (U32)} 的元组,只能由 RecvDone
指令使用。
RecvDone(HloInstruction context)
给定由 Recv
指令创建的上下文,等待数据传输完成并返回收到的数据。
减少
另请参阅 XlaBuilder::Reduce
。
将一个归约函数并行应用于一个或多个数组。
Reduce(operands..., init_values..., computation, dimensions)
参数 | 类型 | 语义 |
---|---|---|
operands |
N XlaOp 序列 |
类型为 T_0, ..., T_{N-1} 的 N 个数组。 |
init_values |
N XlaOp 序列 |
N 个 T_0, ..., T_{N-1} 类型的标量。 |
computation |
XlaComputation |
T_0, ..., T_{N-1}, T_0, ..., T_{N-1} -> Collate(T_0, ..., T_{N-1}) 类型的计算。 |
dimensions |
int64 数组 |
要缩减的维度的无序数组。 |
其中:
- N 必须大于或等于 1。
- 计算必须“大致”遵循结合律(见下文)。
- 所有输入数组的维度都必须相同。
- 所有初始值都必须在
computation
下形成一个标识。 - 如果为
N = 1
,则Collate(T)
为T
。 - 如果为
N > 1
,则Collate(T_0, ..., T_{N-1})
是T
类型的N
元素的元组。
此操作会将每个输入数组的一个或多个维度缩减为标量。每个返回的数组的秩为 rank(operand) - len(dimensions)
。该运算的输出为 Collate(Q_0, ..., Q_N)
,其中 Q_i
是类型为 T_i
的数组,其维度如下所述。
不同的后端可以重新关联求和计算。这可能会导致数值差异,因为加法等一些求和函数对浮点数而言不是协同的。不过,如果数据范围有限,对于大多数实际用途,浮点加法足以接近于具有关联性。
示例
使用值为 [10, 11,
12, 13]
的单个 1D 数组中的一个维度进行求和时,如果使用求和函数 f
(即 computation
),则可以按如下方式计算:
f(10, f(11, f(12, f(init_value, 13)))
但还有许多其他可能性,例如
f(init_value, f(f(10, f(init_value, 11)), f(f(init_value, 12), f(init_value, 13))))
以下是一个粗略的伪代码示例,展示了如何使用求和作为初始值为 0 的求和计算来实现求和。
result_shape <- remove all dims in dimensions from operand_shape
# Iterate over all elements in result_shape. The number of r's here is equal
# to the rank of the result
for r0 in range(result_shape[0]), r1 in range(result_shape[1]), ...:
# Initialize this result element
result[r0, r1...] <- 0
# Iterate over all the reduction dimensions
for d0 in range(dimensions[0]), d1 in range(dimensions[1]), ...:
# Increment the result element with the value of the operand's element.
# The index of the operand's element is constructed from all ri's and di's
# in the right order (by construction ri's and di's together index over the
# whole operand shape).
result[r0, r1...] += operand[ri... di]
下面是一个对二维数组(矩阵)进行还原的示例。形状的秩为 2,维度 0 的大小为 2,维度 1 的大小为 3:
使用“add”函数减少维度 0 或 1 的结果:
请注意,这两个求和结果都是 1 维数组。为方便起见,该图将一个列显示为列,另一个显示为行。
下面是一个 3D 数组,是一个更复杂的示例。其秩为 3,维度 0 的大小为 4,维度 1 的大小为 2,维度 2 的大小为 3。为简单起见,值 1 到 6 会复制到维度 0。
与二维示例类似,我们只需减少一个维度。例如,如果将维度 0 缩减,就会得到一个 2 阶数组,其中维度 0 的所有值都被折叠成一个标量:
| 4 8 12 |
| 16 20 24 |
如果我们减少第 2 个维度,我们也会得到一个秩为 2 的数组,其中第 2 个维度的所有值都被折叠为标量:
| 6 15 |
| 6 15 |
| 6 15 |
| 6 15 |
请注意,输入中其余维度之间的相对顺序会保留在输出中,但某些维度可能会被分配新的编号(因为排名会发生变化)。
我们也可以缩减多个维度。减法维度 0 和 1 生成一维数组 [20, 28, 36]
。
对 3D 数组的所有维度进行求和会生成标量 84
。
可变参数求值
当 N > 1
时,reduce 函数应用稍微复杂一些,因为它会同时应用于所有输入。运算数按以下顺序提供给计算:
- 第一个运算元的运行减值
- …
- 针对第 N 个操作数运行缩减值
- 第一个操作数的输入值
- …
- 第 N 个操作数的输入值
例如,请考虑以下求和函数,该函数可用于并行计算 1 维数组的最大值和 argmax:
f: (Float, Int, Float, Int) -> Float, Int
f(max, argmax, value, index):
if value >= max:
return (value, index)
else:
return (max, argmax)
对于 1 维输入数组 V = Float[N], K = Int[N]
和初始值 I_V = Float, I_K = Int
,对唯一输入维度进行求和的结果 f_(N-1)
等同于以下递归应用:
f_0 = f(I_V, I_K, V_0, K_0)
f_1 = f(f_0.first, f_0.second, V_1, K_1)
...
f_(N-1) = f(f_(N-2).first, f_(N-2).second, V_(N-1), K_(N-1))
对值数组和顺序索引数组(即 iota)应用此归约,将对数组进行协同迭代,并返回一个包含最大值和匹配索引的元组。
ReducePrecision
另请参阅 XlaBuilder::ReducePrecision
。
模拟将浮点值转换为精度较低的格式(例如 IEEE-FP16)并恢复为原始格式的影响。低精度格式中的指数和小数部分位数可以任意指定,但并非所有硬件实现都支持所有位大小。
ReducePrecision(operand, mantissa_bits, exponent_bits)
参数 | 类型 | 语义 |
---|---|---|
operand |
XlaOp |
浮点类型 T 的数组。 |
exponent_bits |
int32 |
较低精度格式的指数位数 |
mantissa_bits |
int32 |
低精度格式的小数部分位数 |
结果为 T
类型的数组。输入值会舍入为可用给定小数位数表示的最接近的值(使用“四舍五入”语义),并且任何超出指数位数指定范围的值都会被钳制为正无穷大或负无穷大。NaN
值会保留,但可能会转换为规范的 NaN
值。
精度较低的格式必须至少包含一个指数位(为了将零值与无穷大值区分开来,因为二者的尾数均为零),并且尾数位数必须为非负数。指数或尾数位的数量可能会超过 T
类型的相应值;那么相应转换部分就是一项空操作。
ReduceScatter
另请参阅 XlaBuilder::ReduceScatter
。
ReduceScatter 是一种集合操作,能够有效地执行 AllReduce,然后将结果分散到 shard_count
块中(沿着 scatter_dimension
拆分为多个 shard_count
块),副本组中的副本 i
会接收 ith
分片。
ReduceScatter(operand, computation, scatter_dim, shard_count,
replica_group_ids, channel_id)
参数 | 类型 | 语义 |
---|---|---|
operand |
XlaOp |
要在各个副本之间进行求和的数组或数组的非空元组。 |
computation |
XlaComputation |
求余计算 |
scatter_dimension |
int64 |
要散射的维度。 |
shard_count |
int64 |
要拆分 scatter_dimension 的块数 |
replica_groups |
int64 的矢量向量 |
执行归约的组 |
channel_id |
可选的int64 |
用于跨模块通信的可选通道 ID |
- 当
operand
是数组元组时,会对元组的每个元素执行缩减-散点运算。 replica_groups
是执行缩减操作的副本组的列表(可以使用ReplicaId
检索当前副本的副本 ID)。每个组中的副本顺序决定了全局缩减结果的分散顺序。replica_groups
必须为空(在这种情况下,所有副本都属于单个组),或者包含与副本数量相同的元素。如果有多个副本组,则它们的大小必须相同。例如,replica_groups = {0, 2}, {1, 3}
会在副本0
和2
以及1
和3
之间执行归约,然后分散结果。shard_count
是每个副本组的大小。当replica_groups
为空时,我们需要此操作。如果replica_groups
不为空,则shard_count
必须等于每个副本组的大小。channel_id
用于跨模块通信:只有具有相同channel_id
的reduce-scatter
操作才能相互通信。
输出形状是输入形状,其中 scatter_dimension
缩小了 shard_count
倍。例如,如果有两个副本,并且操作数在两个副本上的值分别为 [1.0, 2.25]
和 [3.0, 5.25]
,那么对于 scatter_dim
为 0
的此运算,第一个副本的输出值为 [4.0]
,第二个副本的输出值为 [7.5]
。
ReduceWindow
另请参阅 XlaBuilder::ReduceWindow
。
将归约函数应用于 N 个多维数组序列的每个窗口中的所有元素,生成由 N 个多维数组的单个或一个元组作为输出。每个输出数组的元素数量与窗口的有效位置数量相同。池化层可以表示为 ReduceWindow
。与 Reduce
类似,应用的 computation
始终会传递给左侧的 init_values
。
ReduceWindow(operands..., init_values..., computation, window_dimensions,
window_strides, padding)
参数 | 类型 | 语义 |
---|---|---|
operands |
N XlaOps |
一系列类型为 T_0,..., T_{N-1} 的 N 个多维数组,每个数组都代表窗口放置的基准区域。 |
init_values |
N XlaOps |
归约的 N 个起始值,N 个操作数分别对应一个。如需了解详情,请参阅减少。 |
computation |
XlaComputation |
类型为 T_0, ..., T_{N-1}, T_0, ..., T_{N-1} -> Collate(T_0, ..., T_{N-1}) 的求和函数,用于应用于所有输入运算数的每个窗口中的元素。 |
window_dimensions |
ArraySlice<int64> |
窗口维度值的整数数组 |
window_strides |
ArraySlice<int64> |
窗口步长值的整数数组 |
base_dilations |
ArraySlice<int64> |
基底膨胀值的整数数组 |
window_dilations |
ArraySlice<int64> |
窗口膨胀值的整数数组 |
padding |
Padding |
窗口的内边距类型(Padding::kSame,它会进行填充,以便在步长为 1 时实现与输入相同的输出形状;或者 Padding::kValid,它不使用内边距,并在窗口不再适合时“停止”窗口) |
其中:
- N 必须大于或等于 1。
- 所有输入数组必须具有相同的维度。
- 如果为
N = 1
,则Collate(T)
为T
。 - 如果为
N > 1
,Collate(T_0, ..., T_{N-1})
是一个由类型为(T0,...T{N-1})
的N
元素组成的元组。
下面的代码和图展示了使用 ReduceWindow
的示例。输入是一个大小为 [4x6] 的矩阵,并且 window_dimensions 和 window_stride_dimensions 均为 [2x3]。
// Create a computation for the reduction (maximum).
XlaComputation max;
{
XlaBuilder builder(client_, "max");
auto y = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "y");
auto x = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "x");
builder.Max(y, x);
max = builder.Build().value();
}
// Create a ReduceWindow computation with the max reduction computation.
XlaBuilder builder(client_, "reduce_window_2x3");
auto shape = ShapeUtil::MakeShape(F32, {4, 6});
auto input = builder.Parameter(0, shape, "input");
builder.ReduceWindow(
input,
/*init_val=*/builder.ConstantLiteral(LiteralUtil::MinValue(F32)),
*max,
/*window_dimensions=*/{2, 3},
/*window_stride_dimensions=*/{2, 3},
Padding::kValid);
在某个维度中设置步长为 1 表示该维度中某个窗口的位置距离其相邻窗口 1 个元素。如需指定各个窗口互不重叠,window_stride_dimensions 应等于 window_dimensions。下图展示了使用两个不同步长值的情况。填充会应用于输入的每个维度,并且计算方式与输入以填充后的维度进入的情况相同。
举一个重要的填充示例,考虑使用维度 3
和输入数组 [10000, 1000, 100, 10, 1]
的步长 2
计算缩减窗口最小值(初始值为 MAX_FLOAT
)。填充 kValid
会计算两个有效窗口([10000, 1000, 100]
和 [100, 10, 1]
)中的最小值,从而生成输出 [100, 1]
。填充 kSame
首先会填充该数组,这样缩减窗口后的形状将与步长 1 的输入相同,方法是在两边添加初始元素,得到 [MAX_VALUE, 10000, 1000, 100, 10, 1,
MAX_VALUE]
。对内边距数组运行 reduce-window 会对三个窗口 [MAX_VALUE, 10000, 1000]
、[1000, 100, 10]
和 [10, 1, MAX_VALUE]
进行操作,并生成 [1000, 10, 1]
。
归约函数的求值顺序是任意的,并且可能是非确定性的。因此,求和函数不应对重新关联过于敏感。如需了解详情,请参阅 Reduce
上下文中关于关联性的讨论。
ReplicaId
另请参阅 XlaBuilder::ReplicaId
。
返回副本的唯一 ID(U32 标量)。
ReplicaId()
每个副本的唯一 ID 是介于 [0, N)
之间的无符号整数,其中 N
是副本数量。由于所有副本都在运行同一程序,因此程序中的 ReplicaId()
调用将在每个副本上返回不同的值。
Reshape
另请参阅 XlaBuilder::Reshape
和 Collapse
运算。
将数组的维度重新调整为新配置。
Reshape(operand, new_sizes)
Reshape(operand, dimensions, new_sizes)
参数 | 类型 | 语义 |
---|---|---|
operand |
XlaOp |
类型为 T 的数组 |
dimensions |
int64 矢量 |
尺寸的收起顺序 |
new_sizes |
int64 矢量 |
新维度的大小向量 |
从概念上讲,reshape 会先将数组展平为数据值的一维矢量,然后将此矢量提炼为新形状。输入参数是类型为 T 的任意数组、维度索引的编译时常量矢量,以及结果的维度大小的编译时常量矢量。dimension
矢量中的值(如果给出)必须是 T 的所有维度的排列;如果未给出,则默认为 {0, ..., rank - 1}
。dimensions
中的维度顺序是从循环嵌套中的变化最慢的维度(最主要的维度)到变化最快的维度(最次要的维度),该循环嵌套会将输入数组合并为单个维度。new_sizes
矢量决定了输出数组的大小。new_sizes
中索引 0 处的值是维度 0 的大小,索引 1 处的值是维度 1 的大小,依此类推。new_size
维度的乘积必须等于运算数维度大小的乘积。将展开前数组优化为由 new_sizes
定义的多维数组时,new_sizes
中的维度将按照从最慢变化(最主要)到变化最快(最小)的顺序排序。
例如,设 v 为一个包含 24 个元素的数组:
let v = f32[4x2x3] { { {10, 11, 12}, {15, 16, 17} },
{ {20, 21, 22}, {25, 26, 27} },
{ {30, 31, 32}, {35, 36, 37} },
{ {40, 41, 42}, {45, 46, 47} } };
In-order collapse:
let v012_24 = Reshape(v, {0,1,2}, {24});
then v012_24 == f32[24] {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27,
30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47};
let v012_83 = Reshape(v, {0,1,2}, {8,3});
then v012_83 == f32[8x3] { {10, 11, 12}, {15, 16, 17},
{20, 21, 22}, {25, 26, 27},
{30, 31, 32}, {35, 36, 37},
{40, 41, 42}, {45, 46, 47} };
Out-of-order collapse:
let v021_24 = Reshape(v, {1,2,0}, {24});
then v012_24 == f32[24] {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42,
15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47};
let v021_83 = Reshape(v, {1,2,0}, {8,3});
then v021_83 == f32[8x3] { {10, 20, 30}, {40, 11, 21},
{31, 41, 12}, {22, 32, 42},
{15, 25, 35}, {45, 16, 26},
{36, 46, 17}, {27, 37, 47} };
let v021_262 = Reshape(v, {1,2,0}, {2,6,2});
then v021_262 == f32[2x6x2] { { {10, 20}, {30, 40},
{11, 21}, {31, 41},
{12, 22}, {32, 42} },
{ {15, 25}, {35, 45},
{16, 26}, {36, 46},
{17, 27}, {37, 47} } };
一种特殊情况是,调整形状可以将单元素数组转换为标量,反之亦然。例如,
Reshape(f32[1x1] { {5} }, {0,1}, {}) == 5;
Reshape(5, {}, {1,1}) == f32[1x1] { {5} };
倒转
另请参阅 XlaBuilder::Rev
。
Rev(operand, dimensions)
参数 | 类型 | 语义 |
---|---|---|
operand |
XlaOp |
类型为 T 的数组 |
dimensions |
ArraySlice<int64> |
要反转的维度 |
沿指定的 dimensions
反转 operand
数组中元素的顺序,生成形状相同的输出数组。多维索引处的运算数数组的每个元素都以转换后的索引存储在输出数组中。多维索引的转换方法是,对要反转的每个维度的索引进行反转(即,如果大小为 N 的维度是反转维度之一,则其索引 i 会转换为 N - 1 - i)。
Rev
运算的一个用途是在神经网络中的梯度计算期间沿着两个窗口维度逆转卷积权重数组。
RngNormal
另请参阅 XlaBuilder::RngNormal
。
使用按照 \(N(\mu, \sigma)\) 正态分布生成的随机数字构建给定形状的输出。参数 \(\mu\) 和 \(\sigma\)以及输出形状必须采用浮点元素类型。而且这些参数必须是标量值。
RngNormal(mu, sigma, shape)
参数 | 类型 | 语义 |
---|---|---|
mu |
XlaOp |
T 型标量,指定所生成数字的平均值 |
sigma |
XlaOp |
类型为 T 的标量,用于指定生成的 |
shape |
Shape |
类型为 T 的输出形状 |
RngUniform
另请参阅 XlaBuilder::RngUniform
。
使用在区间 \([a,b)\)内按照均匀分布生成的随机数字构建给定形状的输出。参数和输出元素类型必须是布尔类型、整数类型或浮点类型,并且类型必须一致。CPU 和 GPU 后端目前仅支持 F64、F32、F16、BF16、S64、U64、S32 和 U32。此外,参数需要是标量值。如果为 \(b <= a\) ,则结果由实现定义。
RngUniform(a, b, shape)
参数 | 类型 | 语义 |
---|---|---|
a |
XlaOp |
用于指定间隔下限的类型为 T 的标量 |
b |
XlaOp |
T 型标量,指定间隔上限 |
shape |
Shape |
T 型的输出形状 |
RngBitGenerator
使用指定的算法(或后端默认算法)生成一个形状为给定形状且填充有均匀随机位数的输出,并返回更新后的状态(与初始状态具有相同的形状)和生成的随机数据。
初始状态是当前随机数生成的初始状态。它以及所需的形状和有效值取决于所使用的算法。
输出一定是初始状态的确定性函数,但不保证在后端和不同编译器版本之间是确定性的。
RngBitGenerator(algorithm, key, shape)
参数 | 类型 | 语义 |
---|---|---|
algorithm |
RandomAlgorithm |
要使用的 PRNG 算法。 |
initial_state |
XlaOp |
PRNG 算法的初始状态。 |
shape |
Shape |
生成数据的输出形状。 |
algorithm
的可用值:
rng_default
:具有后端专用形状要求的后端专用算法。rng_three_fry
:基于计数器的 ThreeFry PRNG 算法。initial_state
形状为u64[2]
,包含任意值。Salmon et al. SC 2011. 并行随机数:只需三步即可完成。rng_philox
:Philox 算法,用于并行生成随机数。initial_state
形状为u64[3]
,可包含任意值。Salmon et al. SC 2011. 并行随机数:只需三步即可完成。
散点图
XLA 分散操作会生成一系列结果,即输入数组 operands
的值,其中使用 update_computation
将多个切片(在 scatter_indices
指定的索引处)更新为 updates
中的值序列。
另请参阅 XlaBuilder::Scatter
。
scatter(operands..., scatter_indices, updates..., update_computation,
index_vector_dim, update_window_dims, inserted_window_dims,
scatter_dims_to_operand_dims)
参数 | 类型 | 语义 |
---|---|---|
operands |
一系列 N 个 XlaOp |
需要分散到 T_0, ..., T_N 类型的 N 个数组。 |
scatter_indices |
XlaOp |
包含必须分散到的切片的起始索引的数组。 |
updates |
一系列 N 个 XlaOp |
类型为 T_0, ..., T_N 的 N 个数组。updates[i] 包含必须用于散射 operands[i] 的值。 |
update_computation |
XlaComputation |
用于组合输入数组中的现有值和分散期间的更新的计算。此计算的类型应为 T_0, ..., T_N, T_0, ..., T_N -> Collate(T_0, ..., T_N) 。 |
index_vector_dim |
int64 |
scatter_indices 中包含起始索引的维度。 |
update_window_dims |
ArraySlice<int64> |
updates 形状中的一组尺寸,属于窗口尺寸。 |
inserted_window_dims |
ArraySlice<int64> |
必须插入到 updates 形状中的窗口尺寸集。 |
scatter_dims_to_operand_dims |
ArraySlice<int64> |
从分散索引到操作数索引空间的维度映射。此数组会被解读为将 i 映射到 scatter_dims_to_operand_dims[i] 。它必须是一对一且总计。 |
indices_are_sorted |
bool |
是否保证调用方对索引进行排序。 |
unique_indices |
bool |
调用方是否保证索引是唯一的。 |
其中:
- N 必须大于或等于 1。
operands
[0
]、...、operands
[N-1
] 必须都具有相同的尺寸。updates
[0
]、...、updates
[N-1
] 必须具有相同的维度。- 如果为
N = 1
,则Collate(T)
为T
。 - 如果为
N > 1
,则Collate(T_0, ..., T_N)
是T
类型的N
元素的元组。
如果 index_vector_dim
等于 scatter_indices.rank
,我们会隐式地将 scatter_indices
视为具有尾随 1
维度。
我们将类型为 ArraySlice<int64>
的 update_scatter_dims
定义为 updates
形状中不属于 update_window_dims
的一组维度,按升序排列。
散点图的参数应遵循以下约束条件:
每个
updates
数组的秩都必须为update_window_dims.size + scatter_indices.rank - 1
。每个
updates
数组中维度i
的边界必须符合以下要求:- 如果
update_window_dims
中存在i
(即某些k
等于update_window_dims
[k
]),则在考虑inserted_window_dims
后,updates
中的维度i
的边界不得超出operand
的相应边界(即adjusted_window_bounds
[k
],其中adjusted_window_bounds
包含operand
的边界,其中索引为inserted_window_dims
的边界已移除)。 - 如果
update_scatter_dims
中存在i
(即某些k
等于update_scatter_dims
[k
]),则updates
中的维度i
的边界必须等于scatter_indices
的相应边界,并跳过index_vector_dim
(即,如果k
<index_vector_dim
,则跳过scatter_indices.shape.dims
[k
],否则为scatter_indices.shape.dims
[k+1
])。
- 如果
update_window_dims
必须按升序排列,不包含任何重复的维度编号,并且必须在[0, updates.rank)
范围内。inserted_window_dims
必须按升序排列,不包含任何重复的维度编号,并且必须在[0, operand.rank)
范围内。operand.rank
必须等于update_window_dims.size
和inserted_window_dims.size
的总和。“
scatter_dims_to_operand_dims.size
”必须等于scatter_indices.shape.dims
[index_vector_dim
],且其值必须在[0, operand.rank)
范围内。
对于每个 updates
数组中的给定索引 U
,必须应用此更新的相应 operands
数组中的相应索引 I
的计算方式如下:
- 让
G
= {U
[k
] fork
inupdate_scatter_dims
}。使用G
在scatter_indices
数组中查找索引矢量S
,以便S
[i
] =scatter_indices
[Combine(G
,i
)],其中 Combine(A, b) 将index_vector_dim
位置的 b 插入 A 中。 - 使用
scatter_dims_to_operand_dims
映射将S
分散,然后使用S
在operand
中创建索引S
in
。更正式的形式:S
如果k
<scatter_dims_to_operand_dims.size
,则in
[scatter_dims_to_operand_dims
[k
]] =S
[k
]。S
in
[_
] =0
,否则。
- 通过根据
inserted_window_dims
将索引散布在U
中的update_window_dims
,在每个operands
数组中创建索引W
in
。更正式地说:- 如果
k
在update_window_dims
中,则W
in
[window_dims_to_operand_dims
(k
)] =U
[k
],其中window_dims_to_operand_dims
是域为 [0
,update_window_dims.size
) 且范围为 [0
,operand.rank
) \inserted_window_dims
的单调函数。(例如,如果update_window_dims.size
为4
、operand.rank
为6
且inserted_window_dims
为 {0
,2
},则window_dims_to_operand_dims
为 {0
→1
,1
→3
,2
→4
,3
→5
})。 W
in
[_
] =0
,否则。
- 如果
I
为W
in
+S
in
,其中 + 表示按元素相加。
总的来说,分散操作可定义如下。
- 使用
operands
初始化output
,即对于operands
[J
] 数组中的所有索引J
和O
:output
[J
][O
] =operands
[J
][O
] - 对于
updates
[J
] 数组中的每个索引U
和operand
[J
] 数组中的对应索引O
,如果O
是output
的有效索引:(output
[0
][O
], ...,output
[N-1
][O
]) =update_computation
(output
[0
][O
], ...,output
[N-1
][O
],updates
[0
][U
], ...,updates
[N-1
][U
])
更新的应用顺序是不确定的。因此,当 updates
中的多个索引引用 operands
中的同一索引时,output
中的相应值将是非确定性的。
请注意,传入 update_computation
的第一个参数始终是 output
数组中的当前值,第二个参数始终是 updates
数组中的值。这一点尤其重要,尤其是在 update_computation
不可交换的情况下。
如果 indices_are_sorted
设置为 true,则 XLA 可以假定 scatter_indices
已由用户排序(按升序排序,在根据 scatter_dims_to_operand_dims
散布其值之后)。否则,语义就是实现定义的。
如果 unique_indices
设为 true,则 XLA 可以假定所有分散到的元素都是唯一的。因此,XLA 可以使用非原子操作。如果 unique_indices
设置为 true,并且要分散到的索引不唯一,则语义由实现定义。
通俗地说,散点操作可以被视为收集操作的反转,即散点操作会更新输入中由相应收集操作提取的元素。
如需详细了解非正式说明和示例,请参阅 Gather
下的“非正式说明”部分。
选择
另请参阅 XlaBuilder::Select
。
根据谓词数组的值,使用两个输入数组的元素构造输出数组。
Select(pred, on_true, on_false)
参数 | 类型 | 语义 |
---|---|---|
pred |
XlaOp |
类型为 PRED 的数组 |
on_true |
XlaOp |
类型为 T 的数组 |
on_false |
XlaOp |
类型为 T 的数组 |
数组 on_true
和 on_false
必须具有相同的形状。这也是输出数组的形状。数组 pred
的维度必须与 on_true
和 on_false
相同,并且采用 PRED
元素类型。
对于 pred
的每个元素 P
,如果 P
的值为 true
,则从 on_true
中获取输出数组的相应元素;如果 P
的值为 false
,则从 on_false
中获取相应元素。作为一种受限形式的广播,pred
可以是 PRED
类型的标量。在这种情况下,如果 pred
为 true
,则输出数组完全取自 on_true
;如果 pred
为 false
,则输出数组完全取自 on_false
。
使用非标量 pred
的示例:
let pred: PRED[4] = {true, false, false, true};
let v1: s32[4] = {1, 2, 3, 4};
let v2: s32[4] = {100, 200, 300, 400};
==>
Select(pred, v1, v2) = s32[4]{1, 200, 300, 4};
包含标量 pred
的示例:
let pred: PRED = true;
let v1: s32[4] = {1, 2, 3, 4};
let v2: s32[4] = {100, 200, 300, 400};
==>
Select(pred, v1, v2) = s32[4]{1, 2, 3, 4};
支持元组之间的选择。为此,元组被视为标量类型。如果 on_true
和 on_false
是元组(必须具有相同的形状!),则 pred
必须是类型为 PRED
的标量。
SelectAndScatter
另请参阅 XlaBuilder::SelectAndScatter
。
此操作可以视为一个复合操作,该操作首先对 operand
数组计算 ReduceWindow
以从每个窗口中选择一个元素,然后将 source
数组分散到所选元素的索引,以构建与运算数组具有相同形状的输出数组。二进制 select
函数用于通过对每个窗口应用该函数来从每个窗口中选择一个元素,并且在调用该函数时,第一个参数的索引矢量在字典顺序上小于第二个参数的索引矢量。如果选择了第一个参数,select
函数会返回 true
;如果选择了第二个参数,该函数会返回 false
,并且该函数必须保持传递性(即,如果 select(a, b)
和 select(b, c)
为 true
,则 select(a, c)
也为 true
),这样,所选元素就不会依赖于在给定窗口内遍历的元素顺序。
系统会对输出数组中的每个选定索引应用函数 scatter
。它接受两个标量参数:
- 输出数组中所选索引处的当前值
- 适用于所选索引的
source
中的散射值
它会组合这两个参数,并返回一个标量值,用于更新输出数组中所选索引的值。最初,输出数组的所有索引都设为 init_value
。
输出数组的形状与 operand
数组相同,并且 source
数组的形状必须与对 operand
数组应用 ReduceWindow
运算的结果相同。SelectAndScatter
可用于反向传播神经网络中池化层的梯度值。
SelectAndScatter(operand, select, window_dimensions, window_strides,
padding, source, init_value, scatter)
参数 | 类型 | 语义 |
---|---|---|
operand |
XlaOp |
窗口滑动所依据的类型为 T 的数组 |
select |
XlaComputation |
类型为 T, T -> PRED 的二进制计算,适用于每个窗口中的所有元素;如果选择第一个参数,则返回 true ;如果选择第二个参数,则返回 false |
window_dimensions |
ArraySlice<int64> |
窗口维度值的整数数组 |
window_strides |
ArraySlice<int64> |
窗口步长值的整数数组 |
padding |
Padding |
窗口的填充类型(Padding::kSame 或 Padding::kValid) |
source |
XlaOp |
包含要散布的值的类型为 T 的数组 |
init_value |
XlaOp |
输出数组的初始值的类型为 T 的标量值 |
scatter |
XlaComputation |
类型为 T, T -> T 的二进制计算,用于将每个散射源元素应用于其目标元素 |
下图显示了使用 SelectAndScatter
的示例,其中 select
函数计算其参数中的最大值。请注意,当窗口重叠时(如图 (2) 所示),不同的窗口可能会多次选择 operand
数组的索引。在图中,值为 9 的元素被顶部的两个窗口(蓝色和红色)同时选中,二进制加法 scatter
函数会生成值为 8(2 + 6)的输出元素。
scatter
函数的求值顺序是任意的,并且可能是非确定性的。因此,scatter
函数不应对重新关联过于敏感。如需了解详情,请参阅 Reduce
上下文中关于关联性的讨论。
发送
另请参阅 XlaBuilder::Send
。
Send(operand, channel_handle)
参数 | 类型 | 语义 |
---|---|---|
operand |
XlaOp |
要发送的数据(类型为 T 的数组) |
channel_handle |
ChannelHandle |
每个 send/recv 对的唯一标识符 |
将给定操作数数据发送到共享相同通道句柄的另一个计算中的 Recv
指令。不返回任何数据。
与 Recv
操作类似,Send
操作的客户端 API 表示同步通信,并在内部分解为 2 个 HLO 指令(Send
和 SendDone
),以实现异步数据传输。另请参阅 HloInstruction::CreateSend
和 HloInstruction::CreateSendDone
。
Send(HloInstruction operand, int64 channel_id)
启动将操作数异步传输到具有相同通道 ID 的 Recv
指令分配的资源。返回一个上下文,后续 SendDone
指令使用该上下文等待数据传输完成。上下文是 {操作数(形状)、请求标识符 (U32)} 的元组,只能由 SendDone
指令使用。
SendDone(HloInstruction context)
给定由 Send
指令创建的上下文,等待数据传输完成。该指令不会返回任何数据。
通道指令的调度
每个通道(Recv
、RecvDone
、Send
、SendDone
)的 4 条指令的执行顺序如下。
Recv
发生在Send
之前Send
发生在RecvDone
之前Recv
发生在RecvDone
之前Send
发生在SendDone
之前
当后端编译器为通过通道指令进行通信的每个计算生成线性时间表时,这些计算之间不得存在循环。例如,以下时间安排会导致死锁。
切片
另请参阅 XlaBuilder::Slice
。
切片会从输入数组中提取子数组。子数组与输入的排名相同,并且包含输入数组中边界框内的值,其中边界框的维度和索引作为参数提供给切片操作。
Slice(operand, start_indices, limit_indices, strides)
参数 | 类型 | 语义 |
---|---|---|
operand |
XlaOp |
类型为 T 的 N 维数组 |
start_indices |
ArraySlice<int64> |
由 N 个整数构成的列表,其中包含每个维度切片的起始索引。值必须大于或等于零。 |
limit_indices |
ArraySlice<int64> |
包含每个维度 slice 的结束索引(不含该索引)的 N 个整数的列表。每个值必须大于或等于维度的相应 start_indices 值,且小于或等于维度的大小。 |
strides |
ArraySlice<int64> |
一个包含 N 个整数的列表,用于确定 slice 的输入步长。该 slice 会选择维度 d 中的每个 strides[d] 元素。 |
一维示例:
let a = {0.0, 1.0, 2.0, 3.0, 4.0}
Slice(a, {2}, {4}) produces:
{2.0, 3.0}
二维示例:
let b =
{ {0.0, 1.0, 2.0},
{3.0, 4.0, 5.0},
{6.0, 7.0, 8.0},
{9.0, 10.0, 11.0} }
Slice(b, {2, 1}, {4, 3}) produces:
{ { 7.0, 8.0},
{10.0, 11.0} }
排序
另请参阅 XlaBuilder::Sort
。
Sort(operands, comparator, dimension, is_stable)
参数 | 类型 | 语义 |
---|---|---|
operands |
ArraySlice<XlaOp> |
要排序的操作数。 |
comparator |
XlaComputation |
要使用的比较器计算。 |
dimension |
int64 |
排序所依据的维度。 |
is_stable |
bool |
是否应使用稳定排序。 |
如果只提供一个操作数:
如果操作数是秩为 1 的张量(数组),则结果为排序数组。如果您想将数组按升序排序,则比较器应执行小于比较。正式地,对数组进行排序后,它会为所有索引位置
i, j
保留i < j
,该值为comparator(value[i], value[j]) = comparator(value[j], value[i]) = false
或comparator(value[i], value[j]) = true
。如果运算数的秩较高,则会沿着所提供的维度对运算数进行排序。例如,对于 2 阶张量(矩阵),维度值
0
将单独对每列进行排序,维度值1
将对每行单独排序。如果未提供维度编号,则默认选择最后一个维度。对于要排序的维度,将采用与排名为 1 的维度相同的排序顺序。
如果提供了 n > 1
运算数:
所有
n
运算数都必须是维度相同的张量。张量的元素类型可能有所不同。所有运算数都会一起排序,而不是单独排序。从概念上讲,运算数会被视为元组。在检查每个运算数在索引位置
i
和j
的元素是否需要交换时,系统会使用2 * n
标量参数调用比较器,其中参数2 * k
对应于k-th
运算数在位置i
的值,参数2 * k + 1
对应于k-th
运算数在位置j
的值。因此,通常,比较器会将参数2 * k
和2 * k + 1
进行比较,并可能使用其他参数对作为平局决胜依据。结果是一个包含运算数的元组,这些运算数按排序顺序(按上述提供的维度划分)组成。元组的
i-th
运算数对应于 Sort 的i-th
运算数。
例如,如果有三个操作数 operand0 = [3, 1]
、operand1 = [42, 50]
和 operand2 = [-3.0, 1.1]
,并且比较器仅使用小于运算符比较 operand0
的值,则排序的输出是元组 ([1, 3], [50, 42], [1.1, -3.0])
。
如果将 is_stable
设置为 true,则可以保证排序是稳定的,也就是说,如果比较器认为某些元素相等,则会保留相等值的相对顺序。两个元素 e1
和 e2
只有在 comparator(e1, e2) = comparator(e2, e1) = false
时才相等。默认情况下,is_stable
设置为 false。
Transpose
另请参阅 tf.reshape
操作。
Transpose(operand)
参数 | 类型 | 语义 |
---|---|---|
operand |
XlaOp |
要转置的操作数。 |
permutation |
ArraySlice<int64> |
如何排列维度。 |
使用给定排列对操作数维度进行排列,即 ∀ i . 0 ≤ i < rank ⇒ input_dimensions[permutation[i]] = output_dimensions[i]
。
这与 Reshape(operand, permutation, Permute(permutation, operand.shape.dimensions)) 相同。
TriangularSolve
另请参阅 XlaBuilder::TriangularSolve
。
通过正向或后退换元法求解包含下三角系数矩阵的线性方程组。沿前导维度广播,此例程会在给定 a
和 b
的情况下,为变量 x
解析矩阵系统 op(a) * x =
b
或 x * op(a) = b
之一,其中 op(a)
为 op(a) = a
、op(a) = Transpose(a)
或 op(a) = Conj(Transpose(a))
。
TriangularSolve(a, b, left_side, lower, unit_diagonal, transpose_a)
参数 | 类型 | 语义 |
---|---|---|
a |
XlaOp |
一个 > 2 且形状为 [..., M, M] 的复数或浮点类型的数组。 |
b |
XlaOp |
一个秩 > 2 且形状相同的数组,如果 left_side 为 true,则形状为 [..., M, K] ,否则为 [..., K, M] 。 |
left_side |
bool |
指示是解 op(a) * x = b (true ) 形式还是 x * op(a) = b (false ) 形式的系统。 |
lower |
bool |
是否使用 a 的上三角或下三角。 |
unit_diagonal |
bool |
如果为 true ,则假定 a 的主对角线元素为 1 ,且不会被访问。 |
transpose_a |
Transpose |
是否按原样使用 a 、对其转置或取其共轭转置。 |
系统仅从 a
的下三角形/上三角形读取输入数据,具体取决于 lower
的值。其他三角形的值会被忽略。输出数据会在同一三角形中返回;另一个三角形中的值由实现定义,可以是任何值。
如果 a
和 b
的秩大于 2,则将它们视为矩阵批次,其中除了次要 2 个维度之外,所有其他维度都是批次维度。a
和 b
必须具有相同的批次维度。
元组
另请参阅 XlaBuilder::Tuple
。
一个包含可变数量的数据句柄的元组,每个句柄都有自己的形状。
这类似于 C++ 中的 std::tuple
。从概念上讲:
let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
let s: s32 = 5;
let t: (f32[10], s32) = tuple(v, s);
您可以通过 GetTupleElement
运算解构(访问)元组。
而
另请参阅 XlaBuilder::While
。
While(condition, body, init)
参数 | 类型 | 语义 |
---|---|---|
condition |
XlaComputation |
类型为 T -> PRED 的 XlaComputation,用于定义循环的终止条件。 |
body |
XlaComputation |
类型为 T -> T 的 XlaComputation,用于定义循环的正文。 |
init |
T |
condition 和 body 的参数的初始值。 |
顺序执行 body
,直到 condition
失败。除了下面列出的差异和限制外,这与许多其他语言中的典型 when 循环类似。
While
节点会返回T
类型的值,这是上次执行body
的结果。- 类型
T
的形状是静态确定的,并且所有迭代中的形状必须相同。
计算的 T 参数在第一次迭代中使用 init
值进行初始化,并在每次后续迭代中自动更新为 body
中的新结果。
While
节点的一个主要用例是实现在神经网络中重复执行训练。下面显示了简化的伪代码,其中包含表示计算的图表。代码可以在 while_test.cc
中找到。此示例中的 T
类型是一个 Tuple
,由用于迭代计数的 int32
和用于累加器的 vector[10]
组成。在 1000 次迭代中,循环不断向累加器添加一个常量矢量。
// Pseudocode for the computation.
init = {0, zero_vector[10]} // Tuple of int32 and float[10].
result = init;
while (result(0) < 1000) {
iteration = result(0) + 1;
new_vector = result(1) + constant_vector[10];
result = {iteration, new_vector};
}