以下說明 XlaBuilder
介面中定義的作業語意。一般而言,這些作業會將一對一對應至 xla_data.proto
中 RPC 介面定義的作業。
命名法注意事項:XA 處理的概略資料類型是含有某種統一類型的 N 陣列陣列 (例如 32 位元浮點值)。在本說明文件中,陣列用於表示任意維度陣列。為方便起見,特殊情況會有更具體且熟悉的名稱。例如「向量」是 1 維陣列,「矩陣」則是 2D 陣列。
AfterAll
另請參閱 XlaBuilder::AfterAll
。
AfterAll 會採用各種不同的符記並產生單一符記。符記是原始類型,可在連帶效果作業之間建立執行緒以強制執行排序。AfterAll
可做為在一組作業後用來訂購作業的權杖彙整。
AfterAll(operands)
引數 | 類型 | 語義 |
---|---|---|
operands |
XlaOp |
符記數量變化 |
AllGather
另請參閱 XlaBuilder::AllGather
。
這個外掛程式能跨備用資源執行串連作業。
AllGather(operand, all_gather_dim, shard_count, replica_group_ids,
channel_id)
引數 | 類型 | 語義 |
---|---|---|
operand
|
XlaOp
|
跨備用資源串連的陣列 |
all_gather_dim |
int64 |
串連維度 |
replica_groups
|
int64 向量的向量 |
執行串連作業的群組 |
channel_id
|
選用 int64
|
跨模組通訊的選用頻道 ID |
replica_groups
是執行串連的備用資源群組清單 (您可以使用ReplicaId
擷取目前備用資源的備用資源 ID)。每個群組中的備用資源順序會決定其輸入內容在結果中的顯示順序。replica_groups
必須為空白 (在這種情況下,所有備用資源都屬於單一群組,由0
到N - 1
排序),或包含與備用資源數量相同的元素數量。例如,replica_groups = {0, 2}, {1, 3}
會在備用資源0
與2
以及1
與3
之間執行串連作業。shard_count
是每個備用資源群組的大小。如果replica_groups
為空白,就必須使用這個參數。channel_id
用於跨模組通訊:只有具有相同channel_id
的all-gather
作業才能相互通訊。
輸出形狀是輸入形狀,all_gather_dim
使 shard_count
倍變大。舉例來說,如果有兩個備用資源,而運算元在兩個備用資源中分別具有 [1.0, 2.5]
和 [3.0, 5.25]
的值,則這項運算的輸出值,其中 all_gather_dim
為 0
在兩個備用資源中為 [1.0, 2.5, 3.0,
5.25]
。
AllReduce
另請參閱 XlaBuilder::AllReduce
。
在多項備用資源間執行自訂運算。
AllReduce(operand, computation, replica_group_ids, channel_id)
引數 | 類型 | 語義 |
---|---|---|
operand
|
XlaOp
|
使用陣列或非空白的陣列 在備用資源之間縮減 |
computation |
XlaComputation |
縮減運算 |
replica_groups
|
int64 向量的向量 |
執行縮減作業的群組 |
channel_id
|
選用 int64
|
跨模組通訊的選用頻道 ID |
- 當
operand
是陣列的組合時,系統會對元組的每個元素執行所有簡化作業。 replica_groups
是執行縮減程序的備用資源群組清單 (您可以使用ReplicaId
擷取目前備用資源的備用資源 ID)。replica_groups
必須空白 (假如所有備用資源都屬於單一群組),或包含與備用資源數量相同的元素數量。舉例來說,replica_groups = {0, 2}, {1, 3}
會在備用資源0
與2
,以及1
和3
之間縮減。channel_id
用於跨模組通訊:只有具有相同channel_id
的all-reduce
作業才能相互通訊。
輸出形狀與輸入形狀相同。舉例來說,如果有兩個備用資源,且運算元在兩個備用資源上分別具有 [1.0, 2.5]
和 [3.0, 5.25]
值,則這兩個備用資源上的這項運算和加總計算的輸出值會是 [4.0, 7.75]
。如果輸入內容為元組,則輸出結果也會是元組。
計算 AllReduce
的結果時,每個備用資源需要有一項輸入,因此若其中一個備用資源執行 AllReduce
節點的次數超過另一個節點,則先前的備用資源會一直等待。由於備用資源全都執行相同的程式,因此不會有許多方式可進行,但當迴圈的條件取決於饋入的資料,且加載的資料導致時迴圈疊代作業對一個備用資源進行疊代作業的次數會多於另一個備用資源。
AllToAll
另請參閱 XlaBuilder::AllToAll
。
AllToAll 是一種集體作業,會將所有核心的資料傳送至所有核心。其中包含兩個階段:
- 散佈階段。在每個核心上,運算元會沿著
split_dimensions
分割成split_count
區塊,而區塊會散佈至所有核心,例如將 i 區塊傳送到第 i 個核心。 - 收集階段。每個核心都會沿著
concat_dimension
串連收到的區塊。
您可以透過下列方式設定參與核心:
replica_groups
:每個 ReplicaGroup 都包含參與計算的備用資源 ID 清單 (可使用ReplicaId
擷取目前備用資源的備用資源 ID)。AllToAll 會依照指定的順序套用於子群組中。舉例來說,replica_groups = { {1,2,3}, {4,5,0} }
表示 AllToAll 會套用至備用資源{1, 2, 3}
和收集階段,且收到的區塊會按照 1、2、3 的順序串連。然後,另一個 AllToAll 會套用在備用資源 4、5、0 中,串連順序也是 4、5、0。如果replica_groups
為空白,表示所有備用資源都屬於一個群組,且按照其外觀的串連順序。
需求條件:
split_dimension
運算元的維度大小可由split_count
除去。- 運算元的形狀不是元組。
AllToAll(operand, split_dimension, concat_dimension, split_count,
replica_groups)
引數 | 類型 | 語義 |
---|---|---|
operand |
XlaOp |
N 維度輸入陣列 |
split_dimension
|
int64
|
間隔 [0,
n) 中的值,此值為維度命名,運算元會分割 |
concat_dimension
|
int64
|
間隔 [0,
n) 中的值,用來為維度命名,且該維度會串連 |
split_count
|
int64
|
參與這項作業的核心數量。如果 replica_groups 為空白,這應為備用資源的數量;否則,應等於每個群組中的備用資源數量。 |
replica_groups
|
ReplicaGroup 向量
|
每個群組都包含備用資源 ID 清單。 |
下圖是 Alltoall 的範例。
XlaBuilder b("alltoall");
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, /*split_count=*/4);
在這個範例中,Alltoall 共有 4 個核心。在每個核心上,運算元會依照維度 0 分成 4 個部分,因此每個部分的形狀 f32[4,4]。這 4 個部分會散佈至所有核心。然後每個核心會按照維度 1 的順序,將接收到的零件依核心 0 至 4 的順序串連。因此每個核心的輸出內容都有 f32[16,4] 形狀。
BatchNormGrad
如需演算法的詳細說明,另請參閱 XlaBuilder::BatchNormGrad
和原始批次正規化文件。
計算批次常態的梯度。
BatchNormGrad(operand, scale, mean, variance, grad_output, epsilon, feature_index)
引數 | 類型 | 語義 |
---|---|---|
operand |
XlaOp |
要正規化的 n 維度陣列 (x) |
scale |
XlaOp |
1 個維陣列 (\(\gamma\)) |
mean |
XlaOp |
1 個維陣列 (\(\mu\)) |
variance |
XlaOp |
1 個維陣列 (\(\sigma^2\)) |
grad_output |
XlaOp |
已傳遞到 BatchNormTraining (\(\nabla y\)) 的漸層 |
epsilon |
float |
E 普西龍值 (\(\epsilon\)) |
feature_index |
int64 |
operand 中的特徵維度索引 |
對於地圖項目維度中的每項特徵 (feature_index
是 operand
中地圖項目維度的索引),運算會計算所有其他維度的漸層 (operand
、offset
和 scale
)。feature_index
必須是 operand
中地圖項目維度的有效索引。
這三個漸層是由下列公式定義 (假設 4 維陣列為 operand
,且具有特徵維度索引 l
、批次大小 m
,以及空間大小 w
和 h
):
\[ \begin{split} c_l&= \frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \left( \nabla y_{ijkl} \frac{x_{ijkl} - \mu_l}{\sigma^2_l+\epsilon} \right) \\\\ d_l&= \frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \nabla y_{ijkl} \\\\ \nabla x_{ijkl} &= \frac{\gamma_{l} }{\sqrt{\sigma^2_{l}+\epsilon} } \left( \nabla y_{ijkl} - d_l - c_l (x_{ijkl} - \mu_{l}) \right) \\\\ \nabla \gamma_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \left( \nabla y_{ijkl} \frac{x_{ijkl} - \mu_l}{\sqrt{\sigma^2_{l}+\epsilon} } \right) \\\\\ \nabla \beta_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \nabla y_{ijkl} \end{split} \]
輸入 mean
和 variance
代表批次和空間維度的時刻值。
輸出類型是由三個控點組成的元組:
輸出 | 類型 | 語義 |
---|---|---|
grad_operand
|
XlaOp
|
輸入 operand 的漸層 ($\nabla
x$) |
grad_scale
|
XlaOp
|
輸入 scale 的漸層 ($\nabla
\gamma$) |
grad_offset
|
XlaOp
|
輸入的漸層 offset ($\nabla
\beta$) |
BatchNormInference
如需演算法的詳細說明,另請參閱 XlaBuilder::BatchNormInference
和原始批次正規化文件。
將批次和空間維度的陣列正規化。
BatchNormInference(operand, scale, offset, mean, variance, epsilon, feature_index)
引數 | 類型 | 語義 |
---|---|---|
operand |
XlaOp |
要正規化的 n 維度陣列 |
scale |
XlaOp |
1 個維度陣列 |
offset |
XlaOp |
1 個維度陣列 |
mean |
XlaOp |
1 個維度陣列 |
variance |
XlaOp |
1 個維度陣列 |
epsilon |
float |
E 普西龍值 |
feature_index |
int64 |
operand 中的特徵維度索引 |
對於特徵維度中的每個地圖項目 (feature_index
是 operand
中地圖項目維度的索引),運算會計算所有其他維度的平均值和變異數,並使用平均值和變異數,將 operand
中的每個元素正規化。feature_index
必須是 operand
中地圖項目維度的有效索引。
BatchNormInference
等同於呼叫 BatchNormTraining
,而不需要計算每個批次的 mean
和 variance
。而是使用輸入 mean
和 variance
做為預估值。此操作的目的是縮短推論延遲時間,因此名稱為 BatchNormInference
。
輸出內容是 n 維的正規化陣列,形狀與輸入 operand
相同。
BatchNormTraining
如需演算法的詳細說明,另請參閱 XlaBuilder::BatchNormTraining
和 the original batch normalization paper
。
將批次和空間維度的陣列正規化。
BatchNormTraining(operand, scale, offset, epsilon, feature_index)
引數 | 類型 | 語義 |
---|---|---|
operand |
XlaOp |
要正規化的 n 維度陣列 (x) |
scale |
XlaOp |
1 個維陣列 (\(\gamma\)) |
offset |
XlaOp |
1 個維陣列 (\(\beta\)) |
epsilon |
float |
E 普西龍值 (\(\epsilon\)) |
feature_index |
int64 |
operand 中的特徵維度索引 |
對於特徵維度中的每個地圖項目 (feature_index
是 operand
中地圖項目維度的索引),運算會計算所有其他維度的平均值和變異數,並使用平均值和變異數,將 operand
中的每個元素正規化。feature_index
必須是 operand
中地圖項目維度的有效索引。
在 operand
\(x\) 內,每個批次包含 m
元素,且空間維度大小為 w
和 h
(假設 operand
是 4D 陣列),則演算法如下所示:
計算特徵維度中每個特徵
l
\(\mu_l\) 的批次平均值: \(\mu_l=\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h x_{ijkl}\)計算批次變異數 \(\sigma^2_l\): $\sigma^2l=\frac{1}{mwh}\sum{i=1}^m\sum{j=1}^w\sum{k=1}^h (x_{ijkl} - \mu_l)^2$
正規化、縮放及轉移: \(y_{ijkl}=\frac{\gamma_l(x_{ijkl}-\mu_l)}{\sqrt[2]{\sigma^2_l+\epsilon} }+\beta_l\)
會加入 Epsilon 值,通常為較小的數字,以免發生除數為零的錯誤。
輸出類型是三個 XlaOp
的元組:
輸出 | 類型 | 語義 |
---|---|---|
output
|
XlaOp
|
n 維度陣列,形狀與輸入 operand (y) 相同 |
batch_mean |
XlaOp |
1 個維陣列 (\(\mu\)) |
batch_var |
XlaOp |
1 個維陣列 (\(\sigma^2\)) |
batch_mean
和 batch_var
是使用上述公式,跨批次和空間維度計算的時刻。
BitcastConvertType
另請參閱 XlaBuilder::BitcastConvertType
。
與 TensorFlow 中的 tf.bitcast
類似,執行從資料形狀到目標形狀的元素順位點陣圖作業。輸入和輸出大小必須相符:例如,s32
元素會透過點陣圖處理常式成為 f32
元素,而一個 s32
元素會變為四個 s8
元素。Bitcast 以低階轉換的方式實作,因此具有不同浮點表示法的機器會產生不同的結果。
BitcastConvertType(operand, new_element_type)
引數 | 類型 | 語義 |
---|---|---|
operand |
XlaOp |
T 型別陣列,且會調暗 D |
new_element_type |
PrimitiveType |
類型 U |
運算元和目標形狀的維度必須相符,除了最後一個維度之外,最後一個維度會隨轉換前後的原始大小比例變化。
來源和目的地元素類型不得為元組。
將位元轉換轉換為不同寬度的原始類型
BitcastConvert
HLO 指令支援下列情況:如果輸出元素類型的大小 T'
不等於輸入元素 T
的大小。由於整個作業的概念為位元傳播,且不會變更基礎位元組,因此輸出元素的形狀必須改變。B = sizeof(T), B' =
sizeof(T')
可能會有兩種可能情況。
首先,在 B > B'
時,輸出形狀會取得幾乎一個大小的 B/B'
新尺寸。例如:
f16[10,2]{1,0} %output = f16[10,2]{1,0} bitcast-convert(f32[10]{0} %input)
有效純量的規則維持不變:
f16[2]{0} %output = f16[2]{0} bitcast-convert(f32[] %input)
此外,如果是 B' > B
,指示要求輸入形狀的最後一個邏輯維度等於 B'/B
,而且在轉換期間會捨棄這個維度:
f32[10]{0} %output = f32[10]{0} bitcast-convert(f16[10,2]{1,0} %input)
請注意,不同位元寬度之間的轉換不會有元素。
廣播
另請參閱 XlaBuilder::Broadcast
。
複製陣列中的資料,藉此將維度加入陣列。
Broadcast(operand, broadcast_sizes)
引數 | 類型 | 語義 |
---|---|---|
operand |
XlaOp |
要複製的陣列 |
broadcast_sizes |
ArraySlice<int64> |
新的尺寸 |
新維度會在左側插入;也就是說,如果 broadcast_sizes
包含值為 {a0, ..., aN}
且運算元形狀維度為 {b0, ..., bM}
,則輸出內容形狀會含有維度 {a0, ..., aN, b0, ..., bM}
。
新維度索引複製到運算元的副本,例如
output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
舉例來說,如果 operand
是值為 2.0f
的純量 f32
,且 broadcast_sizes
為 {2, 3}
,則結果會是形狀為 f32[2, 3]
的陣列,而結果中的所有值都將為 2.0f
。
BroadcastInDim
另請參閱 XlaBuilder::BroadcastInDim
。
複製陣列中的資料,藉此展開陣列的大小和排名。
BroadcastInDim(operand, out_dim_size, broadcast_dimensions)
引數 | 類型 | 語義 |
---|---|---|
operand |
XlaOp |
要複製的陣列 |
out_dim_size |
ArraySlice<int64> |
目標形狀的尺寸大小 |
broadcast_dimensions |
ArraySlice<int64> |
運算元形狀中各維度對應的維度 |
與廣播類似,但允許在任何位置新增維度,並將大小設為 1。
operand
會廣播到 out_dim_size
描述的形狀。broadcast_dimensions
會將 operand
的維度對應至目標形狀的尺寸,也就是說,運算元的第 1 個維度會對應至輸出形狀的 broadcast_dimension[i] 尺寸。operand
的尺寸必須包含 1,或是其大小必須與其對應的輸出形狀中的維度相同。剩餘的維度會填入大小為 1 的尺寸。取消佈建維度廣播,然後沿著這些去產生的維度播送,達到輸出形狀。如要進一步瞭解語意,請參閱廣播頁面。
撥打電話
另請參閱 XlaBuilder::Call
。
使用指定的引數叫用運算。
Call(computation, args...)
引數 | 類型 | 語義 |
---|---|---|
computation |
XlaComputation |
具有任意類型的 N 個參數的 T_0, T_1, ..., T_{N-1} -> S 類型計算 |
args |
N XlaOp 秒的順序 |
N 個任意類型的引數 |
args
的陣列和類型必須與 computation
的參數相符。不得包含 args
。
黑洞隊 (Cholesky)
另請參閱 XlaBuilder::Cholesky
。
計算一批對稱 (Hermitian) 正矩陣的 Cholesky 分解。
Cholesky(a, lower)
引數 | 類型 | 語義 |
---|---|---|
a |
XlaOp |
大於 2 的複雜或浮點類型的陣列。 |
lower |
bool |
是否要使用 a 的上三角形。 |
如果 lower
為 true
,會計算低三角形矩陣 l
,也就是 $a = l。l^T$。如果 lower
為 false
,則計算上三角形矩陣 u
,以使其\(a = u^T . u\)。
系統會根據 lower
的值,僅從 a
的上下三角形讀取輸入資料。系統會忽略其他三角形的值。輸出資料會以同一個三角形傳回,其他三角形的值則是由實作定義,可以是任何內容。
如果 a
的排名大於 2,系統會將 a
視為一批矩陣,其中除了次要 2 維度以外的所有維度都是批次維度。
如果 a
非對稱 (Hermitian) 絕對不對稱,結果是由實作定義。
夾式
另請參閱 XlaBuilder::Clamp
。
將運算元限制在最小值和最大值之間的範圍內。
Clamp(min, operand, max)
引數 | 類型 | 語義 |
---|---|---|
min |
XlaOp |
T 類型的陣列 |
operand |
XlaOp |
T 類型的陣列 |
max |
XlaOp |
T 類型的陣列 |
指定運算元和最小值和最大值時,如果運算元在最小值與最大值之間,則傳回該運算元;否則,如果運算元低於這個範圍,則會傳回最小值;如果運算元超過這個範圍,則傳回最大值。也就是 clamp(a, x, b) = min(max(a, x), b)
。
這三個陣列的形狀都必須相同。或者,由於廣播的限制形式,min
和/或 max
也可以是 T
類型的純量。
純量 min
和 max
的範例:
let operand: s32[3] = {-1, 5, 9};
let min: s32 = 0;
let max: s32 = 6;
==>
Clamp(min, operand, max) = s32[3]{0, 5, 6};
收合
另請參閱 XlaBuilder::Collapse
和 tf.reshape
作業。
將陣列的維度收合為一個維度。
Collapse(operand, dimensions)
引數 | 類型 | 語義 |
---|---|---|
operand |
XlaOp |
T 類型的陣列 |
dimensions |
int64 個向量 |
T 維度順序排列組合而成 |
收合時,會以單一維度取代運算元維度的特定子集。輸入引數是 T 類型的任意陣列,以及維度索引的編譯時間常數向量。維度索引必須是按順序排列 (低至高維度數字),以及 T 維度的連續子集。因此,{0, 1, 2}、{0、1} 或 {1, 2} 都是有效維度集,但 {1, 0} 或 {0, 2} 則否。取而代之的是單一新維度,在與取代的尺寸順序相同的位置,新尺寸大小等於原始尺寸大小的乘積。dimensions
中的最低維度編號是迴圈巢狀結構中最慢的不同維度 (最主要),它會收合這些維度,而最高維度編號是最快的變化速度 (最輕微)。如果需要更一般的收合順序,請參閱 tf.reshape
運算子。
舉例來說,假設 v 是包含 24 個元素的陣列:
let v = f32[4x2x3] { { {10, 11, 12}, {15, 16, 17} },
{ {20, 21, 22}, {25, 26, 27} },
{ {30, 31, 32}, {35, 36, 37} },
{ {40, 41, 42}, {45, 46, 47} } };
// Collapse to a single dimension, leaving one dimension.
let v012 = Collapse(v, {0,1,2});
then v012 == f32[24] {10, 11, 12, 15, 16, 17,
20, 21, 22, 25, 26, 27,
30, 31, 32, 35, 36, 37,
40, 41, 42, 45, 46, 47};
// Collapse the two lower dimensions, leaving two dimensions.
let v01 = Collapse(v, {0,1});
then v01 == f32[4x6] { {10, 11, 12, 15, 16, 17},
{20, 21, 22, 25, 26, 27},
{30, 31, 32, 35, 36, 37},
{40, 41, 42, 45, 46, 47} };
// Collapse the two higher dimensions, leaving two dimensions.
let v12 = Collapse(v, {1,2});
then v12 == f32[8x3] { {10, 11, 12},
{15, 16, 17},
{20, 21, 22},
{25, 26, 27},
{30, 31, 32},
{35, 36, 37},
{40, 41, 42},
{45, 46, 47} };
CollectivePermute
另請參閱 XlaBuilder::CollectivePermute
。
CollectivePermute 屬於集體作業,可跨備用資源收發資料。
CollectivePermute(operand, source_target_pairs)
引數 | 類型 | 語義 |
---|---|---|
operand |
XlaOp |
N 維度輸入陣列 |
source_target_pairs |
<int64, int64> 個向量 |
(source_repo_id、target_repo_id) 組合清單。針對每個組合,運算元會從來源備用資源傳送至目標備用資源。 |
請注意,source_target_pair
具有下列限制:
- 任意兩對組合的目標備用資源 ID 不得相同,而且來源備用資源 ID 不得相同。
- 如果備用資源 ID 不是任何配對中的目標,則該備用資源的輸出內容會是包含 0(與輸入相同形狀) 的張量。
串連
另請參閱 XlaBuilder::ConcatInDim
。
串連組合多個陣列運算元的陣列。陣列的排名與每個輸入陣列運算元 (排名必須相同) 相同,且包含引數的指定順序。
Concatenate(operands..., dimension)
引數 | 類型 | 語義 |
---|---|---|
operands |
N XlaOp 序列 |
類型 T 的 N 陣列,尺寸為 [L0, L1, ...]。需要 N >= 1。 |
dimension |
int64 |
間隔 [0, N) 中的值,用來命名要在 operands 之間串連的維度。 |
除了 dimension
以外,所有尺寸都必須相同。這是因為 XLA 不支援「堆疊」陣列。另請注意,排名-0 的值無法串連 (因為您無法為發生串連的維度命名)。
1D 範例:
Concat({ {2, 3}, {4, 5}, {6, 7} }, 0)
>>> {2, 3, 4, 5, 6, 7}
2D 範例:
let a = {
{1, 2},
{3, 4},
{5, 6},
};
let b = {
{7, 8},
};
Concat({a, b}, 0)
>>> {
{1, 2},
{3, 4},
{5, 6},
{7, 8},
}
圖表:
須符合條件
另請參閱 XlaBuilder::Conditional
。
Conditional(pred, true_operand, true_computation, false_operand,
false_computation)
引數 | 類型 | 語義 |
---|---|---|
pred |
XlaOp |
PRED 類型的純量 |
true_operand |
XlaOp |
「 \(T_0\)」類型的引數 |
true_computation |
XlaComputation |
\(T_0 \to S\)類型的 XlaComputation |
false_operand |
XlaOp |
「 \(T_1\)」類型的引數 |
false_computation |
XlaComputation |
\(T_1 \to S\)類型的 XlaComputation |
如果 pred
是 true
就執行 true_computation
,如果 pred
是 false
則執行 false_computation
並傳回結果。
true_computation
必須使用 \(T_0\) 類型的單一引數,並採用 true_operand
叫用,且目標類型必須相同。false_computation
必須使用 \(T_1\) 類型的單一引數,並會透過 false_operand
叫用,而這些引數必須屬於相同類型。回傳的 true_computation
和 false_computation
值類型必須相同。
請注意,視 pred
的值而定,系統只會執行 true_computation
和 false_computation
其中之一。
Conditional(branch_index, branch_computations, branch_operands)
引數 | 類型 | 語義 |
---|---|---|
branch_index |
XlaOp |
S32 類型的純量 |
branch_computations |
N XlaComputation 序列 |
\(T_0 \to S , T_1 \to S , ..., T_{N-1} \to S\)類型的 XlaComputations |
branch_operands |
N XlaOp 序列 |
\(T_0 , T_1 , ..., T_{N-1}\)類型的引數 |
執行 branch_computations[branch_index]
,然後傳回結果。如果 branch_index
是 < 0 或 >= N 的 S32
,則 branch_computations[N-1]
會做為預設分支版本執行。
每個 branch_computations[b]
都必須使用一個型別的單一引數, \(T_b\) 並用 branch_operands[b]
叫用,因為必須屬於相同類型。每個 branch_computations[b]
的傳回值類型必須相同。
請注意,視 branch_index
的值而定,系統只會執行其中一個 branch_computations
。
轉換 (卷積)
另請參閱 XlaBuilder::Conv
。
如 ConvWithGeneralPadding,但邊框間距是以快速的方式指定,可為 SAME 或 VALID。SAME 的邊框間距會填補輸入 (lhs
) 並填入零,這樣在不考慮納入考量的情況下,輸出的形狀就會與輸入相同。VALID 邊框間距就是沒有邊框間距。
ConvWithGeneralPadding (對稱)
另請參閱 XlaBuilder::ConvWithGeneralPadding
。
計算類神經網路所用種類的捲積。這裡的捲積可視為在 n 維基本區域之間移動的 n 維視窗,且會對視窗的每個可能位置執行運算。
引數 | 類型 | 語義 |
---|---|---|
lhs |
XlaOp |
將輸入項目 n+2 陣列排名 |
rhs |
XlaOp |
對核心權重排名 n+2 陣列 |
window_strides |
ArraySlice<int64> |
n-d 陣列核心跨距 |
padding |
ArraySlice< pair<int64,int64>> |
n-d 陣列 (低、高) 邊框間距 |
lhs_dilation |
ArraySlice<int64> |
n-d lhs 除法係數陣列 |
rhs_dilation |
ArraySlice<int64> |
n-d Rhs 計算因數陣列 |
feature_group_count |
int64 | 特徵群組的數量 |
batch_group_count |
int64 | 批次群組的數量 |
n 是空間維度數量,lhs
引數是代表底區域的 n+2 陣列。雖然 Rhs 也是輸入內容,但這稱為「輸入」。在類神經網路中,這是指輸入啟動項目。n+2 維度的順序如下:
batch
:這個維度中的每個座標都代表執行卷積的獨立輸入。z/depth/features
:底區的每個 (y,x) 位置都有相關聯的向量,並納進這個維度。spatial_dims
:說明n
空間維度,以定義視窗移動的基本區域。
rhs
引數是排名 n+2 的陣列,用來說明卷積篩選器/核心/視窗。維度的順序如下:
output-z
:輸出內容的z
維度。input-z
:此維度的大小乘以feature_group_count
,應等於z
維度的大小 (以弧度為單位)。spatial_dims
:說明n
空間維度,用於定義在基本區域上移動的 n-d 視窗。
window_strides
引數會指定空間維度的捲積窗速。舉例來說,如果第一個空間維度的跨距為 3,則視窗只能放置在座標上,第一個空間索引可以用 3 除盡。
padding
引數會指定要套用至底區域的零邊框間距量。邊框間距數量可以是負數;負值邊框間距的絕對值代表在執行卷積之前,要從指定維度移除的元素數量。padding[0]
會指定 y
和 padding[1]
維度的邊框間距,藉此指定維度 x
的邊框間距。每個組合的邊框間距都是第一個元素,而高邊框間距是第二個元素。低邊框間距會以較低索引的方向套用,而高邊框間距則是根據較高索引的方向套用。舉例來說,如果 padding[1]
是 (2,3)
,則邊框間距在左側為 20,右側會有一個 3 0。使用邊框間距的做法,等同在輸入 (lhs
) 中插入相同的零值,再執行卷積。
lhs_dilation
和 rhs_dilation
引數會分別指定在每個空間維度中套用至 lh 和 rh 的除法因數。如果空間維度的除法係數是 d,則 d-1 孔會以隱含的方式放置在該維度中的每個項目之間,進而增加陣列的大小。孔洞會填滿免人工作業值,而卷積值代表零。
Rh 的進化也稱為群落卷積。詳情請參閱
tf.nn.atrous_conv2d
。拉鍊關係也稱為轉換式捲積。詳情請參閱 tf.nn.conv2d_transpose
。
feature_group_count
引數 (預設值 1) 可用於已分組的對話。feature_group_count
必須是輸入和輸出特徵維度的除數。如果 feature_group_count
大於 1,表示從概念上來說,輸入和輸出特徵維度和 rhs
輸出特徵維度會平均分割為多個 feature_group_count
群組,每個群組都是由連續的特徵子序列組成。rhs
的輸入特徵維度必須等於 lhs
輸入特徵維度除以 feature_group_count
(因此已經具有一組輸入特徵群組的大小)。i-th 群組會一起用於計算許多獨立卷積的 feature_group_count
。這些卷積的結果會在輸出特徵維度中串連在一起。
為了進行深度卷積,feature_group_count
引數會設為輸入特徵維度,而篩選器會從 [filter_height, filter_width, in_channels, channel_multiplier]
重設為 [filter_height, filter_width, 1, in_channels * channel_multiplier]
。詳情請參閱 tf.nn.depthwise_conv2d
。
batch_group_count
(預設值 1) 引數可在反向傳播期間用於分組的篩選器。batch_group_count
必須是 lhs
(輸入) 批次維度大小的除數。如果 batch_group_count
大於 1,表示輸出批量的尺寸應為 input batch
/ batch_group_count
。batch_group_count
必須是輸出特徵大小的除數。
輸出形狀的維度按照以下順序排列:
batch
:此維度的大小乘以batch_group_count
,應等於batch
維度的大小 (以弧度為單位)。z
:大小與核心上的output-z
相同 (rhs
)。spatial_dims
:每個卷積視窗有效位置的一個值。
上圖顯示 batch_group_count
欄位的運作方式。實際上,我們會將每個 lh 分批分為 batch_group_count
群組,並針對輸出特徵執行相同操作。接著,我們會針對每個群組執行配對卷積,並沿著輸出特徵維度串連輸出結果。所有其他維度 (特徵和空間) 的作業語意皆維持不變。
卷積視窗的有效位置取決於步距和填充後底區域的大小。
如果想說明卷積的作用,請考慮使用 2D 卷積,並在輸出內容中挑選一些固定的 batch
、z
、y
、x
座標。然後,(y,x)
是視窗底邊的一個位置 (例如,左上角取決於您所解讀空間維度的方式)。現在,我們有 2D 窗口,從底區域擷取,其中每個 2D 點都與一個 1D 向量相關聯,因此得到一個 3D 方塊。從卷積核心中,由於我們修正了輸出座標 z
,因此我們也有 3D 包裝盒。這兩個方塊的尺寸相同,因此可以將兩個方塊中元素相鄰的元素加總 (類似點積)。也就是輸出值。
請注意,如果 output-z
範例:5,每個視窗位置會在輸出內容中產生 5 個值,並放入輸出的 z
維度中。這些值在使用卷積核心的哪個部分不同,每個 output-z
座標都有使用另外的 3D 方塊。您可以將其視為 5 個不同的捲積,每個卷積都有不同篩選器。
以下是含邊框間距和小數的 2D 卷積的虛擬程式碼:
for (b, oz, oy, ox) { // output coordinates
value = 0;
for (iz, ky, kx) { // kernel coordinates and input z
iy = oy*stride_y + ky - pad_low_y;
ix = ox*stride_x + kx - pad_low_x;
if ((iy, ix) inside the base area considered without padding) {
value += input(b, iz, iy, ix) * kernel(oz, iz, ky, kx);
}
}
output(b, oz, oy, ox) = value;
}
ConvertElementType
另請參閱 XlaBuilder::ConvertElementType
。
與 C++ 中元素導向的 static_cast
類似,執行從資料形狀到目標形狀的元素導向轉換作業。維度必須相符,且轉換是以元素為準,例如:s32
元素會透過 s32
到 f32
的轉換處理常式成為 f32
元素。
ConvertElementType(operand, new_element_type)
引數 | 類型 | 語義 |
---|---|---|
operand |
XlaOp |
T 型別陣列,且會調暗 D |
new_element_type |
PrimitiveType |
類型 U |
運算元的尺寸必須與目標形狀相符。來源和目的地元素類型不得為元組。
如 T=s32
到 U=f32
等轉換將會執行正規化的從浮點轉換常式轉換程序,例如從「接近平均」方向。
let a: s32[3] = {0, 1, 2};
let b: f32[3] = convert(a, f32);
then b == f32[3]{0.0, 1.0, 2.0}
CrossReplicaSum
執行總和計算的 AllReduce
。
CustomCall
另請參閱 XlaBuilder::CustomCall
。
在計算中呼叫使用者提供的函式。
CustomCall(target_name, args..., shape)
引數 | 類型 | 語義 |
---|---|---|
target_name |
string |
函式的名稱。系統會發出呼叫指示,指定這個符號名稱。 |
args |
N XlaOp 秒的順序 |
任意類型的 N 個引數,系統會將這些引數傳遞至函式。 |
shape |
Shape |
函式的輸出形狀 |
無論引數類型或類型為何,函式簽章都相同:
extern "C" void target_name(void* out, void** in);
舉例來說,如果 CustomCall 用於下列方式:
let x = f32[2] {1,2};
let y = f32[2x3] { {10, 20, 30}, {40, 50, 60} };
CustomCall("myfunc", {x, y}, f32[3x3])
以下是 myfunc
的實作範例:
extern "C" void myfunc(void* out, void** in) {
float (&x)[2] = *static_cast<float(*)[2]>(in[0]);
float (&y)[2][3] = *static_cast<float(*)[2][3]>(in[1]);
EXPECT_EQ(1, x[0]);
EXPECT_EQ(2, x[1]);
EXPECT_EQ(10, y[0][0]);
EXPECT_EQ(20, y[0][1]);
EXPECT_EQ(30, y[0][2]);
EXPECT_EQ(40, y[1][0]);
EXPECT_EQ(50, y[1][1]);
EXPECT_EQ(60, y[1][2]);
float (&z)[3][3] = *static_cast<float(*)[3][3]>(out);
z[0][0] = x[1] + y[1][0];
// ...
}
使用者提供的函式不得有副作用,且執行作業必須為冪等。
Dot
另請參閱 XlaBuilder::Dot
。
Dot(lhs, rhs)
引數 | 類型 | 語義 |
---|---|---|
lhs |
XlaOp |
T 類型的陣列 |
rhs |
XlaOp |
T 類型的陣列 |
這項作業的確切語意取決於運算元:
輸入 | 輸出 | 語義 |
---|---|---|
向量 [n] dot 向量 [n] |
純量 | 向量內積 |
矩陣 [m x k] dot 向量 [k] |
向量 [m] | 矩陣向量乘法 |
矩陣 [m x k] dot 矩陣 [k x n] |
矩陣 [m x n] | 矩陣矩陣乘法 |
這項作業會針對 lhs
的第二個維度 (排名第一,則為第一個維度) 和 rhs
的第一個維度執行產品總和。這就是「合約」的維度lhs
和 rhs
的約定尺寸必須相同。實際上,可用於執行向量、向量/矩陣乘法或矩陣/矩陣乘法之間的內積積點。
DotGeneral
另請參閱 XlaBuilder::DotGeneral
。
DotGeneral(lhs, rhs, dimension_numbers)
引數 | 類型 | 語義 |
---|---|---|
lhs |
XlaOp |
T 類型的陣列 |
rhs |
XlaOp |
T 類型的陣列 |
dimension_numbers |
DotDimensionNumbers |
收縮和批量維度編號 |
與點號類似,但允許同時為 lhs
和 rhs
指定合約及批次維度編號。
點維度號碼欄位 | 類型 | 語義 |
---|---|---|
lhs_contracting_dimensions
|
重複的 int64 | lhs 約定尺寸
|
rhs_contracting_dimensions
|
重複的 int64 | rhs 約定尺寸
|
lhs_batch_dimensions
|
重複的 int64 | lhs 個批次維度號碼 |
rhs_batch_dimensions
|
重複的 int64 | rhs 個批次維度號碼 |
DotGeneral 會對 dimension_numbers
中指定的約定維度執行產品總和。
與 lhs
和 rhs
相關聯的約定維度編號不必相同,但尺寸大小必須相同。
包含合約維度編號的範例:
lhs = { {1.0, 2.0, 3.0},
{4.0, 5.0, 6.0} }
rhs = { {1.0, 1.0, 1.0},
{2.0, 2.0, 2.0} }
DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(1);
dnums.add_rhs_contracting_dimensions(1);
DotGeneral(lhs, rhs, dnums) -> { {6.0, 12.0},
{15.0, 30.0} }
與 lhs
和 rhs
相關聯的批次維度編號必須具有相同的維度大小。
包含批次維度編號的範例 (批次大小 2、2x2 矩陣):
lhs = { { {1.0, 2.0},
{3.0, 4.0} },
{ {5.0, 6.0},
{7.0, 8.0} } }
rhs = { { {1.0, 0.0},
{0.0, 1.0} },
{ {1.0, 0.0},
{0.0, 1.0} } }
DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(2);
dnums.add_rhs_contracting_dimensions(1);
dnums.add_lhs_batch_dimensions(0);
dnums.add_rhs_batch_dimensions(0);
DotGeneral(lhs, rhs, dnums) -> { { {1.0, 2.0},
{3.0, 4.0} },
{ {5.0, 6.0},
{7.0, 8.0} } }
輸入 | 輸出 | 語義 |
---|---|---|
[b0, m、k] dot [b0、k、n] |
[b0、m、n] | Batch Matmul |
[b0, b1, m, k] dot [b0, b1, k, n] |
[b0、b1、m、n] | Batch Matmul |
隨後,產生的維度編號是從批次維度開始,然後是 lhs
非合約/非批次維度,最後是 rhs
非合約/非批次維度。
DynamicSlice
另請參閱 XlaBuilder::DynamicSlice
。
DynamicSlice 會在動態 start_indices
的輸入陣列中擷取子陣列。每個維度的切片大小都會在 size_indices
中傳遞,用來指定每個維度中專屬片段間隔的終點:[start, start + size)。start_indices
的形狀必須是 == 1,尺寸大小等於 operand
的排名。
DynamicSlice(operand, start_indices, size_indices)
引數 | 類型 | 語義 |
---|---|---|
operand |
XlaOp |
T 類型的 N 維陣列 |
start_indices |
N XlaOp 序列 |
N 純量整數清單,內含每個維度切片的起始索引。值必須大於或等於零。 |
size_indices |
ArraySlice<int64> |
N 整數清單,內含各維度的區塊大小。每個值必須嚴格大於零,且起始 + 大小必須小於或等於尺寸大小,以免包裝模數尺寸大小。 |
執行配量之前,請先對 [1, N)
中每個索引 i
套用下列轉換,計算有效的配量索引:
start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - size_indices[i])
這可確保擷取的配量一律與運算元陣列相關。如果套用轉換前該切片是在邊界內,則轉換不會產生任何作用。
1D 範例:
let a = {0.0, 1.0, 2.0, 3.0, 4.0}
let s = {2}
DynamicSlice(a, s, {2}) produces:
{2.0, 3.0}
2D 範例:
let b =
{ {0.0, 1.0, 2.0},
{3.0, 4.0, 5.0},
{6.0, 7.0, 8.0},
{9.0, 10.0, 11.0} }
let s = {2, 1}
DynamicSlice(b, s, {2, 2}) produces:
{ { 7.0, 8.0},
{10.0, 11.0} }
DynamicUpdateSlice
另請參閱 XlaBuilder::DynamicUpdateSlice
。
DynamicUpdateSlice 會產生結果,也就是輸入陣列 operand
的值,並在 start_indices
覆寫片段 update
。update
的形狀會決定結果的子陣列形狀。start_indices
的形狀必須是 == 1,尺寸大小等於 operand
的排名。
DynamicUpdateSlice(operand, update, start_indices)
引數 | 類型 | 語義 |
---|---|---|
operand |
XlaOp |
T 類型的 N 維陣列 |
update |
XlaOp |
含有配量更新項目 T 類型的 N 維度陣列。每個更新形狀的維度都必須大於零,且起始 + 更新必須小於或等於各維度的運算元大小,以免產生範圍外更新索引。 |
start_indices |
N XlaOp 序列 |
N 純量整數清單,內含每個維度切片的起始索引。值必須大於或等於零。 |
執行配量之前,請先對 [1, N)
中每個索引 i
套用下列轉換,計算有效的配量索引:
start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - update.dimension_size[i])
這樣可確保更新的配量一律與運算元陣列相關。如果套用轉換前該切片是在邊界內,則轉換不會產生任何作用。
1D 範例:
let a = {0.0, 1.0, 2.0, 3.0, 4.0}
let u = {5.0, 6.0}
let s = {2}
DynamicUpdateSlice(a, u, s) produces:
{0.0, 1.0, 5.0, 6.0, 4.0}
2D 範例:
let b =
{ {0.0, 1.0, 2.0},
{3.0, 4.0, 5.0},
{6.0, 7.0, 8.0},
{9.0, 10.0, 11.0} }
let u =
{ {12.0, 13.0},
{14.0, 15.0},
{16.0, 17.0} }
let s = {1, 1}
DynamicUpdateSlice(b, u, s) produces:
{ {0.0, 1.0, 2.0},
{3.0, 12.0, 13.0},
{6.0, 14.0, 15.0},
{9.0, 16.0, 17.0} }
元素優先二元算術運算
另請參閱 XlaBuilder::Add
。
支援一組元素優先的二元算術運算。
Op(lhs, rhs)
其中 Op
是 Add
(新增)、Sub
(減號)、Mul
(乘法)、Div
(除法)、Rem
(其餘)、Max
(最大)、Min
(最小)、LogicalAnd
(邏輯 AND) 或 LogicalOr
(邏輯 OR)。
引數 | 類型 | 語義 |
---|---|---|
lhs |
XlaOp |
左側運算元:T 類型的陣列 |
rhs |
XlaOp |
右側運算元:T 類型的陣列 |
引數的形狀必須相似或相容。請參閱「廣播」說明文件,瞭解何謂與形狀相容。運算結果具有形狀,也就是播送兩個輸入陣列的結果。在這個變體中,除非其中一個運算元是純量,否則「不支援」不同排名陣列之間的作業。
當 Op
為 Rem
時,結果的正負號取自被除數,且結果的絕對值一律小於除數的絕對值。
整數除法溢位 (帶正負號/無正負號除以零,或具有 -1
的 INT_SMIN
的帶正負號除號) 會產生定義值的實作。
下列作業提供了支援不同排名廣播支援的替代變化版本:
Op(lhs, rhs, broadcast_dimensions)
其中 Op
與上述相同。此運算的變體應用於不同排名陣列之間的算術運算,例如將矩陣新增至向量。
額外的 broadcast_dimensions
運算元是整數切片,可用於將較低排名運算元的排名擴充到較高排名運算元的排名。broadcast_dimensions
會將排名較低的形狀維度對應至較高排名形狀的維度。展開形狀的未對應的維度會填入大小一。去產生維度廣播,然後沿著這些去產生維度播送形狀,讓兩個運算元的形狀相等。如要進一步瞭解語意,請參閱廣播頁面。
元素對照比較作業
另請參閱 XlaBuilder::Eq
。
支援一組標準元素元素二元比較作業。請注意,比較浮點類型時,適用標準 IEEE 754 浮點比較語意。
Op(lhs, rhs)
其中 Op
是 Eq
(等於)、Ne
(不等於)、Ge
(大於或等於)、Gt
(大於)、Le
(小於或等於)、Lt
(小於) 其中之一。另一組運算子:EqTotalOrder、NTotalOrder、GeTotalOrder、GtTotalOrder、LeTotalOrder 和 LtTotalOrder 相同,但額外支援透過浮點數的總訂單,但強制執行 -NaN < -Inf < - +Finite < -0 <N.<N ><N.
引數 | 類型 | 語義 |
---|---|---|
lhs |
XlaOp |
左側運算元:T 類型的陣列 |
rhs |
XlaOp |
右側運算元:T 類型的陣列 |
引數的形狀必須相似或相容。請參閱「廣播」說明文件,瞭解何謂與形狀相容。運算結果具有形狀,也就是播送兩個輸入陣列,其類型為 PRED
。在這個變因中,系統「不支援」不同排名陣列之間的作業,除非其中一個運算元是純量。
下列作業提供了支援不同排名廣播支援的替代變化版本:
Op(lhs, rhs, broadcast_dimensions)
其中 Op
與上述相同。該運算的變體應用於不同排名陣列之間的運算,例如將矩陣新增至向量。
額外的 broadcast_dimensions
運算元是整數切片,用於指定廣播運算元要使用的維度。如要進一步瞭解語意,請參閱廣播頁面。
元素重要一元函式
XlaBuilder 支援下列元素相關一元函式:
Abs(operand)
元素依序抽象 x -> |x|
。
Ceil(operand)
元素優先會下降 x -> ⌈x⌉
。
Cos(operand)
元素相關餘弦值 x -> cos(x)
。
Exp(operand)
元素端自然指數 x -> e^x
。
Floor(operand)
元素定義底價 x -> ⌊x⌋
。
Imag(operand)
複雜 (或實) 形狀中的元素順位想像部分。天氣為x -> imag(x)
。如果運算元是浮點類型,就會傳回 0。
IsFinite(operand)
會測試 operand
的每個元素是否為有限的無限大,即非正無限大,且不是 NaN
。傳回與輸入具有相同形狀的 PRED
值陣列,其中每個元素都是 true
,只有在對應的輸入元素有限時。
Log(operand)
元素端自然對數 x -> ln(x)
。
LogicalNot(operand)
元素邏輯邏輯不是 x -> !(x)
。
Logistic(operand)
元素重要邏輯函式運算 x ->
logistic(x)
。
PopulationCount(operand)
會計算 operand
各元素中設定的位元數。
Neg(operand)
元素相關否定 x -> -x
。
Real(operand)
是複雜 (或實) 形狀中的元素真正部分。天氣為x -> real(x)
。如果運算元是浮點類型,就會傳回相同的值。
Rsqrt(operand)
平方根運算 x -> 1.0 / sqrt(x)
的元素相互對應。
Sign(operand)
元素相關符號運算 x -> sgn(x)
,其中
\[\text{sgn}(x) = \begin{cases} -1 & x < 0\\ -0 & x = -0\\ NaN & x = NaN\\ +0 & x = +0\\ 1 & x > 0 \end{cases}\]
使用 operand
元素類型的比較運算子。
Sqrt(operand)
元素方平方根運算 x -> sqrt(x)
。
Cbrt(operand)
元素優先立方根根作業 x -> cbrt(x)
。
Tanh(operand)
元素優先的雙曲正切 x -> tanh(x)
。
Round(operand)
元素順位四捨五入,與零連結。
RoundNearestEven(operand)
元素順位進位,與最接近的偶數相關聯。
引數 | 類型 | 語義 |
---|---|---|
operand |
XlaOp |
函式的運算元 |
該函式會套用至 operand
陣列中的每個元素,因此產生具有相同形狀的陣列。operand
可以是純量 (排名 0)。
Fft
XLA FFT 運算會實作正向和反向 Fourier 轉換,用於真實和複雜的輸入/輸出。系統支援高達 3 軸的多維度 FFT。
另請參閱 XlaBuilder::Fft
。
引數 | 類型 | 語義 |
---|---|---|
operand |
XlaOp |
我們是 Fourier 轉換的陣列。 |
fft_type |
FftType |
請參閱下表。 |
fft_length |
ArraySlice<int64> |
轉換軸的時間網域長度。因為 RFFT(fft_length=[16]) 的輸出形狀與 RFFT(fft_length=[17]) 相同,IRFFT 才能設定最內軸適中的大小,因此需要特別注意。 |
FftType |
語義 |
---|---|
FFT |
轉向複雜至複雜的 FFT。形狀維持不變。 |
IFFT |
從複雜到複雜的 FFT 的逆轉。形狀維持不變。 |
RFFT |
轉送真實至複雜的 FFT。如果 fft_length[-1] 是非零值,則最內軸的形狀會縮減為 fft_length[-1] // 2 + 1 ,並省略轉換訊號超出 Nyquist 頻率範圍之外的反轉換部分。 |
IRFFT |
反向實際至複雜的 FFT (意即處理複雜且傳回真實的 FFT)。如果 fft_length[-1] 是非零值,則最內軸的形狀會展開為 fft_length[-1] ,從 1 的對向變化的對向 fft_length[-1] // 2 + 1 項目推斷出已轉換信號的一部分,超出 Nyquist 頻率。 |
多維度 FFT
如果提供多個 fft_length
,相當於對最內軸的每一個軸套用串聯 FFT 運算。請注意,針對實際 > 複雜且複雜的真實情況,最內層的軸轉換是先 (有效) 執行 (RFFT,最後用於 IRFFT),因此最內層的軸是改變大小的原因。其他軸轉換則會複雜->複雜。
實作詳情
CPU FFT 由 Eigen 的 TensorFFT 提供支援。GPU FFT 會使用 cuFFT。
收集
XLA 收集作業會將輸入陣列的多個配量 (每個配量可能有所不同的執行階段偏移) 拼接在一起。
一般語意
另請參閱 XlaBuilder::Gather
。如需更直覺的說明,請參閱下方的「實用資訊」部分。
gather(operand, start_indices, offset_dims, collapsed_slice_dims, slice_sizes, start_index_map)
引數 | 類型 | 語義 |
---|---|---|
operand |
XlaOp |
我們收集資料的來源陣列。 |
start_indices |
XlaOp |
這個陣列包含我們收集的切片起始索引。 |
index_vector_dim |
int64 |
start_indices 中「包含」起始索引的維度。詳細說明如下。 |
offset_dims |
ArraySlice<int64> |
輸出形狀中的維度組合,可偏移到從運算元切割的陣列中。 |
slice_sizes |
ArraySlice<int64> |
slice_sizes[i] 是維度 i 的區塊邊界。 |
collapsed_slice_dims |
ArraySlice<int64> |
每個區塊中收合的維度組合。這些維度的大小必須為 1。 |
start_index_map |
ArraySlice<int64> |
這張地圖說明如何將 start_indices 中的索引對應至運算元。 |
indices_are_sorted |
bool |
是否保證會依照呼叫端排序索引。 |
為方便起見,我們在輸出陣列中,將非 offset_dims
中的維度標示為 batch_dims
。
輸出內容是排名 batch_dims.size
加上 offset_dims.size
的陣列。
operand.rank
必須等於 offset_dims.size
和 collapsed_slice_dims.size
的總和。此外,slice_sizes.size
也必須等於 operand.rank
。
如果 index_vector_dim
等於 start_indices.rank
,我們會隱含 start_indices
為結尾的 1
維度 (也就是說,如果 start_indices
是 [6,7]
形狀,而 index_vector_dim
是 2
,我們會隱性將 start_indices
的形狀視為 [6,7,1]
)。
沿著維度 i
的輸出陣列邊界的計算方式如下:
如果
batch_dims
中有i
(例如,部分k
等於batch_dims[k]
),我們會從start_indices.shape
中挑選相應的維度邊界,也就是在k
<index_vector_dim
時略過start_indices.shape.dims
(也就是說,如果k
<index_vector_dim
,則為start_indices.shape.dims
[k
+1
])。如果
offset_dims
中有i
(例如,部分k
等於offset_dims
[k
]),我們會在計算collapsed_slice_dims
之後,選擇slice_sizes
的對應邊界 (也就是說,我們會挑選adjusted_slice_sizes
[k
],其中adjusted_slice_sizes
為slice_sizes
,並移除了索引collapsed_slice_dims
的邊界)。
一般而言,與指定輸出索引 Out
相對應的運算元索引 In
會以下列方式計算:
讓
G
= {Out
[k
] forbatch_dims
} 中的k
。使用G
分離向量S
,以便S
[i
] =start_indices
[Combine(G
,i
)],其中合併(A, b) 會在位置index_vector_dim
插入 b。請注意,即使G
為空白,也明確定義了這項設定:如果G
為空白,則S
=start_indices
。使用
start_index_map
分批S
,使用S
建立起始索引S
in
。operand
更精確:S
in
[start_index_map
[k
]] =S
[k
] (如果k
<start_index_map.size
)。S
in
[_
] =0
其他。
根據
collapsed_slice_dims
集,將索引散佈在Out
中的偏移維度,藉此將索引O
in
建立為operand
。更精確:O
in
[remapped_offset_dims
(k
)] =Out
[offset_dims
[k
]] 表示k
<offset_dims.size
(下方定義remapped_offset_dims
)。O
in
[_
] =0
其他。
In
是O
in
+S
in
,其中 + 是附加元素。
remapped_offset_dims
是具有網域 [0
、offset_dims.size
) 和範圍 [0
, operand.rank
) \ collapsed_slice_dims
的單音函式。舉例來說offset_dims.size
為 4
,operand.rank
為 6
,collapsed_slice_dims
為 {0
, 2
},而 collapsed_slice_dims
為 {0
→1
、1
→3
、2
→4
、3
→5
}。remapped_offset_dims
如果將 indices_are_sorted
設為 True,XLA 會假設 start_indices
已依使用者排序 (以 start_index_map
遞增順序排列)。如果不是,系統便會定義語意。
非正式說明及範例
一般而言,輸出陣列中的每個索引 Out
都會對應至運算元陣列中的元素 E
,計算方式如下:
我們會使用
Out
中的批次維度,從start_indices
查詢起始索引。我們會使用
start_index_map
將起始索引 (大小可能小於 operand.rank 的運算值) 對應至「完整」起始索引,再到operand
中。我們會使用完整的起始索引,動態分割大小為
slice_sizes
的片段。我們會收合
collapsed_slice_dims
維度來重新調整切片。由於所有收合的切片維度都必須限制 1,因此這個重形一律合法。我們會使用
Out
中的偏移維度為這個配量建立索引,以取得對應至輸出索引Out
的輸入元素E
。
在後續的所有範例中,index_vector_dim
都設為 start_indices.rank
- 1
。較有趣的 index_vector_dim
值並不會徹底改變作業方式,而是讓視覺化呈現作業更加繁瑣。
如要瞭解上述所有功能如何搭配運作,讓我們來看看從 [16,11]
陣列收集 5 個形狀 [8,6]
的部分範例。片段在 [16,11]
陣列中的位置可表示為 S64[2]
形狀的索引向量,因此 5 個位置的組合能以 S64[5,2]
陣列表示。
接著,收集作業的行為可描述為採用 [G
、O
0
、O
1
] 的索引轉換,也就是輸出形狀中的索引,然後以下列方式將其對應至輸入陣列中的元素:
我們先使用 G
從收集索引陣列選取 (X
,Y
) 向量。索引 [G
,O
0
,O
] 的輸出陣列中的元素是索引 [X
+O
0
,Y
+O
] 的輸入陣列中的元素。1
1
slice_sizes
為 [8,6]
,可決定 O0
和 O1
的範圍,進而決定片段的邊界。
這項收集作業可做為批次動態配量,並使用 G
做為批次維度。
收集索引可以包含多種資訊。舉例來說,對於上述範例使用「收集索引」形狀 [4,5,2]
的較通用範例,將會轉譯下列索引:
同樣地,這會是批次動態的區塊 G
0
和 G
1
做為批次維度。配量大小仍為 [8,6]
。
XLA 中的收集作業會以下列方式概略說明上述非正式語意:
我們可以設定輸出形狀中的哪些維度是偏移維度 (最後一個範例中含有
O
0
、O
1
的維度)。輸出批次維度 (最後一個範例包含G
0
的維度、G
1
) 會定義為非偏移維度的輸出維度。輸出形狀中明確存在的輸出偏移維度數量可能小於輸入排名。這些「缺少」的維度 (明確列為
collapsed_slice_dims
) 的維度大小必須是1
。由於配量大小為1
,因此它們唯一有效的索引是0
,省略這些元素並不會造成混淆。從「Gather Indices」陣列擷取的配量 (在最後一個例子中為
X
、Y
) 的元素可能少於輸入陣列排名,而明確對應表示索引應如何擴展,才能獲得與輸入內容相同的排名。
為最後一個例子,我們使用 (2) 和 (3) 實作 tf.gather_nd
:
G
0
和 G
1
會照常從收集索引陣列中分割起始索引,只是起始索引只有一個元素 X
。同理,只有一個輸出偏移索引值為 O
0
。不過,在做為索引用於輸入陣列之前,這些索引會依據「Gather Index Mapping」(正式說明中的 start_index_map
) 和「Offset Mapping」(正式說明中的 remapped_offset_dims
) 展開至 [X
、0
] 和 [0
、O
0
]。G
X
0
0
0
0
0
0
O
O
O
G
G
G
1
1
GatherIndices
tf.gather_nd
在這個案例中,slice_sizes
為 [1,11]
。基本上,這表示收集索引陣列中的每個索引 X
都會挑選整個資料列,而結果是這些資料列的串連結果。
GetDimensionSize
另請參閱 XlaBuilder::GetDimensionSize
。
傳回運算元的指定尺寸大小。運算元必須是陣列形狀。
GetDimensionSize(operand, dimension)
引數 | 類型 | 語義 |
---|---|---|
operand |
XlaOp |
N 維度輸入陣列 |
dimension |
int64 |
以間隔 [0, n) 為單位的值,用來指定維度 |
SetDimensionSize
另請參閱 XlaBuilder::SetDimensionSize
。
設定 XlaOp 指定維度的動態大小。運算元必須是陣列形狀。
SetDimensionSize(operand, size, dimension)
引數 | 類型 | 語義 |
---|---|---|
operand |
XlaOp |
n 維度輸入陣列。 |
size |
XlaOp |
int32 代表執行階段動態大小。 |
dimension |
int64 |
以間隔 [0, n) 為單位的值,用來指定維度。 |
結果會傳遞運算元,以編譯器追蹤動態維度。
下游縮減運算會忽略填充值。
let v: f32[10] = f32[10]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
let five: s32 = 5;
let six: s32 = 6;
// Setting dynamic dimension size doesn't change the upper bound of the static
// shape.
let padded_v_five: f32[10] = set_dimension_size(v, five, /*dimension=*/0);
let padded_v_six: f32[10] = set_dimension_size(v, six, /*dimension=*/0);
// sum == 1 + 2 + 3 + 4 + 5
let sum:f32[] = reduce_sum(padded_v_five);
// product == 1 * 2 * 3 * 4 * 5
let product:f32[] = reduce_product(padded_v_five);
// Changing padding size will yield different result.
// sum == 1 + 2 + 3 + 4 + 5 + 6
let sum:f32[] = reduce_sum(padded_v_six);
GetTupleElement
另請參閱 XlaBuilder::GetTupleElement
。
將索引時間常數值編入索引為元組。
這個值必須為編譯時間常數,這樣形狀推論才能判斷結果值的類型。
這類似於 C++ 中的 std::get<int N>(t)
。概念:
let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
let s: s32 = 5;
let t: (f32[10], s32) = tuple(v, s);
let element_1: s32 = gettupleelement(t, 1); // Inferred shape matches s32.
另請參閱 tf.tuple
。
動態內廣告
另請參閱 XlaBuilder::Infeed
。
Infeed(shape)
引數 | 類型 | 語義 |
---|---|---|
shape |
Shape |
從動態內介面讀取的資料形狀。形狀的版面配置欄位必須設為符合傳送至裝置的資料版面配置;否則,其行為為未定義。 |
這個外掛程式能從裝置的隱含動態內串流介面讀取單一資料項目,將資料解讀為指定的形狀和版面配置,然後傳回資料的 XlaOp
。運算中可以有多個動態饋給內作業,但每個動態饋給內作業的順序必須一致。舉例來說,以下程式碼中的兩個動態饋給,由於迴圈之間有依附元件,因此順序總計。
result1 = while (condition, init = init_value) {
Infeed(shape)
}
result2 = while (condition, init = result1) {
Infeed(shape)
}
不支援巢狀元組形狀。如果是空白元組形狀,動態內作業實際上是免人工管理,在不讀取裝置內動態時的任何資料的情況下繼續作業。
器皿打擊樂
另請參閱 XlaBuilder::Iota
。
Iota(shape, iota_dimension)
在裝置上建構常數常值,而非可能的大型主機傳輸作業。建立具有指定形狀的陣列,並保留從 0 開始的值,並隨指定維度遞增 1 的值。如為浮點類型,產生的陣列等同於 ConvertElementType(Iota(...))
。其中 Iota
為積分類型,而轉換是浮點類型。
引數 | 類型 | 語義 |
---|---|---|
shape |
Shape |
由 Iota() 建立的陣列形狀 |
iota_dimension |
int64 |
要遞增的維度。 |
例如,Iota(s32[4, 8], 0)
會傳回
[[0, 0, 0, 0, 0, 0, 0, 0 ],
[1, 1, 1, 1, 1, 1, 1, 1 ],
[2, 2, 2, 2, 2, 2, 2, 2 ],
[3, 3, 3, 3, 3, 3, 3, 3 ]]
可退貨 (費用:Iota(s32[4, 8], 1)
)
[[0, 1, 2, 3, 4, 5, 6, 7 ],
[0, 1, 2, 3, 4, 5, 6, 7 ],
[0, 1, 2, 3, 4, 5, 6, 7 ],
[0, 1, 2, 3, 4, 5, 6, 7 ]]
地圖
另請參閱 XlaBuilder::Map
。
Map(operands..., computation)
引數 | 類型 | 語義 |
---|---|---|
operands |
N XlaOp 秒的順序 |
N 類型陣列 T0..T{N-1} |
computation |
XlaComputation |
具有 N 類型 T 和任意類型之 N 參數的 T_0, T_1, .., T_{N + M -1} -> S 類型計算 |
dimensions |
int64 陣列 |
地圖維度陣列 |
對指定的 operands
陣列套用純量函式,產生相同維度的陣列,其中每個元素都是套用於輸入陣列中對應元素的對應函式結果。
對應函式是設有限制的任意運算,它具有 N 個純量類型 T
輸入和 S
類型的單一輸出。輸出結果的維度與運算元相同,只是元素類型 T 已替換為 S。
例如:Map(op1, op2, op3, computation, par1)
會在輸入陣列中的每個 (多維度) 索引之間對應 elem_out <-
computation(elem1, elem2, elem3, par1)
,以產生輸出陣列。
OptimizationBarrier
禁止任何最佳化傳遞,使其無法在屏障中移動運算。
確保所有輸入都會先評估,再有依附於障礙輸出內容的運算子。
防溢乳墊
另請參閱 XlaBuilder::Pad
。
Pad(operand, padding_value, padding_config)
引數 | 類型 | 語義 |
---|---|---|
operand |
XlaOp |
T 類型的陣列 |
padding_value |
XlaOp |
要填入新增的邊框間距 T 類型的純量 |
padding_config |
PaddingConfig |
兩個邊緣 (低、高) 以及每個尺寸元素之間的邊框間距 |
透過陣列前後的邊框間距,以及指定 padding_value
的陣列元素之間,展開指定的 operand
陣列。padding_config
會指定邊緣邊框間距的量,以及每個維度的內部邊框間距。
PaddingConfig
是 PaddingConfigDimension
的重複欄位,其中包含每個維度的三個欄位:edge_padding_low
、edge_padding_high
和 interior_padding
。
edge_padding_low
和 edge_padding_high
會分別指定在索引 0 旁 (位於索引 0 旁) 與高端 (位於最高索引旁) 新增的邊框間距量。邊緣邊框間距的量可以是負數,負值邊框間距的絕對值代表要從指定維度中移除的元素數量。
interior_padding
會指定在每個維度中任兩個元素之間加入的邊框間距量,而且不得為負數。內部邊框間距是在邏輯間距之前發生,因此如果邊緣邊框間距為負數,元素會從內部填充運算元中移除。
如果邊緣邊框間距組合全都為 (0, 0),且內部邊框間距值為 0,則這項作業即為免人工管理。下圖顯示二維陣列的不同 edge_padding
和 interior_padding
值範例。
最佳化建議
另請參閱 XlaBuilder::Recv
。
Recv(shape, channel_handle)
引數 | 類型 | 語義 |
---|---|---|
shape |
Shape |
要接收的資料形狀 |
channel_handle |
ChannelHandle |
每個傳送/接收組合的專屬 ID |
在共用相同管道控制代碼的其他運算中,從 Send
指令接收指定形狀的資料。傳回接收資料的 XlaOp。
Recv
作業的用戶端 API 代表同步通訊。不過,指令會在內部分解為 2 個 HLO 指令 (Recv
和 RecvDone
),以啟用非同步資料移轉。另請參閱 HloInstruction::CreateRecv
和 HloInstruction::CreateRecvDone
。
Recv(const Shape& shape, int64 channel_id)
分配接收相同 channel_id 的 Send
指令資料所需的資源。傳回已分配資源的結構定義,下列 RecvDone
指令可用於等待資料移轉完成。結構定義是 {receive buffer (shape)、要求 ID (U32)} 的元組,只能用於 RecvDone
指令。
RecvDone(HloInstruction context)
請依據 Recv
指令建立的結構定義,等待資料移轉作業完成並傳回接收的資料。
遏止
另請參閱 XlaBuilder::Reduce
。
將縮減函式平行套用到一或多個陣列。
Reduce(operands..., init_values..., computation, dimensions)
引數 | 類型 | 語義 |
---|---|---|
operands |
N XlaOp 的序列 |
T_0, ..., T_{N-1} 類型的 N 陣列。 |
init_values |
N XlaOp 的序列 |
T_0, ..., T_{N-1} 類型的 N 純量。 |
computation |
XlaComputation |
T_0, ..., T_{N-1}, T_0, ..., T_{N-1} -> Collate(T_0, ..., T_{N-1}) 類型的運算。 |
dimensions |
int64 陣列 |
要減少的未排序維度陣列。 |
在此情況下:
- N 不得小於 1。
- 計算作業必須具有「大致」關聯 (如下所示)。
- 所有輸入陣列的維度都必須相同。
- 所有初始值均須在
computation
下形成身分。 - 如果值為
N = 1
,則Collate(T)
為T
。 - 如果為
N > 1
,則Collate(T_0, ..., T_{N-1})
是T
類型的N
元素組合。
這項作業會將每個輸入陣列的一或多個維度縮減為純量。每個傳回的陣列的排名為 rank(operand) - len(dimensions)
。運算的輸出內容為 Collate(Q_0, ..., Q_N)
,其中 Q_i
是 T_i
類型的陣列,說明如下的維度。
不同的後端可以重新連結縮減計算之間的關聯。這可能會造成數字差異,因為有些縮減函式 (例如加法) 並未與浮點值建立關聯。不過,如果資料範圍有限,則新增浮點數就足以與大多數實用的用途相關聯。
示例
使用 [10, 11,
12, 13]
值,在單一 1D 陣列中縮減一個維度時,縮減函式 f
(為 computation
) 之後,可以計算出為
f(10, f(11, f(12, f(init_value, 13)))
除此之外,還有許多其他可能性,例如
f(init_value, f(f(10, f(init_value, 11)), f(f(init_value, 12), f(init_value, 13))))
以下提供概略的虛擬程式碼範例,說明如何實作縮減功能,其中使用加總做為初始值為 0 的縮減運算。
result_shape <- remove all dims in dimensions from operand_shape
# Iterate over all elements in result_shape. The number of r's here is equal
# to the rank of the result
for r0 in range(result_shape[0]), r1 in range(result_shape[1]), ...:
# Initialize this result element
result[r0, r1...] <- 0
# Iterate over all the reduction dimensions
for d0 in range(dimensions[0]), d1 in range(dimensions[1]), ...:
# Increment the result element with the value of the operand's element.
# The index of the operand's element is constructed from all ri's and di's
# in the right order (by construction ri's and di's together index over the
# whole operand shape).
result[r0, r1...] += operand[ri... di]
以下是減少 2D 陣列 (矩陣) 的範例。形狀為 2,尺寸 0 為 2,尺寸 1 則為 3:
使用「新增」函式縮小維度 0 或維度 1 的結果:
請注意,兩個縮減結果都是 1D 陣列。在此圖中,為一欄顯示一欄,另一列則為一欄,方便您一目瞭然。
以下為較複雜的範例,以下是 3D 陣列。其排名為 3、大小為 4 的維度 0、大小 2 的維度 1 及大小 3 的維度 2。為簡單起見,1 到 6 的值會在維度 0 之間複製。
如同 2D 範例,我們只縮減一個維度。比方說,縮小維度 0 後得到的結果是 排名-2 的陣列,而維度 0 的所有值均會併入純量:
| 4 8 12 |
| 16 20 24 |
如果我們縮減維度 2,也會取得排名 2 的陣列,其中維度 2 的所有值會縮放成純量:
| 6 15 |
| 6 15 |
| 6 15 |
| 6 15 |
請注意,輸出中會保留其他維度之間的相對順序,但由於排名改變,部分維度可能會獲派新的數字。
您也可以減少多個維度。新增維度 0 和 1 會產生 1D 陣列 [20, 28, 36]
。
將所有維度縮小 3D 陣列會產生純量 84
。
緩解壓力
設為 N > 1
時,縮減函式應用程式會稍微複雜,因為同時會套用至所有輸入。運算元會按照下列順序提供給運算:
- 針對第一個運算元執行縮減值
- ...
- 針對第 N' 運算元執行縮減值
- 第一個運算元的輸入值
- ...
- N' 運算元的輸入值
舉例來說,假設下列縮減函式可用來平行計算 1-D 陣列的最大值和 argmax:
f: (Float, Int, Float, Int) -> Float, Int
f(max, argmax, value, index):
if value >= max:
return (value, index)
else:
return (max, argmax)
針對 1-D 輸入陣列 V = Float[N], K = Int[N]
以及 init 值 I_V = Float, I_K = Int
,僅縮小輸入維度的結果 f_(N-1)
與下列遞迴應用程式相同:
f_0 = f(I_V, I_K, V_0, K_0)
f_1 = f(f_0.first, f_0.second, V_1, K_1)
...
f_(N-1) = f(f_(N-2).first, f_(N-2).second, V_(N-1), K_(N-1))
將此縮減值套用到值陣列,以及依序索引 (例如 iota) 陣列會共同處理陣列,並傳回包含最大值與比對索引值的元組。
ReducePrecision
另請參閱 XlaBuilder::ReducePrecision
。
模擬將浮點值轉換為較低精確度格式 (例如 IEEE-FP16) 以及改回原始格式的效果。儘管不是每個硬體實作都支援所有位元大小,但可以任意指定低精確度格式的指數和 mantissa 位元數量。
ReducePrecision(operand, mantissa_bits, exponent_bits)
引數 | 類型 | 語義 |
---|---|---|
operand |
XlaOp |
浮點類型 T 的陣列。 |
exponent_bits |
int32 |
較低精確度格式的指數位元數 |
mantissa_bits |
int32 |
較低精確度格式的 mantissa 位元數 |
結果是 T
類型的陣列。輸入值會四捨五入為以指定量數位元數表示的最接近值 (使用「幾天脈動」語意),任何超過指數位元數量指定範圍的值都會歸入正或負無限大。NaN
值會保留,但可轉換為標準 NaN
值。
較低精度格式必須至少有一個指數位元 (為了區分零值與無限大,因為兩者皆有零尾數),且必須有非負數的 mantissa 位元數。指數或 mantisa 位元數可能會超過 T
類型的對應值,這時轉換的對應部分只是免人工管理。
ReduceScatter
另請參閱 XlaBuilder::ReduceScatter
。
ReduceScatter 是一種集體作業,可有效執行 AllReduce,並將結果分成 scatter_dimension
的 shard_count
區塊,而備用資源群組中的備用資源 i
會收到 ith
資料分割。
ReduceScatter(operand, computation, scatter_dim, shard_count,
replica_group_ids, channel_id)
引數 | 類型 | 語義 |
---|---|---|
operand |
XlaOp |
要跨備用資源縮減的陣列或非空白陣列組合。 |
computation |
XlaComputation |
縮減運算 |
scatter_dimension |
int64 |
分散的維度。 |
shard_count |
int64 |
要分割的「scatter_dimension 」區塊數量 |
replica_groups |
int64 向量的向量 |
要執行縮減的群組 |
channel_id |
選用 int64 |
跨模組通訊的選用頻道 ID |
- 如果
operand
是陣列組合,系統會對元組的每個元素執行縮減散佈器。 replica_groups
是執行縮減程序的備用資源群組清單 (您可以使用ReplicaId
擷取目前備用資源的備用資源 ID)。每個群組中的備用資源順序會決定所有縮減結果的分散順序。replica_groups
必須空白 (在這種情況下,所有備用資源都屬於單一群組),或包含與備用資源數量相同的元素數量。如有多個備用資源群組,這些群組的大小都必須相同。舉例來說,replica_groups = {0, 2}, {1, 3}
會在備用資源0
與2
之間進行縮減,以及1
和3
,然後分散結果。shard_count
是每個備用資源群組的大小。如果replica_groups
為空白,就必須使用這個參數。如果replica_groups
不是空白,shard_count
必須等於每個備用資源群組的大小。channel_id
用於跨模組通訊:只有具有相同channel_id
的reduce-scatter
作業才能彼此通訊。
輸出形狀是輸入形狀,scatter_dimension
縮小了 shard_count
倍。舉例來說,如果有兩個備用資源,而運算元在兩個備用資源上分別具有 [1.0, 2.25]
和 [3.0, 5.25]
值,則這項運算的輸出值 (scatter_dim
為 0
) 為第一個備用資源的 [4.0]
,而第二個備用資源的 scatter_dim
為 [7.5]
。
ReduceWindow
另請參閱 XlaBuilder::ReduceWindow
。
將縮減函式套用至 N 多維陣列序列的每個窗口中的所有元素,產生 N 多維度陣列的單色或元組做為輸出。每個輸出陣列的元素數量與視窗有效位置數量相同。集區層可以使用 ReduceWindow
表示。與 Reduce
類似,套用的 computation
一律會在左側傳遞 init_values
。
ReduceWindow(operands..., init_values..., computation, window_dimensions,
window_strides, padding)
引數 | 類型 | 語義 |
---|---|---|
operands |
N XlaOps |
T_0,..., T_{N-1} 類型的 N 多維度陣列序列,每個陣列都代表視窗位置的基礎區域。 |
init_values |
N XlaOps |
縮減作業的 N 起始值,N 運算元各一個。詳情請參閱減少。 |
computation |
XlaComputation |
T_0, ..., T_{N-1}, T_0, ..., T_{N-1} -> Collate(T_0, ..., T_{N-1}) 類型的縮減函式,可套用至所有輸入運算元中每個視窗中的元素。 |
window_dimensions |
ArraySlice<int64> |
視窗維度值的整數陣列 |
window_strides |
ArraySlice<int64> |
區間步值的整數陣列 |
base_dilations |
ArraySlice<int64> |
基本相除值的整數陣列 |
window_dilations |
ArraySlice<int64> |
窗體除法值的整數陣列 |
padding |
Padding |
視窗的邊框間距類型 (Padding::kSame,邊框間距類型為 1 時填充與輸入相同) 或 Padding::kValid (不使用邊框間距,在視窗不再符合時就會「停止」) |
在此情況下:
- N 不得小於 1。
- 所有輸入陣列的維度都必須相同。
- 如果值為
N = 1
,則Collate(T)
為T
。 - 如果為
N > 1
,則Collate(T_0, ..., T_{N-1})
是(T0,...T{N-1})
類型的N
元素組合。
下方的程式碼和圖表提供 ReduceWindow
的使用範例。輸入是大小 [4x6] 的矩陣, window_dimensions 和 window_stride_dimensions 皆為 [2x3]。
// Create a computation for the reduction (maximum).
XlaComputation max;
{
XlaBuilder builder(client_, "max");
auto y = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "y");
auto x = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "x");
builder.Max(y, x);
max = builder.Build().value();
}
// Create a ReduceWindow computation with the max reduction computation.
XlaBuilder builder(client_, "reduce_window_2x3");
auto shape = ShapeUtil::MakeShape(F32, {4, 6});
auto input = builder.Parameter(0, shape, "input");
builder.ReduceWindow(
input,
/*init_val=*/builder.ConstantLiteral(LiteralUtil::MinValue(F32)),
*max,
/*window_dimensions=*/{2, 3},
/*window_stride_dimensions=*/{2, 3},
Padding::kValid);
維度中的跨距 1 代表視窗的位置與相鄰視窗位置為 1 個元素。為了指定沒有視窗彼此重疊,window_stride_dimensions 應等於 window_dimensions。下圖說明兩個不同步差值的用法。邊框間距會套用至輸入的每個維度,且計算方式與輸入內容在邊框間距之後的維度相同。
對於一般的邊框間距範例,請考慮用維度 3
計算縮減視窗最小值 (初始值為 MAX_FLOAT
),跨輸入陣列 [10000, 1000, 100, 10, 1]
計算 2
。邊框間距 kValid
會計算兩個有效期間 ([10000, 1000, 100]
和 [100, 10, 1]
) 的最小值,因此輸出 [100, 1]
。加上 kSame
填充陣列後,會先填充陣列,讓縮減視窗之後的形狀與第 2 個體的輸入「相同」,方法是在兩邊新增初始元素,取得 [MAX_VALUE, 10000, 1000, 100, 10, 1,
MAX_VALUE]
。在填充陣列上執行縮減視窗,會在三個視窗 [MAX_VALUE, 10000, 1000]
、[1000, 100, 10]
、[10, 1, MAX_VALUE]
和收益 [1000, 10, 1]
上執行。
縮減函式的評估順序為任意值,且可能不具確定性。因此,縮減函式不應過於敏感,以防重新建立關聯。詳情請參閱 Reduce
的情境中有關關聯性的討論。
ReplicaId
另請參閱 XlaBuilder::ReplicaId
。
傳回備用資源的專屬 ID (U32 純量)。
ReplicaId()
每個備用資源的專屬 ID 都是間隔 [0, N)
中的無正負號整數,其中 N
是備用資源數量。由於所有備用資源都執行相同的程式,因此該程式中的 ReplicaId()
呼叫會在每個備用資源上傳回不同的值。
重塑
另請參閱 XlaBuilder::Reshape
和 Collapse
作業。
將陣列的維度重塑為新設定。
Reshape(operand, new_sizes)
Reshape(operand, dimensions, new_sizes)
引數 | 類型 | 語義 |
---|---|---|
operand |
XlaOp |
T 類型的陣列 |
dimensions |
int64 個向量 |
以及維度收合順序 |
new_sizes |
int64 個向量 |
新維度的向量大小 |
概念上,將陣列重新塑造為一維資料值的向量,然後再將這個向量修正為新的形狀。輸入引數是 T 類型的任意陣列、維度索引的編譯時間常數向量,以及結果維度大小的編譯時間常數向量。dimension
向量中的值 (如有指定) 必須是 T 所有維度的排列組合;如未指定,則預設值為 {0, ..., rank - 1}
。dimensions
中的維度順序是從迴圈巢狀結構中最慢的維度 (最主要) 到快速變動的維度 (最小),也就是將輸入陣列收合為單一維度。new_sizes
向量會決定輸出陣列的大小。new_sizes
中索引 0 的值是維度 0 的大小,索引 1 的值是維度 1 的大小,依此類推。new_size
維度的乘積必須等於運算元維度大小的乘積。將收合的陣列修正為 new_sizes
定義的多維度陣列時,new_sizes
中的維度會由最低變化 (最主要) 排序,變為變化最快 (最次要)。
舉例來說,假設 v 是包含 24 個元素的陣列:
let v = f32[4x2x3] { { {10, 11, 12}, {15, 16, 17} },
{ {20, 21, 22}, {25, 26, 27} },
{ {30, 31, 32}, {35, 36, 37} },
{ {40, 41, 42}, {45, 46, 47} } };
In-order collapse:
let v012_24 = Reshape(v, {0,1,2}, {24});
then v012_24 == f32[24] {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27,
30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47};
let v012_83 = Reshape(v, {0,1,2}, {8,3});
then v012_83 == f32[8x3] { {10, 11, 12}, {15, 16, 17},
{20, 21, 22}, {25, 26, 27},
{30, 31, 32}, {35, 36, 37},
{40, 41, 42}, {45, 46, 47} };
Out-of-order collapse:
let v021_24 = Reshape(v, {1,2,0}, {24});
then v012_24 == f32[24] {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42,
15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47};
let v021_83 = Reshape(v, {1,2,0}, {8,3});
then v021_83 == f32[8x3] { {10, 20, 30}, {40, 11, 21},
{31, 41, 12}, {22, 32, 42},
{15, 25, 35}, {45, 16, 26},
{36, 46, 17}, {27, 37, 47} };
let v021_262 = Reshape(v, {1,2,0}, {2,6,2});
then v021_262 == f32[2x6x2] { { {10, 20}, {30, 40},
{11, 21}, {31, 41},
{12, 22}, {32, 42} },
{ {15, 25}, {35, 45},
{16, 26}, {36, 46},
{17, 27}, {37, 47} } };
重新形狀是特殊情況,可將單一元素陣列轉換為純量,反之亦然。比如
Reshape(f32[1x1] { {5} }, {0,1}, {}) == 5;
Reshape(5, {}, {1,1}) == f32[1x1] { {5} };
還原 (反轉)
另請參閱 XlaBuilder::Rev
。
Rev(operand, dimensions)
引數 | 類型 | 語義 |
---|---|---|
operand |
XlaOp |
T 類型的陣列 |
dimensions |
ArraySlice<int64> |
就會反向移動 |
沿著指定的 dimensions
反轉 operand
陣列中的元素順序,產生相同形狀的輸出陣列。多維度索引的運算元陣列的每個元素,都會儲存在轉換索引的輸出陣列中。反轉每個維度中的索引並將其反轉,藉此轉換多維度索引。也就是說,如果大小為 N 的維度是反轉維度之一,其索引 i 就會轉換為 N - 1 - i。
在類神經網路的漸層計算期間,Rev
作業的其中一個用途,就是沿著兩個視窗維度反轉卷積權重陣列。
RngNormal
另請參閱 XlaBuilder::RngNormal
。
使用在 \(N(\mu, \sigma)\) 常態分配後產生的隨機數字,建構指定形狀的輸出。參數 \(\mu\) 和 \(\sigma\)以及輸出形狀必須有浮點元素類型。參數在其他參數中則必須為純量值。
RngNormal(mu, sigma, shape)
引數 | 類型 | 語義 |
---|---|---|
mu |
XlaOp |
T 類型的純量,指定產生的數字平均值 |
sigma |
XlaOp |
T 類型的純量,指定產生的標準差 |
shape |
Shape |
T 類型的輸出形狀 |
RngUniform
另請參閱 XlaBuilder::RngUniform
。
使用在間隔 \([a,b)\)平均分配後產生的隨機數字,建構指定形狀的輸出。參數和輸出元素類型必須為布林值類型、積分類型或浮點類型,且類型必須一致。CPU 和 GPU 後端目前僅支援 F64、F32、F16、BF16、S64、U64、S32 和 U32。此外,參數必須為純量值。如果 \(b <= a\) 結果已定義實作,
RngUniform(a, b, shape)
引數 | 類型 | 語義 |
---|---|---|
a |
XlaOp |
T 類型的純量指定了間隔下限 |
b |
XlaOp |
T 類型的純量指定了間隔上限 |
shape |
Shape |
T 類型的輸出形狀 |
RngBitGenerator
使用指定的演算法 (或後端預設值),以特定形狀填入相同的隨機位元,並傳回更新的狀態 (形狀與初始狀態相同) 和產生的隨機資料。
初始狀態是目前隨機號碼產生作業的初始狀態。它以及所需的形狀和有效值,皆取決於使用的演算法。
輸出內容保證是初始狀態的確定性函式,但「無法」保證在後端和不同編譯器版本之間具有確定性。
RngBitGenerator(algorithm, key, shape)
引數 | 類型 | 語義 |
---|---|---|
algorithm |
RandomAlgorithm |
要使用的 PRNG 演算法。 |
initial_state |
XlaOp |
PRNG 演算法的初始狀態。 |
shape |
Shape |
所產生資料的輸出形狀。 |
algorithm
的可用值:
rng_default
:具有後端專屬形狀要求的後端特定演算法。rng_three_fry
:ThreeFry 計數器型 PRNG 演算法。initial_state
形狀為u64[2]
搭配任意值。Salmon 等人,SC 2011,平行隨機號碼:如 1、2、3 一樣簡單。rng_philox
:用來平行產生隨機號碼的 Philox 演算法。initial_state
形狀為任意值的u64[3]
。Salmon 等人,SC 2011,平行隨機號碼:如 1、2、3 一樣簡單。
散布圖
XLA 散佈運算會產生一連串的結果,也就是輸入陣列 operands
的值,並在 updates
中使用 update_computation
將多個配量 (指定索引) 更新為 updates
中的值序列。scatter_indices
另請參閱 XlaBuilder::Scatter
。
scatter(operands..., scatter_indices, updates..., update_computation,
index_vector_dim, update_window_dims, inserted_window_dims,
scatter_dims_to_operand_dims)
引數 | 類型 | 語義 |
---|---|---|
operands |
N XlaOp 的序列 |
N 類型的 T_0, ..., T_N 陣列。 |
scatter_indices |
XlaOp |
這個陣列包含必須分離的切片起始索引。 |
updates |
N XlaOp 的序列 |
T_0, ..., T_N 類型的 N 陣列。updates[i] 包含分批使用 operands[i] 時必須使用的值。 |
update_computation |
XlaComputation |
要用於合併輸入陣列中現有值以及分批期間更新的運算。此計算應為 T_0, ..., T_N, T_0, ..., T_N -> Collate(T_0, ..., T_N) 類型。 |
index_vector_dim |
int64 |
scatter_indices 中包含起始索引的維度。 |
update_window_dims |
ArraySlice<int64> |
updates 形狀中屬於視窗尺寸的維度組合。 |
inserted_window_dims |
ArraySlice<int64> |
必須將一組視窗尺寸插入 updates 形狀中。 |
scatter_dims_to_operand_dims |
ArraySlice<int64> |
維度對應於散佈索引與運算元索引空間。系統會將此陣列解讀為將 i 對應至 scatter_dims_to_operand_dims[i] 。呈一對一且必需。 |
indices_are_sorted |
bool |
是否保證會依照呼叫端排序索引。 |
unique_indices |
bool |
呼叫端是否可保證不會重複顯示索引。 |
在此情況下:
- N 不得小於 1。
operands
[0
]、...、operands
[N-1
] 必須具有相同的維度。updates
[0
]、...、updates
[N-1
] 必須具有相同的維度。- 如果值為
N = 1
,則Collate(T)
為T
。 - 如果為
N > 1
,Collate(T_0, ..., T_N)
是T
類型的N
元素組合。
如果 index_vector_dim
等於 scatter_indices.rank
,我們會隱含 scatter_indices
的結尾 1
維度。
我們將 ArraySlice<int64>
類型的 update_scatter_dims
定義為 updates
形狀中不在 update_window_dims
中的維度組合,以遞增順序排列。
分散的引數應遵循下列限制:
每個
updates
陣列的排名都必須是update_window_dims.size + scatter_indices.rank - 1
。每個
updates
陣列中i
的維度邊界必須符合下列規定:- 如果
i
出現在update_window_dims
中 (也就是部分k
的update_window_dims
[k
]),則updates
中的i
維度邊界不得超出inserted_window_dims
的對應邊界 (也就是說,adjusted_window_bounds
[k
],其中adjusted_window_bounds
包含operand
的邊界,且已移除索引inserted_window_dims
的邊界)。operand
- 如果
update_scatter_dims
中存在i
(也就是部分k
等於update_scatter_dims
[k
]),則updates
中的i
維度邊界必須等於scatter_indices
的對應邊界,則略過index_vector_dim
(也就是scatter_indices.shape.dims
[k
],如果k
<index_vector_dim
和scatter_indices.shape.dims
[k+1
] 不同)。
- 如果
update_window_dims
必須採用遞增順序,不得包含任何重複的維度數字,並且位於[0, updates.rank)
範圍內。inserted_window_dims
必須採用遞增順序,不得包含任何重複的維度數字,並且位於[0, operand.rank)
範圍內。operand.rank
必須等於update_window_dims.size
和inserted_window_dims.size
的總和。scatter_dims_to_operand_dims.size
必須等於scatter_indices.shape.dims
[index_vector_dim
],且值必須在[0, operand.rank)
範圍內。
針對每個 updates
陣列中的指定索引 U
,系統會依照以下方式計算更新範圍的對應 operands
陣列中的對應索引 I
:
- 讓
G
= {U
[k
] forupdate_scatter_dims
} 中的k
。使用G
查詢scatter_indices
陣列中的索引向量S
,例如S
[i
] =scatter_indices
[Combine(G
,i
)],其中 Combine(A, b) 會在index_vector_dim
的位置插入 b 至 A。 - 使用
scatter_dims_to_operand_dims
對應散佈S
,使用S
將S
in
編入索引。operand
更正式:S
in
[scatter_dims_to_operand_dims
[k
]] =S
[k
] (如果k
<scatter_dims_to_operand_dims.size
)。S
in
[_
] =0
其他。
- 根據
inserted_window_dims
,透過散佈在U
的update_window_dims
分批索引,將索引W
in
新增至每個operands
陣列。更正式:W
in
[window_dims_to_operand_dims
(k
)] =U
[k
] 如果k
位於update_window_dims
,其中window_dims_to_operand_dims
是含有 [0
,update_window_dims.size
] 及範圍 [0
,operand.rank
) \inserted_window_dims
的單調函式。(例如,如果update_window_dims.size
為4
,operand.rank
為6
,而inserted_window_dims
為 {0
,2
},則window_dims_to_operand_dims
為 {0
→1
、1
→3
、2
→4
、3
→5
})。W
in
[_
] =0
其他。
I
是W
in
+S
in
,其中 + 是附加元素。
簡單來說,分批運算的定義如下。
- 使用
operands
初始化output
,例如針對operands
[J
] 陣列中的所有索引O
初始化output
,例如:
output
[J
][O
] =operands
[J
][O
]J
- 針對
updates
[J
] 陣列中的每個索引U
,以及operand
[J
] 陣列中的對應索引O
,如果O
是output
的有效索引:
(output
[0
][O
], ...,output
[N-1
][O
]) =update_computation
(output
[0
][O
], ..., ,output
[O
],updates
[0
][U
], ...,updates
[N-1
][U
])N-1
套用更新的順序不具確定性。因此,當 updates
中有多個索引參照 operands
中的同一個索引時,output
中的對應值將不具確定性。
請注意,傳遞至 update_computation
的第一個參數一律是 output
陣列的目前值,第二個參數則一律是 updates
陣列的值。當 update_computation
不可交換時,這一點尤其重要。
如果將 indices_are_sorted
設為 True,XLA 會假設 start_indices
已依使用者排序 (以 start_index_map
遞增順序排列)。如果不是,系統便會定義語意。
如果將 unique_indices
設為 True,XLA 會假設分散的所有元素都不重複。因此 XLA 可以使用非不可分割的作業。如果將 unique_indices
設為 True,且已分割的索引並非唯一,系統就會定義語意。
在不正式的情況下,散佈運算可做為收集運算的「反轉」,也就是說,散佈器會更新對應收集運算所擷取的輸入內容中的元素。
如需詳盡的非正式說明與範例,請參閱 Gather
底下的「非正式說明」部分。
請選取申請項目
另請參閱 XlaBuilder::Select
。
根據述詞陣列的值,從兩個輸入陣列的元素建構輸出陣列。
Select(pred, on_true, on_false)
引數 | 類型 | 語義 |
---|---|---|
pred |
XlaOp |
PRED 類型的陣列 |
on_true |
XlaOp |
T 類型的陣列 |
on_false |
XlaOp |
T 類型的陣列 |
陣列 on_true
和 on_false
的形狀必須相同。這也是輸出陣列的形狀。陣列 pred
的維度性必須與 on_true
和 on_false
相同,且為 PRED
元素類型。
對於 pred
的每個元素 P
,如果 P
的值為 true
,則會從 on_true
取得輸出陣列的對應元素;如果 P
的值為 false
,則會從 on_false
取得。由於播送,
pred
可以是純量 PRED
。在此情況下,如果 pred
是 true
,系統會從 on_true
完整取得輸出陣列,如果 pred
是 false
,則會從 on_false
擷取。
使用非純量 pred
的範例:
let pred: PRED[4] = {true, false, false, true};
let v1: s32[4] = {1, 2, 3, 4};
let v2: s32[4] = {100, 200, 300, 400};
==>
Select(pred, v1, v2) = s32[4]{1, 200, 300, 4};
純量 pred
範例:
let pred: PRED = true;
let v1: s32[4] = {1, 2, 3, 4};
let v2: s32[4] = {100, 200, 300, 400};
==>
Select(pred, v1, v2) = s32[4]{1, 2, 3, 4};
支援選取元組之間的選項。因此,元組視為純量類型。如果 on_true
和 on_false
是元組 (形狀必須相同!),則 pred
必須是 PRED
類型的純量。
SelectAndScatter
另請參閱 XlaBuilder::SelectAndScatter
。
這項作業可視為複合式作業,用於先計算 operand
陣列的 ReduceWindow
從每個時段選取元素,然後將 source
陣列分散至所選元素的索引,以建構與運算元陣列相同的形狀的輸出陣列。二進位 select
函式的用途是將元素套用到各個視窗,藉此從各個視窗中選取元素,並呼叫第一個參數的索引向量在字母序上少於第二個參數的索引向量。如果選取第一個參數,select
函式會傳回 true
,如果選取第二個參數,且函式必須保留轉換 (亦即 select(a, b)
和 select(b, c)
為 true
,則 select(a, c)
也是 true
),因此所選元素不會取決於元素針對特定時段的周遊順序。false
「scatter
」函式會套用至輸出陣列中的每個所選索引。這會採用兩個純量參數:
- 輸出陣列中所選索引的當前值
- 適用於所選索引的
source
分批值
系統會合併兩個參數並傳回純量值,用來更新輸出陣列中所選索引的值。輸出陣列的所有索引一開始會設為 init_value
。
輸出陣列的形狀與 operand
陣列相同,且 source
陣列的形狀必須與對 operand
陣列套用 ReduceWindow
運算的結果相同。SelectAndScatter
可用於反向傳播類神經網路中池化層的漸層值。
SelectAndScatter(operand, select, window_dimensions, window_strides,
padding, source, init_value, scatter)
引數 | 類型 | 語義 |
---|---|---|
operand |
XlaOp |
視窗滑動的 T 類型陣列 |
select |
XlaComputation |
T, T -> PRED 類型的二進位計算,會套用至每個視窗中的所有元素;如果選取第一個參數,則傳回 true ;如果選取第二個參數,則傳回 false |
window_dimensions |
ArraySlice<int64> |
視窗維度值的整數陣列 |
window_strides |
ArraySlice<int64> |
區間步值的整數陣列 |
padding |
Padding |
視窗的邊框間距類型 (Padding::kSame 或 Padding::kValid) |
source |
XlaOp |
類型 T 的陣列,其值要分散 |
init_value |
XlaOp |
針對輸出陣列初始值 T 類型的純量值 |
scatter |
XlaComputation |
執行 T, T -> T 類型的二進位計算,將各個散佈來源元素及其目的地元素套用至該元素 |
下圖顯示 SelectAndScatter
的使用範例,其中 select
函式會計算其參數中的最大值。請注意,當時段重疊時,如下方圖 (2) 所示,不同時段可能會多次選取 operand
陣列的索引。在圖中,上方兩個視窗 (藍色和紅色) 都會選取值 9 的元素,而二元新增 scatter
函式會產生值 8 (2 + 6) 的輸出元素。
scatter
函式的評估順序為任意值,且可能不具確定性。因此,scatter
函式不應過於敏感,以防重新建立關聯。詳情請參閱 Reduce
的情境中有關關聯性的討論。
傳送
另請參閱 XlaBuilder::Send
。
Send(operand, channel_handle)
引數 | 類型 | 語義 |
---|---|---|
operand |
XlaOp |
要傳送的資料 (T 類型的陣列) |
channel_handle |
ChannelHandle |
每個傳送/接收組合的專屬 ID |
在共用相同管道控制代碼的其他運算中,將指定運算元資料傳送至 Recv
指令。不傳回任何資料。
與 Recv
作業類似,Send
作業的用戶端 API 代表同步通訊,並在內部分解為 2 個 HLO 指令 (Send
和 SendDone
),以便啟用非同步資料移轉。另請參閱 HloInstruction::CreateSend
和 HloInstruction::CreateSendDone
。
Send(HloInstruction operand, int64 channel_id)
針對具有相同管道 ID 的 Recv
指令,將運算元以非同步方式轉移至指派的資源。傳回結構定義,下列 SendDone
指示用來等待資料移轉完成。結構定義是 {operand (shape)、要求 ID (U32)} 的元組,只能用於 SendDone
指令。
SendDone(HloInstruction context)
請依據 Send
指令建立的結構定義,等待資料移轉作業完成。該指示不會傳回任何資料。
頻道排程操作說明
每個管道 (Recv
、RecvDone
、Send
、SendDone
) 的 4 個指令的執行順序如下。
Recv
發生在Send
之前Send
發生在RecvDone
之前Recv
發生在RecvDone
之前Send
發生在SendDone
之前
當後端編譯器為每個透過通道指示進行通訊的運算作業產生線性排程時,就不會有跨運算週期。舉例來說,如果下方排程出現死結,
配量
另請參閱 XlaBuilder::Slice
。
切片可以擷取輸入陣列中的子陣列。子陣列的排名與輸入的排名相同,而且包含輸入陣列中的值,其中定界框的維度和索引是做為配量運算的引數。
Slice(operand, start_indices, limit_indices, strides)
引數 | 類型 | 語義 |
---|---|---|
operand |
XlaOp |
T 類型的 N 維陣列 |
start_indices |
ArraySlice<int64> |
N 整數清單,內含每個維度切片的起始索引。值必須大於或等於零。 |
limit_indices |
ArraySlice<int64> |
N 整數清單,其中包含每個維度切片的結尾索引 (不含)。每個值都必須大於或等於維度各自的 start_indices 值,且小於或等於維度的大小。 |
strides |
ArraySlice<int64> |
決定切片輸入步距的 N 個整數清單。配量會選擇維度 d 中的每個 strides[d] 元素。 |
1D 範例:
let a = {0.0, 1.0, 2.0, 3.0, 4.0}
Slice(a, {2}, {4}) produces:
{2.0, 3.0}
2D 範例:
let b =
{ {0.0, 1.0, 2.0},
{3.0, 4.0, 5.0},
{6.0, 7.0, 8.0},
{9.0, 10.0, 11.0} }
Slice(b, {2, 1}, {4, 3}) produces:
{ { 7.0, 8.0},
{10.0, 11.0} }
排序
另請參閱 XlaBuilder::Sort
。
Sort(operands, comparator, dimension, is_stable)
引數 | 類型 | 語義 |
---|---|---|
operands |
ArraySlice<XlaOp> |
要排序的運算元。 |
comparator |
XlaComputation |
要使用的比較子運算。 |
dimension |
int64 |
要做為排序依據的維度。 |
is_stable |
bool |
是否使用穩定排序。 |
如果只提供一個運算元:
如果運算元是排名第 1 的張量 (陣列),則結果會是排序的陣列。 如果您想以遞增順序將陣列排序,比較子應執行小於比較的結果。陣列排序之後,陣列會保留所有索引位置
i, j
,i < j
comparator(value[i], value[j]) = comparator(value[j], value[i]) = false
或comparator(value[i], value[j]) = true
。如果運算元的排名較高,運算元會根據提供的維度排序。舉例來說,如果是排名-2 張量 (矩陣),
0
的維度值會獨立排序每欄,而1
的維度值則會獨立排序每個資料列。如未提供維度編號,系統會預設選擇最後一個維度。而排序的維度則會套用相同的排序順序,與 排名-1 的情況相同。
如果提供 n > 1
運算元:
所有
n
運算元必須是具有相同維度的張量。張量的元素類型可能不同。所有運算元都會組合在一起,而非個別排序。從概念上來說,運算元會視為元組檢查索引位置
i
和j
中的每個運算元元素是否需要替換時,系統會使用2 * n
純量參數呼叫比較子,其中參數2 * k
對應k-th
運算元中i
的值,而2 * k + 1
參數則對應k-th
運算元中j
位置的值。通常,比較器會彼此比較2 * k
和2 * k + 1
參數,並可能使用其他參數組合做為綁定斷路器。結果是一個元組,其中包含按排序順序的運算元 (以及上述的維度)。元組的
i-th
運算元會對應至 Sort 的i-th
運算元。
舉例來說,假設有三個運算元 operand0 = [3, 1]
、operand1 = [42, 50]
、operand2 = [-3.0, 1.1]
,且比較子只比較 operand0
的值小於或小於的值,則排序結果會是元組 ([1, 3], [50, 42], [1.1, -3.0])
。
如果 is_stable
設為 true,則排序保證一定穩定,也就是說,如果比較子認為元素與比較器相同,則系統會保留同等值的相對順序。只有在 comparator(e1, e2) = comparator(e2, e1) = false
的情況下,e1
和 e2
才會相等。根據預設,is_stable
會設為 false。
轉置
另請參閱 tf.reshape
作業相關說明。
Transpose(operand)
引數 | 類型 | 語義 |
---|---|---|
operand |
XlaOp |
要轉置的運算元。 |
permutation |
ArraySlice<int64> |
如何控制維度。 |
依指定排列方式排列運算元維度,因此 ∀ i . 0 ≤ i < rank ⇒ input_dimensions[permutation[i]] = output_dimensions[i]
。
這與 Reshape(運算元, 排列, Permute(permutation, operand.shape.dimensions) 相同。
TriangularSolve
另請參閱 XlaBuilder::TriangularSolve
。
透過前移或反向替代來解開具有較低或上三角形係數矩陣的線性方程式。這個處理常式會根據主要維度播送,解決 a
和 b
為變數 x
的其中一個矩陣系統 op(a) * x =
b
或 x * op(a) = b
,其中 op(a)
為 op(a) = a
、op(a) = Transpose(a)
或 op(a) = Conj(Transpose(a))
。
TriangularSolve(a, b, left_side, lower, unit_diagonal, transpose_a)
引數 | 類型 | 語義 |
---|---|---|
a |
XlaOp |
排名 > 2 的陣列或浮點類型 (形狀為 [..., M, M] ) 的陣列。 |
b |
XlaOp |
排名 > 2 的相同類型陣列,形狀為 [..., M, K] 時,如果 left_side 為 true,則值為 [..., K, M] 。 |
left_side |
bool |
表示要解開 op(a) * x = b (true ) 或 x * op(a) = b (false ) 形式的系統。 |
lower |
bool |
是否要使用 a 的上三角形。 |
unit_diagonal |
bool |
如果為 true ,則 a 的對角元素會假設為 1 ,而非存取。 |
transpose_a |
Transpose |
決定是否要依原樣使用 a 、轉置或接受共通轉置。 |
系統會根據 lower
的值,僅從 a
的上下三角形讀取輸入資料。系統會忽略其他三角形的值。輸出資料會以同一個三角形傳回,其他三角形的值則是由實作定義,可以是任何內容。
如果 a
和 b
的排名大於 2,系統會將其視為矩陣批次,其中除次要 2 維度以外的所有維度都是批次維度。a
和 b
的批次維度必須相同。
組合式
另請參閱 XlaBuilder::Tuple
。
包含數量不同資料控點的元組,每個控點都有各自的形狀。
這類似於 C++ 中的 std::tuple
。概念:
let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
let s: s32 = 5;
let t: (f32[10], s32) = tuple(v, s);
您可以透過 GetTupleElement
作業解構 (存取) 元組。
儘管
另請參閱 XlaBuilder::While
。
While(condition, body, init)
引數 | 類型 | 語義 |
---|---|---|
condition |
XlaComputation |
T -> PRED 類型的 XlaComputation 可定義迴圈的終止條件。 |
body |
XlaComputation |
定義迴圈主體的 XlaComputation 類型為 T -> T , |
init |
T |
condition 和 body 參數的初始值。 |
依序執行 body
,直到 condition
失敗為止。這與一般使用許多其他語言的迴圈類似,唯下列差異和限制除外。
While
節點會傳回T
類型的值,也就是上次執行body
的結果。T
類型的形狀是靜態決定,在所有疊代中都必須相同。
在第一個疊代作業中,計算的 T 參數會使用 init
值進行初始化,並在每個後續疊代中,從 body
自動更新為新結果。
While
節點的主要用途之一,是實作類神經網路中的訓練重複執行。簡化的虛擬程式碼如下所示,並以圖表表示運算作業。您可以在 while_test.cc
中找到程式碼。在這個範例中,T
類型是 Tuple
,其中包含疊代計數的 int32
和加總的 vector[10]
。如果是 1000 次疊代作業,迴圈會持續為累計值加入常數向量。
// Pseudocode for the computation.
init = {0, zero_vector[10]} // Tuple of int32 and float[10].
result = init;
while (result(0) < 1000) {
iteration = result(0) + 1;
new_vector = result(1) + constant_vector[10];
result = {iteration, new_vector};
}