運算語意

以下說明 XlaBuilder 介面中定義的運算語意。通常,這些作業會一對一對應至 xla_data.proto 中 RPC 介面中定義的作業。

關於名稱的說明:XLA 處理的一般資料類型是 N 維陣列,可保留某些統一型別 (例如 32 位元浮點) 的元素。在整份說明文件中,陣列一詞用於表示任意維度的陣列。為方便起見,特殊案例會有更明確且熟悉的名稱;例如「向量」是 1 維陣列,「矩陣」則是 2D 陣列。

AfterAll

另請參閱 XlaBuilder::AfterAll

AfterAll 接受變數數量的符記並產生單一符記。符記是原始類型,可在副作用運算之間建立執行緒,以便強制排序。AfterAll 可用於代碼彙整,在進行集合作業後排序作業。

AfterAll(operands)

引數 類型 語義學
operands XlaOp 符記的變數數量

AllGather

另請參閱 XlaBuilder::AllGather

在備援機制之間執行連接。

AllGather(operand, all_gather_dim, shard_count, replica_group_ids, channel_id)

引數 類型 語義學
operand XlaOp 在不同備用資源之間串連的陣列
all_gather_dim int64 連接維度
replica_groups int64 的向量向量 執行串連的群組
channel_id 選用 int64 用於跨模組通訊的選用管道 ID
  • replica_groups 是執行連結的複本群組清單 (您可以使用 ReplicaId 擷取目前複本的複本 ID)。每個群組中的複本順序,會決定複本輸入內容在結果中的順序。replica_groups 必須為空白 (在這種情況下,所有備用資源均屬於單一群組,依 0N - 1 排序),或包含與備用資源數量相同的元素數量。例如,replica_groups = {0, 2}, {1, 3} 會在備用資源 02,以及 13 之間執行串連。
  • shard_count 是每個複本群組的大小。在 replica_groups 為空白的情況下,我們需要這個值。
  • channel_id 用於跨模組通訊:只有具有相同 channel_idall-gather 作業才能相互通訊。

輸出形狀是 all_gather_dim 使 shard_count 變大的輸入形狀。舉例來說,如果有兩個副本,且運算元在兩個副本中分別具有 [1.0, 2.5][3.0, 5.25] 的值,則此運算子的輸出值 (all_gather_dim0) 會在兩個副本中皆為 [1.0, 2.5, 3.0, 5.25]

AllReduce

另請參閱 XlaBuilder::AllReduce

在備援機制中執行自訂運算。

AllReduce(operand, computation, replica_group_ids, channel_id)

引數 類型 語義學
operand XlaOp 陣列或非空陣列元組,用於在備用資源之間執行縮減作業
computation XlaComputation 減法運算
replica_groups int64 的向量向量 要執行減法運算的群組
channel_id 選用 int64 用於跨模組通訊的選用管道 ID
  • 如果 operand 是陣列的元組,則會對元組的每個元素執行 all-reduce。
  • replica_groups 是執行縮減作業的備份群組清單 (您可以使用 ReplicaId 擷取目前備份的 ID)。replica_groups 必須為空白 (此時所有備份都屬於單一群組),或包含與備份數量相同的元素。例如,replica_groups = {0, 2}, {1, 3} 會在備用資源 0213 之間執行縮減作業。
  • channel_id 用於跨模組通訊:只有具備相同 channel_idall-reduce 作業能相互通訊。

輸出形狀與輸入形狀相同。舉例來說,如果有兩個副本,且運算元在兩個副本中分別具有 [1.0, 2.5][3.0, 5.25] 的值,則這項運算和加總運算的輸出值會在兩個副本中皆為 [4.0, 7.75]。如果輸入是元組,輸出內容也會是元組。

計算 AllReduce 的結果時,需要從每個複本取得一個輸入內容,因此如果某個複本執行 AllReduce 節點的次數多於另一個複本,則前者會一直等待。由於副本都執行相同的程式,因此發生這種情況的機率不高,但如果 while 迴圈的條件取決於 infeed 的資料,且 infeed 的資料會導致 while 迴圈在某個副本上重複執行的次數多於另一個副本,就有可能發生這種情況。

AllToAll

另請參閱 XlaBuilder::AllToAll

AllToAll 是集體運算,可將資料從所有核心傳送至所有核心。它包含兩個階段:

  1. 散佈階段。在每個核心上,運算元會沿著 split_dimensions 分割為 split_count 個區塊,並將區塊分散到所有核心,例如第 ith 個區塊會傳送至第 ith 個核心。
  2. 收集階段:每個核心會沿著 concat_dimension 串連收到的區塊。

您可以透過下列方式設定參與的核心:

  • replica_groups:每個 ReplicaGroup 都包含在運算中參與的備用資源 ID 清單 (可使用 ReplicaId 擷取目前備用資源的備用資源 ID)。AllToAll 會依照指定的順序套用於子群組。舉例來說,replica_groups = { {1,2,3}, {4,5,0} } 表示 AllToAll 會套用至複本 {1, 2, 3} 和收集階段,且收到的區塊會以 1、2、3 的順序連接。接著,系統會在複本 4、5、0 中套用另一個 AllToAll,連結順序也是 4、5、0。如果 replica_groups 為空白,則所有複本都屬於同一個群組,並按照其出現順序連結。

需求條件:

  • split_dimension 上運算元的維度大小可被 split_count 整除。
  • 運算元的形狀不是元組。

AllToAll(operand, split_dimension, concat_dimension, split_count, replica_groups)

引數 類型 語義學
operand XlaOp n 維度輸入陣列
split_dimension int64 間隔 [0, n) 中的值,為運算元命名的維度命名
concat_dimension int64 間隔 [0, n) 中的值,為連接分割區塊的維度命名
split_count int64 參與此作業的核心數量。如果 replica_groups 為空白,則應為複本數量;否則,應等於每個群組中的複本數量。
replica_groups ReplicaGroup 向量 每個群組都包含副本 ID 清單。

以下是 Alltoall 的範例。

XlaBuilder b("alltoall");
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, /*split_count=*/4);

在本例中,有 4 個核心參與 Alltoall。在每個核心上,運算元會沿著第 0 個維度分成 4 個部分,因此每個部分的形狀為 f32[4,4]。這 4 個部分會分散到所有核心。然後,每個核心會依照核心 0 至 4 的順序,沿著第 1 個維度連接收到的部分。因此,每個核心的輸出內容都具有 f32[16,4] 的形狀。

BatchNormGrad

如需演算法的詳細說明,請參閱 XlaBuilder::BatchNormGrad原始批次規格化論文

計算批次常態的梯度。

BatchNormGrad(operand, scale, mean, variance, grad_output, epsilon, feature_index)

引數 類型 語義學
operand XlaOp 要正規化的 n 個維陣列 (x)
scale XlaOp 1 個維度陣列 (\(\gamma\))
mean XlaOp 1 維陣列 (\(\mu\))
variance XlaOp 1 個維度陣列 (\(\sigma^2\))
grad_output XlaOp 傳遞至 BatchNormTraining 的漸層 (\(\nabla y\))
epsilon float 隱私損失值 (\(\epsilon\))
feature_index int64 operand 的特徵維度索引

針對特徵維度中的每個地圖項目 (feature_indexoperand 中的特徵維度的索引),運算會根據所有其他維度計算 operandoffsetscale 等方面的梯度。feature_index 必須是 operand 中特徵維度的有效索引。

三個漸層是由下列公式定義 (假設 4D 陣列為 operand,且特徵維度索引為 l,批量為 m 和空間大小 wh):

\[ \begin{split} c_l&= \frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \left( \nabla y_{ijkl} \frac{x_{ijkl} - \mu_l}{\sigma^2_l+\epsilon} \right) \\\\ d_l&= \frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \nabla y_{ijkl} \\\\ \nabla x_{ijkl} &= \frac{\gamma_{l} }{\sqrt{\sigma^2_{l}+\epsilon} } \left( \nabla y_{ijkl} - d_l - c_l (x_{ijkl} - \mu_{l}) \right) \\\\ \nabla \gamma_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \left( \nabla y_{ijkl} \frac{x_{ijkl} - \mu_l}{\sqrt{\sigma^2_{l}+\epsilon} } \right) \\\\\ \nabla \beta_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \nabla y_{ijkl} \end{split} \]

輸入 meanvariance 代表批次和空間維度的時刻值。

輸出類型是三個句柄的元組:

輸出內容 類型 語義學
grad_operand XlaOp 輸入 operand 的漸層 ($\nabla x$)
grad_scale XlaOp 相對於輸入 scale 的梯度 ($\nabla \gamma$)
grad_offset XlaOp 輸入 offset 的梯度 ($\nabla \beta$)

BatchNormInference

如需演算法的詳細說明,請參閱 XlaBuilder::BatchNormInference原始批次規格化論文

將陣列正規化,並跨批次和空間維度。

BatchNormInference(operand, scale, offset, mean, variance, epsilon, feature_index)

引數 類型 語義學
operand XlaOp 要正規化的 n 維陣列
scale XlaOp 1 個維陣列
offset XlaOp 1 個維陣列
mean XlaOp 1 個維陣列
variance XlaOp 1 個維陣列
epsilon float 隱私損失值
feature_index int64 operand 中的特徵維度索引

對於特徵維度中的每個特徵 (feature_indexoperand 中特徵維度的索引),此運算會計算所有其他維度的平均值和變異數,並使用平均值和變異數將 operand 中的每個元素正規化。feature_index 必須是 operand 中地圖項目維度的有效索引。

BatchNormInference 相當於在不計算每個批次的 meanvariance 的情況下呼叫 BatchNormTraining。而是使用輸入的 meanvariance 做為預估值。此運算的目的在於減少推論的延遲時間,因此名為 BatchNormInference

輸出內容是 n 維的標準化陣列,其形狀與輸入 operand 相同。

BatchNormTraining

如需演算法的詳細說明,請參閱 XlaBuilder::BatchNormTrainingthe original batch normalization paper

跨批次和空間維度將陣列正規化。

BatchNormTraining(operand, scale, offset, epsilon, feature_index)

引數 類型 語義學
operand XlaOp 要正規化的 n 個維陣列 (x)
scale XlaOp 1 維陣列 (\(\gamma\))
offset XlaOp 1 維陣列 (\(\beta\))
epsilon float 隱私損失值 (\(\epsilon\))
feature_index int64 operand 的特徵維度索引

對於特徵維度中的每個特徵 (feature_indexoperand 中特徵維度的索引),此運算會計算所有其他維度的平均值和變異數,並使用平均值和變異數將 operand 中的每個元素正規化。feature_index 必須是 operand 中地圖項目維度的有效索引。

這個演算法適用於 operand \(x\) 中的每個批次,其中包含的 m 元素具有 wh 做為空間維度大小 (假設 operand 是 4 維陣列):

  • 計算特徵維度中每個特徵 l 的批次平均值 \(\mu_l\) : \(\mu_l=\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h x_{ijkl}\)

  • 計算批次變異數 \(\sigma^2_l\): $\sigma^2l=\frac{1}{mwh}\sum{i=1}^m\sum{j=1}^w\sum{k=1}^h (x_{ijkl} - \mu_l)^2$

  • 正規化、縮放及位移:\(y_{ijkl}=\frac{\gamma_l(x_{ijkl}-\mu_l)}{\sqrt[2]{\sigma^2_l+\epsilon} }+\beta_l\)

為了避免除以零的錯誤,系統會加入通常為小數的 epsilon 值。

輸出類型是三個 XlaOp 的元組:

輸出內容 類型 語義學
output XlaOp 形狀與輸入內容 operand 相同的 ND 陣列 (y)
batch_mean XlaOp 1 維陣列 (\(\mu\))
batch_var XlaOp 1 維陣列 (\(\sigma^2\))

使用上述公式,按批次和空間維度計算 batch_meanbatch_var 時刻。

BitcastConvertType

另請參閱 XlaBuilder::BitcastConvertType

與 TensorFlow 中的 tf.bitcast 類似,執行從資料形狀到目標形狀的元素位元投放作業。輸入和輸出大小必須相符:例如,s32 元素會透過位元組轉換例行程序成為 f32 元素,而一個 s32 元素會成為四個 s8 元素。位元轉換會以低階轉換方式實作,因此具有不同浮點表示法的機器會產生不同的結果。

BitcastConvertType(operand, new_element_type)

引數 類型 語義學
operand XlaOp 具有維度 D 的類型 T 陣列
new_element_type PrimitiveType 類型 U

除了最後一個維度會根據轉換前後的元素大小比率而變更外,運算元和目標形狀的維度必須相符。

來源和目的地元素類型不得為元組。

將位元組轉換為不同寬度的原始類型

BitcastConvert HLO 指令支援輸出元素類型 T' 的大小不等於輸入元素 T 的大小。由於整個運算的概念是點陣圖,且不會變更基礎位元組,因此輸出元素的形狀必須改變。對於 B = sizeof(T), B' = sizeof(T'),有兩種可能的情況。

首先,當 B > B' 時,輸出形狀會取得新的最小維度大小 B/B'。例如:

  f16[10,2]{1,0} %output = f16[10,2]{1,0} bitcast-convert(f32[10]{0} %input)

有效量值的規則維持不變:

  f16[2]{0} %output = f16[2]{0} bitcast-convert(f32[] %input)

或者,如果是 B' > B,指令需要輸入形狀的最後一個邏輯維度等於 B'/B,而這個維度會在轉換期間捨棄:

  f32[10]{0} %output = f32[10]{0} bitcast-convert(f16[10,2]{1,0} %input)

請注意,不同位元寬之間的轉換並非元素式轉換。

廣播

另請參閱 XlaBuilder::Broadcast

複製陣列中的資料,為陣列新增維度。

Broadcast(operand, broadcast_sizes)

引數 類型 語義學
operand XlaOp 要複製的陣列
broadcast_sizes ArraySlice<int64> 新尺寸的大小

左側會插入新維度,也就是說,如果 broadcast_sizes 的值為 {a0, ..., aN},且運算元形狀的維度是 {b0, ..., bM},則輸出形狀的尺寸會是 {a0, ..., aN, b0, ..., bM}

新維度會編入運算元副本,也就是

output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]

例如,如果 operand 是值為 2.0f 的純量 f32,而 broadcast_sizes{2, 3},則結果會是具有形狀 f32[2, 3] 的陣列,且結果中的所有值都會是 2.0f

BroadcastInDim

另請參閱 XlaBuilder::BroadcastInDim

透過複製陣列中的資料,擴大陣列的大小和排名。

BroadcastInDim(operand, out_dim_size, broadcast_dimensions)

引數 類型 語義學
operand XlaOp 要複製的陣列
out_dim_size ArraySlice<int64> 指定形狀的尺寸大小
broadcast_dimensions ArraySlice<int64> 運算元形狀中每個維度對應的是目標形狀中的哪個維度

與 Broadcast 類似,但可在任何位置新增維度,並擴充大小為 1 的現有維度。

operand 會廣播至 out_dim_size 描述的形狀。broadcast_dimensions 會將 operand 的維度對應至目標形狀的維度,也就是將運算元的第 i 維度對應至輸出形狀的 broadcast_dimension[i] 維度。operand 的維度必須為 1,或與所對應輸出形狀中的維度相同。其餘維度則會填入大小為 1 的維度。拆解維度廣播,然後根據這些去生成維度播送,以達到輸出形狀。如要進一步瞭解語意,請參閱廣播頁面

撥打電話

另請參閱 XlaBuilder::Call

使用指定引數叫用運算。

Call(computation, args...)

引數 類型 語義學
computation XlaComputation 帶有任意類型的 N 參數的 T_0, T_1, ..., T_{N-1} -> S 類型計算
args N 個 XlaOp 的序列 N 個任意型別的引數

args 的位數和類型必須與 computation 的參數相符。允許不含 args

Cholesky

另請參閱 XlaBuilder::Cholesky

計算一批對稱 (Hermitian) 正定矩陣的 Cholesky 分解

Cholesky(a, lower)

引數 類型 語義學
a XlaOp 排名 > 2 陣列的複雜或浮點類型。
lower bool 要使用 a 的上方或下三角形。

如果 lowertrue,則會計算下三角矩陣 l,讓 $a = l。l^T$。如果 lowerfalse,則會計算上三角矩陣 u,以便\(a = u^T . u\)。

輸入資料只會從 a 的下/上三角讀取,具體取決於 lower 的值。系統會忽略其他三角形的值。輸出資料會在同一三角形中傳回;其他三角形中的值是由實作定義,可能為任何值。

如果 a 的秩大於 2,a 會視為一批矩陣,其中除了次要 2 個維度以外,所有都是批次維度。

如果 a 不是對稱 (Hermitian) 正定矩陣,則結果會由實作定義。

夾子

另請參閱 XlaBuilder::Clamp

將運算元限制在最小值和最大值之間的範圍內。

Clamp(min, operand, max)

引數 類型 語義學
min XlaOp T 類型的陣列
operand XlaOp 類型為 T 的陣列
max XlaOp 類型為 T 的陣列

給定運算元和最小值與最大值,如果運算元位於最小值和最大值之間,則會傳回運算元;如果運算元低於這個範圍,則會傳回最小值;如果運算元高於這個範圍,則會傳回最大值。也就是說,clamp(a, x, b) = min(max(a, x), b)

這三個陣列的形狀必須相同。或者,由於廣播的限制形式,min 和/或 max 可以是 T 類型的純量。

使用純量 minmax 的範例:

let operand: s32[3] = {-1, 5, 9};
let min: s32 = 0;
let max: s32 = 6;
==>
Clamp(min, operand, max) = s32[3]{0, 5, 6};

收合

另請參閱 XlaBuilder::Collapsetf.reshape 作業。

將陣列的維度縮減為一個維度。

Collapse(operand, dimensions)

引數 類型 語義學
operand XlaOp 類型為 T 的陣列
dimensions int64 向量 以順序排列的 T 維度連續子集。

Collapse 會將運算式維度的指定子集取代為單一維度。輸入引數是任意型別 T 的陣列,以及維度索引的編譯時間常數向量。維度索引必須是順序 (由低到高的維度數字),且是 T 維度的連續子集。因此,{0, 1, 2}、{0, 1} 或 {1, 2} 都是有效的維度組合,但 {1, 0} 或 {0, 2} 則無效。這些維度會被單一新維度取代,且在維度序列中的位置與所取代的維度相同,新維度的大小等於原始維度大小的乘積。dimensions 中最低的維度編號,是迴圈巢狀結構中變化最慢的維度 (最主要),該巢狀結構會折疊這些維度,而最高的維度編號則是變化最快的維度 (最次要)。如需更通用的摺疊順序,請參閱 tf.reshape 運算子。

例如,讓 v 為 24 個元素的陣列:

let v = f32[4x2x3] { { {10, 11, 12},  {15, 16, 17} },
{ {20, 21, 22},  {25, 26, 27} },
{ {30, 31, 32},  {35, 36, 37} },
{ {40, 41, 42},  {45, 46, 47} } };

// Collapse to a single dimension, leaving one dimension.
let v012 = Collapse(v, {0,1,2});
then v012 == f32[24] {10, 11, 12, 15, 16, 17,
20, 21, 22, 25, 26, 27,
30, 31, 32, 35, 36, 37,
40, 41, 42, 45, 46, 47};

// Collapse the two lower dimensions, leaving two dimensions.
let v01 = Collapse(v, {0,1});
then v01 == f32[4x6] { {10, 11, 12, 15, 16, 17},
{20, 21, 22, 25, 26, 27},
{30, 31, 32, 35, 36, 37},
{40, 41, 42, 45, 46, 47} };

// Collapse the two higher dimensions, leaving two dimensions.
let v12 = Collapse(v, {1,2});
then v12 == f32[8x3] { {10, 11, 12},
{15, 16, 17},
{20, 21, 22},
{25, 26, 27},
{30, 31, 32},
{35, 36, 37},
{40, 41, 42},
{45, 46, 47} };

CollectivePermute

另請參閱 XlaBuilder::CollectivePermute

CollectivePermute 是一種集體運算,可跨副本傳送及接收資料。

CollectivePermute(operand, source_target_pairs)

引數 類型 語義學
operand XlaOp n 維輸入陣列
source_target_pairs <int64, int64> 向量 包含 (source_replica_id, target_replica_id) 組合的清單。對於每個組合,運算元會從來源備份傳送至目標備份。

請注意,source_target_pair 有下列限制:

  • 任何兩組都不得使用相同的目標備援 ID,且不得使用相同的來源備援 ID。
  • 如果備用資源 ID 不是任何組合中的目標,則該備用資源的輸出會是包含 0 個、和輸入相同形狀的張量。

串連

另請參閱 XlaBuilder::ConcatInDim

連接會從多個陣列運算元組合陣列。陣列的等級與每個輸入陣列運算元相同 (必須與其他運算元等級相同),並按照指定的順序包含引數。

Concatenate(operands..., dimension)

引數 類型 語義學
operands N 個 XlaOp 的序列 類型為 T 的 N 個陣列,其中的維度 [L0、L1、...],須有 N >= 1。
dimension int64 [0, N) 區間中的值,用於命名要在 operands 之間連接的維度。

除了 dimension 以外,所有尺寸都必須相同。這是因為 XLA 不支援「不連續」陣列。請注意,rank-0 值無法串連 (因為無法命名串連發生的維度)。

1 維範例:

Concat({ {2, 3}, {4, 5}, {6, 7} }, 0)
>>> {2, 3, 4, 5, 6, 7}

2D 範例:

let a = {
{1, 2},
{3, 4},
{5, 6},
};
let b = {
{7, 8},
};
Concat({a, b}, 0)
>>> {
{1, 2},
{3, 4},
{5, 6},
{7, 8},
}

圖:

條件式

另請參閱 XlaBuilder::Conditional

Conditional(pred, true_operand, true_computation, false_operand, false_computation)

引數 類型 語義學
pred XlaOp PRED 類型的純量
true_operand XlaOp 類型為 \(T_0\)的引數
true_computation XlaComputation 「 \(T_0 \to S\)」類型的 XlaComputation
false_operand XlaOp 「 \(T_1\)」類型的引數
false_computation XlaComputation 類型為 \(T_1 \to S\)的 XlaComputation

如果 predtrue,就會執行 true_computation,如果 predfalse,則會執行 false_computation,並傳回結果。

true_computation 必須採用單一 \(T_0\) 類型引數,並且會以 true_operand 呼叫,且 true_operand 必須屬於相同類型。false_computation 必須使用一個類型的單一引數 \(T_1\) ,並透過 false_operand 叫用 (必須屬於相同類型)。true_computationfalse_computation 的傳回值類型必須相同。

請注意,系統會根據 pred 的值執行 true_computationfalse_computation 其中一個。

Conditional(branch_index, branch_computations, branch_operands)

引數 類型 語義學
branch_index XlaOp S32 類型的純量
branch_computations N XlaComputation 序列 類型為 \(T_0 \to S , T_1 \to S , ..., T_{N-1} \to S\)的 XlaComputation
branch_operands N 個 XlaOp 的序列 \(T_0 , T_1 , ..., T_{N-1}\)類型的引數

執行 branch_computations[branch_index],並傳回結果。如果 branch_index 是小於 0 或大於等於 N 的 S32,則會以預設分支版本執行 branch_computations[N-1]

每個 branch_computations[b] 都必須使用單一 \(T_b\) 類型引數,並且會以 branch_operands[b] 叫用,且必須為相同類型。每個 branch_computations[b] 的傳回值類型必須相同。

請注意,系統會根據 branch_index 的值執行其中一個 branch_computations

Conv (卷積)

另請參閱 XlaBuilder::Conv

作為 ConvWithGeneralPadding,但邊框間距是以短方式指定為 SAME 或 VALID。SAME 填充會以零填充輸入內容 (lhs),讓輸出內容的形狀與不考量步幅時的輸入內容相同。VALID 邊框間距單純表示沒有邊框間距。

ConvWithGeneralPadding (卷積)

另請參閱 XlaBuilder::ConvWithGeneralPadding

計算類神經網路中該種類的捲積。這裡的捲積可以想成是 ND 窗型,跨越 N 維基本區域移動,並且會對視窗的每個可能位置執行運算。

引數 類型 語義學
lhs XlaOp 排名 n+2 輸入陣列
rhs XlaOp 權重 n+2 陣列的核函式
window_strides ArraySlice<int64> 核步長的 n-d 陣列
padding ArraySlice< pair<int64,int64>> 填補的 n-d 陣列 (低、高)
lhs_dilation ArraySlice<int64> n-d 左手邊擴張因數陣列
rhs_dilation ArraySlice<int64> n-d 右側值擴張係數陣列
feature_group_count int64 特徵群組數量
batch_group_count int64 批次群組的數量

讓 n 代表空間維度數量。lhs 引數是描述基礎區域的 n+2 等級陣列。雖然 rh 也是輸入內容在類神經網路中,這些是輸入啟用。n+2 維度的順序如下:

  • batch:這個維度的每個座標都代表要執行卷積的獨立輸入。
  • z/depth/features:基本區域中的每個 (y,x) 位置都有一個相關聯的向量,會進入這個維度。
  • spatial_dims:說明 n 空間維度,定義視窗移動的基礎區域。

rhs 引數是描述卷積濾鏡/核/視窗的 n+2 等級陣列。維度依序如下:

  • output-z:輸出內容的 z 維度。
  • input-z:這個維度的大小乘以 feature_group_count 應等於左側 z 維度的大小。
  • spatial_dims:說明 n 空間維度,定義在基本區域中移動的 n 維視窗。

window_strides 引數會指定空間維度的捲積窗口。舉例來說,如果第一個空間維度的步幅為 3,則視窗只能放在第一個空間索引可被 3 整除的座標。

padding 引數會指定要套用至基底區域的零邊框間距數量。填充量可以為負值,負值填充的絕對值表示在執行卷積之前,從指定維度移除的元素數量。padding[0] 指定維度 y 的邊框,padding[1] 則指定維度 x 的邊框。每個組合的第一個元素為低邊框間距,第二個元素為高邊框間距。低邊框間距會按照較低索引的方向套用,高邊框間距則按照較高索引的方向套用。舉例來說,如果 padding[1](2,3),則第二個空間維度會在左側加上 2 個零,在右側加上 3 個零。使用填充功能,就等同於在進行卷積之前,將相同的零值插入輸入內容 (lhs)。

lhs_dilationrhs_dilation 引數會指定在每個空間維度中,分別套用至左側和右側的擴張因子。如果空間維度中的色差係數為 d,則該維度中的每個項目之間會以隱含方式放置 d-1 孔,增加陣列大小。這些孔會填補無運算值,而卷積代表為零。

對右側邊緣進行擴張也稱為「擴張卷積」。詳情請參閱 tf.nn.atrous_conv2d。左側的擴張也稱為轉置的卷積。詳情請參閱 tf.nn.conv2d_transpose

feature_group_count 引數 (預設值為 1) 可用於分組卷積。feature_group_count 必須是輸入和輸出特徵維度的除數。如果 feature_group_count 大於 1,表示在概念上,輸入和輸出特徵維度以及 rhs 輸出特徵維度會平均分割成許多 feature_group_count 群組,每個群組都包含特徵的連續子序列。rhs 的輸入特徵維度必須等於 lhs 輸入特徵維度除以 feature_group_count (也就是說,它已具有一組輸入特徵的大小)。第 i 組會一起用於計算許多個獨立卷積的 feature_group_count。這些卷積的結果會在輸出特徵維度中連接在一起。

針對深度卷積,feature_group_count 引數會設為輸入特徵維度,而篩選器會從 [filter_height, filter_width, in_channels, channel_multiplier] 重塑為 [filter_height, filter_width, 1, in_channels * channel_multiplier]。詳情請參閱 tf.nn.depthwise_conv2d

在反向傳播期間,您可以使用 batch_group_count (預設值 1) 引數為分組篩選器。batch_group_count 必須是 lhs (輸入) 批次維度的大小除數。如果 batch_group_count 大於 1,表示輸出批次維度的大小應為 input batch / batch_group_countbatch_group_count 必須是輸出特徵大小的除數。

輸出形狀的維度如下:

  • batch:這個維度的大小乘以 batch_group_count 應等於左側 batch 維度的大小。
  • z:與核心 (rhs) 上的 output-z 相同大小。
  • spatial_dims:每個卷積視窗的有效位置對應一個值。

上圖顯示 batch_group_count 欄位的運作方式。實際上,我們會將每個左手邊批次切割成 batch_group_count 群組,並對輸出功能執行相同的操作。然後,針對這些群組,我們會成對的捲積,並將輸出和輸出特徵維度串連起來。所有其他維度的作業語意 (地圖和空間) 都保持不變。

卷積視窗的有效位置取決於步幅和填充後的底部區域大小。

為了說明卷積的運作方式,請考慮 2D 卷積,並在輸出內容中選取一些固定的 batchzyx 座標。接著,(y,x) 是視窗在底層區域內的角落位置 (例如左上角,視您解讀空間維度的做法而定)。我們現在有一個從基礎區域擷取的 2D 視窗,其中每個 2D 點都與 1D 向量相關聯,因此我們會取得 3D 方塊。從卷積核心中,我們已修正輸出座標 z,因此還有 3D 框。這兩個盒子的尺寸相同,因此我們可以擷取兩個方塊之間元素的元素總和 (類似內積)。也就是輸出值。

請注意,如果 output-z 為 5,則視窗的每個位置會在輸出內容中產生 5 個值,並輸出至輸出內容的 z 維度。這些值的差異在於所使用的卷積核部分,每個 output-z 座標都會使用不同的 3D 值方塊。因此,您可以將其視為 5 個獨立的卷積,每個卷積都有不同的濾鏡。

以下是 2D 卷積的虛擬程式碼,搭配邊框間距和間距:

for (b, oz, oy, ox) {  // output coordinates
  value = 0;
  for (iz, ky, kx) {  // kernel coordinates and input z
    iy = oy*stride_y + ky - pad_low_y;
    ix = ox*stride_x + kx - pad_low_x;
    if ((iy, ix) inside the base area considered without padding) {
      value += input(b, iz, iy, ix) * kernel(oz, iz, ky, kx);
    }
  }
  output(b, oz, oy, ox) = value;
}

ConvertElementType

另請參閱 XlaBuilder::ConvertElementType

與 C++ 中依元素的 static_cast 類似,執行從資料形狀到目標形狀的元素轉換運算。維度必須相符,且轉換是元素的影響元素;例如,s32 元素會透過 s32f32 的轉換常式成為 f32 元素。

ConvertElementType(operand, new_element_type)

引數 類型 語義學
operand XlaOp 具有維度 D 的類型 T 陣列
new_element_type PrimitiveType 類型 U

運算元的維度與目標形狀必須相符。來源和目的地元素類型不得為元組。

T=s32U=f32 這類轉換會執行標準化 int 到 float 的轉換例程,例如四捨五入。

let a: s32[3] = {0, 1, 2};
let b: f32[3] = convert(a, f32);
then b == f32[3]{0.0, 1.0, 2.0}

CrossReplicaSum

使用加總運算執行 AllReduce

CustomCall

另請參閱 XlaBuilder::CustomCall

在運算中呼叫使用者提供的函式。

CustomCall(target_name, args..., shape)

引數 類型 語義學
target_name string 函式名稱。系統會發出以此符號名稱為目標的呼叫指示。
args N 個 XlaOp 的序列 任意類型的 N 個引數,會傳遞至函式。
shape Shape 函式的輸出形狀

無論 args 的類型或類型為何,函式簽名都相同:

extern "C" void target_name(void* out, void** in);

舉例來說,如果 CustomCall 的用途如下:

let x = f32[2] {1,2};
let y = f32[2x3] { {10, 20, 30}, {40, 50, 60} };

CustomCall("myfunc", {x, y}, f32[3x3])

以下是 myfunc 實作範例:

extern "C" void myfunc(void* out, void** in) {
  float (&x)[2] = *static_cast<float(*)[2]>(in[0]);
  float (&y)[2][3] = *static_cast<float(*)[2][3]>(in[1]);
  EXPECT_EQ(1, x[0]);
  EXPECT_EQ(2, x[1]);
  EXPECT_EQ(10, y[0][0]);
  EXPECT_EQ(20, y[0][1]);
  EXPECT_EQ(30, y[0][2]);
  EXPECT_EQ(40, y[1][0]);
  EXPECT_EQ(50, y[1][1]);
  EXPECT_EQ(60, y[1][2]);
  float (&z)[3][3] = *static_cast<float(*)[3][3]>(out);
  z[0][0] = x[1] + y[1][0];
  // ...
}

使用者提供的函式不得有副作用,且執行方式必須是冪等的。

Dot

另請參閱 XlaBuilder::Dot

Dot(lhs, rhs)

引數 類型 語義學
lhs XlaOp T 類型的陣列
rhs XlaOp 類型為 T 的陣列

這項作業的確切語意取決於運算元的排名:

輸入 輸出 語義學
向量 [n] dot 向量 [n] 純量 向量點積
矩陣 [m x k] dot 向量 [k] 向量 [m] 矩陣-向量相乘
矩陣 [m x k] dot 矩陣 [k x n] 矩陣 [m x n] 矩陣-矩陣相乘

此運算會在 lhs 的第二個維度 (或第一個維度,如果其等級為 1) 和 rhs 的第一個維度上,執行乘積的總和。這些是「約定」維度。lhsrhs 的縮減維度必須相同。在實際應用中,可用於在向量之間執行內積、向量/矩陣相乘或矩陣/矩陣相乘。

DotGeneral

另請參閱 XlaBuilder::DotGeneral

DotGeneral(lhs, rhs, dimension_numbers)

引數 類型 語義學
lhs XlaOp T 類型的陣列
rhs XlaOp 類型為 T 的陣列
dimension_numbers DotDimensionNumbers 收縮和批次維度數

與 Dot 類似,但允許 lhsrhs 指定合約及批次維度號碼。

DotDimensionsNumbers 欄位 類型 語義學
lhs_contracting_dimensions repeated int64 lhs 收縮維度數字
rhs_contracting_dimensions 重複 int64 rhs 包量尺寸
lhs_batch_dimensions repeated int64 lhs 個批次維度編號
rhs_batch_dimensions 重複 int64 rhs 批次維度編號

DotGeneral 會在 dimension_numbers 中指定的收縮維度上執行乘積和運算。

lhsrhs 中關聯的合約維度編號不必相同,但尺寸必須相同。

合約維度編號範例:

lhs = { {1.0, 2.0, 3.0},
{4.0, 5.0, 6.0} }

rhs = { {1.0, 1.0, 1.0},
{2.0, 2.0, 2.0} }

DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(1);
dnums.add_rhs_contracting_dimensions(1);

DotGeneral(lhs, rhs, dnums) -> { {6.0, 12.0},
{15.0, 30.0} }

lhsrhs 的關聯批量維度編號必須具有相同的維度大小。

以下是包含批次維度數字的範例 (批次大小 2,2x2 矩陣):

lhs = { { {1.0, 2.0},
{3.0, 4.0} },
{ {5.0, 6.0},
{7.0, 8.0} } }

rhs = { { {1.0, 0.0},
{0.0, 1.0} },
{ {1.0, 0.0},
{0.0, 1.0} } }

DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(2);
dnums.add_rhs_contracting_dimensions(1);
dnums.add_lhs_batch_dimensions(0);
dnums.add_rhs_batch_dimensions(0);

DotGeneral(lhs, rhs, dnums) -> { { {1.0, 2.0},
{3.0, 4.0} },
{ {5.0, 6.0},
{7.0, 8.0} } }
輸入 輸出 語義學
[b0, m, k] dot [b0, k, n] [b0、m、n] 批次矩陣乘法
[b0, b1, m, k] dot [b0, b1, k, n] [b0、b1、m、n] 批次 matmul

因此,產生的維度編號會以批次維度開頭,接著是 lhs 非收縮/非批次維度,最後是 rhs 非收縮/非批次維度。

DynamicSlice

另請參閱 XlaBuilder::DynamicSlice

DynamicSlice 會從動態 start_indices 的輸入陣列中擷取子陣列。每個維度中的配量大小都會透過 size_indices 傳遞,進而指定每個維度中專屬切片間隔的終點:[開始、開始 + 大小]。start_indices 的形狀必須排名 == 1,尺寸等於 operand 的排名。

DynamicSlice(operand, start_indices, size_indices)

引數 類型 語義學
operand XlaOp 類型為 T 的 N 維陣列
start_indices N XlaOp 序列 N 純量整數清單,內含每個維度切片的起始索引。值必須大於或等於 0。
size_indices ArraySlice<int64> 列出 N 整數清單,其中包含每個維度的區塊大小。每個值都必須大於 0,且開頭 + 大小必須小於或等於維度大小,以免換行的模數大小。

有效切片索引的計算方式為,在執行切片之前,針對 [1, N) 中的每個索引 i 套用下列轉換:

start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - size_indices[i])

這麼做可確保擷取的切片一律在運算子陣列的邊界內。如果切片在套用轉換前已在邊界內,轉換就不會生效。

1 維範例:

let a = {0.0, 1.0, 2.0, 3.0, 4.0}
let s = {2}

DynamicSlice(a, s, {2}) produces:
{2.0, 3.0}

2D 範例:

let b =
{ {0.0,  1.0,  2.0},
{3.0,  4.0,  5.0},
{6.0,  7.0,  8.0},
{9.0, 10.0, 11.0} }
let s = {2, 1}

DynamicSlice(b, s, {2, 2}) produces:
{ { 7.0,  8.0},
{10.0, 11.0} }

DynamicUpdateSlice

另請參閱 XlaBuilder::DynamicUpdateSlice

DynamicUpdateSlice 會產生結果,也就是輸入陣列 operand 的值,其中 update 會在 start_indices 上覆寫。update 的形狀會決定要更新的結果子陣列形狀。start_indices 的形狀必須是 rank == 1,且維度大小等於 operand 的 rank。

DynamicUpdateSlice(operand, update, start_indices)

引數 類型 語義學
operand XlaOp 類型為 T 的 N 維陣列
update XlaOp 包含切片更新的 T 型 N 維陣列。更新形狀的每個維度都必須大於零,且開始 + 更新必須小於或等於每個維度的運算元大小,以免產生超出邊界更新索引。
start_indices N 個 XlaOp 的序列 N 個單點整數清單,其中包含每個維度的切片起始索引。值必須大於或等於 0。

系統會先對 [1, N) 中的每個索引 i 套用下列轉換,再執行配量,藉此計算有效片段的索引值:

start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - update.dimension_size[i])

這可確保更新後的切片一律會在運算子陣列的邊界內。如果配量在套用轉換之前是邊界,轉換就不會產生任何作用。

1 維範例:

let a = {0.0, 1.0, 2.0, 3.0, 4.0}
let u = {5.0, 6.0}
let s = {2}

DynamicUpdateSlice(a, u, s) produces:
{0.0, 1.0, 5.0, 6.0, 4.0}

2D 範例:

let b =
{ {0.0,  1.0,  2.0},
{3.0,  4.0,  5.0},
{6.0,  7.0,  8.0},
{9.0, 10.0, 11.0} }
let u =
{ {12.0,  13.0},
{14.0,  15.0},
{16.0,  17.0} }

let s = {1, 1}

DynamicUpdateSlice(b, u, s) produces:
{ {0.0,  1.0,  2.0},
{3.0, 12.0, 13.0},
{6.0, 14.0, 15.0},
{9.0, 16.0, 17.0} }

元素級別二進位算術運算

另請參閱 XlaBuilder::Add

系統支援一組元素的二進位算術運算。

Op(lhs, rhs)

其中 OpAdd (加總)、Sub(分項)、Mul (乘分)、Div (乘除)、Pow (乘)、Rem (最大範圍)、Max (最大)、Min (最小)、And (邏輯 AND)、Or (邏輯 OR)、Xor (邏輯 XOR)、ShiftLeft (邏輯 XOR)、向右/向右 ShiftRightArithmetic (邏輯 1、右側 / 右側位位 ShiftRightArithmetic)、實位第 1 和 ShiftRightArithmetic (邏輯 1、右側 / 對數ShiftRightLogicalAtan2Complex

引數 類型 語義學
lhs XlaOp 左側運算元:型別為 T 的陣列
rhs XlaOp 右側運算元:類型為 T 的陣列

引數的形狀必須相似或相容。請參閱廣播說明文件,瞭解形狀相容性的意義。作業的結果具有形狀,這是廣播兩個輸入陣列的結果。在這個變化版本中,系統「不」支援不同階層陣列之間的運算,除非其中一個運算元是純量。

OpRem 時,結果的符號會取自除數,且結果的絕對值一律小於除數的絕對值。

整數除法溢位 (帶正負號/未簽署的除數/餘數為零,或具有 -1INT_SMIN 已簽署除數/剩餘數) 會產生實作定義的值。

針對下列作業,有支援不同等級廣播的替代變體:

Op(lhs, rhs, broadcast_dimensions)

其中 Op 與上述相同。這項運算的變化版本應用於不同階層陣列之間的算術運算 (例如將矩陣加到向量)。

額外的 broadcast_dimensions 運算元是整數切片,用於將較低階運算元的階級擴展至較高階運算元的階級。broadcast_dimensions 會將較低階形狀的維度對應至較高階形狀的維度。展開形狀中未對應的維度會填入大小為 1 的維度。然後沿著這些退化維度廣播形狀,以便將兩個運算元的形狀均等化。廣播頁面會詳細說明語意。

元素比較運算

另請參閱 XlaBuilder::Eq

支援一組標準元素級別二元比較運算。請注意,比較浮點類型時,適用標準 IEEE 754 浮點比較語意。

Op(lhs, rhs)

其中 OpEq (等於)、Ne (不等於)、Ge (大於或等於)、Gt (大於)、Le (小於或等於)、Lt (小於)。另一組運算子 (EqTotalOrder、NeTotalOrder、GeTotalOrder、GtTotalOrder、LeTotalOrder 和 LtTotalOrder) 提供相同功能,但它們還支援浮點數的總順序,方法是強制執行 -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN。

引數 類型 語義學
lhs XlaOp 左側運算元:型別為 T 的陣列
rhs XlaOp 右側運算元:類型為 T 的陣列

引數的形狀必須相似或相容。請參閱廣播說明文件,瞭解形狀相容性的意義。作業的結果具有形狀,這是以元素類型 PRED 廣播兩個輸入陣列的結果。在這個變化版本中,系統支援不同階層陣列之間的運算,除非其中一個運算元是純量。

下列作業有提供不同排名廣播支援的替代變化版本:

Op(lhs, rhs, broadcast_dimensions)

其中 Op 與上述相同。該作業的變體應該用於比較不同排名陣列之間的作業 (例如在向量中新增矩陣)。

額外的 broadcast_dimensions 運算元是整數片段,用於指定用於廣播運算元的維度。廣播頁面會詳細說明語意。

元素級別的單一函式

XlaBuilder 支援以下元素級一元函式:

Abs(operand) 元素為 abs x -> |x|

Cbrt(operand) 元素層級立方根運算 x -> cbrt(x)

Ceil(operand) 元素層級 - x -> ⌈x⌉

Clz(operand) 逐元素計算前置零。

Cos(operand) 元素逐元素餘弦 x -> cos(x)

Erf(operand) 元素級別誤差函數 x -> erf(x),其中

\(\text{erf}(x) = \frac{2}{\sqrt{\pi} }\int_0^x e^{-t^2} \, dt\)。

Exp(operand) 元素逐元素自然指數 x -> e^x

Expm1(operand) 元素逐元素自然指數減一 x -> e^x - 1

Floor(operand) 元素逐元素的底層 x -> ⌊x⌋

Imag(operand) 複雜 (或實) 形狀的元素逐元素虛部。x -> imag(x)。如果運算元是浮點類型,則會傳回 0。

IsFinite(operand) 會測試 operand 的每個元素是否有限,也就是說,不是正數或負無限大,也不是 NaN。傳回 PRED 值的陣列,其形狀與輸入值相同,其中每個元素都是 true,前提是相應的輸入元素必須是有限的。

Log(operand) 元素逐元素自然對數 x -> ln(x)

Log1p(operand) 依據元素移動到自然對數 x -> ln(1+x)

Logistic(operand) 元素邏輯邏輯函式計算 x -> logistic(x)

Neg(operand) 元素逐元素否定 x -> -x

Not(operand) 元素逐元素邏輯否定 x -> !(x)

PopulationCount(operand) 計算 operand 中每個元素設定的位元數。

Real(operand) 以元素為中心,在複雜 (或實際) 形狀中。 x -> real(x)。如果運算元是浮點類型,則會傳回相同的值。

Round(operand) 元素方向無條件進位,相差不遠。

RoundNearestEven(operand) 依元素無條件進位,會連結至最接近的偶數。

Rsqrt(operand) 平方根運算 x -> 1.0 / sqrt(x) 的元素順序倒數。

Sign(operand) 元素級別符號運算 x -> sgn(x),其中

\[\text{sgn}(x) = \begin{cases} -1 & x < 0\\ -0 & x = -0\\ NaN & x = NaN\\ +0 & x = +0\\ 1 & x > 0 \end{cases}\]

方法是使用 operand 元素的比較運算子。

Sin(operand) 元素位置的正弦 x -> sin(x)

Sqrt(operand) 元素級別平方根運算 x -> sqrt(x)

Tan(operand) 元素逐元素的切線 x -> tan(x)

Tanh(operand) 元素逐元素雙曲正切 x -> tanh(x)

引數 類型 語義學
operand XlaOp 函式的運算元

函式會套用至 operand 陣列中的每個元素,產生形狀相同的陣列。可以使用 operand 做為純量 (排名 0)。

Fft

XLA FFT 運算會針對實數和複數輸入/輸出,實作正向和反向的傅立葉變換。支援最多 3 個軸的多維 FFT。

另請參閱 XlaBuilder::Fft

引數 類型 語義學
operand XlaOp 我們是 Fourier 轉換的陣列。
fft_type FftType 請參閱下表。
fft_length ArraySlice<int64> 要轉換的軸的時間域長度。由於 RFFT(fft_length=[16]) 的輸出形狀與 RFFT(fft_length=[17]) 相同,因此 IRFFT 必須將它調整至最內側的軸大小。
FftType 語義學
FFT 正向複雜-複雜 FFT。形狀不變。
IFFT 反向複雜到複雜 FFT。形狀不變。
RFFT 將真實轉向的 FFT 模型。如果 fft_length[-1] 為非零值,則最內軸的形狀會縮減為 fft_length[-1] // 2 + 1,省略 Nyquist 頻率以外的轉換信號的反相共軛部分。
IRFFT 反向至複雜的 FFT (亦即複雜且傳回實境)。如果 fft_length[-1] 為非零值,則最內軸的形狀會展開為 fft_length[-1],從 1fft_length[-1] // 2 + 1 項目的反元素推斷超出 Nyquist 頻率的轉換信號。

多維 FFT

如果提供多個 fft_length,這相當於將 FFT 作業的串聯套用至每個內軸的每個軸。請注意,對於 real->complex 和 complex->real 情況,最內側的軸轉換會 (實際上) 優先執行 (RFFT;IRFFT 則為最後),因此最內側的軸會變更大小。其他軸轉換作業則會是 complex->complex。

實作詳情

CPU FFT 由 Eigen 的 TensorFFT 提供支援。GPU FFT 會使用 cuFFT。

Gather

XLA 收集運算會將輸入陣列的多個切片 (每個切片的執行時間偏移值可能不同) 拼接在一起。

一般語意

另請參閱 XlaBuilder::Gather。如需更直觀的說明,請參閱下方的「非正式說明」一節。

gather(operand, start_indices, offset_dims, collapsed_slice_dims, slice_sizes, start_index_map)

引數 類型 語義學
operand XlaOp 我們要從中收集資料的陣列。
start_indices XlaOp 這個陣列含有我們收集的配量起始索引。
index_vector_dim int64 start_indices 中「包含」起始索引的維度。詳情請參閱下文。
offset_dims ArraySlice<int64> 輸出形狀中的維度集合,會偏移至從運算子切割的陣列。
slice_sizes ArraySlice<int64> slice_sizes[i] 是維度 i 中切片的邊界。
collapsed_slice_dims ArraySlice<int64> 每個切片中已摺疊的維度組合。這些維度的大小必須是 1。
start_index_map ArraySlice<int64> 這張對應表說明如何將 start_indices 中的索引對應至運算元的有效索引。
indices_are_sorted bool 是否保證索引會由呼叫端排序。

為方便起見,我們將輸出陣列中未在 offset_dims 中標示的維度標示為 batch_dims

輸出內容為 batch_dims.size + offset_dims.size 等級的陣列。

operand.rank 必須等於 offset_dims.sizecollapsed_slice_dims.size 的總和。此外,slice_sizes.size 必須等於 operand.rank

如果 index_vector_dim 等於 start_indices.rank,我們會隱含地將 start_indices 視為具有尾隨 1 維度的形狀 (也就是說,如果 start_indices 的形狀為 [6,7],而 index_vector_dim2,則我們會隱含地將 start_indices 的形狀視為 [6,7,1])。

沿著維度 i 計算輸出陣列的邊界如下:

  1. 如果 i 出現在 batch_dims 中 (也就是說,對於某些 ki 等於 batch_dims[k]),我們會從 start_indices.shape 中挑選對應的維度邊界,略過 index_vector_dim (也就是說,如果 k < index_vector_dim,請選取 start_indices.shape.dims[k];否則選取 start_indices.shape.dims[k+1])。

  2. 如果 ioffset_dims 中 (也就是部分 koffset_dims[k]),則在計算 collapsed_slice_dims 後,我們會挑選 slice_sizes 的對應邊界 (即選擇 adjusted_slice_sizes[k],其中 adjusted_slice_sizesslice_sizes,且已移除索引 collapsed_slice_dims 的邊界)。

正式來說,對應給定輸出索引 Out 的運算元索引 In 的計算方式如下:

  1. G = { Out[k] for k in batch_dims }。使用 G 切片向量 S,使 S[i] = start_indices[Combine(G, i)],其中 Combine(A, b) 會將 b 插入 A 的 index_vector_dim 位置。請注意,即使 G 為空白,也已定義明確:如果 G 為空白,則 S = start_indices

  2. 使用 start_index_mapS 分批為 S,使用 S 將起始索引 Sin 建立為 operand。具體來說:

    1. Sin[start_index_map[k]] = S[k] (如果 k < start_index_map.size)。

    2. Sin[_] = 0,否則為 0

  3. 根據 collapsed_slice_dims 集合,在 Out 中的偏移維度中散布索引,藉此在 operand 中建立索引 Oin。具體來說:

    1. Oin[remapped_offset_dims(k)] = Out[offset_dims[k]] 如果 k < offset_dims.size 已定義,remapped_offset_dims

    2. Oin[_] = 0,否則

  4. InOin + Sin,其中 + 是元素相加。

remapped_offset_dims 是一個單調函式,網域為 [0, offset_dims.size),範圍為 [0, operand.rank) \ collapsed_slice_dims。因此,如果offset_dims.size4operand.rank6collapsed_slice_dims 為 {02},remapped_offset_dims 為 {01132435}。

如果將 indices_are_sorted 設為 true,XLA 可以假設 start_indices 已由使用者排序 (以遞增順序排列,根據 start_index_map 散布其值之後)。如果不是,就會定義語意。

非正式的說明與範例

非正式地說,輸出陣列中的每個索引 Out 都對應至運算子陣列中的元素 E,計算方式如下:

  • 我們會使用 Out 中的批次維度,從 start_indices 查詢起始索引。

  • 我們使用 start_index_map 將起始索引 (大小可能小於 operand.rank) 對應至 operand 中的「完整」起始索引。

  • 我們會使用完整的起始索引,動態切出大小為 slice_sizes 的切片。

  • 我們會收合 collapsed_slice_dims 維度,重新塑形片段。由於所有已摺疊的切片維度都必須有 1 個邊界,因此這個重塑作業一律合法。

  • 我們使用 Out 中的偏移維度為這個切片建立索引,取得輸入元素 E (對應至輸出索引 Out)。

在後續所有範例中,index_vector_dim 都設為 start_indices.rank - 1index_vector_dim 的其他有趣值不會從根本上改變運算,但會讓視覺呈現更加繁瑣。

為了瞭解上述所有條件如何搭配運作,以下舉例說明從 [16,11] 陣列收集 5 個 [8,6] 切片的範例。切片至 [16,11] 陣列的位置可以表示為 S64[2] 形狀的索引向量,因此 5 個位置組合可以表示為 S64[5,2] 陣列。

接著,收集作業的行為可視為索引轉換,該轉換會取得 [G,O0,O1],也就是輸出形狀中的索引,並以以下方式將其對應至輸入陣列中的元素:

我們會先使用 G 從收集索引陣列中選取 (X,Y) 向量。輸出陣列中索引 [G,O0,O1] 的元素,就是輸入陣列中索引 [X+O0,Y+O1] 的元素。

slice_sizes[8,6],可決定 O0 和 O1 的範圍,進而決定切片的邊界。

這項收集作業可做為批次動態切片使用 G 做為批次維度。

收集的索引可能是多維度的。舉例來說,在上述使用「gather indices」形狀陣列 [4,5,2] 的較通用版範例中,系統會轉譯索引,如下所示:

同樣地,這也會做為批次動態切片 G0G1 的批次維度。配量大小仍為 [8,6]

XLA 中的收集運算會以以下方式概略上述非正式語意:

  1. 我們可以設定輸出形狀中的哪些維度是偏移維度 (最後一個範例包含 O0O1 的維度)。輸出批次維度 (包含 G0 的維度,在上一例中為 G1) 定義為非偏移維度的輸出維度。

  2. 輸出形狀中明確顯示的輸出偏移維度數量可能少於輸入排名。這些「缺少」的維度 (明確列為 collapsed_slice_dims) 必須具有 1 的切片大小。由於這些元素的切片大小為 1,因此唯一有效的索引為 0,而省略這些元素不會造成歧義。

  3. 從「Gather Indices」陣列 (上一個範例中的 (XY)) 擷取的切片,可能比輸入陣列的等級少了一些元素,而明確的對應關係會決定如何擴充索引,以便與輸入內容的等級相同。

最後一個範例,我們使用 (2) 和 (3) 實作 tf.gather_nd

G0G1 用於從收集索引陣列中切出起始索引,這與平常一樣,但起始索引只有一個元素 X。同樣地,只有一個輸出偏移索引,其值為 O0。不過,在用於輸入陣列的索引之前,這些索引會根據「Gather Index Mapping」(正式說明中的 start_index_map) 和「Offset Mapping」(正式說明中的 remapped_offset_dims) 擴展為 [X,0] 和 [0,O0],總計為 [X,O0]。換句話說,輸出索引 [G0,G1,O0] 會對應至輸入索引 [GatherIndices[G0,G1,0],O0],這會為我們提供 tf.gather_nd 的語意。

本例的 slice_sizes[1,11]。直覺上表示,收集索引陣列中的每個索引 X 都會選取整個資料列,而結果是所有這些資料列的串連。

GetDimensionSize

另請參閱 XlaBuilder::GetDimensionSize

傳回運算元指定維度的大小。運算元必須是陣列形狀。

GetDimensionSize(operand, dimension)

引數 類型 語義學
operand XlaOp n 維度輸入陣列
dimension int64 指定維度的 [0, n) 間隔值

SetDimensionSize

另請參閱 XlaBuilder::SetDimensionSize

設定 XlaOp 指定維度的動態大小。運算元必須是陣列形狀。

SetDimensionSize(operand, size, dimension)

引數 類型 語義學
operand XlaOp n 維輸入陣列。
size XlaOp int32,代表執行階段動態大小。
dimension int64 指定維度的 [0, n) 間隔值。

傳遞運算元結果,其中包含由編譯器追蹤的動態維度。

下游的縮減運算作業會忽略填補的值。

let v: f32[10] = f32[10]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
let five: s32 = 5;
let six: s32 = 6;

// Setting dynamic dimension size doesn't change the upper bound of the static
// shape.
let padded_v_five: f32[10] = set_dimension_size(v, five, /*dimension=*/0);
let padded_v_six: f32[10] = set_dimension_size(v, six, /*dimension=*/0);

// sum == 1 + 2 + 3 + 4 + 5
let sum:f32[] = reduce_sum(padded_v_five);
// product == 1 * 2 * 3 * 4 * 5
let product:f32[] = reduce_product(padded_v_five);

// Changing padding size will yield different result.
// sum == 1 + 2 + 3 + 4 + 5 + 6
let sum:f32[] = reduce_sum(padded_v_six);

GetTupleElement

另請參閱 XlaBuilder::GetTupleElement

將索引轉換為具有編譯時間常數值的元組。

值必須是編譯時間常數,這樣形狀推論才能判斷產生值的類型。

這類似於 C++ 中的 std::get<int N>(t)。概念上:

let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
let s: s32 = 5;
let t: (f32[10], s32) = tuple(v, s);
let element_1: s32 = gettupleelement(t, 1);  // Inferred shape matches s32.

另請參閱 tf.tuple

Infeed

另請參閱 XlaBuilder::Infeed

Infeed(shape)

引數 類型 語義學
shape Shape 從 Infeed 介面讀取的資料形狀。形狀的版面配置欄位必須設為與傳送至裝置的資料版面配置相符,否則其行為未定義。

從裝置的隱含動態饋給串流介面讀取單一資料項目,將資料解讀為指定形狀及其版面配置,並傳回資料的 XlaOp。計算中允許有多個中繼操作,但中繼操作之間必須有總順序。舉例來說,下方程式碼中的兩個 Infeed 具有總順序,因為 while 迴圈之間存在依賴關係。

result1 = while (condition, init = init_value) {
  Infeed(shape)
}

result2 = while (condition, init = result1) {
  Infeed(shape)
}

系統不支援巢狀元組形狀。如果是空的元組形狀,Infeed 作業實際上是無作業,且不會從裝置的 Infeed 讀取任何資料。

器皿打擊樂

另請參閱 XlaBuilder::Iota

Iota(shape, iota_dimension)

在裝置上建構常數字面值,而非可能龐大的主機轉移。建立具有指定形狀的陣列,並保留從零開始,並在指定維度上依一個遞增值遞增的值。如果是浮點類型,產生的陣列等同於 ConvertElementType(Iota(...)),其中 Iota 是積分類型,且轉換為浮點類型。

引數 類型 語義學
shape Shape Iota() 建立的陣列形狀
iota_dimension int64 要遞增的維度。

舉例來說,Iota(s32[4, 8], 0) 會傳回

  [[0, 0, 0, 0, 0, 0, 0, 0 ],
   [1, 1, 1, 1, 1, 1, 1, 1 ],
   [2, 2, 2, 2, 2, 2, 2, 2 ],
   [3, 3, 3, 3, 3, 3, 3, 3 ]]

可退貨 (費用:Iota(s32[4, 8], 1))

  [[0, 1, 2, 3, 4, 5, 6, 7 ],
   [0, 1, 2, 3, 4, 5, 6, 7 ],
   [0, 1, 2, 3, 4, 5, 6, 7 ],
   [0, 1, 2, 3, 4, 5, 6, 7 ]]

地圖

另請參閱 XlaBuilder::Map

Map(operands..., computation)

引數 類型 語義學
operands N 個 XlaOp 的序列 N 個類型 T 的陣列,其中 T 為 0..T{N-1}
computation XlaComputation T_0, T_1, .., T_{N + M -1} -> S 型別的運算,其中 N 個參數為 T 型別,M 為任意型別
dimensions int64 陣列 地圖維度的陣列

對指定的 operands 陣列套用標量函式,產生相同維度的陣列,其中每個元素都是將對應函式套用至輸入陣列中的對應元素所產生的結果。

對應函式是具有限制的任意運算,其中具有 N 個純量類型 T 的輸入內容,以及一項類型為 S 的輸出。輸出內容的維度與運算元相同,唯一差別在於元素類型 T 會替換為 S。

例如:Map(op1, op2, op3, computation, par1) 會將 elem_out <- computation(elem1, elem2, elem3, par1) 對應至輸入陣列中的每個 (多維) 索引,以產生輸出陣列。

OptimizationBarrier

禁止任何最佳化過程移動跨越障礙。

確保系統會先評估所有輸入的運算子,然後再評估依附於任何運算子輸出的運算子。

熱敷墊

另請參閱 XlaBuilder::Pad

Pad(operand, padding_value, padding_config)

引數 類型 語義學
operand XlaOp T 類型的陣列
padding_value XlaOp T 型別的純量,用於填入新增的邊框
padding_config PaddingConfig 兩側邊框間距 (低、高) 和各維度元素之間的間距

透過在陣列周圍和陣列元素之間使用指定 padding_value 的邊框間距,擴展指定的 operand 陣列。padding_config 會指定每個維度的邊框間距和內部間距。

PaddingConfigPaddingConfigDimension 的重複欄位,其中包含每個維度的三個欄位:edge_padding_lowedge_padding_highinterior_padding

edge_padding_lowedge_padding_high 分別指定在各維度的低端 (靠近索引 0) 和高端 (靠近最高索引) 新增的邊框間距量。邊緣邊框的邊框寬度可以為負值,負邊框的絕對值表示從指定維度移除的元素數量。

interior_padding 會指定在每個維度的任兩個元素之間加入的邊框間距量,且不得為負值。內部邊框間距在邏輯上會出現在邊框間距之前,因此在邊框間距為負值的情況下,系統會從內部邊框間距運算元中移除元素。

如果邊緣邊框組合全部為 (0, 0),且內部邊框值全部為 0,則此作業會是無操作。下圖顯示了二維陣列的不同 edge_paddinginterior_padding 值範例。

Recv

另請參閱 XlaBuilder::Recv

Recv(shape, channel_handle)

引數 類型 語義學
shape Shape 要接收的資料形狀
channel_handle ChannelHandle 每個傳送/接收組合的專屬 ID

在共用相同管道句柄的其他運算中,從 Send 指令接收指定形狀的資料。傳回已接收資料的 XlaOp。

Recv 作業的用戶端 API 代表同步通訊。不過,這項指令會在內部分解為 2 個 HLO 指令 (RecvRecvDone),以便啟用非同步資料傳輸。另請參閱 HloInstruction::CreateRecvHloInstruction::CreateRecvDone

Recv(const Shape& shape, int64 channel_id)

分配從具有相同 channel_id 的 Send 指令接收資料所需的資源。傳回分配資源的結構定義,下列 RecvDone 指令會使用該資源等待資料移轉完成。這個內容是 {接收緩衝區 (形狀)、要求 ID (U32)} 的元組,且只能由 RecvDone 指令使用。

RecvDone(HloInstruction context)

假設有由 Recv 指令建立的結構定義,系統會等待資料移轉完成並傳回接收的資料。

遏止

另請參閱 XlaBuilder::Reduce

將縮減函式平行套用到一或多個陣列。

Reduce(operands..., init_values..., computation, dimensions)

引數 類型 語義學
operands N 個 XlaOp 的序列 類型為 T_0, ..., T_{N-1} 的 N 個陣列。
init_values N 個 XlaOp 的序列 T_0, ..., T_{N-1} 類型的 N 純量。
computation XlaComputation 類型為 T_0, ..., T_{N-1}, T_0, ..., T_{N-1} -> Collate(T_0, ..., T_{N-1}) 的運算。
dimensions int64 陣列 要縮減的維度無序陣列。

在此情況下:

  • N 必須大於或等於 1。
  • 計算必須「大致」具有結合性 (請參閱下文)。
  • 所有輸入陣列的維度都必須相同。
  • 所有初始值都必須在 computation 下形成身分。
  • 如果是 N = 1Collate(T) 就是 T
  • 如果為 N > 1Collate(T_0, ..., T_{N-1})T 類型的 N 元素組合。

這個運算會將每個輸入陣列的一或多個維度縮減為標量。每個傳回陣列的排名為 rank(operand) - len(dimensions)。此運算子的輸出值為 Collate(Q_0, ..., Q_N),其中 Q_iT_i 類型的陣列,其維度如下所述。

允許不同的後端重新連結縮減運算。這可能會導致數值差異,因為某些減法函式 (例如加法) 不支援浮點運算。不過,如果資料範圍有限,浮點加法在大多數實際用途上就足以達到關聯性。

範例

使用 [10, 11, 12, 13] 值的縮減函式 f (為 computation) 減少單一 1D 陣列中的一個維度時,可視為以下計算結果:

f(10, f(11, f(12, f(init_value, 13)))

但還有許多其他可能性,例如

f(init_value, f(f(10, f(init_value, 11)), f(f(init_value, 12), f(init_value, 13))))

以下舉例說明如何實作縮減作業,使用加總做為初始值為 0 的縮減運算。

result_shape <- remove all dims in dimensions from operand_shape

# Iterate over all elements in result_shape. The number of r's here is equal
# to the rank of the result
for r0 in range(result_shape[0]), r1 in range(result_shape[1]), ...:
  # Initialize this result element
  result[r0, r1...] <- 0

  # Iterate over all the reduction dimensions
  for d0 in range(dimensions[0]), d1 in range(dimensions[1]), ...:
    # Increment the result element with the value of the operand's element.
    # The index of the operand's element is constructed from all ri's and di's
    # in the right order (by construction ri's and di's together index over the
    # whole operand shape).
    result[r0, r1...] += operand[ri... di]

以下是 2D 陣列 (矩陣) 縮減的範例。形狀的等級是 2、維度 0,尺寸為 3:

使用「add」函式減少維度 0 或 1 的結果:

請注意,兩個縮減結果都是 1D 陣列。為方便視覺起見,圖中顯示一個資料欄和另一個資料列。

以下是 3D 陣列的較複雜範例。其排名為 3,維度 0 的大小為 4,維度 1 的大小為 2,維度 2 的大小為 3。為求簡單,系統會在維度 0 之間複製值 1 到 6。

與 2D 範例類似,我們只能減少一個維度。舉例來說,如果我們縮小維度 0,就會得到排名 2 陣列,其中所有維度 0 的值都折疊成純量:

|  4   8  12 |
| 16  20  24 |

如果我們縮減第 2 個維度,也會得到第 2 個維度的所有值折疊成標量,並取得第 2 個維度的陣列:

| 6  15 |
| 6  15 |
| 6  15 |
| 6  15 |

請注意,輸入內容中其他維度之間的相對順序會保留在輸出內容中,但部分維度可能會指派新的編號 (因為排名會變更)。

此外,也可以減少多個維度。新增並減少維度 0 和 1,會產生 1D 陣列 [20, 28, 36]

將所有維度縮減 3D 陣列會產生純量 84

減肥

N > 1 時,reduce 函式應用程序會同時套用至所有輸入內容,因此會稍微複雜一些。運算子會以以下順序提供給運算:

  • 為第一個運算元執行減值
  • ...
  • 執行第 N 個運算元的減數值
  • 第一個運算元的輸入值
  • ...
  • 第 N 個運算元的輸入值

舉例來說,請考慮下列縮減函式,可用於平行計算 1 維陣列的最大值和 argmax:

f: (Float, Int, Float, Int) -> Float, Int
f(max, argmax, value, index):
  if value >= max:
    return (value, index)
  else:
    return (max, argmax)

對於 1 維輸入陣列 V = Float[N], K = Int[N] 和初始值 I_V = Float, I_K = Int,在單一輸入維度中縮減的結果 f_(N-1) 等同於以下遞迴應用程式:

f_0 = f(I_V, I_K, V_0, K_0)
f_1 = f(f_0.first, f_0.second, V_1, K_1)
...
f_(N-1) = f(f_(N-2).first, f_(N-2).second, V_(N-1), K_(N-1))

將這項縮減作業套用至值陣列和序列索引陣列 (即 iota),會同時對陣列進行迭代,並傳回包含最大值和相符索引的元組。

ReducePrecision

另請參閱 XlaBuilder::ReducePrecision

建立將浮點值轉換為較低精確度格式 (例如 IEEE-FP16) 及恢復原始格式的效果。您可以任意指定較低精確度格式的指數和形容詞位元數,但並非所有硬體實作項目都支援所有位元大小。

ReducePrecision(operand, mantissa_bits, exponent_bits)

引數 類型 語義學
operand XlaOp 浮點類型 T 的陣列。
exponent_bits int32 較低精度格式中的指數位元數
mantissa_bits int32 低精確度格式的 mantissa 位元數

結果為 T 類型的陣列。輸入值會四捨五入至最接近的值,可使用指定的尾數位元數 (使用「ties to even」語意),任何超過指數位元數所指定範圍的值都會被箝制為正無窮或負無窮。NaN 值會保留,但可能會轉換為標準 NaN 值。

較低精確度的格式必須至少包含一個指數位元 (為了區分零值和無窮大,因為兩者都具有零尾數位元),且必須包含非負數的尾數位元位元。指數或尾數位元數量可能超過類型 T 的對應值;轉換的對應部分就會變成無操作。

ReduceScatter

另請參閱 XlaBuilder::ReduceScatter

ReduceScatter 是集體運算,可有效執行 AllReduce,然後沿著 scatter_dimension 將結果分割成 shard_count 區塊,並在複本群組中接收 ith 資料分割的複本 i

ReduceScatter(operand, computation, scatter_dim, shard_count, replica_group_ids, channel_id)

引數 類型 語義學
operand XlaOp 陣列或陣列的非空白元組,以便在備用資源間縮減。
computation XlaComputation 減法運算
scatter_dimension int64 要散布繪製的維度。
shard_count int64 要分割的scatter_dimension區塊數量
replica_groups int64 向量的向量 要執行減法運算的群組
channel_id 選填 int64 用於跨模組通訊的選用管道 ID
  • operand 是陣列的元組時,就會針對元組的每個元素執行減少散佈器。
  • replica_groups 是執行縮減作業的備份群組清單 (可使用 ReplicaId 擷取目前備份的備份 ID)。每個群組中的備份順序,決定了全縮減結果的散布順序。replica_groups 必須為空白 (在這種情況下,所有備援資料都屬於單一群組),或包含與備援資料數量相同的元素。如果有多個複本群組,則所有群組的大小都必須相同。舉例來說,replica_groups = {0, 2}, {1, 3} 會在複本 02 之間,以及 13 之間執行縮減作業,然後散布結果。
  • shard_count 是每個備用資源群組的大小。在 replica_groups 為空白的情況下,我們需要這個值。如果 replica_groups 非空白,shard_count 必須等於每個備援群組的大小。
  • channel_id 用於跨模組通訊:只有使用相同 channel_idreduce-scatter 作業才能相互通訊。

輸出形狀為輸入形狀,scatter_dimension縮小shard_count倍。舉例來說,如果有兩個副本,且運算元在兩個副本中分別具有 [1.0, 2.25][3.0, 5.25] 值,則這個運算子的輸出值 (scatter_dim0) 將是第一個副本的 [4.0],第二個副本的 [7.5]

ReduceWindow

另請參閱 XlaBuilder::ReduceWindow

將縮減函式套用至 N 個多維陣列序列的每個視窗中的所有元素,產生單一或 N 個多維陣列的元組做為輸出。每個輸出陣列的元素數量都與視窗的有效位置數量相同。匯集層可以表示為 ReduceWindow。與 Reduce 類似,套用的 computation 一律會在左側傳遞 init_values

ReduceWindow(operands..., init_values..., computation, window_dimensions, window_strides, padding)

引數 類型 語義學
operands N XlaOps 一系列 T_0,..., T_{N-1} 型別的 N 個多維陣列,每個陣列都代表窗口放置的基礎區域。
init_values N XlaOps 運算的 N 個起始值,每個運算元式各一個。詳情請參閱「調降」一節。
computation XlaComputation T_0, ..., T_{N-1}, T_0, ..., T_{N-1} -> Collate(T_0, ..., T_{N-1}) 類型的縮減函式,會套用至所有輸入運算元每個窗口中的元素。
window_dimensions ArraySlice<int64> 視窗維度值的整數陣列
window_strides ArraySlice<int64> 窗邊值的整數陣列
base_dilations ArraySlice<int64> 用於基本放大值的整數陣列
window_dilations ArraySlice<int64> 用於窗口擴大值的整數陣列
padding Padding 視窗的邊框類型 (Padding::kSame,如果步幅為 1,則會填充邊框,使輸出形狀與輸入形狀相同;或 Padding::kValid,不使用邊框,且在視窗無法再填充時「停止」)

在此情況下:

  • N 必須大於或等於 1。
  • 所有輸入陣列的尺寸都必須相同。
  • 如果 N = 1Collate(T)T
  • 如果是 N > 1Collate(T_0, ..., T_{N-1}) 就是 (T0,...T{N-1}) 類型的 N 元素元組。

以下程式碼和圖表顯示 ReduceWindow 的使用範例。輸入內容是大小為 [4x6] 的矩陣,且 window_dimensions 和 window_stride_dimensions 都是 [2x3]。

// Create a computation for the reduction (maximum).
XlaComputation max;
{
  XlaBuilder builder(client_, "max");
  auto y = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "y");
  auto x = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "x");
  builder.Max(y, x);
  max = builder.Build().value();
}

// Create a ReduceWindow computation with the max reduction computation.
XlaBuilder builder(client_, "reduce_window_2x3");
auto shape = ShapeUtil::MakeShape(F32, {4, 6});
auto input = builder.Parameter(0, shape, "input");
builder.ReduceWindow(
    input,
    /*init_val=*/builder.ConstantLiteral(LiteralUtil::MinValue(F32)),
    *max,
    /*window_dimensions=*/{2, 3},
    /*window_stride_dimensions=*/{2, 3},
    Padding::kValid);

在維度中,步幅為 1 表示維度中視窗的位置與相鄰視窗相距 1 個元素。為指定不重疊的視窗,window_stride_dimensions 應等於 window_dimensions。下圖說明了使用兩個不同的步幅值。填充會套用至輸入內容的每個維度,而計算結果與輸入內容填充後的維度相同。

對於非簡單的填充範例,請考慮在輸入陣列 [10000, 1000, 100, 10, 1] 上,使用維度 3 和步幅 2 計算 reduce-window 最小值 (初始值為 MAX_FLOAT)。Padding kValid 會在兩個有效視窗 ([10000, 1000, 100][100, 10, 1]) 中計算最小值,並產生輸出 [100, 1]。Padding kSame 會先為陣列填入邊框,藉此在兩側新增初始元素,讓縮減視窗後的形狀與第 1 步的輸入相同,進而取得 [MAX_VALUE, 10000, 1000, 100, 10, 1, MAX_VALUE]。對填充陣列執行縮減期的作業是三個視窗 [MAX_VALUE, 10000, 1000][1000, 100, 10][10, 1, MAX_VALUE] 和收益 [1000, 10, 1]

縮減函式的評估順序為任意順序,且可能是非決定性的。因此,縮減函式不應對重新關聯過度敏感。詳情請參閱 Reduce 情境中有關關聯性的討論。

ReplicaId

另請參閱 XlaBuilder::ReplicaId

傳回備用資源的專屬 ID (U32 純量)。

ReplicaId()

每個備份的專屬 ID 是 [0, N) 範圍內的無符號整數,其中 N 是備份數量。由於所有副本都執行相同的程式,因此程式中的 ReplicaId() 呼叫會在每個副本上傳回不同的值。

Reshape

另請參閱 XlaBuilder::ReshapeCollapse 作業。

將陣列的維度重新調整為新設定。

Reshape(operand, new_sizes) Reshape(operand, dimensions, new_sizes)

引數 類型 語義學
operand XlaOp 類型為 T 的陣列
dimensions int64 向量 收合維度的順序
new_sizes int64 向量 新尺寸的向量

從概念上來說,重新調整形狀會先將陣列扁平化為資料值的一維向量,然後將這個向量精緻化為新形狀。輸入引數是任意型別 T 的陣列、編譯時間常數向量的維度索引,以及結果的編譯時間常數向量維度大小。如果提供 dimension 向量中的值,則必須是 T 所有維度的排列;如果未提供,預設值為 {0, ..., rank - 1}dimensions 中的維度順序是從迴圈巢狀結構中變化最慢的維度 (最主要) 到變化最快的維度 (最次要),這個迴圈巢狀結構會將輸入陣列折疊為單一維度。new_sizes 向量會決定輸出陣列的大小。new_sizes 中索引 0 的值是維度 0 的大小,索引 1 的值則是維度 1 的大小,以此類推。new_size 維度的乘積必須等於運算元的維度大小乘積。當您將已摺疊的陣列精緻化為 new_sizes 定義的多維陣列時,new_sizes 中的維度會依變化速度由慢到快排序 (最主要) 和 (最次要)。

例如,讓 v 為 24 個元素的陣列:

let v = f32[4x2x3] { { {10, 11, 12}, {15, 16, 17} },
                    { {20, 21, 22}, {25, 26, 27} },
                    { {30, 31, 32}, {35, 36, 37} },
                    { {40, 41, 42}, {45, 46, 47} } };

In-order collapse:
let v012_24 = Reshape(v, {0,1,2}, {24});
then v012_24 == f32[24] {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27,
                         30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47};

let v012_83 = Reshape(v, {0,1,2}, {8,3});
then v012_83 == f32[8x3] { {10, 11, 12}, {15, 16, 17},
                          {20, 21, 22}, {25, 26, 27},
                          {30, 31, 32}, {35, 36, 37},
                          {40, 41, 42}, {45, 46, 47} };

Out-of-order collapse:
let v021_24 = Reshape(v, {1,2,0}, {24});
then v012_24 == f32[24]  {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42,
                          15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47};

let v021_83 = Reshape(v, {1,2,0}, {8,3});
then v021_83 == f32[8x3] { {10, 20, 30}, {40, 11, 21},
                          {31, 41, 12}, {22, 32, 42},
                          {15, 25, 35}, {45, 16, 26},
                          {36, 46, 17}, {27, 37, 47} };


let v021_262 = Reshape(v, {1,2,0}, {2,6,2});
then v021_262 == f32[2x6x2] { { {10, 20}, {30, 40},
                              {11, 21}, {31, 41},
                              {12, 22}, {32, 42} },
                             { {15, 25}, {35, 45},
                              {16, 26}, {36, 46},
                              {17, 27}, {37, 47} } };

在特殊情況下,reshape 可將單一元素陣列轉換為標量,反之亦然。例如:

Reshape(f32[1x1] { {5} }, {0,1}, {}) == 5;
Reshape(5, {}, {1,1}) == f32[1x1] { {5} };

Rev (反轉)

另請參閱 XlaBuilder::Rev

Rev(operand, dimensions)

引數 類型 語義學
operand XlaOp 類型為 T 的陣列
dimensions ArraySlice<int64> 要反轉的維度

沿著指定的 dimensions 反轉 operand 陣列中的元素順序,產生相同形狀的輸出陣列。多維度索引運算元陣列的每個元素,都會儲存在轉換索引的輸出陣列中。會反轉每個維度中的索引即可轉換多維度索引 (也就是說,如果大小 N 的維度是逆向維度之一,其索引 i 會轉換為 N - 1 - i)。

Rev 運算的一種用途,是在類神經網路的梯度計算期間,沿著兩個視窗維度反轉卷積權重陣列。

RngNormal

另請參閱 XlaBuilder::RngNormal

使用在 \(N(\mu, \sigma)\) 常態分佈後產生的隨機數字,建構指定形狀的輸出。參數 \(\mu\) 和 \(\sigma\),以及輸出形狀必須具有浮點元素類型。參數進一步須是純量值。

RngNormal(mu, sigma, shape)

引數 類型 語義學
mu XlaOp 指定產生數字平均值的 T 型純量
sigma XlaOp T 類型的純量,用於指定產生的標準差
shape Shape 輸出型別 T 的形狀

RngUniform

另請參閱 XlaBuilder::RngUniform

根據在區間 \([a,b)\)內均勻分布的隨機號碼,建構指定形狀的輸出內容。參數和輸出元素類型必須是布林值類型、積分類型或浮點類型,且類型必須一致。CPU 和 GPU 後端目前僅支援 F64、F32、F16、BF16、S64、U64、S32 和 U32。此外,參數必須是純量值。如果 \(b <= a\) 為結果,則由實作定義。

RngUniform(a, b, shape)

引數 類型 語義學
a XlaOp T 類型的純量,指定間隔下限
b XlaOp 指定間隔上限的 T 型別向量
shape Shape 輸出型別 T 的形狀

RngBitGenerator

使用指定的演算法 (或後端預設),產生以特定形狀填入均勻隨機位元的輸出內容,並傳回更新的狀態 (與初始狀態相同的形狀) 和產生的隨機資料。

初始狀態是目前隨機號碼產生的初始狀態。所需的形狀和有效值取決於使用的演算法。

輸出內容保證可做為初始狀態的確定性函式,但「不保證」在後端和不同編譯器版本之間具有確定性。

RngBitGenerator(algorithm, key, shape)

引數 類型 語義學
algorithm RandomAlgorithm 要使用的 PRNG 演算法。
initial_state XlaOp PRNG 演算法的初始狀態。
shape Shape 產生資料的輸出形狀。

algorithm 的可用值:

散布圖

XLA 散布運算會產生一系列結果,這些結果是輸入陣列 operands 的值,其中使用 update_computation 更新 updates 中的值序列,並以 scatter_indices 指定的索引切片。

另請參閱 XlaBuilder::Scatter

scatter(operands..., scatter_indices, updates..., update_computation, index_vector_dim, update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims)

引數 類型 語義學
operands N 個 XlaOp 的序列 要散布到 T_0, ..., T_N 類型的 N 個陣列。
scatter_indices XlaOp 陣列,其中包含必須散發至的切片起始索引。
updates N 個 XlaOp 的序列 類型為 T_0, ..., T_N 的 N 個陣列。updates[i] 包含散佈 operands[i] 時必須使用的值。
update_computation XlaComputation 用於結合輸入陣列中現有值和散布期間更新的運算。此計算類型應為 T_0, ..., T_N, T_0, ..., T_N -> Collate(T_0, ..., T_N)
index_vector_dim int64 scatter_indices 中包含起始索引的維度。
update_window_dims ArraySlice<int64> updates 形狀中的視窗大小維度集。
inserted_window_dims ArraySlice<int64> 必須插入 updates 形狀的視窗大小集。
scatter_dims_to_operand_dims ArraySlice<int64> 從分散索引到運算元索引空間的維度。這個陣列會解讀為將 i 對應至 scatter_dims_to_operand_dims[i]。必須是 1 對 1 的關係,且必須是總和。
indices_are_sorted bool 指出索引是否保證由呼叫端排序。
unique_indices bool 呼叫端是否保證索引不重複。

在此情況下:

  • N 必須大於或等於 1。
  • operands[0]、...、operands[N-1] 必須具有相同的尺寸。
  • updates[0]、...、updates[N-1] 的尺寸都必須相同。
  • 如果是 N = 1Collate(T) 就是 T
  • 如果為 N > 1Collate(T_0, ..., T_N)T 類型的 N 元素的元組。

如果 index_vector_dim 等於 scatter_indices.rank,我們隱性地會將 scatter_indices 視為結尾的 1 維度。

我們將 ArraySlice<int64> 類型的 update_scatter_dims 定義為 updates 形狀中不屬於 update_window_dims 的一組維度,並以升冪順序排列。

分散的引數應遵循以下限制:

  • 每個 updates 陣列都必須是 update_window_dims.size + scatter_indices.rank - 1 等級。

  • 每個 updates 陣列中維度 i 的邊界必須符合下列條件:

    • 如果 i 出現在 update_window_dims 中 (即等於某些 kupdate_window_dims[k]),則在考量 inserted_window_dims 後,updates 中的維度 i 邊界不得超過 operand 的對應邊界 (即 adjusted_window_bounds[k],其中 adjusted_window_bounds 包含 operand 的邊界,但已移除索引 inserted_window_dims 的邊界)。
    • 如果 iupdate_scatter_dims 中 (也就是部分 kupdate_scatter_dims[k]),updates 中的維度 i 邊界必須與 scatter_indices 的對應邊界相同,並略過 index_vector_dim (如果 k < index_vector_dimscatter_indices.shape.dims[k+1],則略過 index_vector_dim)。kscatter_indices.shape.dims
  • update_window_dims 必須以遞增順序排列,且不得重複維度編號,且必須在 [0, updates.rank) 的範圍內。

  • inserted_window_dims 必須依遞增順序排列,不含任何重複維度數字,而且必須位於 [0, operand.rank) 範圍內。

  • operand.rank 必須等於 update_window_dims.sizeinserted_window_dims.size 的總和。

  • scatter_dims_to_operand_dims.size 必須等於 scatter_indices.shape.dims[index_vector_dim],且其值必須在 [0, operand.rank) 範圍內。

針對每個 updates 陣列中的指定索引 U,在要套用更新的相應 operands 陣列中,對應索引 I 的計算方法如下:

  1. G = {U[k] for k in update_scatter_dims}。使用 Gscatter_indices 陣列中查詢索引向量 S,讓 S[i] = scatter_indices[Combine(G, i)],其中 Combine(A, b) 會將 b 插入 A 的 index_vector_dim 位置。
  2. 使用 scatter_dims_to_operand_dims 地圖散布 S,藉此使用 S 將索引 Sin 建立至 operand。更正式:
    1. Sin[scatter_dims_to_operand_dims[k]] = S[k],如果 k < scatter_dims_to_operand_dims.size
    2. Sin[_] = 0,否則為 0
  3. 根據 inserted_window_dims 將索引散布在 Uupdate_window_dims 中,藉此在每個 operands 陣列中建立索引 Win。更正式:
    1. Win[window_dims_to_operand_dims(k)] = U[k] (如果 k 位於 update_window_dims,其中 window_dims_to_operand_dims 是各網域 [0, update_window_dims.size) 和範圍 [0, operand.rank) \ inserted_window_dims 的單調函式]。(舉例來說,如果 update_window_dims.size4operand.rank6,而 inserted_window_dims 是 {0, 2},則 window_dims_to_operand_dims 為 {01132435})。
    2. Win[_] = 0,否則為 0
  4. IWin + Sin,其中 + 是元素相加。

總結來說,散布運算可定義如下:

  • 使用 operands 初始化 output,也就是針對所有 J 索引,以及 operands[J] 陣列中的所有 O 索引:
    output[J][O] = operands[J][O]
  • 針對 updates[J] 陣列中的每個索引 U,以及 operand[J] 陣列中的對應索引 O,如果 Ooutput 的有效索引:
    (output[0][O]、...、output[N-1][O]) =update_computation(output[0][O]、...、output[N-1][O]、updates[0][U]、...、updates[N-1][U])

更新套用順序不確定。因此,當 updates 中的多個索引參照 operands 中的相同索引時,output 中的對應值將非確定性。

請注意,傳遞至 update_computation 的第一個參數一律為 output 陣列的目前值,而第二個參數一律為 updates 陣列的值。這點對於 update_computation 非可交換的情況尤其重要。

如果將 indices_are_sorted 設為 true,XLA 可以假設 scatter_indices 是根據使用者 scatter_dims_to_operand_dims 來遞增其值「之後」以遞增順序排序。如果不是,則語意是由實作定義。

如果將 unique_indices 設為 true,XLA 可以假設分批的所有元素都不重複。因此 XLA 可以使用非原子作業。如果 unique_indices 設為 true,且散布的索引並非唯一,則語意是由實作方式定義。

非正式地說,散布運算可視為收集運算的反向運算,也就是散布運算會更新由對應收集運算擷取的輸入內容元素。

如需詳細的非正式說明和範例,請參閱 Gather 下方的「非正式說明」一節。

選取

另請參閱 XlaBuilder::Select

根據謂詞陣列的值,從兩個輸入陣列的元素建構輸出陣列。

Select(pred, on_true, on_false)

引數 類型 語義學
pred XlaOp PRED 類型的陣列
on_true XlaOp T 類型的陣列
on_false XlaOp 類型為 T 的陣列

陣列 on_trueon_false 的形狀必須相同。這也是輸出陣列的形狀。陣列 pred 必須具有與 on_trueon_false 相同的維度,並使用 PRED 元素類型。

對於 pred 的每個元素 P,如果 P 的值為 true,則輸出陣列的對應元素會從 on_true 取得;如果 P 的值為 false,則會從 on_false 取得。pred廣播的限制形式,可以是 PRED 類型的標量。在這種情況下,如果 predtrue,輸出陣列會完全從 on_true 取得;如果 predfalse,則會從 on_false 取得。

非純量 pred 的範例:

let pred: PRED[4] = {true, false, false, true};
let v1: s32[4] = {1, 2, 3, 4};
let v2: s32[4] = {100, 200, 300, 400};
==>
Select(pred, v1, v2) = s32[4]{1, 200, 300, 4};

純量 pred 的範例:

let pred: PRED = true;
let v1: s32[4] = {1, 2, 3, 4};
let v2: s32[4] = {100, 200, 300, 400};
==>
Select(pred, v1, v2) = s32[4]{1, 2, 3, 4};

支援在元組之間進行選取。為此,元組屬於純量類型。如果 on_trueon_false 為元組 (其形狀必須相同!),則 pred 必須是 PRED 類型的純量。

SelectAndScatter

另請參閱 XlaBuilder::SelectAndScatter

這項運算可視為複合運算,首先會在 operand 陣列上計算 ReduceWindow,從每個視窗中選取一個元素,然後將 source 陣列散布至所選元素的索引,以建構與運算子陣列相同形狀的輸出陣列。二進位 select 函式可用於在每個視窗中套用,從中選取元素,並以第一個參數的索引向量在字典順序上小於第二個參數的索引向量為條件來呼叫。如果選取第一個參數,select 函式會傳回 true,如果選取第二個參數,則會傳回 false,且函式必須具備傳遞性 (如果 select(a, b)select(b, c)true,則 select(a, c) 也為 true),以便所選元素不依賴特定視窗中經過的元素順序。

函式 scatter 會套用至輸出陣列中的每個所選索引。它會使用兩個純量參數:

  1. 輸出陣列中所選索引的目前值
  2. 適用於所選索引的 source 分批值

這個方法會合併兩個參數並傳回純量值,用來更新輸出陣列中所選索引的值。輸出陣列的所有索引一開始都是設為 init_value

輸出陣列的形狀與 operand 陣列相同,而 source 陣列的形狀必須與在 operand 陣列上套用 ReduceWindow 運算的結果相同。SelectAndScatter 可用於在神經網路中,為匯集層回傳梯度值。

SelectAndScatter(operand, select, window_dimensions, window_strides, padding, source, init_value, scatter)

引數 類型 語義學
operand XlaOp 系統在其中滑動視窗的 T 型陣列
select XlaComputation 類型為 T, T -> PRED 的二進位運算,可套用至每個視窗中的所有元素;如果選取第一個參數,則會傳回 true,如果選取第二個參數,則會傳回 false
window_dimensions ArraySlice<int64> 視窗維度值的整數陣列
window_strides ArraySlice<int64> 視窗步距值的整數陣列
padding Padding 視窗的邊框間距類型 (Padding::kSame 或 Padding::kValid)
source XlaOp 類型 T 的陣列,有要分批的值
init_value XlaOp 輸出陣列初始值的 T 類型純量值
scatter XlaComputation T, T -> T 類型的二進位運算,以便套用每個散佈來源元素及其目的地元素

下圖顯示使用 SelectAndScatter 的範例,其中的 select 函式會計算其參數中的最大值。請注意,當時段重疊 (如下圖 (2) 所示) 時,可能會透過不同視窗多次選取 operand 陣列的索引。在圖中,兩個頂端視窗 (藍色和紅色) 都選取了值為 9 的元素,而二進位加法 scatter 函式會產生值為 8 (2 + 6) 的輸出元素。

scatter 函式的評估順序為任意順序,且可能是非決定性的。因此,scatter 函式不應過度敏感到重新建立關聯,如需詳細資訊,請參閱 Reduce 的關聯性討論內容。

傳送

另請參閱 XlaBuilder::Send

Send(operand, channel_handle)

引數 類型 語義學
operand XlaOp 要傳送的資料 (T 類型陣列)
channel_handle ChannelHandle 每個傳送/接收組合的專屬 ID

將指定的運算元資料傳送至共用相同管道句柄的另一個運算中 Recv 指令。不會傳回任何資料。

Recv 作業類似,Send 作業的用戶端 API 代表同步通訊,並在內部分解為 2 個 HLO 指令 (SendSendDone),以啟用非同步資料移轉。另請參閱 HloInstruction::CreateSendHloInstruction::CreateSendDone

Send(HloInstruction operand, int64 channel_id)

使用相同的管道 ID,將運算元非同步轉移至 Recv 指令所分配的資源。傳回結構定義,下列 SendDone 指令將使用等待資料移轉完成。這個內容是 {運算子 (形狀)、要求 ID (U32)} 的元組,且只能由 SendDone 指令使用。

SendDone(HloInstruction context)

Send 指令建立的內容中,等待資料傳輸完成。這項指令不會傳回任何資料。

頻道指示排程

每個管道的 4 個指令的執行順序如下:RecvRecvDoneSendSendDone

  • Recv 會在 Send 之前發生
  • Send 會在 RecvDone 之前發生
  • Recv發生在 RecvDone之前
  • Send 會在 SendDone 之前發生

當後端編譯器為透過通道指令通訊的每個運算產生線性排程時,運算作業中就不能有週期。舉例來說,下列排程會導致死結。

配量

另請參閱 XlaBuilder::Slice

切片會從輸入陣列中擷取子陣列。子陣列與輸入的排名相同,且包含輸入陣列中的定界框值,其中定界框的維度和索引會做為切片運算的引數。

Slice(operand, start_indices, limit_indices, strides)

引數 類型 語義學
operand XlaOp 類型為 T 的 N 維陣列
start_indices ArraySlice<int64> 包含 N 個整數的清單,其中包含每個維度的切片起始索引。值必須大於或等於零。
limit_indices ArraySlice<int64> 包含 N 個整數的清單,其中包含每個維度切片的結束索引 (不含)。每個值都必須大於或等於維度的個別 start_indices 值,且小於或等於維度的大小。
strides ArraySlice<int64> 這裡列出的 N 個整數,用來決定切片的輸入步距。此片段會挑選尺寸 d 中的所有 strides[d] 元素。

1D 範例:

let a = {0.0, 1.0, 2.0, 3.0, 4.0}
Slice(a, {2}, {4}) produces:
  {2.0, 3.0}

2D 範例:

let b =
 { {0.0,  1.0,  2.0},
   {3.0,  4.0,  5.0},
   {6.0,  7.0,  8.0},
   {9.0, 10.0, 11.0} }

Slice(b, {2, 1}, {4, 3}) produces:
  { { 7.0,  8.0},
    {10.0, 11.0} }

排序

另請參閱 XlaBuilder::Sort

Sort(operands, comparator, dimension, is_stable)

引數 類型 語義學
operands ArraySlice<XlaOp> 要排序的運算元。
comparator XlaComputation 要使用的比較器運算。
dimension int64 排序的維度。
is_stable bool 是否應使用穩定排序。

如果只提供一個運算元:

  • 如果運算元是秩 1 張量 (陣列),結果會是排序陣列。如果您想將陣列排序為遞增順序,比較器應執行小於比較。正式來說,陣列排序後,會為所有索引位置 i, j 保留 i < j,而 i < jcomparator(value[i], value[j]) = comparator(value[j], value[i]) = falsecomparator(value[i], value[j]) = true

  • 如果運算元具有較高的排名,則會依照提供的維度排序運算元。舉例來說,如果是秩為 2 的張量 (矩陣),0 維度值會獨立排序每個資料欄,而 1 維度值會獨立排序每個資料列。如未提供維度編號,系統會預設選擇最後一個維度。對於排序的維度,會套用與排名 1 相同的排序順序。

如果提供 n > 1 運算元:

  • 所有 n 運算元都必須是相同維度的張量。張量的元素類型可能不同。

  • 所有運算元皆會一起排序,不會個別排序。從概念上來說,運算元式會視為元組。在檢查索引位置 ij 的每個運算元的元素是否需要交換時,會使用 2 * n 標量參數呼叫比較器,其中參數 2 * k 對應 k-th 運算元的 i 位置,而參數 2 * k + 1 對應 k-th 運算元的 j 位置。一般情況下,比較子會比較參數 2 * k2 * k + 1,並可能會使用其他參數組合做為串連的斷路器。

  • 結果是元組,其中包含以排序順序排列的運算元 (如上所述,沿著提供的維度)。元組的 i-th 運算元對應於 Sort 的 i-th 運算元。

舉例來說,如果有三個運算元 operand0 = [3, 1]operand1 = [42, 50]operand2 = [-3.0, 1.1],且比較器只比較 operand0 的值,並以小於運算,則排序的輸出內容就是元組 ([1, 3], [50, 42], [1.1, -3.0])

如果 is_stable 設為 true,系統會保證排序穩定,也就是說,如果比較器認為某些元素相等,則會保留相等值的相對順序。只有在 comparator(e1, e2) = comparator(e2, e1) = false 的情況下,兩個元素 e1e2 才會相等。根據預設,is_stable 會設為 false。

轉置

另請參閱 tf.reshape 運算。

Transpose(operand)

引數 類型 語義學
operand XlaOp 要轉置的運算元。
permutation ArraySlice<int64> 如何排列維度。

使用指定的排列法對運算元維度進行排列,因此為 ∀ i . 0 ≤ i < rank ⇒ input_dimensions[permutation[i]] = output_dimensions[i]

這與 Reshape(operand, permutation, Permute(permutation, operand.shape.dimensions)) 相同。

TriangularSolve

另請參閱 XlaBuilder::TriangularSolve

用正向或後代換解具有底角係數矩陣的線性方程式系統。這個例行程式會沿著前導維度進行廣播,在 ab 的情況下,為變數 x 解決其中一個矩陣系統 op(a) * x = bx * op(a) = b,其中 op(a)op(a) = aop(a) = Transpose(a)op(a) = Conj(Transpose(a))

TriangularSolve(a, b, left_side, lower, unit_diagonal, transpose_a)

引數 類型 語義學
a XlaOp 陣列的階層為 2,且為形狀為 [..., M, M] 的複數或浮點型別。
b XlaOp 如果 left_side 為 true,則為 rank > 2 個相同類型的陣列,形狀為 [..., M, K];否則為 [..., K, M]
left_side bool 表示要解決 op(a) * x = b (true) 或 x * op(a) = b (false) 形式的系統。
lower bool 是否使用 a 的上三角或下三角。
unit_diagonal bool 如果 true,則 a 的對角元素會假設為 1,且不會存取。
transpose_a Transpose 是否要使用 a 原樣、轉置或取其共軛轉置。

輸入資料只會從 a 的下/上三角讀取,具體取決於 lower 的值。系統會忽略其他三角形的值。輸出資料會在同一三角形中傳回;其他三角形中的值是由實作定義,可能為任何值。

如果 ab 的階層大於 2,系統會將其視為矩陣的批次,其中除了次要 2 個維度以外,所有都是批次維度。ab 的批次維度必須相同。

元組

另請參閱 XlaBuilder::Tuple

一個元組,其中包含可變數數量的資料句柄,每個句柄都有各自的形狀。

這類似於 C++ 中的 std::tuple。概念上:

let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
let s: s32 = 5;
let t: (f32[10], s32) = tuple(v, s);

您可以透過 GetTupleElement 運算來解構 (存取) 元組。

While

另請參閱 XlaBuilder::While

While(condition, body, init)

引數 類型 語義學
condition XlaComputation 定義迴圈結束條件的 T -> PRED 類型 XlaComputation。
body XlaComputation 定義迴圈主體的 T -> T 類型 XlaComputation。
init T conditionbody 參數的初始值。

依序執行 body,直到 condition 失敗為止。這類似於以其他語言進行迴圈,但下列差異和限制除外。

  • While 節點會傳回 T 類型的值,這是 body 上次執行的結果。
  • T 類型的形狀是以靜態方式決定,且在所有疊代作業中都必須相同。

計算作業的 T 參數會在第一個迴迭中使用 init 值初始化,並在後續每個迴迭中自動更新為 body 的新結果。

While 節點的主要用途之一,是實作在神經網路中重複執行訓練的功能。下方列出簡化的虛擬程式碼,以及代表運算的圖表。您可以在 while_test.cc 中找到這段程式碼。此範例中的 T 類型為 Tuple,其中包含用於疊代計數的 int32,以及用於累加器的 vector[10]。在 1000 次迭代中,迴圈會持續將常數向量加到累加器。

// Pseudocode for the computation.
init = {0, zero_vector[10]} // Tuple of int32 and float[10].
result = init;
while (result(0) < 1000) {
  iteration = result(0) + 1;
  new_vector = result(1) + constant_vector[10];
  result = {iteration, new_vector};
}