操作语义

以下介绍了 XlaBuilder 接口中定义的操作的语义。通常,这些操作会与 xla_data.proto 中的 RPC 接口中定义的操作一对一映射。

关于命名法的一则说明:XLA 处理的泛化数据类型是一个 N 维数组,其中包含某种统一类型(例如 32 位浮点数)的元素。在本文档中,数组用于表示任意维数的数组。为方便起见,特殊情况具有更具体的熟悉名称;例如,矢量是 1 维数组,矩阵是 2 维数组。

AfterAll

另请参阅 XlaBuilder::AfterAll

AfterAll 接受可变数量的令牌,并生成单个令牌。令牌是基元类型,可在有副作用的操作之间串行化以强制执行排序。AfterAll 可用作令牌的联接,用于在集合运算之后对操作进行排序。

AfterAll(operands)

参数 类型 语义
operands XlaOp 可变参数数量

AllGather

另请参阅 XlaBuilder::AllGather

跨副本执行串联。

AllGather(operand, all_gather_dim, shard_count, replica_group_ids, channel_id)

参数 类型 语义
operand XlaOp 要在各个副本之间串联的数组
all_gather_dim int64 串联维度
replica_groups int64 的矢量矢量 要联接的组
channel_id 可选 int64 用于跨模块通信的可选通道 ID
  • replica_groups 是一系列副本组,系统会在这些副本组之间执行串联操作(可以使用 ReplicaId 检索当前副本的副本 ID)。每个组中的副本顺序决定了其输入在结果中的顺序。replica_groups 必须为空(在这种情况下,所有副本都属于单个组,按 0N - 1 的顺序排列),或者包含与副本数量相同的元素。例如,replica_groups = {0, 2}, {1, 3} 会在副本 02 以及 13 之间执行串联操作。
  • shard_count 是每个副本组的大小。在 replica_groups 为空的情况下,我们需要这样做。
  • channel_id 用于跨模块通信:只有具有相同 channel_idall-gather 操作才能相互通信。

输出形状是输入形状,其中 all_gather_dim 被放大了 shard_count 倍。例如,如果有两个副本,并且运算数在两个副本上的值分别为 [1.0, 2.5][3.0, 5.25],那么在 all_gather_dim0 的情况下,此运算的输出值在两个副本上都为 [1.0, 2.5, 3.0, 5.25]

AllReduce

另请参阅 XlaBuilder::AllReduce

跨副本执行自定义计算。

AllReduce(operand, computation, replica_group_ids, channel_id)

参数 类型 语义
operand XlaOp 要在各个副本中进行求和的数组或数组的非空元组
computation XlaComputation 减排量计算
replica_groups int64 的矢量矢量 执行缩减操作的组
channel_id 可选 int64 用于跨模块通信的可选通道 ID
  • operand 是数组元组时,系统会对元组的每个元素执行全局求和。
  • replica_groups 是执行缩减操作的副本组的列表(可以使用 ReplicaId 检索当前副本的副本 ID)。replica_groups 必须为空(在这种情况下,所有副本都属于单个组),或者包含与副本数量相同的元素数。例如,replica_groups = {0, 2}, {1, 3} 会在副本 02 以及 13 之间执行求值。
  • channel_id 用于跨模块通信:只有具有相同 channel_idall-reduce 操作才能相互通信。

输出形状与输入形状相同。例如,如果有两个副本,并且运算数在两个副本上的值分别为 [1.0, 2.5][3.0, 5.25],则这两个副本中此运算和求和计算的输出值均为 [4.0, 7.75]。如果输入是元组,则输出也是元组。

计算 AllReduce 的结果需要从每个副本获取一个输入,因此,如果一个副本执行 AllReduce 节点的次数多于另一个副本,则前一个副本将永远等待。由于所有副本都运行相同的程序,因此发生这种情况的可能性并不多,但如果 while 循环的条件依赖于 infeed 中的数据,并且 infeed 中的数据导致 while 循环在一个副本上迭代次数比另一个副本多,就有可能发生这种情况。

AllToAll

另请参阅 XlaBuilder::AllToAll

AllToAll 是一种集合操作,用于将数据从所有核心发送到所有核心。该方法分两个阶段来实现:

  1. 散射阶段。在每个核心上,操作数会沿 split_dimensions 拆分为 split_count 个块,这些块会分散到所有核心,例如,第 i 个块会发送到第 i 个核心。
  2. 收集阶段。每个核心都会沿 concat_dimension 串联收到的块。

参与的核心可以通过以下方式进行配置:

  • replica_groups:每个 ReplicaGroup 都包含参与计算的副本 ID 的列表(可以使用 ReplicaId 检索当前副本的副本 ID)。AllToAll 将按指定顺序在子组内应用。例如,replica_groups = { {1,2,3}, {4,5,0} } 表示将在副本 {1, 2, 3} 内以及在收集阶段应用 AllToAll,并且收到的块将按 1、2、3 的相同顺序串联。然后,系统会在副本 4、5、0 中应用另一个 AllToAll,并且串联顺序也为 4、5、0。如果 replica_groups 为空,则所有副本都属于一个组,按其出现的串联顺序排列。

前提条件:

  • split_dimension 上的运算数的维度大小可被 split_count 整除。
  • 运算元的形状不是元组。

AllToAll(operand, split_dimension, concat_dimension, split_count, replica_groups)

参数 类型 语义
operand XlaOp n 维输入数组
split_dimension int64 一个介于 [0, n) 之间的值,用于命名操作数沿其拆分的数据维度
concat_dimension int64 一个介于 [0, n) 之间的值,用于指定沿哪个维度串联分块
split_count int64 参与此操作的核心数。如果 replica_groups 为空,则此值应为副本数量;否则,此值应等于每个组中的副本数量。
replica_groups ReplicaGroup 矢量 每个组都包含一个副本 ID 列表。

下面显示了 Alltoall 的示例。

XlaBuilder b("alltoall");
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, /*split_count=*/4);

在此示例中,有 4 个核心参与 Alltoall。在每个核心上,操作数会沿着维度 1 拆分为 4 个部分,因此每个部分的形状为 f32[4,4]。4 个部分会分散到所有核心。然后,每个核心都会按核心 0-4 的顺序沿维度 0 串联收到的部分。因此,每个核心上的输出形状为 f32[16,4]。

BatchNormGrad

如需详细了解该算法,请参阅 XlaBuilder::BatchNormGrad原始批量归一化论文

计算批处理正则化的梯度。

BatchNormGrad(operand, scale, mean, variance, grad_output, epsilon, feature_index)

参数 类型 语义
operand XlaOp 要归一化的 n 维数组 (x)
scale XlaOp 一维数组 (γ)
mean XlaOp 一维数组 (μ)
variance XlaOp 一维数组 (σ2)
grad_output XlaOp 传递给 BatchNormTraining (y) 的渐变
epsilon float 艾普西隆值 (ϵ)
feature_index int64 operand 中特征维度的索引

对于特征维度中的每个特征(feature_indexoperand 中特征维度的索引),该运算会计算相对于所有其他维度的 operandoffsetscale 的梯度。feature_index 必须是 operand 中地图项维度的有效索引。

这三个梯度由以下公式定义(假设 4 维数组为 operand,特征维度索引为 l,批处理大小为 m,空间大小为 wh):

cl=1mwhmi=1wj=1hk=1(yijklxijklμlσ2l+ϵ)dl=1mwhmi=1wj=1hk=1yijklxijkl=γlσ2l+ϵ(yijkldlcl(xijklμl))γl=mi=1wj=1hk=1(yijklxijklμlσ2l+ϵ) βl=mi=1wj=1hk=1yijkl

输入 meanvariance 表示批处理和空间维度中的时刻值。

输出类型是三个句柄的元组:

输出 类型 语义
grad_operand XlaOp 相对于输入 operand 的梯度 (x)
grad_scale XlaOp 相对于输入 scale 的梯度 (γ)
grad_offset XlaOp 相对于输入 offset 的梯度(β)

BatchNormInference

如需详细了解该算法,请参阅 XlaBuilder::BatchNormInference原始批量归一化论文

对批量和空间维度中的数组进行归一化。

BatchNormInference(operand, scale, offset, mean, variance, epsilon, feature_index)

参数 类型 语义
operand XlaOp 要归一化的 n 维数组
scale XlaOp 1 维数组
offset XlaOp 1 维数组
mean XlaOp 1 维数组
variance XlaOp 1 维数组
epsilon float ε 值
feature_index int64 operand 中特征维度的索引

对于特征维度中的每个特征(feature_indexoperand 中特征维度的索引),该操作会计算所有其他维度的均值和方差,并使用均值和方差对 operand 中的每个元素进行归一化处理。feature_index 必须是 operand 中地图项维度的有效索引。

BatchNormInference 等同于调用 BatchNormTraining,而无需为每个批处理计算 meanvariance。它会改用输入 meanvariance 作为估算值。此运算的目的是减少推理延迟时间,因此得名为 BatchNormInference

输出是一个与输入 operand 形状相同的 n 维归一化数组。

BatchNormTraining

如需详细了解该算法,请参阅 XlaBuilder::BatchNormTrainingthe original batch normalization paper

对批量和空间维度中的数组进行归一化。

BatchNormTraining(operand, scale, offset, epsilon, feature_index)

参数 类型 语义
operand XlaOp 要归一化的 n 维数组 (x)
scale XlaOp 一维数组 (γ)
offset XlaOp 一维数组 (β)
epsilon float 艾普西隆值 (ϵ)
feature_index int64 operand 中特征维度的索引

对于特征维度中的每个特征(feature_indexoperand 中特征维度的索引),该操作会计算所有其他维度的均值和方差,并使用均值和方差对 operand 中的每个元素进行归一化处理。feature_index 必须是 operand 中地图项维度的有效索引。

对于 operand x 中的每个批次,如果该批次包含 m 元素,且空间维度的大小为 wh(假设 operand 是一个 4 维数组),则算法如下所示:

  • 计算特征维度中每个特征 l 的批处理均值 μlμl=1mwhmi=1wj=1hk=1xijkl

  • 计算批处理方差 σ2l: $\sigma^2l=\frac{1}{mwh}\sum{i=1}^m\sum{j=1}^w\sum{k=1}^h (x_{ijkl} - \mu_l)^2$

  • 标准化、缩放和移位:yijkl=γl(xijklμl)2σ2l+ϵ+βl

添加 epsilon 值(通常为小数)是为了避免除零错误。

输出类型是三个 XlaOp 的元组:

输出 类型 语义
output XlaOp 与输入 operand (y) 形状相同的 n 维数组
batch_mean XlaOp 一维数组 (μ)
batch_var XlaOp 一维数组 (σ2)

batch_meanbatch_var 是使用上述公式在批处理和空间维度上计算的瞬时值。

BitcastConvertType

另请参阅 XlaBuilder::BitcastConvertType

与 TensorFlow 中的 tf.bitcast 类似,执行从数据形状到目标形状的逐元素按位转换运算。输入和输出大小必须一致:例如,s32 元素通过位转换例程变为 f32 元素,一个 s32 元素将变为四个 s8 元素。位转换是作为低级转换实现的,因此具有不同浮点表示形式的机器将会给出不同的结果。

BitcastConvertType(operand, new_element_type)

参数 类型 语义
operand XlaOp 类型为 T 且维度为 D 的数组
new_element_type PrimitiveType 类型 U

除了最后一个维度(将按转换前后基元大小的比率而变化)外,运算数和目标形状的维度必须一致。

源和目标元素类型不得为元组。

对不同宽度的基元类型进行位转换

BitcastConvert HLO 指令支持输出元素类型 T' 的大小不等于输入元素 T 的大小的情况。由于整个操作在概念上是位转换,并且不会更改底层字节,因此输出元素的形状必须更改。对于 B = sizeof(T), B' = sizeof(T'),有两种可能的情况。

首先,当 B > B' 时,输出形状会获得大小为 B/B' 的新最小次元。例如:

  f16[10,2]{1,0} %output = f16[10,2]{1,0} bitcast-convert(f32[10]{0} %input)

对于有效标量,规则保持不变:

  f16[2]{0} %output = f16[2]{0} bitcast-convert(f32[] %input)

或者,对于 B' > B,该指令要求输入形状的最后一个逻辑维度等于 B'/B,并且此维度会在转换期间被舍弃:

  f32[10]{0} %output = f32[10]{0} bitcast-convert(f16[10,2]{1,0} %input)

请注意,不同位宽之间的转换不是按元素进行的。

广播

另请参阅 XlaBuilder::Broadcast

通过复制数组中的数据,向数组添加维度。

Broadcast(operand, broadcast_sizes)

参数 类型 语义
operand XlaOp 要复制的数组
broadcast_sizes ArraySlice<int64> 新维度的尺寸

新维度会插入到左侧,即如果 broadcast_sizes 的值为 {a0, ..., aN},并且操作数形状的维度为 {b0, ..., bM},则输出的形状的维度为 {a0, ..., aN, b0, ..., bM}

新维度会对运算数的副本进行编制索引,即

output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]

例如,如果 operand 是值为 2.0f 的标量 f32,并且 broadcast_sizes{2, 3},则结果将是形状为 f32[2, 3] 的数组,并且结果中的所有值均为 2.0f

BroadcastInDim

另请参阅 XlaBuilder::BroadcastInDim

通过复制数组中的数据来扩展数组的大小和维度数。

BroadcastInDim(operand, out_dim_size, broadcast_dimensions)

参数 类型 语义
operand XlaOp 要复制的数组
out_dim_size ArraySlice<int64> 目标形状的尺寸
broadcast_dimensions ArraySlice<int64> 操作数形状的每个维度对应于目标形状中的哪个维度

与广播类似,但允许在任何位置添加维度,并扩展大小为 1 的现有维度。

operand 会广播到 out_dim_size 描述的形状。broadcast_dimensions 会将 operand 的维度映射到目标形状的维度,即将运算数的第 i 个维度映射到输出形状的 broadcast_dimension[i] 个维度。operand 的维度必须为 1 或与其映射到的输出形状中的维度相同。系统会使用大小为 1 的维度填充其余维度。然后,退化维度广播会沿这些退化维度广播,以达到输出形状。广播页面详细介绍了这些语义。

致电

另请参阅 XlaBuilder::Call

使用给定参数调用计算。

Call(computation, args...)

参数 类型 语义
computation XlaComputation 使用任意类型的 N 个参数计算类型 T_0, T_1, ..., T_{N-1} -> S
args 一系列 N 个 XlaOp 任意类型的 N 个实参

args 的 arity 和类型必须与 computation 的参数相匹配。可以没有 args

CompositeCall

另请参阅 XlaBuilder::CompositeCall

封装由其他 StableHLO 运算组成的运算,接受输入和 composite_attributes 并生成结果。运算的语义由分解属性实现。复合运算可以替换为其分解,而无需更改程序语义。如果内嵌分解不提供相同的操作语义,请优先使用 custom_call。

版本字段(默认为 0)用于表示复合体的语义何时发生变化。

此运算以具有属性 is_composite=truekCall 形式实现。decomposition 字段由 computation 属性指定。前端属性存储前缀为 composite. 的其余属性。

CompositeCall 操作示例:

f32[] call(f32[] %cst), to_apply=%computation, is_composite=true,
frontend_attributes = {
  composite.name="foo.bar",
  composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},
  composite.version="1"
}

Call(computation, args..., name, composite_attributes, version)

参数 类型 语义
inputs XlaOp 值的变参数量
name string 复合体的名称
composite_attributes 可选 string 可选的字符串化属性字典
decomposition XlaComputation 使用任意类型的 N 个参数计算类型 T_0, T_1, ..., T_{N-1} -> S
version int64 将复合运算的语义从数字更新为版本

Cholesky

另请参阅 XlaBuilder::Cholesky

计算一组对称(赫尔米特)正定矩阵的 Cholesky 分解

Cholesky(a, lower)

参数 类型 语义
a XlaOp 一个复杂类型或浮点类型的数组,维度大于 2。
lower bool 是否使用 a 的上三角或下三角。

如果 lowertrue,则计算下三角矩阵 l,使 a=llT。如果 lowerfalse,则计算上三角矩阵 u,使a=uT.u

系统仅从 a 的下三角形/上三角形读取输入数据,具体取决于 lower 的值。系统会忽略另一个三角形中的值。输出数据会在同一三角形中返回;另一个三角形中的值由实现定义,可以是任何值。

如果 a 的维度超过 2 个,则 a 会被视为一批矩阵,其中除次要的 2 个维度外,所有其他维度都是批次维度。

如果 a 不是对称(赫尔米特)正定矩阵,则结果由实现定义。

限制取值范围

另请参阅 XlaBuilder::Clamp

将运算数限制在最小值和最大值之间的范围内。

Clamp(min, operand, max)

参数 类型 语义
min XlaOp 类型为 T 的数组
operand XlaOp 类型为 T 的数组
max XlaOp 类型为 T 的数组

给定运算数以及最小值和最大值,如果运算数在最小值和最大值之间,则返回运算数;如果运算数低于此范围,则返回最小值;如果运算数高于此范围,则返回最大值。即 clamp(a, x, b) = min(max(a, x), b)

这三个数组的形状必须相同。或者,作为广播的一种受限形式,min 和/或 max 可以是类型为 T 的标量。

使用标量 minmax 的示例:

let operand: s32[3] = {-1, 5, 9};
let min: s32 = 0;
let max: s32 = 6;
==>
Clamp(min, operand, max) = s32[3]{0, 5, 6};

收起

另请参阅 XlaBuilder::Collapsetf.reshape 运算。

将数组的维度合并为一维。

Collapse(operand, dimensions)

参数 类型 语义
operand XlaOp 类型为 T 的数组
dimensions int64 矢量 T 维度的有序连续子集。

收起会将运算对象维度的给定子集替换为单个维度。输入参数是类型为 T 的任意数组,以及维度索引的编译时常量矢量。维度索引必须是 T 维度的连续子集,且按顺序(从低到高维度编号)排列。因此,{0, 1, 2}、{0, 1} 或 {1, 2} 都是有效的维度组合,但 {1, 0} 或 {0, 2} 无效。它们会被一个新维度取代,新维度会在维度序列中与被替换的维度处于相同的位置,并且新维度的大小等于原始维度大小的乘积。dimensions 中的维度编号越低,表示在用于收起这些维度的循环嵌套中,该维度的变化越慢(最主要),维度编号越高,表示该维度的变化越快(最次要)。如果需要更通用的收起排序,请参阅 tf.reshape 运算符。

例如,设 v 为一个包含 24 个元素的数组:

let v = f32[4x2x3] { { {10, 11, 12},  {15, 16, 17} },
{ {20, 21, 22},  {25, 26, 27} },
{ {30, 31, 32},  {35, 36, 37} },
{ {40, 41, 42},  {45, 46, 47} } };

// Collapse to a single dimension, leaving one dimension.
let v012 = Collapse(v, {0,1,2});
then v012 == f32[24] {10, 11, 12, 15, 16, 17,
20, 21, 22, 25, 26, 27,
30, 31, 32, 35, 36, 37,
40, 41, 42, 45, 46, 47};

// Collapse the two lower dimensions, leaving two dimensions.
let v01 = Collapse(v, {0,1});
then v01 == f32[4x6] { {10, 11, 12, 15, 16, 17},
{20, 21, 22, 25, 26, 27},
{30, 31, 32, 35, 36, 37},
{40, 41, 42, 45, 46, 47} };

// Collapse the two higher dimensions, leaving two dimensions.
let v12 = Collapse(v, {1,2});
then v12 == f32[8x3] { {10, 11, 12},
{15, 16, 17},
{20, 21, 22},
{25, 26, 27},
{30, 31, 32},
{35, 36, 37},
{40, 41, 42},
{45, 46, 47} };

CollectivePermute

另请参阅 XlaBuilder::CollectivePermute

CollectivePermute 是一种集合操作,用于跨副本发送和接收数据。

CollectivePermute(operand, source_target_pairs)

参数 类型 语义
operand XlaOp n 维输入数组
source_target_pairs <int64, int64> 矢量 一个包含 (source_replica_id, target_replica_id) 对的列表。对于每个对,运算数都会从源副本发送到目标副本。

请注意,source_target_pair 存在以下限制:

  • 任何两个对都不能具有相同的目标副本 ID,也不能具有相同的源副本 ID。
  • 如果某个副本 ID 不是任何对中的目标,则该副本上的输出是一个由 0 组成的张量,其形状与输入相同。

串联

另请参阅 XlaBuilder::ConcatInDim

串联用于从多个数组运算对象组合数组。该数组的维度数与每个输入数组运算数(必须具有相同的维度数)相同,并且包含按指定顺序的参数。

Concatenate(operands..., dimension)

参数 类型 语义
operands 一系列 N 个 XlaOp 类型为 T 且维度为 [L0, L1, ...] 的 N 个数组。要求 N >= 1。
dimension int64 一个介于 [0, N) 之间的值,用于指定要在 operands 之间串联的维度。

除了 dimension 之外,所有维度都必须相同。这是因为 XLA 不支持“带有空格”的数组。另请注意,0 维值无法串联(因为无法为串联发生的维度命名)。

一维示例:

Concat({ {2, 3}, {4, 5}, {6, 7} }, 0)
>>> {2, 3, 4, 5, 6, 7}

二维示例:

let a = {
{1, 2},
{3, 4},
{5, 6},
};
let b = {
{7, 8},
};
Concat({a, b}, 0)
>>> {
{1, 2},
{3, 4},
{5, 6},
{7, 8},
}

示意图:

基于条件

另请参阅 XlaBuilder::Conditional

Conditional(pred, true_operand, true_computation, false_operand, false_computation)

参数 类型 语义
pred XlaOp 类型为 PRED 的标量
true_operand XlaOp 类型为 T0的参数
true_computation XlaComputation 类型为 T0S的 XlaComputation
false_operand XlaOp 类型为 T1的参数
false_computation XlaComputation 类型为 T1S的 XlaComputation

如果 predtrue,则执行 true_computation;如果 predfalse,则执行 false_computation,并返回结果。

true_computation 必须接受一个类型为 T0 的参数,并将通过 true_operand 调用,该参数也必须是同一类型。false_computation 必须接受一个类型为 T1 的参数,并将通过类型相同的 false_operand 调用。true_computationfalse_computation 返回值的类型必须相同。

请注意,系统只会执行 true_computationfalse_computation 中的一项,具体取决于 pred 的值。

Conditional(branch_index, branch_computations, branch_operands)

参数 类型 语义
branch_index XlaOp 类型为 S32 的标量
branch_computations 一系列 N 个 XlaComputation 类型为 T0S,T1S,...,TN1S的 XlaComputation
branch_operands 一系列 N 个 XlaOp 类型为 T0,T1,...,TN1的参数

执行 branch_computations[branch_index] 并返回结果。如果 branch_index 是一个小于 0 或大于等于 N 的 S32,则会将 branch_computations[N-1] 作为默认分支进行执行。

每个 branch_computations[b] 都必须接受一个类型为 Tb 的参数,并且将通过 branch_operands[b] 调用,该参数也必须为相同类型。每个 branch_computations[b] 的返回值类型必须相同。

请注意,系统只会根据 branch_index 的值执行其中一个 branch_computations

Conv(卷积)

另请参阅 XlaBuilder::Conv

与 ConvWithGeneralPadding 相同,但以简写方式指定填充,即 SAME 或 VALID。SAME 填充会使用零填充输入 (lhs),以便在不考虑步长的情况下,输出与输入具有相同的形状。有效填充只是表示不填充。

ConvWithGeneralPadding(卷积)

另请参阅 XlaBuilder::ConvWithGeneralPadding

计算神经网络中使用的卷积类型。在这里,卷积可以被视为在 n 维基准区域中移动的 n 维窗口,并且系统会针对窗口的每个可能位置执行计算。

参数 类型 语义
lhs XlaOp 输入的(n+2)维数组
rhs XlaOp 核权重的 (n+2) 维数组
window_strides ArraySlice<int64> 内核步长的 n 维数组
padding ArraySlice< pair<int64,int64>> 一个包含(低、高)内边距的 n 维数组
lhs_dilation ArraySlice<int64> n-d 左侧边界膨胀因子数组
rhs_dilation ArraySlice<int64> n-d 右侧边界膨胀因子数组
feature_group_count int64 特征组的数量
batch_group_count int64 批处理组的数量

设 n 为空间维度数。lhs 参数是一个描述基底区域的(n+2)维数组。这称为输入,尽管 rhs 当然也是输入。在神经网络中,这些是输入激活。n+2 个维度按以下顺序排列:

  • batch:此维度中的每个坐标都代表一个要执行卷积的独立输入。
  • z/depth/features:基准区域中的每个 (y,x) 位置都与一个矢量相关联,该矢量会进入此维度。
  • spatial_dims:描述 n 空间尺寸,用于定义窗口移动所跨越的基本区域。

rhs 参数是一个 (n+2) 维数组,用于描述卷积滤波器/核/窗口。维度按以下顺序排列:

  • output-z:输出的 z 维度。
  • input-z:此维度的大小乘以 feature_group_count 应等于 lhs 中的 z 维度的大小。
  • spatial_dims:描述用于定义在基准区域中移动的 n 维窗口的 n 空间维度。

window_strides 参数用于指定空间维度中卷积窗口的步长。例如,如果第一个空间维度的步长为 3,则窗口只能放置在第一个空间索引可被 3 整除的坐标处。

padding 参数用于指定要应用于基准区域的零内边距量。填充量可以为负值,负值填充的绝对值表示在执行卷积之前要从指定维度中移除的元素数量。padding[0] 用于指定尺寸 y 的内边距,padding[1] 用于指定尺寸 x 的内边距。每个对的第一个元素是内边距下限,第二个元素是内边距上限。低内边距会应用在索引较低的方向,而高内边距会应用在索引较高的方向。例如,如果 padding[1](2,3),则第二个空间维度将在左侧添加 2 个零,在右侧添加 3 个零作为内边距。使用填充等同于在进行卷积之前将这些相同的零值插入输入 (lhs)。

lhs_dilationrhs_dilation 参数分别指定要应用于每个空间维度的左侧和右侧的膨胀因子。如果空间维度的膨胀系数为 d,则会在该维度中的每个条目之间隐式放置 d-1 个空洞,从而增加数组的大小。这些空洞会填充无操作值,对于卷积,这意味着填充零值。

对右侧进行膨胀也称为 atrous 卷积。如需了解详情,请参阅 tf.nn.atrous_conv2d。左侧的膨胀也称为转置卷积。如需了解详情,请参阅 tf.nn.conv2d_transpose

feature_group_count 参数(默认值为 1)可用于分组卷积。feature_group_count 需要同时是输入特征维度和输出特征维度的除数。如果 feature_group_count 大于 1,则表示在概念上,输入和输出特征维度以及 rhs 输出特征维度会均匀地拆分为许多 feature_group_count 组,每个组由连续的特征子序列组成。rhs 的输入特征维度需要等于 lhs 输入特征维度除以 feature_group_count(因此它已经具有一组输入特征的大小)。将第 i 组组合使用,可为许多单独的卷积计算 feature_group_count。这些卷积的结果会在输出特征维度中串联在一起。

对于深度卷积,feature_group_count 参数将设置为输入特征维度,并且滤镜将从 [filter_height, filter_width, in_channels, channel_multiplier] 重塑为 [filter_height, filter_width, 1, in_channels * channel_multiplier]。如需了解详情,请参阅 tf.nn.depthwise_conv2d

在反向传播期间,batch_group_count(默认值为 1)参数可用于分组过滤器。batch_group_count 需要是 lhs(输入)批次维度的大小的除数。如果 batch_group_count 大于 1,则表示输出批处理维度应为大小为 input batch / batch_group_count 的向量。batch_group_count 必须是输出特征大小的分母。

输出形状具有以下维度,按以下顺序排列:

  • batch:此维度的大小乘以 batch_group_count 应等于 lhs 中的 batch 维度的大小。
  • z:与内核上的 output-z 相同大小 (rhs)。
  • spatial_dims:对于卷积窗口的每个有效放置,都有一个值。

上图展示了 batch_group_count 字段的运作方式。实际上,我们会将每个 lhs 批处理切片为 batch_group_count 组,并对输出特征执行相同的操作。然后,对于这些组中的每个组,我们都会进行对偶卷积,并沿输出特征维度串联输出。所有其他维度(地图项和空间)的操作语义保持不变。

卷积窗口的有效放置由步长和填充后的基准区域大小决定。

为了说明卷积的运作方式,我们来考虑一下二维卷积,并在输出中选择一些固定的 batchzyx 坐标。然后,(y,x) 是窗口在基准区域内的一个角落的位置(例如左上角,具体取决于您对空间尺寸的解读方式)。现在,我们有一个从基准区域获取的二维窗口,其中每个二维点都与一个一维矢量相关联,因此我们得到了一个三维盒子。在卷积核中,由于我们固定了输出坐标 z,因此我们还拥有一个 3D 盒子。这两个框具有相同的维度,因此我们可以对这两个框之间的元素级乘积求和(类似于点积)。这就是输出值。

请注意,如果 output-z 为5,则窗口的每个位置都会在输出的 z 维度中生成 5 个值。这些值的不同之处在于所使用的卷积核的部分 - 每个 output-z 坐标都有一个单独的 3D 值盒。因此,您可以将其视为 5 个单独的卷积,每个卷积都有不同的滤镜。

以下是带有填充和步长的 2D 卷积的伪代码:

for (b, oz, oy, ox) {  // output coordinates
  value = 0;
  for (iz, ky, kx) {  // kernel coordinates and input z
    iy = oy*stride_y + ky - pad_low_y;
    ix = ox*stride_x + kx - pad_low_x;
    if ((iy, ix) inside the base area considered without padding) {
      value += input(b, iz, iy, ix) * kernel(oz, iz, ky, kx);
    }
  }
  output(b, oz, oy, ox) = value;
}

ConvertElementType

另请参阅 XlaBuilder::ConvertElementType

与 C++ 中的元素级 static_cast 类似,执行从数据形状到目标形状的元素级转换操作。维度必须匹配,并且转换是按元素进行的;例如,s32 元素通过 s32f32 转换例程变为 f32 元素。

ConvertElementType(operand, new_element_type)

参数 类型 语义
operand XlaOp 类型为 T 且维度为 D 的数组
new_element_type PrimitiveType 类型 U

运算数和目标形状的维度必须匹配。源和目标元素类型不得为元组。

T=s32 转换为 U=f32 等转换将执行标准化整数转换为浮点值转换例程,例如向最近偶数舍入。

let a: s32[3] = {0, 1, 2};
let b: f32[3] = convert(a, f32);
then b == f32[3]{0.0, 1.0, 2.0}

CrossReplicaSum

使用求和计算执行 AllReduce

CustomCall

另请参阅 XlaBuilder::CustomCall

在计算中调用用户提供的函数。

CustomCall(target_name, args..., shape)

参数 类型 语义
target_name string 函数的名称。系统会发出一个以此符号名称为目标的调用指令。
args 一系列 N 个 XlaOp 任意类型的 N 个参数,将传递给函数。
shape Shape 函数的输出形状

无论参数的个数或类型如何,函数签名都是相同的:

extern "C" void target_name(void* out, void** in);

例如,如果 CustomCall 的使用方式如下:

let x = f32[2] {1,2};
let y = f32[2x3] { {10, 20, 30}, {40, 50, 60} };

CustomCall("myfunc", {x, y}, f32[3x3])

下面是 myfunc 实现的示例:

extern "C" void myfunc(void* out, void** in) {
  float (&x)[2] = *static_cast<float(*)[2]>(in[0]);
  float (&y)[2][3] = *static_cast<float(*)[2][3]>(in[1]);
  EXPECT_EQ(1, x[0]);
  EXPECT_EQ(2, x[1]);
  EXPECT_EQ(10, y[0][0]);
  EXPECT_EQ(20, y[0][1]);
  EXPECT_EQ(30, y[0][2]);
  EXPECT_EQ(40, y[1][0]);
  EXPECT_EQ(50, y[1][1]);
  EXPECT_EQ(60, y[1][2]);
  float (&z)[3][3] = *static_cast<float(*)[3][3]>(out);
  z[0][0] = x[1] + y[1][0];
  // ...
}

用户提供的函数不得有副作用,并且其执行必须是幂等的。

另请参阅 XlaBuilder::Dot

Dot(lhs, rhs)

参数 类型 语义
lhs XlaOp 类型为 T 的数组
rhs XlaOp 类型为 T 的数组

此运算的确切语义取决于运算元的阶数:

输入 输出 语义
矢量 [n] dot 矢量 [n] 标量 矢量点积
矩阵 [m x k] dot 向量 [k] 矢量 [m] 矩阵-矢量乘法
矩阵 [m x k] dot 矩阵 [k x n] 矩阵 [m x n] 矩阵-矩阵乘法

该运算会对 lhs 的第二个维度(如果只有 1 个维度,则为第一个维度)和 rhs 的第一个维度执行乘积求和。这些是“收缩”的维度。lhsrhs 的收缩维度必须相同。在实践中,它可用于执行向量之间的点积、向量/矩阵乘法或矩阵/矩阵乘法。

DotGeneral

另请参阅 XlaBuilder::DotGeneral

DotGeneral(lhs, rhs, dimension_numbers)

参数 类型 语义
lhs XlaOp 类型为 T 的数组
rhs XlaOp 类型为 T 的数组
dimension_numbers DotDimensionNumbers 合同和批量维度编号

与 Dot 类似,但允许为 lhsrhs 指定收缩和批量维度编号。

DotDimensionNumbers 字段 类型 语义
lhs_contracting_dimensions repeated int64 lhs 缩减维度编号
rhs_contracting_dimensions repeated int64 rhs 缩减维度编号
lhs_batch_dimensions repeated int64 lhs 批量维度编号
rhs_batch_dimensions repeated int64 rhs 批量维度编号

DotGeneral 会对 dimension_numbers 中指定的收缩维度进行产品求和。

lhsrhs 中的关联收缩维度编号不必相同,但必须具有相同的维度大小。

缩减维度编号的示例:

lhs = { {1.0, 2.0, 3.0},
{4.0, 5.0, 6.0} }

rhs = { {1.0, 1.0, 1.0},
{2.0, 2.0, 2.0} }

DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(1);
dnums.add_rhs_contracting_dimensions(1);

DotGeneral(lhs, rhs, dnums) -> { {6.0, 12.0},
{15.0, 30.0} }

lhsrhs 中关联的批次维度编号必须具有相同的维度大小。

包含批量维度数的示例(批量大小为 2,2x2 矩阵):

lhs = { { {1.0, 2.0},
{3.0, 4.0} },
{ {5.0, 6.0},
{7.0, 8.0} } }

rhs = { { {1.0, 0.0},
{0.0, 1.0} },
{ {1.0, 0.0},
{0.0, 1.0} } }

DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(2);
dnums.add_rhs_contracting_dimensions(1);
dnums.add_lhs_batch_dimensions(0);
dnums.add_rhs_batch_dimensions(0);

DotGeneral(lhs, rhs, dnums) -> { { {1.0, 2.0},
{3.0, 4.0} },
{ {5.0, 6.0},
{7.0, 8.0} } }
输入 输出 语义
[b0, m, k] dot [b0, k, n] [b0, m, n] 批量 matmul
[b0, b1, m, k] dot [b0, b1, k, n] [b0, b1, m, n] 批量 matmul

因此,生成的维度编号从批处理维度开始,然后是 lhs 非收缩/非批处理维度,最后是 rhs 非收缩/非批处理维度。

DynamicSlice

另请参阅 XlaBuilder::DynamicSlice

DynamicSlice 会从动态 start_indices 的输入数组中提取子数组。每个维度中的 slice 的大小会在 size_indices 中传递,其中指定了每个维度中不含边界值的 slice 间隔的结束点:[start, start + size)。start_indices 的形状必须为 1 维,且维度大小等于 operand 的维度数。

DynamicSlice(operand, start_indices, size_indices)

参数 类型 语义
operand XlaOp 类型为 T 的 N 维数组
start_indices 一系列 N 个 XlaOp 包含每个维度 slice 的起始索引的 N 个标量整数的列表。值必须大于或等于零。
size_indices ArraySlice<int64> 包含每个维度的 slice 大小的 N 个整数的列表。每个值都必须严格大于零,并且 start + size 必须小于或等于维度的大小,以避免模除维度大小的封装。

有效的 slice 索引是通过在执行 slice 之前对 [1, N) 中的每个索引 i 应用以下转换计算得出的:

start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - size_indices[i])

这可确保提取的 slice 始终相对于操作数数组在边界内。如果切片在应用转换之前在边界内,则转换无效。

一维示例:

let a = {0.0, 1.0, 2.0, 3.0, 4.0}
let s = {2}

DynamicSlice(a, s, {2}) produces:
{2.0, 3.0}

二维示例:

let b =
{ {0.0,  1.0,  2.0},
{3.0,  4.0,  5.0},
{6.0,  7.0,  8.0},
{9.0, 10.0, 11.0} }
let s = {2, 1}

DynamicSlice(b, s, {2, 2}) produces:
{ { 7.0,  8.0},
{10.0, 11.0} }

DynamicUpdateSlice

另请参阅 XlaBuilder::DynamicUpdateSlice

DynamicUpdateSlice 会生成一个结果,即输入数组 operand 的值,并在 start_indices 处覆盖切片 updateupdate 的形状决定了要更新的结果的子数组的形状。start_indices 的形状必须为 1 维,且维度大小等于 operand 的维度数。

DynamicUpdateSlice(operand, update, start_indices)

参数 类型 语义
operand XlaOp 类型为 T 的 N 维数组
update XlaOp 包含 slice 更新的类型为 T 的 N 维数组。更新形状的每个维度都必须严格大于零,并且 start + update 必须小于或等于每个维度的运算数大小,以避免生成超出范围的更新索引。
start_indices 一系列 N 个 XlaOp 包含每个维度 slice 的起始索引的 N 个标量整数的列表。值必须大于或等于零。

有效的 slice 索引是通过在执行 slice 之前对 [1, N) 中的每个索引 i 应用以下转换计算得出的:

start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - update.dimension_size[i])

这可确保更新后的 slice 始终相对于操作数数组在边界内。如果切片在应用转换之前在边界内,则转换无效。

一维示例:

let a = {0.0, 1.0, 2.0, 3.0, 4.0}
let u = {5.0, 6.0}
let s = {2}

DynamicUpdateSlice(a, u, s) produces:
{0.0, 1.0, 5.0, 6.0, 4.0}

二维示例:

let b =
{ {0.0,  1.0,  2.0},
{3.0,  4.0,  5.0},
{6.0,  7.0,  8.0},
{9.0, 10.0, 11.0} }
let u =
{ {12.0,  13.0},
{14.0,  15.0},
{16.0,  17.0} }

let s = {1, 1}

DynamicUpdateSlice(b, u, s) produces:
{ {0.0,  1.0,  2.0},
{3.0, 12.0, 13.0},
{6.0, 14.0, 15.0},
{9.0, 16.0, 17.0} }

元素级二进制算术运算

另请参阅 XlaBuilder::Add

支持一组元素级二进制算术运算。

Op(lhs, rhs)

其中 Op 为以下各项之一:Add(加法)、Sub(减法)、Mul(乘法)、Div(除法)、Pow(乘方)、Rem(余数)、Max(最大值)、Min(最小值)、And(逻辑 AND)、Or(逻辑 OR)、Xor(逻辑 XOR)、ShiftLeft(向左移位)、ShiftRightArithmetic(向右算术移位)、ShiftRightLogical(逻辑向右移位)、Atan2(2 个参数的反正切函数)或 Complex(将实部和虚部组合成复数)

参数 类型 语义
lhs XlaOp 左侧运算数:类型为 T 的数组
rhs XlaOp 右侧操作数:类型为 T 的数组

参数的形状必须相似或兼容。如需了解形状兼容的含义,请参阅广播文档。操作的结果的形状是广播两个输入数组的结果。在此变体中,支持不同秩数数组之间的运算,除非其中一个运算数是标量。

OpRem 时,结果的符号取自被除数,并且结果的绝对值始终小于除数的绝对值。

整数除法溢出(对零进行有符号/无符号除法/求余或对 INT_SMIN 进行有符号除法/求余 -1)会产生实现定义的值。

这些运算有支持不同维度广播的替代变体:

Op(lhs, rhs, broadcast_dimensions)

其中,Op 与上文相同。此操作变体应用于不同秩数的数组之间的算术运算(例如向向量添加矩阵)。

额外的 broadcast_dimensions 运算数是一个整数切片,用于将低维运算数的维数扩展到高维运算数的维数。broadcast_dimensions 会将低维形状的维度映射到高维形状的维度。展开形状的未映射维度会填充大小为 1 的维度。然后,退化维度广播会沿这些退化维度广播形状,以使两个运算数的形状相等。广播页面详细介绍了这些语义。

元素级比较运算

另请参阅 XlaBuilder::Eq

支持一组标准的元素级二进制比较运算。请注意,比较浮点类型时,系统会应用标准 IEEE 754 浮点比较语义。

Op(lhs, rhs)

其中,OpEq(等于)、Ne(不等于)、Ge(大于或等于)、Gt(大于)、Le(小于或等于)、Lt(小于)之一。另一组运算符(EqTotalOrder、NeTotalOrder、GeTotalOrder、GtTotalOrder、LeTotalOrder 和 LtTotalOrder)提供相同的功能,但它们还支持对浮点数进行总排序,方法是强制执行 -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN。

参数 类型 语义
lhs XlaOp 左侧运算数:类型为 T 的数组
rhs XlaOp 右侧操作数:类型为 T 的数组

参数的形状必须相似或兼容。如需了解形状兼容的含义,请参阅广播文档。操作的结果的形状是将两个元素类型为 PRED 的输入数组广播的结果。在此变体中,支持不同秩数数组之间的运算,除非其中一个运算数是标量。

这些运算有支持不同维度广播的替代变体:

Op(lhs, rhs, broadcast_dimensions)

其中,Op 与上文相同。此运算变体应用于不同秩的数组之间的比较运算(例如向向量添加矩阵)。

额外的 broadcast_dimensions 运算数是一个整数 slice,用于指定用于广播运算数的维度。广播页面详细介绍了这些语义。

元素级单元函数

XlaBuilder 支持以下按元素的单元函数:

Abs(operand) 元素级绝对值 x -> |x|

Cbrt(operand) 元素级立方根运算 x -> cbrt(x)

Ceil(operand) 按元素取上限 x -> ⌈x⌉

Clz(operand) 按元素统计前导零数。

Cos(operand) 按元素计算的余弦 x -> cos(x)

Erf(operand) 元素级误差函数 x -> erf(x),其中

erf(x)=2πx0et2dt

Exp(operand) 按元素计算的自然指数 x -> e^x

Expm1(operand) 按元素计算自然指数减一 x -> e^x - 1

Floor(operand) 按元素取整 x -> ⌊x⌋

Imag(operand) 复杂(或实数)形状的元素级虚部。x -> imag(x)。如果操作数是浮点类型,则返回 0。

IsFinite(operand) 用于测试 operand 的每个元素是否有限,即不是正无穷大、负无穷大,也不是 NaN。返回一个与输入相同形状的 PRED 值数组,其中每个元素都为 true,且仅当相应的输入元素为有限值时才为 true

Log(operand) 按元素计算的自然对数 x -> ln(x)

Log1p(operand) 按元素偏移的自然对数 x -> ln(1+x)

Logistic(operand) 元素级逻辑函数计算 x -> logistic(x)

Neg(operand) 按元素取反 x -> -x

Not(operand) 元素级逻辑非 x -> !(x)

PopulationCount(operand) 计算 operand 的每个元素中设置的位数。

Real(operand) 复杂(或实数)形状的元素级实部。x -> real(x)。如果操作数是浮点类型,则返回相同的值。

Round(operand) 按元素舍入,平分时向远离 0 的方向舍入。

RoundNearestEven(operand) 按元素舍入,舍入到最接近的偶数。

Rsqrt(operand) 对平方根运算 x -> 1.0 / sqrt(x) 执行元素级反向运算。

Sign(operand) 元素级符号运算 x -> sgn(x),其中

sgn(x)={1x<00x=0NaNx=NaN+0x=+01x>0

使用元素类型 operand 的比较运算符。

Sin(operand) 按元素计算的正弦 x -> sin(x)

Sqrt(operand) 元素级平方根运算 x -> sqrt(x)

Tan(operand) 元素级导数 x -> tan(x)

Tanh(operand) 按元素计算的双曲正切 x -> tanh(x)

参数 类型 语义
operand XlaOp 函数的操作数

该函数会应用于 operand 数组中的每个元素,从而生成形状相同的数组。operand 可以是标量(0 维)。

Fft

XLA FFT 运算针对实数和复数输入/输出实现正向和反向傅里叶转换。支持对最多 3 个轴进行多维 FFT。

另请参阅 XlaBuilder::Fft

参数 类型 语义
operand XlaOp 要进行傅里叶转换的数组。
fft_type FftType 请参见下表。
fft_length ArraySlice<int64> 要转换的轴的时间域长度。由于 RFFT(fft_length=[16]) 的输出形状与 RFFT(fft_length=[17]) 相同,因此 IRFFT 尤其需要执行此操作,以便将最内侧轴调整为合适大小。
FftType 语义
FFT 正向复杂数到复杂数 FFT。形状保持不变。
IFFT 复杂数到复杂数的反 FFT。形状保持不变。
RFFT 正向实数到复数 FFT。如果 fft_length[-1] 为非零值,则最内侧轴的形状会缩减为 fft_length[-1] // 2 + 1,从而省略转换信号超出奈奎频率的反向共轭部分。
IRFFT 实数到复数的 FFT 的逆运算(即接受复数,返回实数)。如果 fft_length[-1] 为非零值,则最内侧轴的形状会扩展为 fft_length[-1],从 1fft_length[-1] // 2 + 1 条目的反共轭推断出转换信号超出奈奎频率的部分。

多维 FFT

如果提供多个 fft_length,则相当于对最内侧的每个轴应用一系列 FFT 操作。请注意,对于实数到复数和复数到实数的情况,最内侧轴转换(实际上)会先执行(RFFT;IRFFT 为最后),因此最内侧轴会发生尺寸变化。然后,其他轴转换将是复杂->复杂。

实现细节

CPU FFT 由 Eigen 的 TensorFFT 提供支持。GPU FFT 使用 cuFFT。

Gather

XLA 汇总操作会将输入数组的多个切片(每个切片的运行时偏移可能不同)缝合在一起。

一般语义学

另请参阅 XlaBuilder::Gather。 如需更直观的说明,请参阅下文中的“非正式说明”部分。

gather(operand, start_indices, offset_dims, collapsed_slice_dims, slice_sizes, start_index_map)

参数 类型 语义
operand XlaOp 我们要从中收集数据的数组。
start_indices XlaOp 包含我们收集的切片的起始索引的数组。
index_vector_dim int64 start_indices 中“包含”起始索引的维度。如需了解详情,请参阅下文。
offset_dims ArraySlice<int64> 输出形状中的一组维度,用于偏移到从运算数切片得到的数组。
slice_sizes ArraySlice<int64> slice_sizes[i] 是维度 i 上的 slice 的边界。
collapsed_slice_dims ArraySlice<int64> 每个 Slice 中被收起的一组维度。这些维度必须为 1。
start_index_map ArraySlice<int64> 一个映射,用于描述如何将 start_indices 中的索引映射到操作数中的有效索引。
indices_are_sorted bool 是否保证索引由调用方进行排序。

为方便起见,我们将输出数组(而非 offset_dims)中的维度标记为 batch_dims

输出是一个维度为 batch_dims.size + offset_dims.size 的数组。

operand.rank 必须等于 offset_dims.sizecollapsed_slice_dims.size 的总和。此外,slice_sizes.size 必须等于 operand.rank

如果 index_vector_dim 等于 start_indices.rank,我们会隐式地将 start_indices 视为具有尾随 1 维度(即,如果 start_indices 的形状为 [6,7]index_vector_dim2,则我们会隐式地将 start_indices 的形状视为 [6,7,1])。

沿着维度 i 的输出数组的边界按如下方式计算:

  1. 如果 batch_dims 中存在 i(即对于某些 k,等于 batch_dims[k]),则我们会从 start_indices.shape 中选择相应的维度边界,跳过 index_vector_dim(即如果 k < index_vector_dim,则选择 start_indices.shape.dims[k];否则,选择 start_indices.shape.dims[k+1])。

  2. 如果 i 存在于 offset_dims 中(即对于某个 k,等于 offset_dims[k]),则我们会在考虑 collapsed_slice_dims 后从 slice_sizes 中选择相应的边界(即我们选择 adjusted_slice_sizes[k],其中 adjusted_slice_sizes 是移除了索引 collapsed_slice_dims 的边界的 slice_sizes)。

正式地,与给定输出索引 Out 对应的操作数索引 In 的计算方式如下:

  1. G = { Out[k] for k in batch_dims }。使用 G 切片出矢量 S,使得 S[i] = start_indices[Combine(G, i)],其中 Combine(A, b) 会将 b 插入到 A 的 index_vector_dim 位置。请注意,即使 G 为空,此操作也是定义良好的:如果 G 为空,则 S = start_indices

  2. 使用 start_index_mapS 分散到 operand 中,然后使用 Soperand 中创建起始索引 Sin。更具体地说:

    1. 如果 k < start_index_map.size,则 Sin[start_index_map[k]] = S[k]。

    2. 否则 Sin[_] = 0

  3. 通过根据 collapsed_slice_dims 集在 Out 中的偏移维度上分散索引,在 operand 中创建索引 Oin。更具体地说:

    1. O如果 k < offset_dims.sizeremapped_offset_dims 在下文中定义),则 in[remapped_offset_dims(k)] = Out[offset_dims[k]]。

    2. 否则 Oin[_] = 0

  4. InOin + Sin,其中 + 表示按元素相加。

remapped_offset_dims 是一个单调函数,其域为 [0, offset_dims.size),范围为 [0, operand.rank) \ collapsed_slice_dims。因此,例如,如果 offset_dims.size4operand.rank6collapsed_slice_dims 为 {0, 2},则 remapped_offset_dims 为 {01, 13, 24, 35}。

如果 indices_are_sorted 设置为 true,则 XLA 可以假定 start_indices 已由用户排序(按升序排列,根据 start_index_map 分散其值之后)。如果没有,则语义由实现定义。

非正式说明和示例

非正式地,输出数组中的每个索引 Out 都对应于运算数数组中的元素 E,计算方式如下:

  • 我们使用 Out 中的批处理维度从 start_indices 中查找起始索引。

  • 我们使用 start_index_map 将起始索引(大小可能小于 operand.rank)映射到 operand 中的“完整”起始索引。

  • 我们使用完整的起始索引动态切片出大小为 slice_sizes 的 slice。

  • 我们通过收起 collapsed_slice_dims 维度来调整 slice 的形状。由于所有已收起的 slice 维度都必须具有 1 的边界,因此此重塑始终是合法的。

  • 我们使用 Out 中的偏移量维度对此 slice 进行编入,以获取与输出索引 Out 对应的输入元素 E

在以下所有示例中,index_vector_dim 均设置为 start_indices.rank-1index_vector_dim 的更有趣的值不会从根本上改变操作,但会使可视化表示更为繁琐。

为了直观地了解上述所有内容如何协同工作,我们来看一个示例,该示例会从 [16,11] 数组中收集 5 个形状为 [8,6] 的 slice。[16,11] 数组中 slice 的位置可以表示为形状为 S64[2] 的索引矢量,因此这组 5 个位置可以表示为 S64[5,2] 数组。

然后,汇总操作的行为可以描述为一个索引转换,该转换接受 [G,O0,O1](输出形状中的索引),并按如下方式将其映射到输入数组中的元素:

我们首先使用 G 从汇总索引数组中选择一个 (X,Y) 矢量。然后,输出数组中索引为 [G,O0,O1] 的元素就是输入数组中索引为 [X+O0,Y+O1] 的元素。

slice_sizes[8,6],它决定了 O0 和 O1 的范围,而这反过来又决定了 slice 的边界。

此汇总操作充当批量动态 slice,其中 G 是批量维度。

汇总索引可以是多维的。例如,上述示例的更通用版本使用形状为 [4,5,2] 的“gather 索引”数组,会按如下方式转换索引:

同样,这会充当批量动态 slice G0,并将 G1 用作批量维度。切片大小仍为 [8,6]

XLA 中的汇总操作会以以下方式对上述非正式语义进行泛化:

  1. 我们可以配置输出形状中的哪些维度是偏移维度(上例中包含 O0O1 的维度)。输出批次维度(上例中包含 G0G1 的维度)定义为非偏移维度的输出维度。

  2. 输出形状中明确存在的输出偏移维度数量可能小于输入维度数量。这些“缺失”的维度(明确列为 collapsed_slice_dims)的 slice 大小必须为 1。由于它们的 slice 大小为 1,因此唯一有效的索引是 0,并且省略它们不会引入歧义。

  3. 从“Gather Indices”(上例中的 (X, Y))数组中提取的切片可能包含的元素数量少于输入数组的维度数,并且显式映射决定了索引应如何展开,以使其维度数与输入相同。

最后一个示例中,我们使用 (2) 和 (3) 来实现 tf.gather_nd

G0G1 用于像往常一样从汇总索引数组中切片出起始索引,不同之处在于起始索引只有一个元素 X。同样,只有一个输出偏移量索引,其值为 O0。不过,在用作输入数组的索引之前,这些索引会根据“Gather Index Mapping”(正式说明中的 start_index_map)和“Offset Mapping”(正式说明中的 remapped_offset_dims)分别展开为 [X,0] 和 [0,O0],总和为 [X,O0]。换句话说,输出索引 [G0,G1,O0] 会映射到输入索引 [GatherIndices[G0,G1,0],O0],这为我们提供了 tf.gather_nd 的语义。

在本例中,slice_sizes[1,11]。直观地讲,这意味着汇总索引数组中的每个索引 X 都会选择一整行,而结果是所有这些行的串联。

GetDimensionSize

另请参阅 XlaBuilder::GetDimensionSize

返回运算数的给定维度的大小。运算数必须为数组形式。

GetDimensionSize(operand, dimension)

参数 类型 语义
operand XlaOp n 维输入数组
dimension int64 一个介于 [0, n) 之间的值,用于指定维度

SetDimensionSize

另请参阅 XlaBuilder::SetDimensionSize

设置 XlaOp 给定维度的动态大小。运算数必须为数组形式。

SetDimensionSize(operand, size, dimension)

参数 类型 语义
operand XlaOp n 维输入数组。
size XlaOp int32,表示运行时动态大小。
dimension int64 一个介于 [0, n) 之间的值,用于指定维度。

因此,传递运算数,并由编译器跟踪动态维度。

下游缩减运算会忽略填充的值。

let v: f32[10] = f32[10]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
let five: s32 = 5;
let six: s32 = 6;

// Setting dynamic dimension size doesn't change the upper bound of the static
// shape.
let padded_v_five: f32[10] = set_dimension_size(v, five, /*dimension=*/0);
let padded_v_six: f32[10] = set_dimension_size(v, six, /*dimension=*/0);

// sum == 1 + 2 + 3 + 4 + 5
let sum:f32[] = reduce_sum(padded_v_five);
// product == 1 * 2 * 3 * 4 * 5
let product:f32[] = reduce_product(padded_v_five);

// Changing padding size will yield different result.
// sum == 1 + 2 + 3 + 4 + 5 + 6
let sum:f32[] = reduce_sum(padded_v_six);

GetTupleElement

另请参阅 XlaBuilder::GetTupleElement

指向具有编译时常量值的元组中的索引。

该值必须是编译时常量,以便形状推理可以确定生成的值的类型。

这类似于 C++ 中的 std::get<int N>(t)。从概念上讲:

let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
let s: s32 = 5;
let t: (f32[10], s32) = tuple(v, s);
let element_1: s32 = gettupleelement(t, 1);  // Inferred shape matches s32.

另请参阅 tf.tuple

信息流

另请参阅 XlaBuilder::Infeed

Infeed(shape)

参数 类型 语义
shape Shape 从信息流接口读取的数据的形状。形状的布局字段必须设置为与发送到设备的数据的布局一致;否则,其行为将不明确。

从设备的隐式信息流接口读取单个数据项,将数据解读为给定形状及其布局,并返回数据的 XlaOp。计算中允许有多个 infeed 操作,但这些 infeed 操作之间必须有总顺序。例如,以下代码中的两个 infeed 具有总顺序,因为 while 循环之间存在依赖项。

result1 = while (condition, init = init_value) {
  Infeed(shape)
}

result2 = while (condition, init = result1) {
  Infeed(shape)
}

不支持嵌套元组形状。对于空元组形状,信息流操作实际上是无操作,并且会继续执行,而不会从设备的信息流中读取任何数据。

Iota

另请参阅 XlaBuilder::Iota

Iota(shape, iota_dimension)

在设备上构建常量字面量,而不是可能较大的主机传输。创建一个具有指定形状的数组,其中包含从零开始并沿指定维度递增 1 的值。对于浮点类型,生成的数组等同于 ConvertElementType(Iota(...)),其中 Iota 为整数类型,并且转换为浮点类型。

参数 类型 语义
shape Shape Iota() 创建的数组的形状
iota_dimension int64 要沿其递增的维度。

例如,Iota(s32[4, 8], 0) 会返回

  [[0, 0, 0, 0, 0, 0, 0, 0 ],
   [1, 1, 1, 1, 1, 1, 1, 1 ],
   [2, 2, 2, 2, 2, 2, 2, 2 ],
   [3, 3, 3, 3, 3, 3, 3, 3 ]]

退货费用 Iota(s32[4, 8], 1)

  [[0, 1, 2, 3, 4, 5, 6, 7 ],
   [0, 1, 2, 3, 4, 5, 6, 7 ],
   [0, 1, 2, 3, 4, 5, 6, 7 ],
   [0, 1, 2, 3, 4, 5, 6, 7 ]]

地图

另请参阅 XlaBuilder::Map

Map(operands..., computation)

参数 类型 语义
operands 一系列 N 个 XlaOp 类型为 T0..T{N-1} 的 N 个数组
computation XlaComputation 类型为 T_0, T_1, .., T_{N + M -1} -> S 的计算,其中包含类型为 T 的 N 个参数和类型为任意的 M 个参数
dimensions int64 数组 映射维度数组

对给定的 operands 数组应用标量函数,生成一个维度相同的数组,其中每个元素都是应用于输入数组中相应元素的映射函数的结果。

映射的函数是任意计算,但有以下限制:它有 N 个标量类型为 T 的输入,以及一个类型为 S 的输出。输出的维度与运算数相同,但元素类型 T 会替换为 S。

例如:Map(op1, op2, op3, computation, par1) 会将 elem_out <- computation(elem1, elem2, elem3, par1) 映射到输入数组中的每个(多维)索引,以生成输出数组。

OptimizationBarrier

阻止任何优化阶段将计算移出屏障。

确保在任何依赖于屏障输出的运算符之前对所有输入进行求值。

垫子

另请参阅 XlaBuilder::Pad

Pad(operand, padding_value, padding_config)

参数 类型 语义
operand XlaOp 类型为 T 的数组
padding_value XlaOp 类型为 T 的标量,用于填充添加的内边距
padding_config PaddingConfig 两侧的内边距(低、高)以及每个维度元素之间的内边距

使用给定 padding_value 在给定 operand 数组周围以及数组元素之间添加内边距,从而扩展给定 operand 数组。padding_config 用于指定每个尺寸的外边内边距和内部内边距。

PaddingConfigPaddingConfigDimension 的重复字段,其中包含每个维度的三个字段:edge_padding_lowedge_padding_highinterior_padding

edge_padding_lowedge_padding_high 分别指定在每个维度的低端(靠近索引 0)和高端(靠近最高索引)添加的内边距量。边缘内边距的大小可以为负值,负内边距的绝对值表示要从指定维度中移除的元素数量。

interior_padding 用于指定在每个维度中的任意两个元素之间添加的内边距;该值不得为负。从逻辑上讲,内部内边距会先于边缘内边距发生,因此在负边缘内边距的情况下,系统会从内边距内运算对象中移除元素。

如果边缘内边距对均为 (0, 0),且内部内边距值均为 0,则此操作将不执行任何操作。下图显示了二维数组的不同 edge_paddinginterior_padding 值示例。

Recv

另请参阅 XlaBuilder::Recv

Recv(shape, channel_handle)

参数 类型 语义
shape Shape 要接收的数据的形状
channel_handle ChannelHandle 每个 send/recv 对的唯一标识符

从共享相同通道句柄的其他计算中的 Send 指令接收给定形状的数据。返回收到的数据的 XlaOp。

Recv 操作的客户端 API 代表同步通信。不过,该指令会在内部分解为 2 个 HLO 指令(RecvRecvDone),以实现异步数据传输。另请参阅 HloInstruction::CreateRecvHloInstruction::CreateRecvDone

Recv(const Shape& shape, int64 channel_id)

分配从具有相同 channel_id 的 Send 指令接收数据所需的资源。返回已分配资源的上下文,后续的 RecvDone 指令会使用该上下文来等待数据传输完成。上下文是 {接收缓冲区(形状)、请求标识符 (U32)} 的元组,只能由 RecvDone 指令使用。

RecvDone(HloInstruction context)

给定由 Recv 指令创建的上下文,等待数据传输完成并返回收到的数据。

减少

另请参阅 XlaBuilder::Reduce

将一个归约函数并行应用于一个或多个数组。

Reduce(operands..., init_values..., computation, dimensions)

参数 类型 语义
operands 一系列 N 个 XlaOp 类型为 T_0, ..., T_{N-1} 的 N 个数组。
init_values 一系列 N 个 XlaOp 类型为 T_0, ..., T_{N-1} 的 N 个标量。
computation XlaComputation 类型为 T_0, ..., T_{N-1}, T_0, ..., T_{N-1} -> Collate(T_0, ..., T_{N-1}) 的计算。
dimensions int64 数组 要缩减的维度的无序数组。

其中:

  • N 必须大于或等于 1。
  • 计算必须“大致”遵循结合律(见下文)。
  • 所有输入数组的维度都必须相同。
  • 所有初始值都必须在 computation 下形成一个标识。
  • 如果为 N = 1,则 Collate(T)T
  • 如果为 N > 1Collate(T_0, ..., T_{N-1}) 是一个由类型为 TN 元素组成的元组。

此操作会将每个输入数组的一个或多个维度缩减为标量。每个返回的数组的维度数为 number_of_dimensions(operand) - len(dimensions)。该运算的输出为 Collate(Q_0, ..., Q_N),其中 Q_i 是类型为 T_i 的数组,其维度如下所述。

不同的后端可以重新关联求和计算。这可能会导致数值差异,因为一些求和等求值函数对浮点数而言并不具有结合性。不过,如果数据范围有限,对于大多数实际用途,浮点加法足以接近于具有关联性。

示例

使用值为 [10, 11, 12, 13] 的单个 1D 数组中的一个维度进行求和时,如果使用求和函数 f(即 computation),则可以按如下方式计算:

f(10, f(11, f(12, f(init_value, 13)))

但还有许多其他可能性,例如

f(init_value, f(f(10, f(init_value, 11)), f(f(init_value, 12), f(init_value, 13))))

以下是一个粗略的伪代码示例,展示了如何使用求和作为初始值为 0 的求和计算来实现求和。

result_shape <- remove all dims in dimensions from operand_shape

# Iterate over all elements in result_shape. The number of r's here is equal
# to the number of dimensions of the result.
for r0 in range(result_shape[0]), r1 in range(result_shape[1]), ...:
  # Initialize this result element
  result[r0, r1...] <- 0

  # Iterate over all the reduction dimensions
  for d0 in range(dimensions[0]), d1 in range(dimensions[1]), ...:
    # Increment the result element with the value of the operand's element.
    # The index of the operand's element is constructed from all ri's and di's
    # in the right order (by construction ri's and di's together index over the
    # whole operand shape).
    result[r0, r1...] += operand[ri... di]

下面是一个对二维数组(矩阵)进行还原的示例。该形状有 2 个维度,维度 0 的大小为 2,维度 1 的大小为 3:

使用“add”函数减少维度 0 或 1 的结果:

请注意,这两个求和结果都是 1 维数组。图表中将一个显示为列,另一个显示为行,只是为了方便视觉呈现。

下面是一个 3D 数组,是一个更复杂的示例。其维度数为 3,维度 0 的大小为 4,维度 1 的大小为 2,维度 2 的大小为 3。为简单起见,值 1 到 6 会复制到维度 0 中。

与二维示例类似,我们只需减少一个维度。例如,如果我们对第 0 个维度进行还原,则会得到一个二维数组,其中第 0 个维度中的所有值都被折叠为一个标量:

|  4   8  12 |
| 16  20  24 |

如果我们减少第 2 个维度,我们也会得到一个二维数组,其中第 2 个维度中的所有值都被折叠为一个标量:

| 6  15 |
| 6  15 |
| 6  15 |
| 6  15 |

请注意,输入中其余维度之间的相对顺序会保留在输出中,但由于维度数量会发生变化,因此某些维度可能会被分配新的编号。

我们还可以减少多个维度。对维度 0 和 1 进行加法减化会产生 1D 数组 [20, 28, 36]

对 3D 数组的所有维度进行求和会生成标量 84

可变参数求值

N > 1 时,reduce 函数应用稍微复杂一些,因为它会同时应用于所有输入。运算数会按以下顺序提供给计算:

  • 第一个运算元的运行减值
  • 第 N 个运算数的运行减值
  • 第一个操作数的输入值
  • 第 N 个运算数的输入值

例如,请考虑以下求和函数,该函数可用于并行计算 1 维数组的最大值和 argmax:

f: (Float, Int, Float, Int) -> Float, Int
f(max, argmax, value, index):
  if value >= max:
    return (value, index)
  else:
    return (max, argmax)

对于 1 维输入数组 V = Float[N], K = Int[N] 和初始值 I_V = Float, I_K = Int,对唯一输入维度进行求和的结果 f_(N-1) 等同于以下递归应用:

f_0 = f(I_V, I_K, V_0, K_0)
f_1 = f(f_0.first, f_0.second, V_1, K_1)
...
f_(N-1) = f(f_(N-2).first, f_(N-2).second, V_(N-1), K_(N-1))

将此求值运算应用于一个值数组和一个顺序索引数组(即 iota)会对这两个数组进行并行迭代,并返回一个包含最大值和匹配索引的元组。

ReducePrecision

另请参阅 XlaBuilder::ReducePrecision

模拟将浮点值转换为精度较低的格式(例如 IEEE-FP16)并恢复为原始格式的影响。低精度格式中的指数和小数部分位数可以任意指定,但并非所有硬件实现都支持所有位大小。

ReducePrecision(operand, mantissa_bits, exponent_bits)

参数 类型 语义
operand XlaOp 浮点类型 T 的数组。
exponent_bits int32 较低精度格式的指数位数
mantissa_bits int32 低精度格式的小数部分位数

结果为 T 类型的数组。输入值会舍入为可用给定小数位数表示的最接近的值(使用“四舍五入”语义),并且任何超出指数位数指定范围的值都会被钳制为正无穷大或负无穷大。NaN 值会保留,但可能会转换为规范的 NaN 值。

精度较低的格式必须至少包含一个指数位(为了区分零值和无穷大值,因为这两个值的尾数均为零),并且尾数位数必须为非负数。指数或小数部分位数可能会超过类型 T 的相应值;然后,转换的相应部分只是无操作。

ReduceScatter

另请参阅 XlaBuilder::ReduceScatter

ReduceScatter 是一种集合操作,可有效执行 AllReduce,然后通过沿 scatter_dimension 将结果拆分为 shard_count 个分块来分散结果,并且副本组中的副本 i 会接收 ith 分片。

ReduceScatter(operand, computation, scatter_dim, shard_count, replica_group_ids, channel_id)

参数 类型 语义
operand XlaOp 要在各个副本之间进行求和的数组或数组的非空元组。
computation XlaComputation 减排量计算
scatter_dimension int64 要散点的维度。
shard_count int64 要将 scatter_dimension 拆分成的区块数量
replica_groups int64 的矢量矢量 执行缩减操作的组
channel_id 可选 int64 用于跨模块通信的可选通道 ID
  • operand 是数组元组时,系统会对元组的每个元素执行 reduce-scatter。
  • replica_groups 是执行缩减操作的副本组的列表(可以使用 ReplicaId 检索当前副本的副本 ID)。每个组中的副本顺序决定了全局缩减结果的分散顺序。replica_groups 必须为空(在这种情况下,所有副本都属于单个组),或者包含与副本数量相同的元素。如果有多个副本组,则它们的大小都必须相同。例如,replica_groups = {0, 2}, {1, 3} 会在副本 02 以及 13 之间执行求和,然后分散结果。
  • shard_count 是每个副本组的大小。在 replica_groups 为空的情况下,我们需要此属性。如果 replica_groups 不为空,则 shard_count 必须等于每个副本组的大小。
  • channel_id 用于跨模块通信:只有具有相同 channel_idreduce-scatter 操作才能相互通信。

输出形状是输入形状,其中 scatter_dimension 缩小了 shard_count 倍。例如,如果有两个副本,并且操作数在两个副本上的值分别为 [1.0, 2.25][3.0, 5.25],则此运算(其中 scatter_dim0)的输出值对于第一个副本为 [4.0],对于第二个副本为 [7.5]

ReduceWindow

另请参阅 XlaBuilder::ReduceWindow

对一系列 N 个多维数组的每个窗口中的所有元素应用一个求和函数,以生成一个或一个包含 N 个多维数组的元组作为输出。每个输出数组的元素数与窗口的有效位置数相同。池化层可以表示为 ReduceWindow。与 Reduce 类似,应用的 computation 始终会传递给左侧的 init_values

ReduceWindow(operands..., init_values..., computation, window_dimensions, window_strides, padding)

参数 类型 语义
operands N XlaOps 一系列类型为 T_0,..., T_{N-1} 的 N 个多维数组,每个数组都代表窗口放置的基准区域。
init_values N XlaOps 求和的 N 个起始值,每个操作数对应一个值。如需了解详情,请参阅减少
computation XlaComputation 类型为 T_0, ..., T_{N-1}, T_0, ..., T_{N-1} -> Collate(T_0, ..., T_{N-1}) 的求和函数,用于应用于所有输入运算元的每个窗口中的元素。
window_dimensions ArraySlice<int64> 窗口尺寸值的整数数组
window_strides ArraySlice<int64> 窗口步长值的整数数组
base_dilations ArraySlice<int64> 基底膨胀值的整数数组
window_dilations ArraySlice<int64> 窗口膨胀值的整数数组
padding Padding 窗口的填充类型(Padding::kSame,如果步长为 1,则进行填充以使输出形状与输入相同;或者 Padding::kValid,不使用填充,并在窗口不再适合时“停止”窗口)

其中:

  • N 必须大于或等于 1。
  • 所有输入数组的维度都必须相同。
  • 如果为 N = 1,则 Collate(T)T
  • 如果为 N > 1Collate(T_0, ..., T_{N-1}) 是一个由类型为 (T0,...T{N-1})N 元素组成的元组。

以下代码和图展示了使用 ReduceWindow 的示例。输入是一个大小为 [4x6] 的矩阵,并且 window_dimensions 和 window_stride_dimensions 均为 [2x3]。

// Create a computation for the reduction (maximum).
XlaComputation max;
{
  XlaBuilder builder(client_, "max");
  auto y = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "y");
  auto x = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "x");
  builder.Max(y, x);
  max = builder.Build().value();
}

// Create a ReduceWindow computation with the max reduction computation.
XlaBuilder builder(client_, "reduce_window_2x3");
auto shape = ShapeUtil::MakeShape(F32, {4, 6});
auto input = builder.Parameter(0, shape, "input");
builder.ReduceWindow(
    input,
    /*init_val=*/builder.ConstantLiteral(LiteralUtil::MinValue(F32)),
    *max,
    /*window_dimensions=*/{2, 3},
    /*window_stride_dimensions=*/{2, 3},
    Padding::kValid);

在某个维度中设置步长为 1 表示该维度中某个窗口的位置距离相邻窗口 1 个元素。如需指定各个窗口互不重叠,window_stride_dimensions 应等于 window_dimensions。下图展示了使用两个不同步长值的情况。填充会应用于输入的每个维度,并且计算结果与输入以填充后的维度进入时相同。

如需查看一个非琐碎的填充示例,不妨考虑对输入数组 [10000, 1000, 100, 10, 1] 计算维度为 3 且步长为 2 的 reduce 窗口最小值(初始值为 MAX_FLOAT)。填充 kValid 会计算两个有效窗口([10000, 1000, 100][100, 10, 1])中的最小值,从而生成输出 [100, 1]。填充 kSame 会先对数组进行填充,以便通过在两侧添加初始元素,使求和窗口后的形状与步长 1 的输入相同,从而得到 [MAX_VALUE, 10000, 1000, 100, 10, 1, MAX_VALUE]。对内边距数组运行 reduce-window 会对三个窗口 [MAX_VALUE, 10000, 1000][1000, 100, 10][10, 1, MAX_VALUE] 进行操作,并生成 [1000, 10, 1]

归约函数的求值顺序是任意的,并且可能是非确定性的。因此,求和函数不应对重新关联过于敏感。如需了解详情,请参阅 Reduce 上下文中关于关联性的讨论。

ReplicaId

另请参阅 XlaBuilder::ReplicaId

返回副本的唯一 ID (U32 标量)。

ReplicaId()

每个副本的唯一 ID 是 [0, N) 范围内的无符号整数,其中 N 是副本数量。由于所有副本都运行相同的程序,因此程序中的 ReplicaId() 调用将在每个副本上返回不同的值。

Reshape

另请参阅 XlaBuilder::ReshapeCollapse 运算。

将数组的维度重新调整为新配置。

Reshape(operand, dimensions)

参数 类型 语义
operand XlaOp 类型为 T 的数组
dimensions int64 矢量 新维度的大小向量

从概念上讲,reshape 会先将数组展平为数据值的一维矢量,然后将此矢量提炼为新形状。输入参数是类型为 T 的任意数组、维度索引的编译时常量矢量,以及结果的维度大小的编译时常量矢量。dimensions 矢量决定了输出数组的大小。dimensions 中索引 0 处的值是尺寸 0 的大小,索引 1 处的值是尺寸 1 的大小,依此类推。dimensions 维度的乘积必须等于运算元的维度大小的乘积。将展开的数组细化为由 dimensions 定义的多维数组时,dimensions 中的维度会按变化最慢(最主要)到变化最快(最次要)的顺序排列。

例如,设 v 为一个包含 24 个元素的数组:

let v = f32[4x2x3] { { {10, 11, 12}, {15, 16, 17} },
                    { {20, 21, 22}, {25, 26, 27} },
                    { {30, 31, 32}, {35, 36, 37} },
                    { {40, 41, 42}, {45, 46, 47} } };

let v012_24 = Reshape(v, {24});
then v012_24 == f32[24] {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27,
                         30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47};

let v012_83 = Reshape(v, {8,3});
then v012_83 == f32[8x3] { {10, 11, 12}, {15, 16, 17},
                          {20, 21, 22}, {25, 26, 27},
                          {30, 31, 32}, {35, 36, 37},
                          {40, 41, 42}, {45, 46, 47} };

作为一种特殊情况,reshape 可以将单元素数组转换为标量,反之亦然。例如,

Reshape(f32[1x1] { {5} }, {}) == 5;
Reshape(5, {1,1}) == f32[1x1] { {5} };

倒车

另请参阅 XlaBuilder::Rev

Rev(operand, dimensions)

参数 类型 语义
operand XlaOp 类型为 T 的数组
dimensions ArraySlice<int64> 要反转的维度

沿指定的 dimensions 反转 operand 数组中元素的顺序,生成形状相同的输出数组。在多维索引处,运算数数组的每个元素都会存储在经过转换的索引处的输出数组中。多维索引的转换方法是,对要反转的每个维度的索引进行反转(即,如果大小为 N 的维度是反转维度之一,则其索引 i 会转换为 N - 1 - i)。

Rev 运算的一个用途是在神经网络中计算梯度时,沿两个窗口维度反转卷积权重数组。

RngNormal

另请参阅 XlaBuilder::RngNormal

使用按照 N(μ,σ) 正态分布生成的随机数构建给定形状的输出。参数 μσ以及输出形状必须采用浮点元素类型。此外,参数必须是标量值。

RngNormal(mu, sigma, shape)

参数 类型 语义
mu XlaOp 类型为 T 的标量,用于指定生成的数字的均值
sigma XlaOp 类型为 T 的标量,用于指定生成的
shape Shape 类型为 T 的输出形状

RngUniform

另请参阅 XlaBuilder::RngUniform

使用在区间 [a,b)内按照均匀分布生成的随机数字构建给定形状的输出。参数和输出元素类型必须是布尔类型、整数类型或浮点类型,并且类型必须一致。CPU 和 GPU 后端目前仅支持 F64、F32、F16、BF16、S64、U64、S32 和 U32。此外,参数需要是标量值。如果为 b<=a ,则结果由实现定义。

RngUniform(a, b, shape)

参数 类型 语义
a XlaOp 用于指定间隔下限的类型为 T 的标量
b XlaOp 用于指定间隔上限的类型为 T 的标量
shape Shape 类型为 T 的输出形状

RngBitGenerator

使用指定的算法(或后端默认算法)生成一个形状为给定形状且填充有均匀随机位数的输出,并返回更新后的状态(与初始状态具有相同的形状)和生成的随机数据。

初始状态是当前随机数生成的初始状态。它以及所需的形状和有效值取决于所使用的算法。

输出一定是初始状态的确定性函数,但保证在后端和不同编译器版本之间是确定性的。

RngBitGenerator(algorithm, key, shape)

参数 类型 语义
algorithm RandomAlgorithm 要使用的 PRNG 算法。
initial_state XlaOp PRNG 算法的初始状态。
shape Shape 生成数据的输出形状。

algorithm 的可用值:

散点图

XLA 分散操作会生成一系列结果,即输入数组 operands 的值,其中使用 update_computation 将多个切片(在 scatter_indices 指定的索引处)更新为 updates 中的值序列。

另请参阅 XlaBuilder::Scatter

scatter(operands..., scatter_indices, updates..., update_computation, index_vector_dim, update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims)

参数 类型 语义
operands 一系列 N 个 XlaOp 要分散到其中的 N 个类型为 T_0, ..., T_N 的数组。
scatter_indices XlaOp 包含必须分散到的切片的起始索引的数组。
updates 一系列 N 个 XlaOp 类型为 T_0, ..., T_N 的 N 个数组。updates[i] 包含必须用于散射 operands[i] 的值。
update_computation XlaComputation 用于组合输入数组中的现有值和分散期间的更新的计算。此计算的类型应为 T_0, ..., T_N, T_0, ..., T_N -> Collate(T_0, ..., T_N)
index_vector_dim int64 scatter_indices 中包含起始索引的维度。
update_window_dims ArraySlice<int64> updates 形状的一组尺寸,即窗口尺寸
inserted_window_dims ArraySlice<int64> 必须插入 updates 形状的一组窗口尺寸
scatter_dims_to_operand_dims ArraySlice<int64> 从分散索引到操作数索引空间的维度映射。系统会将此数组解读为将 i 映射到 scatter_dims_to_operand_dims[i]。它必须是一对一且总计。
indices_are_sorted bool 是否保证索引由调用方进行排序。
unique_indices bool 调用方是否保证索引是唯一的。

其中:

  • N 必须大于或等于 1。
  • operands[0]、...、operands[N-1] 必须具有相同的维度。
  • updates[0]、...、updates[N-1] 必须具有相同的维度。
  • 如果为 N = 1,则 Collate(T)T
  • 如果为 N > 1Collate(T_0, ..., T_N) 是一个由类型为 TN 元素组成的元组。

如果 index_vector_dim 等于 scatter_indices.rank,我们会隐式地将 scatter_indices 视为具有尾随 1 维度。

我们将类型为 ArraySlice<int64>update_scatter_dims 定义为 updates 形状中不属于 update_window_dims 的一组维度,按升序排列。

scatter 的参数应遵循以下约束条件:

  • 每个 updates 数组都必须具有 update_window_dims.size + scatter_indices.rank - 1 个维度。

  • 每个 updates 数组中维度 i 的边界必须符合以下要求:

    • 如果 update_window_dims 中存在 i(即对于某个 k,等于 update_window_dims[k]),则在考虑 inserted_window_dims 后,updates 中维度 i 的上限不得超过 operand 的相应上限(即 adjusted_window_bounds[k],其中 adjusted_window_bounds 包含 operand 的上限,但已移除索引 inserted_window_dims 处的上限)。
    • 如果 update_scatter_dims 中存在 i(即,对于某个 k,等于 update_scatter_dims[k]),则 updates 中维度 i 的上限必须等于 scatter_indices 的相应上限,并跳过 index_vector_dim(即,如果 k < index_vector_dim,则为 scatter_indices.shape.dims[k];否则,为 scatter_indices.shape.dims[k+1])。
  • update_window_dims 必须按升序排列,不得有重复的维度编号,且必须在 [0, updates.rank) 范围内。

  • inserted_window_dims 必须按升序排列,不得有重复的维度编号,且必须在 [0, operand.rank) 范围内。

  • operand.rank 必须等于 update_window_dims.sizeinserted_window_dims.size 的总和。

  • scatter_dims_to_operand_dims.size 必须等于 scatter_indices.shape.dims[index_vector_dim],且其值必须在 [0, operand.rank) 范围内。

对于每个 updates 数组中的给定索引 U,必须应用此更新的相应 operands 数组中的相应索引 I 的计算方式如下:

  1. G = { U[k] for k in update_scatter_dims }。使用 Gscatter_indices 数组中查找索引矢量 S,使得 S[i] = scatter_indices[Combine(G, i)],其中 Combine(A, b) 会将 b 插入到 A 的 index_vector_dim 位置。
  2. 使用 scatter_dims_to_operand_dims 映射将 S 散射,然后使用 Soperand 中创建索引 Sin。更正式的形式:
    1. 如果 k < scatter_dims_to_operand_dims.size,则 Sin[scatter_dims_to_operand_dims[k]] = S[k]。
    2. 否则 Sin[_] = 0
  3. 通过根据 inserted_window_dims 将索引散布在 U 中的 update_window_dims,在每个 operands 数组中创建索引 Win。更正式的形式:
    1. 如果 kupdate_window_dims 中,则 Win[window_dims_to_operand_dims(k)] = U[k],其中 window_dims_to_operand_dims 是域为 [0, update_window_dims.size) 且范围为 [0, operand.rank) \ inserted_window_dims 的单调函数。(例如,如果 update_window_dims.size4operand.rank6,并且 inserted_window_dims 为 {0, 2},则 window_dims_to_operand_dims 为 {01, 13, 24, 35})。
    2. 否则 Win[_] = 0
  4. IWin + Sin,其中 + 表示按元素相加。

总的来说,散射操作可以定义如下。

  • 使用 operands 初始化 output,即对于 operands[J] 数组中的所有索引 J,对于所有索引 O
    output[J][O] = operands[J][O]
  • 对于 updates[J] 数组中的每个索引 Uoperand[J] 数组中的对应索引 O,如果 Ooutput 的有效索引:
    (output[0][O], ..., output[N-1][O]) =update_computation(output[0][O], ..., ,output[N-1][O],updates[0][U], ...,updates[N-1][U])

更新的应用顺序是不确定的。因此,当 updates 中的多个索引引用 operands 中的同一索引时,output 中的相应值将是非确定性的。

请注意,传入 update_computation 的第一个参数始终是 output 数组中的当前值,第二个参数始终是 updates 数组中的值。这一点尤其重要,尤其是在 update_computation 不具有交换性的情况下。

如果 indices_are_sorted 设为 true,则 XLA 可以假定 scatter_indices 已由用户排序(按升序排序,根据 scatter_dims_to_operand_dims 散布其值之后)。如果没有,则语义由实现定义。

如果 unique_indices 设为 true,则 XLA 可以假定所有分散到的元素都是唯一的。因此,XLA 可以使用非原子操作。如果 unique_indices 设置为 true,并且要分散到的索引不唯一,则语义由实现定义。

非正式地讲,分散操作可以视为汇总操作的反向操作,即分散操作会更新相应汇总操作提取的输入中的元素。

如需详细的非正式说明和示例,请参阅 Gather 下的“非正式说明”部分。

选择

另请参阅 XlaBuilder::Select

根据谓词数组的值,从两个输入数组的元素构建输出数组。

Select(pred, on_true, on_false)

参数 类型 语义
pred XlaOp 类型为 PRED 的数组
on_true XlaOp 类型为 T 的数组
on_false XlaOp 类型为 T 的数组

数组 on_trueon_false 必须具有相同的形状。这也是输出数组的形状。数组 pred 必须与 on_trueon_false 具有相同的维度,并且采用 PRED 元素类型。

对于 pred 的每个元素 P,如果 P 的值为 true,则输出数组的相应元素从 on_true 中取出;如果 P 的值为 false,则从 on_false 中取出。作为广播的一种受限形式,pred 可以是类型为 PRED 的标量。在这种情况下,如果 predtrue,则输出数组完全取自 on_true;如果 predfalse,则输出数组完全取自 on_false

使用非标量 pred 的示例:

let pred: PRED[4] = {true, false, false, true};
let v1: s32[4] = {1, 2, 3, 4};
let v2: s32[4] = {100, 200, 300, 400};
==>
Select(pred, v1, v2) = s32[4]{1, 200, 300, 4};

使用标量 pred 的示例:

let pred: PRED = true;
let v1: s32[4] = {1, 2, 3, 4};
let v2: s32[4] = {100, 200, 300, 400};
==>
Select(pred, v1, v2) = s32[4]{1, 2, 3, 4};

支持在元组之间进行选择。为此,元组被视为标量类型。如果 on_trueon_false 是元组(必须具有相同的形状!),则 pred 必须是类型为 PRED 的标量。

SelectAndScatter

另请参阅 XlaBuilder::SelectAndScatter

此操作可以视为一个复合操作,该操作首先对 operand 数组计算 ReduceWindow 以从每个窗口中选择一个元素,然后将 source 数组分散到所选元素的索引,以构建与运算数组具有相同形状的输出数组。二进制 select 函数用于通过对每个窗口应用该函数来从每个窗口中选择一个元素,并且在调用该函数时,第一个参数的索引矢量在字典顺序上小于第二个参数的索引矢量。如果选择第一个参数,select 函数会返回 true;如果选择第二个参数,则会返回 false;并且该函数必须具有传递性(即,如果 select(a, b)select(b, c)true,则 select(a, c) 也为 true),以便所选元素不依赖于为给定窗口遍历的元素的顺序。

系统会对输出数组中的每个选定索引应用函数 scatter。它接受两个标量参数:

  1. 输出数组中所选索引处的当前值
  2. 应用于所选索引的 source 中的散射值

它会组合这两个参数,并返回一个标量值,用于更新输出数组中所选索引处的值。最初,输出数组的所有索引都设为 init_value

输出数组的形状与 operand 数组相同,并且 source 数组的形状必须与对 operand 数组应用 ReduceWindow 运算的结果相同。SelectAndScatter 可用于对神经网络中的池化层回传梯度值。

SelectAndScatter(operand, select, window_dimensions, window_strides, padding, source, init_value, scatter)

参数 类型 语义
operand XlaOp 窗口滑动所依据的类型为 T 的数组
select XlaComputation 类型为 T, T -> PRED 的二进制计算,适用于每个窗口中的所有元素;如果选择第一个参数,则返回 true;如果选择第二个参数,则返回 false
window_dimensions ArraySlice<int64> 窗口尺寸值的整数数组
window_strides ArraySlice<int64> 窗口步长值的整数数组
padding Padding 窗口的填充类型(Padding::kSame 或 Padding::kValid)
source XlaOp 包含要散布的值的类型为 T 的数组
init_value XlaOp 输出数组的初始值的类型为 T 的标量值
scatter XlaComputation 类型为 T, T -> T 的二进制计算,用于将每个散射源元素应用于其目标元素

下图展示了使用 SelectAndScatter 的示例,其中 select 函数会计算其参数中的最大值。请注意,当窗口重叠时(如图 (2) 所示),不同的窗口可能会多次选择 operand 数组的索引。在图中,值为 9 的元素被顶部的两个窗口(蓝色和红色)同时选中,二进制加法 scatter 函数会生成值为 8(2 + 6)的输出元素。

scatter 函数的求值顺序是任意的,并且可能是非确定性的。因此,scatter 函数不应对重新关联过于敏感。如需了解详情,请参阅 Reduce 上下文中关于关联性的讨论。

发送

另请参阅 XlaBuilder::Send

Send(operand, channel_handle)

参数 类型 语义
operand XlaOp 要发送的数据(类型为 T 的数组)
channel_handle ChannelHandle 每个 send/recv 对的唯一标识符

将给定操作数数据发送到共享相同通道句柄的另一个计算中的 Recv 指令。不会返回任何数据。

Recv 操作类似,Send 操作的客户端 API 表示同步通信,并在内部分解为 2 个 HLO 指令(SendSendDone),以实现异步数据传输。另请参阅 HloInstruction::CreateSendHloInstruction::CreateSendDone

Send(HloInstruction operand, int64 channel_id)

启动将操作数异步传输到具有相同通道 ID 的 Recv 指令分配的资源。返回一个上下文,后续的 SendDone 指令会使用该上下文来等待数据传输完成。上下文是 {操作数(形状)、请求标识符 (U32)} 的元组,并且只能由 SendDone 指令使用。

SendDone(HloInstruction context)

给定由 Send 指令创建的上下文,等待数据传输完成。该指令不会返回任何数据。

通道指令的调度

每个通道(RecvRecvDoneSendSendDone)的 4 条指令的执行顺序如下所示。

  • Recv 发生在 Send 之前
  • Send 发生在 RecvDone 之前
  • Recv 发生在 RecvDone 之前
  • Send 发生在 SendDone 之前

当后端编译器为通过通道指令进行通信的每个计算生成线性调度时,这些计算之间不得有循环。例如,以下时间表会导致死锁。

请注意,指令的限制仅适用于运行时的 TPU。在 GPU 上,sendrecv 会阻塞,并且在源设备和目标设备之间建立握手之前不会发送任何实际数据。

Slice

另请参阅 XlaBuilder::Slice

切片会从输入数组中提取子数组。子数组的维度数与输入相同,并且包含输入数组内边界框内的值,其中边界框的维度和索引作为 slice 操作的参数提供。

Slice(operand, start_indices, limit_indices, strides)

参数 类型 语义
operand XlaOp 类型为 T 的 N 维数组
start_indices ArraySlice<int64> 包含每个维度 slice 的起始索引的 N 个整数的列表。值必须大于或等于零。
limit_indices ArraySlice<int64> 包含每个维度 slice 的结束索引(不含该索引)的 N 个整数的列表。每个值都必须大于或等于相应维度的 start_indices 值,并且小于或等于该维度的大小。
strides ArraySlice<int64> 一个包含 N 个整数的列表,用于确定 slice 的输入步长。该 slice 会选择维度 d 中的每个 strides[d] 元素。

一维示例:

let a = {0.0, 1.0, 2.0, 3.0, 4.0}
Slice(a, {2}, {4}) produces:
  {2.0, 3.0}

二维示例:

let b =
 { {0.0,  1.0,  2.0},
   {3.0,  4.0,  5.0},
   {6.0,  7.0,  8.0},
   {9.0, 10.0, 11.0} }

Slice(b, {2, 1}, {4, 3}) produces:
  { { 7.0,  8.0},
    {10.0, 11.0} }

排序

另请参阅 XlaBuilder::Sort

Sort(operands, comparator, dimension, is_stable)

参数 类型 语义
operands ArraySlice<XlaOp> 要排序的运算数。
comparator XlaComputation 要使用的比较器计算。
dimension int64 排序依据的维度。
is_stable bool 是否应使用稳定排序。

如果只提供一个运算数:

  • 如果操作数是 1 维张量(数组),则结果为排序数组。如果您想将数组按升序排序,则比较器应执行小于比较。正式地,对数组进行排序后,对于所有索引位置 i, ji < j 均为 comparator(value[i], value[j]) = comparator(value[j], value[i]) = falsecomparator(value[i], value[j]) = true

  • 如果运算数的维度数较多,则运算数会沿着所提供的维度进行排序。例如,对于二维张量(矩阵),维度值为 0 时,系统会独立对每个列进行排序,维度值为 1 时,系统会独立对每行进行排序。如果未提供维度编号,则默认选择最后一个维度。对于要排序的维度,系统会采用与一维情况下相同的排序顺序。

如果提供了 n > 1 运算数:

  • 所有 n 运算数都必须是维度相同的张量。张量的元素类型可能不同。

  • 所有运算数都会一起排序,而不是单独排序。从概念上讲,运算数被视为元组。在检查每个运算数在索引位置 ij 的元素是否需要交换时,系统会使用 2 * n 标量参数调用比较器,其中参数 2 * k 对应于 k-th 运算数在位置 i 的值,参数 2 * k + 1 对应于 k-th 运算数在位置 j 的值。因此,通常比较器会将参数 2 * k2 * k + 1 进行比较,并可能使用其他参数对作为平局决胜依据。

  • 结果是一个元组,由按排序顺序(沿提供的维度,如上所述)排列的运算对象组成。元组的 i-th 运算数对应于 Sort 的 i-th 运算数。

例如,如果有三个操作数 operand0 = [3, 1]operand1 = [42, 50]operand2 = [-3.0, 1.1],并且比较器仅使用小于运算符比较 operand0 的值,则排序的输出是元组 ([1, 3], [50, 42], [1.1, -3.0])

如果将 is_stable 设置为 true,则可以保证排序是稳定的,也就是说,如果比较器认为某些元素相等,则会保留相等值的相对顺序。两个元素 e1e2 只有在 comparator(e1, e2) = comparator(e2, e1) = false 时才相等。默认情况下,is_stable 设置为 false。

TopK

另请参阅 XlaBuilder::TopK

TopK 会找出给定张量最后一个维度的 k 个最大或最小元素的值和索引。

TopK(operand, k, largest)

参数 类型 语义
operand XlaOp 要从中提取前 k 个元素的张量。张量必须具有大于或等于 1 的维度。张量最后一个维度的大小必须大于或等于 k
k int64 要提取的元素数量。
largest bool 要提取最大的 k 元素还是最小的 k 元素。

对于 1 维输入张量(数组),查找数组中的 k 个最大或最小条目,并输出两个数组的元组 (values, indices)。因此,values[j]operand 中第 j 大的/小的条目,其编号为 indices[j]

对于具有多个维度的输入张量,沿最后一个维度计算前 k 个条目,并在输出中保留所有其他维度(行)。因此,对于形状为 [A, B, ..., P, Q]Q >= k 的运算数,输出是一个元组 (values, indices),其中:

values.shape = indices.shape = [A, B, ..., P, k]

如果某行中的两个元素相等,则编号较小的元素会显示在前面。

Transpose

另请参阅 tf.reshape 运算。

Transpose(operand)

参数 类型 语义
operand XlaOp 要转置的运算数。
permutation ArraySlice<int64> 如何对维度进行排列。

使用给定排列对操作数维度进行排列,即 ∀ i . 0 ≤ i < number of dimensions ⇒ input_dimensions[permutation[i]] = output_dimensions[i]

这与 Reshape(operand, permutation, Permute(permutation, operand.shape.dimensions)) 相同。

TriangularSolve

另请参阅 XlaBuilder::TriangularSolve

通过正向或反向代入法,解具有下三角或上三角系数矩阵的线性方程组。沿前导维度广播,此例程会在给定 ab 的情况下,为变量 x 解析矩阵系统 op(a) * x = bx * op(a) = b 之一,其中 op(a)op(a) = aop(a) = Transpose(a)op(a) = Conj(Transpose(a))

TriangularSolve(a, b, left_side, lower, unit_diagonal, transpose_a)

参数 类型 语义
a XlaOp 一个形状为 [..., M, M] 的复数或浮点类型的 > 2 维数组。
b XlaOp 一个形状为 [..., M, K] 且类型相同的 > 2 维数组(如果 left_side 为 true),否则为 [..., K, M]
left_side bool 指示是解析形式为 op(a) * x = b (true) 还是 x * op(a) = b (false) 的系统。
lower bool 是否使用 a 的上三角或下三角。
unit_diagonal bool 如果为 true,则假定 a 的主对角线元素为 1,且不会被访问。
transpose_a Transpose 是否按原样使用 a、对其进行转置或取其共轭转置。

系统仅从 a 的下三角形/上三角形读取输入数据,具体取决于 lower 的值。系统会忽略另一个三角形中的值。输出数据会在同一三角形中返回;另一个三角形中的值由实现定义,可以是任何值。

如果 ab 的维度数大于 2,则将它们视为矩阵批次,其中除了次要 2 个维度之外,所有其他维度都是批次维度。ab 必须具有相同的批次维度。

元组

另请参阅 XlaBuilder::Tuple

一个包含可变数量的数据句柄的元组,每个句柄都有自己的形状。

这类似于 C++ 中的 std::tuple。从概念上讲:

let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
let s: s32 = 5;
let t: (f32[10], s32) = tuple(v, s);

您可以通过 GetTupleElement 运算解构(访问)元组。

另请参阅 XlaBuilder::While

While(condition, body, init)

参数 类型 语义
condition XlaComputation 类型为 T -> PRED 的 XlaComputation,用于定义循环的终止条件。
body XlaComputation 类型为 T -> T 的 XlaComputation,用于定义循环的正文。
init T conditionbody 参数的初始值。

顺序执行 body,直到 condition 失败。这与许多其他语言中的典型 while 循环类似,但存在以下差异和限制。

  • While 节点会返回 T 类型的值,该值是上次执行 body 的结果。
  • 类型 T 的形状是静态确定的,并且在所有迭代中都必须相同。

计算的 T 参数在第一次迭代中使用 init 值进行初始化,并在每次后续迭代中自动更新为 body 中的新结果。

While 节点的一个主要用例是实现在神经网络中重复执行训练。下面显示了经过简化的伪代码,以及表示计算的图表。您可以在 while_test.cc 中找到该代码。在此示例中,类型 T 是一个 Tuple,其中包含一个用于迭代计数的 int32 和一个用于累加器的 vector[10]。在 1000 次迭代中,循环会不断向累加器添加一个常量矢量。

// Pseudocode for the computation.
init = {0, zero_vector[10]} // Tuple of int32 and float[10].
result = init;
while (result(0) < 1000) {
  iteration = result(0) + 1;
  new_vector = result(1) + constant_vector[10];
  result = {iteration, new_vector};
}