以下では、XlaBuilder
インターフェースで定義されているオペレーションのセマンティクスについて説明します。通常、これらのオペレーションは、xla_data.proto
の RPC インターフェースで定義されたオペレーションに 1 対 1 でマッピングされます。
命名規則に関する注記: XLA が扱う一般化データ型は、ある一様な型(32 ビット浮動小数点数など)の要素を保持する N 次元配列です。このドキュメントでは、任意の次元の配列を表すために「配列」という用語を使用します。便宜上、特殊なケースにはより具体的でわかりやすい名前が付けられています。たとえば、ベクトルは 1 次元配列で、行列は 2 次元配列です。
AfterAll
XlaBuilder::AfterAll
もご覧ください。
AfterAll は可変数のトークンを受け取って、単一のトークンを生成します。トークンはプリミティブ タイプであり、副作用のあるオペレーション間でスレッド化されて順序付けを適用できます。AfterAll
は、set オペレーションの後にオペレーションを並べ替えるためのトークンの結合として使用できます。
AfterAll(operands)
引数 | タイプ | セマンティクス |
---|---|---|
operands |
XlaOp |
トークンの可変数 |
AllGather
XlaBuilder::AllGather
もご覧ください。
レプリカ間で連結を実行します。
AllGather(operand, all_gather_dim, shard_count, replica_group_ids,
channel_id)
引数 | タイプ | セマンティクス |
---|---|---|
operand
|
XlaOp
|
レプリカ間で連結する配列 |
all_gather_dim |
int64 |
連結ディメンション |
replica_groups
|
int64 のベクトルのベクトル |
結合が行われるグループ間で |
channel_id
|
省略可 int64 |
モジュール間通信用のオプションのチャンネル ID |
replica_groups
は、連結が実行されるレプリカ グループのリストです(現在のレプリカのレプリカ ID はReplicaId
を使用して取得できます)。各グループ内のレプリカの順序によって、結果に入力が配置される順序が決まります。replica_groups
は空であるか(この場合、すべてのレプリカが0
~N - 1
の順に 1 つのグループに属します)、レプリカの数と同じ数の要素を含める必要があります。たとえば、replica_groups = {0, 2}, {1, 3}
はレプリカ0
と2
、1
と3
の連結を実行します。shard_count
は、各レプリカ グループのサイズです。これは、replica_groups
が空の場合に必要です。channel_id
はモジュール間通信に使用されます。同じchannel_id
を持つall-gather
オペレーションのみが相互に通信できます。
出力シェイプは、all_gather_dim
が shard_count
倍に拡大された入力シェイプです。たとえば、2 つのレプリカがあり、オペランドの値が 2 つのレプリカにそれぞれ [1.0, 2.5]
と [3.0, 5.25]
の場合、all_gather_dim
が 0
であるこの op の出力値は、両方のレプリカで [1.0, 2.5, 3.0,
5.25]
になります。
AllReduce
XlaBuilder::AllReduce
もご覧ください。
レプリカ間でカスタム計算を実行します。
AllReduce(operand, computation, replica_group_ids, channel_id)
引数 | タイプ | セマンティクス |
---|---|---|
operand
|
XlaOp
|
レプリカ間で減算する配列または配列の空でないタプル |
computation |
XlaComputation |
削減の計算 |
replica_groups
|
int64 のベクトルのベクトル |
削減を実行するグループ |
channel_id
|
省略可 int64 |
モジュール間通信用のオプションのチャネル ID |
operand
が配列のタプルの場合、タプルの各要素に対してオール リデュークが実行されます。replica_groups
は、削減が実行されるレプリカ グループのリストです(現在のレプリカのレプリカ ID はReplicaId
を使用して取得できます)。replica_groups
は空であるか(この場合、すべてのレプリカが 1 つのグループに属します)、レプリカの数と同じ数の要素を含める必要があります。たとえば、replica_groups = {0, 2}, {1, 3}
はレプリカ0
と2
、1
と3
の間で減算を行います。channel_id
はモジュール間の通信に使用されます。同じchannel_id
を持つall-reduce
オペレーションのみが相互に通信できます。
出力シェイプは入力シェイプと同じです。たとえば、2 つのレプリカがあり、2 つのレプリカでオペランドの値がそれぞれ [1.0, 2.5]
と [3.0, 5.25]
の場合、このオペレーションと合計計算の出力値は、両方のレプリカで [4.0, 7.75]
になります。入力がタプルの場合、出力もタプルになります。
AllReduce
の結果を計算するには、各レプリカから 1 つの入力が必要になるため、あるレプリカが別のレプリカよりも AllReduce
ノードを複数回実行すると、前のレプリカは永久に待機します。レプリカはすべて同じプログラムを実行しているため、このような状態が発生する方法は多くありませんが、while ループの条件がインフィードのデータに依存し、インフィードされたデータによって、一方のレプリカでもう一方のレプリカよりも while ループがより多くの回数反復される場合は発生する可能性があります。
AllToAll
XlaBuilder::AllToAll
もご覧ください。
AllToAll は、すべてのコアからすべてのコアにデータを送信する一括オペレーションです。次の 2 つのフェーズがあります。
- 散布フェーズ。各コアで、オペランドは
split_dimensions
に沿ってsplit_count
個のブロックに分割され、ブロックはすべてのコアに分散されます(たとえば、i 番目のブロックは i 番目のコアに送信されます)。 - 収集フェーズ。各コアは、受信したブロックを
concat_dimension
に沿って連結します。
参加するコアは、次の方法で構成できます。
replica_groups
: 各 ReplicaGroup には、計算に参加するレプリカ ID のリストが含まれています(現在のレプリカのレプリカ ID はReplicaId
を使用して取得できます)。AllToAll は、指定された順序でサブグループ内に適用されます。たとえば、replica_groups = { {1,2,3}, {4,5,0} }
は、AllToAll がレプリカ{1, 2, 3}
内と集約フェーズで適用され、受信したブロックが 1、2、3 の順序で連結されることを意味します。次に、レプリカ 4、5、0 内に別の AllToAll が適用され、連結順序も 4、5、0 になります。replica_groups
が空の場合、すべてのレプリカは、出現順に連結されて 1 つのグループに属します。
前提条件:
split_dimension
のオペランドのディメンション サイズはsplit_count
で割り切れます。- オペランドの形がタプルではありません。
AllToAll(operand, split_dimension, concat_dimension, split_count,
replica_groups)
引数 | タイプ | セマンティクス |
---|---|---|
operand |
XlaOp |
n 次元の入力配列 |
split_dimension
|
int64
|
演算対象が分割されるディメンションの名前を指定する [0,
n) の範囲内の値 |
concat_dimension
|
int64
|
分割ブロックを連結するディメンションの名前を指定する [0,
n) の範囲内の値 |
split_count
|
int64
|
このオペレーションに参加するコアの数。replica_groups が空の場合は、レプリカ数にする必要があります。それ以外の場合は、各グループのレプリカ数にする必要があります。 |
replica_groups
|
ReplicaGroup ベクトル
|
各グループにはレプリカ ID のリストが含まれています。 |
Alltoall の例を次に示します。
XlaBuilder b("alltoall");
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, /*split_count=*/4);
この例では、Alltoall に参加しているコアは 4 つあります。各コアで、オペランドは次元 0 に沿って 4 つの部分に分割されるため、各部分の形状は f32[4,4] になります。4 つの部分がすべてのコアに分散されています。次に、各コアは、コア 0~4 の順序で、受け取った部分をディメンション 1 に沿って連結します。したがって、各コアの出力は f32[16,4] の形状になります。
BatchNormGrad
アルゴリズムの詳細については、XlaBuilder::BatchNormGrad
と元のバッチ ノーマライゼーションの論文もご覧ください。
バッチ正規化の勾配を計算します。
BatchNormGrad(operand, scale, mean, variance, grad_output, epsilon, feature_index)
引数 | タイプ | セマンティクス |
---|---|---|
operand |
XlaOp |
正規化される n 次元の配列(x) |
scale |
XlaOp |
1 次元配列(\(\gamma\)) |
mean |
XlaOp |
1 次元配列(\(\mu\)) |
variance |
XlaOp |
1 次元の配列(\(\sigma^2\)) |
grad_output |
XlaOp |
BatchNormTraining (\(\nabla y\))に渡されるグラデーション |
epsilon |
float |
イプシロン値(\(\epsilon\)) |
feature_index |
int64 |
operand の特徴ディメンションに対するインデックス |
このオペレーションは、特徴ディメンション内の各特徴(feature_index
は operand
の特徴ディメンションのインデックス)について、他のすべてのディメンション全体で operand
、offset
、scale
に関する勾配を計算します。feature_index
は、operand
の特徴ディメンションの有効なインデックスである必要があります。
3 つのグラデーションは次の式で定義されます(4 次元配列を operand
、特徴ディメンションのインデックス l
、バッチサイズ m
、空間サイズ w
と h
の場合)。
\[ \begin{split} c_l&= \frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \left( \nabla y_{ijkl} \frac{x_{ijkl} - \mu_l}{\sigma^2_l+\epsilon} \right) \\\\ d_l&= \frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \nabla y_{ijkl} \\\\ \nabla x_{ijkl} &= \frac{\gamma_{l} }{\sqrt{\sigma^2_{l}+\epsilon} } \left( \nabla y_{ijkl} - d_l - c_l (x_{ijkl} - \mu_{l}) \right) \\\\ \nabla \gamma_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \left( \nabla y_{ijkl} \frac{x_{ijkl} - \mu_l}{\sqrt{\sigma^2_{l}+\epsilon} } \right) \\\\\ \nabla \beta_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \nabla y_{ijkl} \end{split} \]
入力 mean
と variance
は、バッチディメンションと空間ディメンションにわたるモーメント値を表します。
出力型は、3 つのハンドルのタプルです。
出力 | タイプ | セマンティクス |
---|---|---|
grad_operand
|
XlaOp
|
入力 operand に対する勾配($\nabla x$) |
grad_scale
|
XlaOp
|
入力 scale に関する勾配($\nabla\gamma$) |
grad_offset
|
XlaOp
|
入力 offset に対する勾配($\nabla \beta$) |
BatchNormInference
アルゴリズムの詳細については、XlaBuilder::BatchNormInference
と元のバッチ ノーマライゼーションの論文もご覧ください。
バッチディメンションと空間ディメンション全体で配列を正規化します。
BatchNormInference(operand, scale, offset, mean, variance, epsilon, feature_index)
引数 | タイプ | セマンティクス |
---|---|---|
operand |
XlaOp |
正規化する N 次元配列 |
scale |
XlaOp |
1 次元配列 |
offset |
XlaOp |
1 次元配列 |
mean |
XlaOp |
1 次元配列 |
variance |
XlaOp |
1 次元配列 |
epsilon |
float |
イプシロン値 |
feature_index |
int64 |
operand の特徴量ディメンションのインデックス |
このオペレーションは、特徴ディメンション内の各特徴(feature_index
は operand
の特徴ディメンションのインデックス)について、他のすべてのディメンションの平均と分散を計算し、その平均と分散を使用して operand
の各要素を正規化します。feature_index
は、operand
の特徴ディメンションの有効なインデックスである必要があります。
BatchNormInference
は、バッチごとに mean
と variance
を計算せずに BatchNormTraining
を呼び出す場合と同じです。代わりに、入力 mean
と variance
が推定値として使用されます。このオペレーションの目的は推論のレイテンシを短縮することです。そのため、BatchNormInference
という名前が付けられています。
出力は、入力 operand
と同じ形状の N 次元の正規化配列です。
BatchNormTraining
アルゴリズムの詳細な説明については、XlaBuilder::BatchNormTraining
と the original batch normalization paper
もご覧ください。
バッチ ディメンションと空間ディメンション全体で配列を正規化します。
BatchNormTraining(operand, scale, offset, epsilon, feature_index)
引数 | タイプ | セマンティクス |
---|---|---|
operand |
XlaOp |
正規化される n 次元の配列(x) |
scale |
XlaOp |
1 次元配列(\(\gamma\)) |
offset |
XlaOp |
1 次元の配列(\(\beta\)) |
epsilon |
float |
イプシロン値(\(\epsilon\)) |
feature_index |
int64 |
operand の特徴ディメンションに対するインデックス |
このオペレーションは、特徴ディメンション内の各特徴(feature_index
は operand
の特徴ディメンションのインデックス)について、他のすべてのディメンションの平均と分散を計算し、その平均と分散を使用して operand
の各要素を正規化します。feature_index
は、operand
の特徴ディメンションの有効なインデックスである必要があります。
operand
\(x\) 内の各バッチで、空間ディメンションのサイズが w
と h
の m
要素を含むバッチについて、アルゴリズムは次のように動作します(operand
が 4 次元配列であると仮定します)。
特徴ディメンション内の各特徴
l
のバッチ平均 \(\mu_l\) を計算します。 \(\mu_l=\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h x_{ijkl}\)バッチ分散を計算します \(\sigma^2_l\): $\sigma^2l=\frac{1}{mwh}\sum{i=1}^m\sum{j=1}^w\sum{k=1}^h (x_{ijkl} - \mu_l)^2$
正規化、スケーリング、シフト: \(y_{ijkl}=\frac{\gamma_l(x_{ijkl}-\mu_l)}{\sqrt[2]{\sigma^2_l+\epsilon} }+\beta_l\)
イプシロン値(通常は小数値)は、ゼロ除算エラーを回避するために追加されます。
出力型は、3 つの XlaOp
のタプルです。
出力 | タイプ | セマンティクス |
---|---|---|
output
|
XlaOp
|
入力 operand (y)と同じ形状の n 次元配列 |
batch_mean |
XlaOp |
1 次元配列(\(\mu\)) |
batch_var |
XlaOp |
1 次元配列(\(\sigma^2\)) |
batch_mean
と batch_var
は、上記の数式を使用してバッチ ディメンションと空間ディメンション全体で計算されたモーメントです。
BitcastConvertType
XlaBuilder::BitcastConvertType
もご覧ください。
TensorFlow の tf.bitcast
と同様に、データ シェイプからターゲット シェイプへの要素ごとのビットキャスト オペレーションを実行します。入力と出力のサイズは一致している必要があります。たとえば、s32
要素はビットキャスト ルーチンを介して f32
要素になり、1 つの s32
要素は 4 つの s8
要素になります。ビットキャスト は低レベルのキャストとして実装されているため、浮動小数点表現が異なるマシンでは異なる結果が得られます。
BitcastConvertType(operand, new_element_type)
引数 | タイプ | セマンティクス |
---|---|---|
operand |
XlaOp |
ディメンション D を持つ T 型の配列 |
new_element_type |
PrimitiveType |
タイプ U |
オペランドとターゲット シェイプの寸法は一致している必要があります。ただし、最後の寸法は変換前と変換後のプリミティブ サイズの比率で変化します。
ソースと宛先の要素の型をタプルにすることはできません。
幅の異なるプリミティブ型にビットキャスト変換する
BitcastConvert
HLO 命令は、出力要素型 T'
のサイズが入力要素 T
のサイズと等しくないケースをサポートしています。オペレーション全体は概念的にはビットキャストであり、基盤となるバイトは変更されないため、出力要素の形状を変更する必要があります。B = sizeof(T), B' =
sizeof(T')
については、次の 2 つのケースが考えられます。
まず、B > B'
の場合、出力シェイプにサイズ B/B'
の新しい最小ディメンションが追加されます。例:
f16[10,2]{1,0} %output = f16[10,2]{1,0} bitcast-convert(f32[10]{0} %input)
有効なスカラーでも、このルールは同じです。
f16[2]{0} %output = f16[2]{0} bitcast-convert(f32[] %input)
B' > B
の場合、この命令では入力シェイプの最後の論理ディメンションが B'/B
に等しくする必要があります。このディメンションは変換中に破棄されます。
f32[10]{0} %output = f32[10]{0} bitcast-convert(f16[10,2]{1,0} %input)
異なるビット幅間の変換は要素単位ではなく、
ブロードキャスト
XlaBuilder::Broadcast
もご覧ください。
配列内のデータを複製してディメンションを配列に追加します。
Broadcast(operand, broadcast_sizes)
引数 | タイプ | セマンティクス |
---|---|---|
operand |
XlaOp |
複製する配列 |
broadcast_sizes |
ArraySlice<int64> |
新しいディメンションのサイズ |
新しいディメンションは左側に挿入されます。つまり、broadcast_sizes
の値が {a0, ..., aN}
で、オペランドシェイプのディメンションが {b0, ..., bM}
の場合、出力のシェイプのディメンションは {a0, ..., aN, b0, ..., bM}
になります。
新しいディメンションは、オペランドのコピーに対するインデックス付けを行います。つまり、
output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
たとえば、operand
が値が 2.0f
のスカラー f32
で、broadcast_sizes
が {2, 3}
の場合、結果は形 f32[2, 3]
の配列になり、結果のすべての値は 2.0f
になります。
BroadcastInDim
XlaBuilder::BroadcastInDim
もご覧ください。
配列内のデータを複製して、配列のサイズとランクを拡張します。
BroadcastInDim(operand, out_dim_size, broadcast_dimensions)
引数 | タイプ | セマンティクス |
---|---|---|
operand |
XlaOp |
複製する配列 |
out_dim_size |
ArraySlice<int64> |
ターゲット シェイプのサイズ |
broadcast_dimensions |
ArraySlice<int64> |
オペランドシェイプの各次元に対応する、ターゲットシェイプのディメンション |
Broadcast に似ていますが、任意の場所にディメンションを追加し、サイズ 1 の既存のディメンションを拡張できます。
operand
は、out_dim_size
で記述されたシェイプにブロードキャストされます。broadcast_dimensions
は、operand
のディメンションをターゲット シェイプのディメンションにマッピングします。つまり、オペランドの i 番目のディメンションは、出力シェイプの broadcast_dimension[i] 番目のディメンションにマッピングされます。operand
のディメンションのサイズは 1 にするか、マッピング先の出力シェイプのディメンションと同じサイズにする必要があります。残りのディメンションは、サイズ 1 のディメンションで埋められます。不完全なディメンション ブロードキャストでは、これらの不完全なディメンションに沿ってブロードキャストを行い、出力シェイプに到達します。セマンティクスについては、ブロードキャスト ページで詳しく説明しています。
電話
XlaBuilder::Call
もご覧ください。
指定された引数を使用して計算を呼び出します。
Call(computation, args...)
引数 | タイプ | セマンティクス |
---|---|---|
computation |
XlaComputation |
任意の型の N 個のパラメータを使用した T_0, T_1, ..., T_{N-1} -> S 型の計算 |
args |
N 個の XlaOp のシーケンス |
任意の型の N 個の引数 |
args
の arity と型は、computation
のパラメータと一致する必要があります。args
は指定できません。
Cholesky
XlaBuilder::Cholesky
もご覧ください。
対称(ヘルミチアン)正定値行列のバッチの Cholesky 分解を計算します。
Cholesky(a, lower)
引数 | タイプ | セマンティクス |
---|---|---|
a |
XlaOp |
複素型または浮動小数点型のランクが 2 より大きい配列。 |
lower |
bool |
a の上三角形と下三角形のどちらを使用するかを指定します。 |
lower
が true
の場合、$a = l となるように低三角行列 l
を計算します。l^T$ です。lower
が false
の場合、\(a = u^T . u\)となる上三角行列 u
を計算します。
入力データは、lower
の値に応じて、a
の下部または上部の三角形からのみ読み取られます。他の三角形の値は無視されます。出力データは同じ三角形で返されます。他の三角形の値は実装定義であり、任意の値にすることができます。
a
のランクが 2 より大きい場合、a
は行列のバッチとして扱われ、マイナー 2 次元を除くすべてのディメンションはバッチ ディメンションです。
a
が対称(ヘルミチアン)正定値でない場合、結果は実装定義です。
クランプ
XlaBuilder::Clamp
もご覧ください。
オペランドを最小値と最大値の範囲内にクランプします。
Clamp(min, operand, max)
引数 | タイプ | セマンティクス |
---|---|---|
min |
XlaOp |
T 型の配列 |
operand |
XlaOp |
型 T の配列 |
max |
XlaOp |
型 T の配列 |
オペランド、最小値、最大値が与えられると、最小値と最大値の間の範囲内にある場合はオペランドを返します。それ以外の場合は、オペランドがこの範囲を下回っている場合は最小値を返し、オペランドがこの範囲より大きい場合は最大値を返します。つまり、clamp(a, x, b) = min(max(a, x), b)
のようになります。
3 つの配列はすべて同じ形状である必要があります。または、制限付きのブロードキャストとして、min
または max
を T
型のスカラーにすることもできます。
スカラー min
と max
を使用した例:
let operand: s32[3] = {-1, 5, 9};
let min: s32 = 0;
let max: s32 = 6;
==>
Clamp(min, operand, max) = s32[3]{0, 5, 6};
閉じる
XlaBuilder::Collapse
と tf.reshape
オペレーションもご覧ください。
配列の次元を 1 つの次元に折りたたみます。
Collapse(operand, dimensions)
引数 | タイプ | セマンティクス |
---|---|---|
operand |
XlaOp |
型 T の配列 |
dimensions |
int64 ベクトル |
T のディメンションの順番に並んだ連続したサブセット。 |
折りたたみは、オペランドのディメンションの指定のサブセットを単一のディメンションに置き換えます。入力引数は、T 型の任意の配列と、次元インデックスのビルド時定数ベクトルです。ディメンション インデックスは、T のディメンションの連続するサブセットで、順序付き(ディメンション番号が低い順)である必要があります。したがって、{0, 1, 2}、{0, 1}、{1, 2} はすべて有効なディメンション セットですが、{1, 0} や {0, 2} は無効です。これらのディメンションは、置き換えるディメンションと同じ位置に、1 つの新しいディメンションで置き換えられます。新しいディメンションのサイズは、元のディメンションのサイズの積に等しくなります。dimensions
の最小ディメンション番号は、これらのディメンションを圧縮するループ ネストで最も変化が遅いディメンション(最もメジャー)であり、最大ディメンション番号は最も変化が速いディメンション(最もマイナー)です。一般的な折りたたみ順序が必要な場合は、tf.reshape
演算子をご覧ください。
たとえば、v を 24 個の要素からなる配列とします。
let v = f32[4x2x3] { { {10, 11, 12}, {15, 16, 17} },
{ {20, 21, 22}, {25, 26, 27} },
{ {30, 31, 32}, {35, 36, 37} },
{ {40, 41, 42}, {45, 46, 47} } };
// Collapse to a single dimension, leaving one dimension.
let v012 = Collapse(v, {0,1,2});
then v012 == f32[24] {10, 11, 12, 15, 16, 17,
20, 21, 22, 25, 26, 27,
30, 31, 32, 35, 36, 37,
40, 41, 42, 45, 46, 47};
// Collapse the two lower dimensions, leaving two dimensions.
let v01 = Collapse(v, {0,1});
then v01 == f32[4x6] { {10, 11, 12, 15, 16, 17},
{20, 21, 22, 25, 26, 27},
{30, 31, 32, 35, 36, 37},
{40, 41, 42, 45, 46, 47} };
// Collapse the two higher dimensions, leaving two dimensions.
let v12 = Collapse(v, {1,2});
then v12 == f32[8x3] { {10, 11, 12},
{15, 16, 17},
{20, 21, 22},
{25, 26, 27},
{30, 31, 32},
{35, 36, 37},
{40, 41, 42},
{45, 46, 47} };
CollectivePermute
XlaBuilder::CollectivePermute
もご覧ください。
CollectivePermute は、レプリカ間でデータを送受信する集約オペレーションです。
CollectivePermute(operand, source_target_pairs)
引数 | タイプ | セマンティクス |
---|---|---|
operand |
XlaOp |
n 次元の入力配列 |
source_target_pairs |
<int64, int64> ベクトル |
(source_replica_id、target_replica_id)ペアのリスト。ペアごとに、演算子はソース レプリカからターゲット レプリカに送信されます。 |
source_target_pair
には次の制限があります。
- 2 つのペアでターゲット レプリカ ID が同じで、ソース レプリカ ID が同じにならないようにします。
- レプリカ ID がどのペアでもターゲットでない場合、そのレプリカの出力は、入力と同じ形状の 0 で構成されるテンソルになります。
Concatenate
XlaBuilder::ConcatInDim
もご覧ください。
Concatenate は、複数の配列オペランドから配列を作成します。この配列は、各入力配列オペランドと同じランク(互いに同じランクである必要があります)で、指定された順序で引数を含みます。
Concatenate(operands..., dimension)
引数 | タイプ | セマンティクス |
---|---|---|
operands |
N 個の XlaOp のシーケンス |
次元が [L0、L1、...] の型 T の N 個の配列。N >= 1 が必要です。 |
dimension |
int64 |
operands 間で連結するディメンションの名前の区間 [0, N) の値。 |
dimension
を除くすべてのディメンションは同じである必要があります。これは、XLA が「不揃い」な配列をサポートしていないためです。また、ランク 0 の値は連結できません(連結するディメンションに名前を付けることができないためです)。
1 次元の場合の例:
Concat({ {2, 3}, {4, 5}, {6, 7} }, 0)
>>> {2, 3, 4, 5, 6, 7}
2 次元の例:
let a = {
{1, 2},
{3, 4},
{5, 6},
};
let b = {
{7, 8},
};
Concat({a, b}, 0)
>>> {
{1, 2},
{3, 4},
{5, 6},
{7, 8},
}
図:
条件文
XlaBuilder::Conditional
もご覧ください。
Conditional(pred, true_operand, true_computation, false_operand,
false_computation)
引数 | タイプ | セマンティクス |
---|---|---|
pred |
XlaOp |
PRED 型のスカラー |
true_operand |
XlaOp |
型の引数 \(T_0\) |
true_computation |
XlaComputation |
\(T_0 \to S\)型の XlaComputation |
false_operand |
XlaOp |
型の引数 \(T_1\) |
false_computation |
XlaComputation |
型の XlaComputation \(T_1 \to S\) |
pred
が true
の場合は true_computation
を、pred
が false
の場合は false_computation
を実行し、結果を返します。
true_computation
は \(T_0\) 型の単一の引数を取り、同じ型の true_operand
で呼び出されます。false_computation
は \(T_1\) 型の単一の引数を取り、同じ型の false_operand
で呼び出されます。返される値の型(true_computation
と false_computation
)は同じである必要があります。
pred
の値に応じて、true_computation
と false_computation
のいずれか 1 つのみが実行されます。
Conditional(branch_index, branch_computations, branch_operands)
引数 | タイプ | セマンティクス |
---|---|---|
branch_index |
XlaOp |
S32 型のスカラー |
branch_computations |
N 個の XlaComputation のシーケンス |
\(T_0 \to S , T_1 \to S , ..., T_{N-1} \to S\)型の XlaComputation |
branch_operands |
N XlaOp のシーケンス |
型 \(T_0 , T_1 , ..., T_{N-1}\)の引数 |
branch_computations[branch_index]
を実行し、結果を返します。branch_index
が 0 未満または N 以上である S32
の場合、branch_computations[N-1]
がデフォルトのブランチとして実行されます。
各 branch_computations[b]
は \(T_b\) 型の単一の引数を取り、同じ型である branch_operands[b]
で呼び出されます。各 branch_computations[b]
の戻り値の型は同じである必要があります。
branch_index
の値に応じて、branch_computations
のいずれか 1 つのみが実行されます。
コンバージョン(畳み込み)
XlaBuilder::Conv
もご覧ください。
ConvWithGeneralPadding と同じですが、パディングは SAME または VALID として簡略で指定します。SAME パディングでは、入力(lhs
)をゼロでパディングして、ストライドを考慮しない場合、出力の形状が入力と同じになるようにします。VALID パディングは、パディングなしを意味します。
ConvWithGeneralPadding(畳み込み)
XlaBuilder::ConvWithGeneralPadding
もご覧ください。
ニューラル ネットワークで使用される種類の畳み込みを計算します。ここで、畳み込みは n 次元の底領域間を移動する n 次元のウィンドウと考えることができ、ウィンドウの可能な位置ごとに計算が実行されます。
引数 | タイプ | セマンティクス |
---|---|---|
lhs |
XlaOp |
入力のランク n+2 配列 |
rhs |
XlaOp |
カーネル重みのランク n+2 配列 |
window_strides |
ArraySlice<int64> |
カーネル ストライドの N 次元配列 |
padding |
ArraySlice< pair<int64,int64>> |
(低、高)パディングの n-d 配列 |
lhs_dilation |
ArraySlice<int64> |
n-d lhs dilation factor array |
rhs_dilation |
ArraySlice<int64> |
n-d rhs 拡張係数配列 |
feature_group_count |
int64 | 特徴グループの数 |
batch_group_count |
int64 | バッチグループの数 |
n は空間ディメンションの数です。lhs
引数は、基本領域を表すランク n+2 の配列です。これを入力と呼びますが
rh は入力でもありますニューラル ネットワークでは、これらが入力アクティベーションです。n+2 ディメンションは次の順序で表示されます。
batch
: この次元の各座標は、畳み込み演算が行われる独立した入力を表します。z/depth/features
: ベースエリア内の各(y,x)位置には、このディメンションに格納されるベクトルが関連付けられています。spatial_dims
: ウィンドウが移動するベース領域を定義するn
空間ディメンションを記述します。
rhs
引数は、畳み込みフィルタ / カーネル / ウィンドウを記述するランク n+2 の配列です。ディメンションの順序は次のとおりです。
output-z
: 出力のz
ディメンション。input-z
: このディメンションのサイズにfeature_group_count
を掛けた値は、z
ディメンションのサイズ(lhs)に等しくする必要があります。spatial_dims
: ベースエリアを移動する N 次元ウィンドウを定義するn
空間ディメンションを記述します。
window_strides
引数は、空間ディメンションで畳み込みウィンドウのストライドを指定します。たとえば、最初の空間次元のストライドが 3 の場合、ウィンドウは最初の空間インデックスが 3 で割り切れる座標にのみ配置できます。
padding
引数は、ベース領域に適用するゼロ パディングの量を指定します。パディングの量は負の場合があります。負のパディングの絶対値は、畳み込み演算の前に指定されたディメンションから削除する要素の数を示します。padding[0]
はディメンション y
のパディングを指定し、padding[1]
はディメンション x
のパディングを指定します。各ペアでは、低いパディングが最初の要素、高いパディングが 2 番目の要素になります。低いパディングは低いインデックスの方向に適用され、高いパディングは高いインデックスの方向に適用されます。たとえば、padding[1]
が (2,3)
の場合、2 番目の空間ディメンションの左側に 2 つのゼロ、右側に 3 つのゼロがパディングされます。パディングを使用することは、畳み込みを行う前に同じゼロ値を入力(lhs
)に挿入することと同じです。
lhs_dilation
引数と rhs_dilation
引数は、各空間ディメンションの lhs と rhs に適用される拡張係数をそれぞれ指定します。空間ディメンションの拡張係数が d の場合、そのディメンションの各エントリの間に d-1 個の穴が暗黙的に配置され、配列のサイズが増加します。穴は no-op 値で埋められます。これは、畳み込みではゼロを意味します。
右辺の拡張は、アトラス畳み込みとも呼ばれます。詳しくは、tf.nn.atrous_conv2d
をご覧ください。左辺の拡張は転置畳み込みとも呼ばれます。詳しくは、tf.nn.conv2d_transpose
をご覧ください。
feature_group_count
引数(デフォルト値 1)は、グループ化された畳み込みに使用できます。feature_group_count
は、入力特徴次元と出力特徴次元の両方の除数である必要があります。feature_group_count
が 1 より大きい場合、概念的には、入力特徴と出力特徴のディメンションと rhs
出力特徴のディメンションが、多くの feature_group_count
グループに均等に分割され、各グループが連続する特徴のサブシーケンスで構成されます。入力特徴量のディメンション rhs
は、lhs
入力特徴量のディメンションを feature_group_count
で除算した値と等しくする必要があります(つまり、入力特徴量のグループのサイズがすでに設定されています)。i 番目のグループは、多くの個別の畳み込みの feature_group_count
を計算するために一緒に使用されます。これらの畳み込みの結果は、出力特徴ディメンションで連結されます。
深度畳み込みでは、feature_group_count
引数が入力特徴次元に設定され、フィルタが [filter_height, filter_width, in_channels, channel_multiplier]
から [filter_height, filter_width, 1, in_channels * channel_multiplier]
に再形成されます。詳しくは、tf.nn.depthwise_conv2d
をご覧ください。
batch_group_count
(デフォルト値 1)引数は、バックプロパゲーション中にグループ化されたフィルタに使用できます。batch_group_count
は、lhs
(入力)バッチ ディメンションのサイズの除数にする必要があります。batch_group_count
が 1 より大きい場合、出力バッチ ディメンションのサイズは input batch
/ batch_group_count
にする必要があります。batch_group_count
は、出力特徴サイズの除数にする必要があります。
出力シェイプの寸法は次のとおりです。
batch
: このディメンションのサイズにbatch_group_count
を掛けた値は、lhs のbatch
ディメンションのサイズと等しくする必要があります。z
: カーネルのoutput-z
と同じサイズ(rhs
)。spatial_dims
: 畳み込みウィンドウの有効な配置ごとに 1 つの値。
上の図は、batch_group_count
フィールドの仕組みを示しています。つまり、各 lhs バッチを batch_group_count
グループにスライスし、出力特徴に対しても同様のことを行います。次に、これらのグループごとにペアワイズ畳み込みを実行し、出力特徴次元に沿って出力を連結します。他のすべてのディメンション(特徴と空間)のオペレーション セマンティクスは同じです。
畳み込みウィンドウの有効な配置は、ストライドとパディング後のベース領域のサイズによって決まります。
畳み込みの動作を説明するには、2D 畳み込みを検討し、出力内の固定の batch
、z
、y
、x
座標を選択します。ここで、(y,x)
はベース領域内のウィンドウの角の位置です(空間の寸法の解釈に応じて、左上など)。ベース領域から取得した 2D ウィンドウが作成されました。各 2D ポイントは 1D ベクトルに関連付けられているため、3D ボックスが作成されます。畳み込みカーネルから、出力座標 z
を固定したので、3D ボックスも生成されます。2 つのボックスは同じ次元であるため、2 つのボックス間の要素ごとの積の合計を取得できます(ドット積と同様に)。これが出力値です。
output-z
が5 の場合、ウィンドウの各位置で、出力の z
ディメンションに 5 つの値が出力されます。これらの値は、畳み込みカーネルのどの部分が使用されるかによって異なります。output-z
座標ごとに使用される値の 3D ボックスが別にあります。つまり、5 つの個別の畳み込みで、それぞれに異なるフィルタを使用すると考えることができます。
以下は、パディングとストライディングを使用した 2 次元畳み込みの擬似コードです。
for (b, oz, oy, ox) { // output coordinates
value = 0;
for (iz, ky, kx) { // kernel coordinates and input z
iy = oy*stride_y + ky - pad_low_y;
ix = ox*stride_x + kx - pad_low_x;
if ((iy, ix) inside the base area considered without padding) {
value += input(b, iz, iy, ix) * kernel(oz, iz, ky, kx);
}
}
output(b, oz, oy, ox) = value;
}
ConvertElementType
XlaBuilder::ConvertElementType
もご覧ください。
C++ の要素ごとの static_cast
と同様に、データ シェイプからターゲット シェイプへの要素ごとの変換オペレーションを実行します。ディメンションは一致している必要があります。変換は要素ごとです。たとえば、s32
要素は s32
から f32
への変換ルーティンを介して f32
要素になります。
ConvertElementType(operand, new_element_type)
引数 | タイプ | セマンティクス |
---|---|---|
operand |
XlaOp |
ディメンション D を持つ T 型の配列 |
new_element_type |
PrimitiveType |
タイプ U |
オペランドの寸法とターゲット シェイプは一致している必要があります。ソース要素と宛先要素の型はタプルにできません。
T=s32
から U=f32
への変換などの変換では、整数から最も近い値への丸めなど、整数から浮動小数点への正規化変換ルーチンが実行されます。
let a: s32[3] = {0, 1, 2};
let b: f32[3] = convert(a, f32);
then b == f32[3]{0.0, 1.0, 2.0}
CrossReplicaSum
合計計算で AllReduce
を実行します。
CustomCall
XlaBuilder::CustomCall
もご覧ください。
計算内でユーザー提供の関数を呼び出します。
CustomCall(target_name, args..., shape)
引数 | タイプ | セマンティクス |
---|---|---|
target_name |
string |
関数名。このシンボル名をターゲットとする呼び出し命令が生成されます。 |
args |
N 個の XlaOp のシーケンス |
任意の型の N 個の引数。関数に渡されます。 |
shape |
Shape |
関数の出力形状 |
関数のシグネチャは、引数の arity や型に関係なく同じです。
extern "C" void target_name(void* out, void** in);
たとえば、CustomCall が次のように使用されているとします。
let x = f32[2] {1,2};
let y = f32[2x3] { {10, 20, 30}, {40, 50, 60} };
CustomCall("myfunc", {x, y}, f32[3x3])
myfunc
の実装の例を次に示します。
extern "C" void myfunc(void* out, void** in) {
float (&x)[2] = *static_cast<float(*)[2]>(in[0]);
float (&y)[2][3] = *static_cast<float(*)[2][3]>(in[1]);
EXPECT_EQ(1, x[0]);
EXPECT_EQ(2, x[1]);
EXPECT_EQ(10, y[0][0]);
EXPECT_EQ(20, y[0][1]);
EXPECT_EQ(30, y[0][2]);
EXPECT_EQ(40, y[1][0]);
EXPECT_EQ(50, y[1][1]);
EXPECT_EQ(60, y[1][2]);
float (&z)[3][3] = *static_cast<float(*)[3][3]>(out);
z[0][0] = x[1] + y[1][0];
// ...
}
ユーザー提供の関数には副作用がなく、その実行がべき等である必要があります。
ドット
XlaBuilder::Dot
もご覧ください。
Dot(lhs, rhs)
引数 | タイプ | セマンティクス |
---|---|---|
lhs |
XlaOp |
T 型の配列 |
rhs |
XlaOp |
T 型の配列 |
このオペレーションの正確なセマンティクスは、オペランドのランクによって異なります。
入力 | 出力 | セマンティクス |
---|---|---|
ベクトル [n] dot ベクトル [n] |
スカラー | ベクトルのドット積 |
行列 [m x k] dot ベクトル [k] |
ベクトル [m] | 行列ベクトルの乗算 |
行列 [m x k] dot 行列 [k x n] |
行列 [m x n] | 行列乗算 |
このオペレーションは、lhs
の 2 番目のディメンション(またはランクが 1 の場合、最初のディメンション)と rhs
の最初のディメンションに対して積の総和を実行します。これらは「圧縮」されたディメンションです。lhs
と rhs
の収縮されたディメンションは同じサイズである必要があります。実際には、ベクトル間のドット積、ベクトル / 行列の乗算、行列 / 行列の乗算を実行するために使用できます。
DotGeneral
XlaBuilder::DotGeneral
もご覧ください。
DotGeneral(lhs, rhs, dimension_numbers)
引数 | タイプ | セマンティクス |
---|---|---|
lhs |
XlaOp |
T 型の配列 |
rhs |
XlaOp |
型 T の配列 |
dimension_numbers |
DotDimensionNumbers |
バッチ ディメンション数の縮小と |
ドットに似ていますが、lhs
と rhs
の両方に縮小およびバッチ ディメンション番号を指定できます。
DotDimensionNumbers フィールド | タイプ | セマンティクス |
---|---|---|
lhs_contracting_dimensions
|
繰り返し int64 | lhs 契約ディメンション番号 |
rhs_contracting_dimensions
|
繰り返し int64 | rhs の縮小次元番号 |
lhs_batch_dimensions
|
repeated int64 | lhs バッチ ディメンション番号 |
rhs_batch_dimensions
|
repeated int64 | rhs バッチ ディメンション番号 |
DotGeneral は、dimension_numbers
で指定された収縮ディメンションの積の合計を実行します。
lhs
と rhs
から関連付ける縮小寸法番号は同じである必要はありませんが、同じサイズである必要があります。
ディメンション番号を縮小する例:
lhs = { {1.0, 2.0, 3.0},
{4.0, 5.0, 6.0} }
rhs = { {1.0, 1.0, 1.0},
{2.0, 2.0, 2.0} }
DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(1);
dnums.add_rhs_contracting_dimensions(1);
DotGeneral(lhs, rhs, dnums) -> { {6.0, 12.0},
{15.0, 30.0} }
lhs
と rhs
に関連付けられたバッチ ディメンション番号のディメンション サイズは同じである必要があります。
バッチ ディメンション番号の例(バッチサイズ 2、2x2 行列):
lhs = { { {1.0, 2.0},
{3.0, 4.0} },
{ {5.0, 6.0},
{7.0, 8.0} } }
rhs = { { {1.0, 0.0},
{0.0, 1.0} },
{ {1.0, 0.0},
{0.0, 1.0} } }
DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(2);
dnums.add_rhs_contracting_dimensions(1);
dnums.add_lhs_batch_dimensions(0);
dnums.add_rhs_batch_dimensions(0);
DotGeneral(lhs, rhs, dnums) -> { { {1.0, 2.0},
{3.0, 4.0} },
{ {5.0, 6.0},
{7.0, 8.0} } }
入力 | 出力 | セマンティクス |
---|---|---|
[b0, m, k] dot [b0, k, n] |
[b0, m, n] | バッチ matmul |
[b0, b1, m, k] dot [b0, b1, k, n] |
[b0, b1, m, n] | バッチ matmul |
結果のディメンション番号はバッチ ディメンションから始まり、lhs
の非契約/非バッチ ディメンション、最後に rhs
の非契約/非バッチ ディメンションとなります。
DynamicSlice
XlaBuilder::DynamicSlice
もご覧ください。
DynamicSlice は、動的 start_indices
の入力配列からサブ配列を抽出します。各ディメンションのスライスのサイズは size_indices
で渡されます。これは、各ディメンションのスライス区間の終了ポイント([開始、開始 + サイズ])を指定します。start_indices
の形状は、ランク == 1 で、ディメンション サイズが operand
のランクと等しくする必要があります。
DynamicSlice(operand, start_indices, size_indices)
引数 | タイプ | セマンティクス |
---|---|---|
operand |
XlaOp |
型 T の N 次元配列 |
start_indices |
N 個の XlaOp のシーケンス |
各ディメンションのスライスの開始インデックスを含む N 個のスカラー整数のリスト。0 以上の値を指定してください。 |
size_indices |
ArraySlice<int64> |
各ディメンションのスライスサイズを含む N 個の整数のリスト。各値は 0 より大きく、開始値 + サイズがディメンション サイズ以下である必要があります。これにより、ディメンション サイズのモジュロでラップされるのを回避できます。 |
有効なスライス インデックスは、スライスを実行する前に [1, N)
内の各インデックス i
に次の変換を適用することで計算されます。
start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - size_indices[i])
これにより、抽出されたスライスがオペランド配列に対して常に境界内になります。変換が適用される前にスライスが境界内にある場合、変換は適用されません。
1 次元の場合の例:
let a = {0.0, 1.0, 2.0, 3.0, 4.0}
let s = {2}
DynamicSlice(a, s, {2}) produces:
{2.0, 3.0}
2 次元の例:
let b =
{ {0.0, 1.0, 2.0},
{3.0, 4.0, 5.0},
{6.0, 7.0, 8.0},
{9.0, 10.0, 11.0} }
let s = {2, 1}
DynamicSlice(b, s, {2, 2}) produces:
{ { 7.0, 8.0},
{10.0, 11.0} }
DynamicUpdateSlice
XlaBuilder::DynamicUpdateSlice
もご覧ください。
DynamicUpdateSlice は、入力配列 operand
の値である結果を生成します。スライス update
は start_indices
で上書きされます。update
の形状によって、更新される結果のサブ配列の形状が決まります。start_indices
のシェイプは、ランク == 1 で、ディメンション サイズが operand
のランクと同じである必要があります。
DynamicUpdateSlice(operand, update, start_indices)
引数 | タイプ | セマンティクス |
---|---|---|
operand |
XlaOp |
T 型の N 次元配列 |
update |
XlaOp |
スライスの更新を含む T 型の N 次元配列。更新シェイプの各ディメンションは 0 より大きく、開始位置 + 更新量は各ディメンションのオペランドサイズ以下にする必要があります。これにより、範囲外の更新インデックスが生成されるのを回避できます。 |
start_indices |
N 個の XlaOp のシーケンス |
各ディメンションのスライスの開始インデックスを含む N 個のスカラー整数のリスト。0 以上の値を指定してください。 |
有効なスライス インデックスは、スライスを実行する前に [1, N)
内の各インデックス i
に次の変換を適用することで計算されます。
start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - update.dimension_size[i])
これにより、更新されたスライスはオペランド配列に対して常にインバウンドになります。変換が適用される前にスライスが境界内にある場合、変換は適用されません。
1 次元の場合の例:
let a = {0.0, 1.0, 2.0, 3.0, 4.0}
let u = {5.0, 6.0}
let s = {2}
DynamicUpdateSlice(a, u, s) produces:
{0.0, 1.0, 5.0, 6.0, 4.0}
2 次元の例:
let b =
{ {0.0, 1.0, 2.0},
{3.0, 4.0, 5.0},
{6.0, 7.0, 8.0},
{9.0, 10.0, 11.0} }
let u =
{ {12.0, 13.0},
{14.0, 15.0},
{16.0, 17.0} }
let s = {1, 1}
DynamicUpdateSlice(b, u, s) produces:
{ {0.0, 1.0, 2.0},
{3.0, 12.0, 13.0},
{6.0, 14.0, 15.0},
{9.0, 16.0, 17.0} }
要素ごとのバイナリ算術演算
XlaBuilder::Add
もご覧ください。
要素ごとのバイナリ算術演算のセットがサポートされています。
Op(lhs, rhs)
ここで、Op
は、Add
(加算)、Sub
(減算)、Mul
(乗算)、Div
(除算)、Pow
(べき乗)、Rem
(余り)、Max
(最大値)、Min
(最小)、And
(論理 AND)、Or
(論理 OR)、Xor
(論理 XOR)、ShiftLeft
(左シフト)、ShiftRightArithmetic
(論理 1 または右シフト)、{14/2 項または右シフト)、{14/2 項または右シフト}(ShiftRightArithmetic
ShiftRightLogical
Atan2
Complex
引数 | タイプ | セマンティクス |
---|---|---|
lhs |
XlaOp |
左辺オペランド: T 型の配列 |
rhs |
XlaOp |
右側のオペランド: T 型の配列 |
引数のシェイプは類似しているか、互換性がある必要があります。シェイプの互換性について詳しくは、ブロードキャストのドキュメントをご覧ください。演算の結果は、2 つの入力配列をブロードキャストした結果である形状を持ちます。このバリアントでは、オペランドのいずれかがスカラーでない限り、異なるランクの配列間の演算はサポートされていません。
Op
が Rem
の場合、結果の符号が被除数から取得され、結果の絶対値は常に除数の絶対値よりも小さくなります。
整数の除算オーバーフロー(ゼロによる符号付き / 符号なしの除算 / 余り、または INT_SMIN
と -1
の符号付き除算 / 余り)は、実装定義の値を生成します。
次のオペレーションには、異なるランクのブロードキャストをサポートする代替バリアントがあります。
Op(lhs, rhs, broadcast_dimensions)
ここで、Op
は上記と同じです。この演算のバリアントは、異なるランクの配列間の算術演算(ベクトルに行列を追加するなど)に使用します。
追加の broadcast_dimensions
オペランドは、低ランクのオペランドのランクを高ランクのオペランドのランクまで拡張するために使用される整数のスライスです。broadcast_dimensions
は、下位ランクのシェイプのディメンションを上位ランクのシェイプのディメンションにマッピングします。拡張されたシェイプのマッピングされていないディメンションは、サイズ 1 のディメンションで埋められます。縮退ディメンションのブロードキャストは、これらの縮退ディメンションに沿ってシェイプをブロードキャストし、両方のオペランドのシェイプを均等にします。セマンティクスの詳細については、ブロードキャスト ページをご覧ください。
要素ごとの比較演算
XlaBuilder::Eq
もご覧ください。
標準的な要素ごとのバイナリ比較演算のセットがサポートされています。浮動小数点型を比較するときは、標準の IEEE 754 浮動小数点比較セマンティクスが適用されます。
Op(lhs, rhs)
ここで、Op
は、Eq
(等しい)、Ne
(等しくない)、Ge
(以上)、Gt
(より大きい)、Le
(以下)、Lt
(より小さい)のいずれかです。別の演算子セットである EqTotalOrder、NeTotalOrder、GeTotalOrder、GtTotalOrder、LeTotalOrder、LtTotalOrder は、-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN を適用することで、浮動小数点数に対する全順序をサポートする点を除き、同じ機能を提供します。
引数 | タイプ | セマンティクス |
---|---|---|
lhs |
XlaOp |
左辺オペランド: T 型の配列 |
rhs |
XlaOp |
右側のオペランド: T 型の配列 |
引数のシェイプは類似しているか、互換性がある必要があります。シェイプの互換性について詳しくは、ブロードキャストのドキュメントをご覧ください。演算の結果のシェイプは、要素型 PRED
の 2 つの入力配列をブロードキャストした結果です。このバリアントでは、オペランドの 1 つがスカラーでない限り、異なるランクの配列間の演算はサポートされていません。
以下のオペレーションには、異なるランクのブロードキャストをサポートする代替バリアントがあります。
Op(lhs, rhs, broadcast_dimensions)
ここで、Op
は上記と同じです。この演算のバリエーションは、異なるランクの配列間の比較演算(ベクトルに行列を追加するなど)に使用する必要があります。
追加の broadcast_dimensions
オペランドは、オペランドのブロードキャストに使用するディメンションを指定する整数のスライスです。セマンティクスの詳細については、ブロードキャスト ページをご覧ください。
要素ごとの単項関数
XlaBuilder では、次の要素単位の単項関数がサポートされています。
Abs(operand)
要素単位の絶対値 x -> |x|
。
Cbrt(operand)
要素単位の立方根演算の x -> cbrt(x)
。
Ceil(operand)
要素単位のセル x -> ⌈x⌉
。
Clz(operand)
要素ごとに先頭のゼロをカウントします。
Cos(operand)
要素ごとのコサイン x -> cos(x)
。
Erf(operand)
要素ごとの誤差関数 x -> erf(x)
。
\(\text{erf}(x) = \frac{2}{\sqrt{\pi} }\int_0^x e^{-t^2} \, dt\)。
Exp(operand)
要素単位の自然指数 x -> e^x
。
Expm1(operand)
要素ごとの自然指数マイナス 1 の x -> e^x - 1
。
Floor(operand)
要素ごとの下限 x -> ⌊x⌋
。
Imag(operand)
: 複雑な(または実数)形状の要素単位の虚数部。x -> imag(x)
。オペランドが浮動小数点型の場合は 0 を返します。
IsFinite(operand)
operand
の各要素が有限である(正または負の無限大ではなく、NaN
でもない)かどうかをテストします。入力と同じ形状の PRED
値の配列を返します。対応する入力要素が有限である場合にのみ、各要素が true
になります。
Log(operand)
要素ごとの自然対数 x -> ln(x)
。
Log1p(operand)
要素ごとのシフトされた自然対数 x -> ln(1+x)
。
Logistic(operand)
要素ごとのロジスティック関数の計算 x ->
logistic(x)
。
Neg(operand)
要素単位の否定 x -> -x
。
Not(operand)
要素ごとの論理否定 x -> !(x)
。
PopulationCount(operand)
operand
の各要素に設定されたビット数を計算します。
Real(operand)
複素(または実数)シェイプの要素ごとの実部。x -> real(x)
。オペランドが浮動小数点型の場合は、同じ値を返します。
Round(operand)
要素単位の丸め処理で、0 から離れた位置に戻ります。
RoundNearestEven(operand)
要素ごとの丸め、最も近い偶数に丸められます。
Rsqrt(operand)
平方根演算 x -> 1.0 / sqrt(x)
の要素単位の逆数。
Sign(operand)
要素単位の符号演算 x -> sgn(x)
\[\text{sgn}(x) = \begin{cases} -1 & x < 0\\ -0 & x = -0\\ NaN & x = NaN\\ +0 & x = +0\\ 1 & x > 0 \end{cases}\]
operand
の要素タイプの比較演算子を使用して。
Sin(operand)
要素ごとの正弦 x -> sin(x)
。
Sqrt(operand)
要素単位の平方根演算 x -> sqrt(x)
。
Tan(operand)
要素ごとの接線 x -> tan(x)
。
Tanh(operand)
要素ごとの双曲線正接 x -> tanh(x)
。
引数 | タイプ | セマンティクス |
---|---|---|
operand |
XlaOp |
関数のオペランド |
この関数は operand
配列内の各要素に適用され、同じ形状の配列になります。operand
をスカラー(ランク 0)にすることができます。
Fft
XLA FFT オペレーションは、実数と複素数の入力 / 出力に対して、正規化と逆正規化の 2 つのフーリエ変換を実装します。最大 3 つの軸での多変量 FFT がサポートされています。
XlaBuilder::Fft
もご覧ください。
引数 | タイプ | セマンティクス |
---|---|---|
operand |
XlaOp |
フーリエ変換する配列。 |
fft_type |
FftType |
下の表をご覧ください。 |
fft_length |
ArraySlice<int64> |
変換される軸の時間領域の長さ。これは特に、IRFFT で最内側の軸のサイズを適正化するために必要です(RFFT(fft_length=[16]) の出力形状が RFFT(fft_length=[17]) と同じであるため)。 |
FftType |
セマンティクス |
---|---|
FFT |
複素数から複素数への FFT を前方処理します。形状は変更されません。 |
IFFT |
複素数逆数 FFT。形状は変更されません。 |
RFFT |
実数から複素数への FFT を前方処理します。fft_length[-1] がゼロ以外の値の場合、最内側の軸の形状は fft_length[-1] // 2 + 1 に縮小され、ナイキスト周波数を超える変換された信号の逆コンジュゲート部分が省略されます。 |
IRFFT |
逆実数から複素数への変換 FFT(つまり、複素数を受け取り、実数を返します)。fft_length[-1] がゼロ以外の値の場合、最も内側の軸の形状は fft_length[-1] に拡張され、変換された信号のうちナイキスト周波数を超える部分を 1 エントリの逆共役から fft_length[-1] // 2 + 1 エントリに推測します。 |
多次元 FFT
複数の fft_length
が指定されている場合、最も内側の軸のそれぞれに FFT 演算のカスケードを適用するのと同じ結果になります。実数から複素数への変換と複素数から実数への変換の場合、最も内側の軸変換が(実質的に)最初に実行されます(RFFT の場合、IRFFT の場合は最後)。そのため、最も内側の軸のサイズが変更されます。他の軸変換は複雑 -> 複雑になります
実装の詳細
CPU FFT は、Eigen の TensorFFT をベースにしています。GPU FFT は cuFFT を使用します。
収集
XLA 収集演算は、入力配列の複数のスライス(異なる可能性のあるランタイム オフセットで各スライス)をつなぎ合わせます。
一般セマンティクス
XlaBuilder::Gather
もご覧ください。より直感的な説明については、以下の「非公式の説明」セクションをご覧ください。
gather(operand, start_indices, offset_dims, collapsed_slice_dims, slice_sizes, start_index_map)
引数 | タイプ | セマンティクス |
---|---|---|
operand |
XlaOp |
収集元の配列。 |
start_indices |
XlaOp |
収集するスライスの開始インデックスを含む配列。 |
index_vector_dim |
int64 |
開始インデックスが「含まれる」start_indices 内のディメンション。詳細な説明は以下をご覧ください。 |
offset_dims |
ArraySlice<int64> |
オペランドからスライスされた配列にオフセットする出力シェイプ内のディメンションのセット。 |
slice_sizes |
ArraySlice<int64> |
slice_sizes[i] は、ディメンション i のスライスの境界です。 |
collapsed_slice_dims |
ArraySlice<int64> |
各スライスで閉じられたディメンションのセット。これらのディメンションのサイズは 1 にする必要があります。 |
start_index_map |
ArraySlice<int64> |
start_indices のインデックスをオペランドの有効なインデックスにマッピングする方法を表すマップ。 |
indices_are_sorted |
bool |
呼び出し元によってインデックスが並べ替えられることが保証されているかどうか。 |
便宜上、出力配列内のディメンションには offset_dims
ではなく batch_dims
というラベルを付けています。
出力は、ランク batch_dims.size
+ offset_dims.size
の配列です。
operand.rank
は offset_dims.size
と collapsed_slice_dims.size
の合計と等しい必要があります。また、slice_sizes.size
は operand.rank
と等しい必要があります。
index_vector_dim
が start_indices.rank
と等しい場合、暗黙的に start_indices
は末尾の 1
ディメンションを持つと見なされます(つまり、start_indices
が [6,7]
の形状で、index_vector_dim
が 2
の場合、暗黙的に start_indices
の形状は [6,7,1]
とみなされます)。
ディメンション i
に沿った出力配列の境界は次のように計算されます。
i
がbatch_dims
に存在する場合(つまり、一部のk
に対してbatch_dims[k]
と等しい場合)、index_vector_dim
をスキップしてstart_indices.shape
から対応するディメンションの境界を選択します(つまり、k
<index_vector_dim
の場合はstart_indices.shape.dims
[k
] を選択し、それ以外の場合はstart_indices.shape.dims
[k
+1
] を選択します)。i
がoffset_dims
に存在する場合(つまり、あるk
に対してoffset_dims
[k
] と等しい場合)、collapsed_slice_dims
を考慮した後でslice_sizes
から対応する境界を選択します(つまり、adjusted_slice_sizes
[k
] を選択します。ここでadjusted_slice_sizes
は、インデックスcollapsed_slice_dims
の境界を削除したslice_sizes
です)。
正式には、特定の出力インデックス Out
に対応するオペランド インデックス In
は次のように計算されます。
batch_dims
のk
に対してG
= {Out
[k
] とします。G
を使用して、S
[i
] =start_indices
[Combine(G
,i
)] となるようにベクトルS
をスライスします。Combine(A, b) は、位置index_vector_dim
の b を A に挿入します。これは、G
が空であっても適切に定義されていることに注意してください。G
が空の場合は、S
=start_indices
になります。start_index_map
を使用してS
を分散し、S
を使用してoperand
に開始インデックスS
in
を作成します。具体的には次のようになります。S
in
[start_index_map
[k
]] =S
[k
](k
<start_index_map.size
の場合)。S
in
[_
] =0
(それ以外の場合)。
collapsed_slice_dims
セットに従ってOut
のオフセット ディメンションでインデックスを分散し、operand
にインデックスO
in
を作成します。具体的には次のようになります。k
<offset_dims.size
の場合、O
in
[remapped_offset_dims
(k
)] =Out
[offset_dims
[k
]](remapped_offset_dims
は後で定義されます)。O
in
[_
] =0
In
はO
in
+S
in
です。ここで、+ は要素単位の加算です。
remapped_offset_dims
は、ドメイン [0
, offset_dims.size
) と範囲 [0
, operand.rank
) \ collapsed_slice_dims
の単調関数です。たとえばoffset_dims.size
が 4
、operand.rank
が 6
、collapsed_slice_dims
が {0
, 2
} の場合、remapped_offset_dims
は {0
→1
,
1
→3
, 2
→4
, 3
→5
} です。
indices_are_sorted
が true に設定されている場合、XLA では start_indices
がユーザーによって(start_index_map
に従って値を分散させた後に)昇順で並べ替えられていると想定できます。そうでない場合、セマンティクスは実装定義です。
非公式の説明と例
非公式には、出力配列内のすべてのインデックス Out
は、次のように計算されるオペランド配列内の要素 E
に対応します。
Out
のバッチ ディメンションを使用して、start_indices
から開始インデックスを検索します。start_index_map
を使用して、開始インデックス(サイズが operand.rank より小さい場合があります)をoperand
の「完全な」開始インデックスにマッピングします。完全な開始インデックスを使用して、サイズが
slice_sizes
のスライスを動的にスライスします。collapsed_slice_dims
ディメンションを閉じて、スライスの形状を変更します。すべての圧縮スライス ディメンションに 1 の境界が設定されているため、この再シェイプは常に有効です。Out
のオフセット ディメンションを使用してこのスライスにインデックスを付け、出力インデックスOut
に対応する入力要素E
を取得します。
以降のすべての例では、index_vector_dim
は start_indices.rank
~1
に設定されています。index_vector_dim
に興味深い値を指定しても、オペレーションは根本的に変わりませんが、視覚的な表現が複雑になります。
上記のすべてがどのように連携しているかを直感的に理解するには、[16,11]
配列から 5 つの [8,6]
シェイプのスライスを収集する例を見てみましょう。[16,11]
配列内のスライスの位置は、形状 S64[2]
のインデックス ベクトルとして表すことができるため、5 つの位置のセットは S64[5,2]
配列として表現できます。
集計オペレーションの動作は、[G
、O
0
、O
1
] という出力シェイプのインデックスを受け取り、次のように入力配列内の要素にマッピングするインデックス変換として表すことができます。
まず、G
を使用して、集計インデックス アレイから(X
、Y
)ベクトルを選択します。出力配列のインデックス [G
,O
0
,O
1
] の要素は、入力配列のインデックス [X
+O
0
,Y
+O
1
] の要素になります。
slice_sizes
は [8,6]
です。これは、O0
と O1
の範囲を決定し、これがスライスの境界を決定します。
この収集オペレーションは、G
をバッチ ディメンションとするバッチ動的スライスとして機能します。
収集インデックスは多次元にすることができます。たとえば、上記の例のより一般的なバージョンでは、[4,5,2]
の形状の「集計インデックス」配列を使用して、次のようにインデックスを変換します。
この場合も、これはバッチ動的スライス G
0
、バッチ ディメンションとしての G
1
として機能します。スライスサイズは引き続き [8,6]
です。
XLA の集計オペレーションは、上記の非公式なセマンティクスを次のように一般化します。
出力シェイプのどのディメンションがオフセット ディメンションであるかを構成できます(最後の例の
O
0
、O
1
を含むディメンション)。出力バッチ ディメンション(最後の例ではG
0
、G
1
を含むディメンション)は、オフセット ディメンションではない出力ディメンションとして定義されます。出力シェイプに明示的に存在する出力オフセット ディメンションの数は、入力ランクよりも少なくなる場合があります。これらの「欠落している」ディメンション(
collapsed_slice_dims
として明示的にリストされている)は、スライスサイズが1
である必要があります。これらはスライスサイズが1
であるため、有効なインデックスは0
のみであり、これらを省略してもあいまいさは発生しません。「インデックスの収集」配列(最後の例の
X
、Y
)から抽出されたスライスは、入力配列ランクよりも要素が少ない場合があります。明示的なマッピングは、入力と同じランクにインデックスを拡張する方法を指定します。
最後の例として、(2)と(3)を使用して tf.gather_nd
を実装します。
G
0
と G
1
は、通常どおり、集計インデックス配列から開始インデックスをスライスするために使用されます。ただし、開始インデックスには X
という要素が 1 つだけあります。同様に、値が O
0
の出力オフセット インデックスは 1 つだけです。しかし、入力配列のインデックスとして使用されるインデックスとして使用される前に、入力配列のインデックスとして使用される前に、入力配列のインデックスとして使用される前に、入力配列のインデックスとして使用される前に、入力配列のインデックスとして使用される前に、入力配列のインデックスとして使用。start_index_map
remapped_offset_dims
X
X
0
0
0
0
0
0
0
0
0
O
O
O
O
G
G
G
G
1
1
GatherIndices
tf.gather_nd
このケースの slice_sizes
は [1,11]
です。直感的に、これは、集計インデックス 配列内のすべてのインデックス X
が行全体を選択し、結果がこれらの行の連結になることを意味します。
GetDimensionSize
XlaBuilder::GetDimensionSize
もご覧ください。
オペランドの指定されたディメンションのサイズを返します。オペランドは配列形式にする必要があります。
GetDimensionSize(operand, dimension)
引数 | タイプ | セマンティクス |
---|---|---|
operand |
XlaOp |
n 次元の入力配列 |
dimension |
int64 |
ディメンションを指定する区間 [0, n) の値 |
SetDimensionSize
XlaBuilder::SetDimensionSize
もご覧ください。
XlaOp の指定されたディメンションの動的サイズを設定します。オペランドは配列型である必要があります。
SetDimensionSize(operand, size, dimension)
引数 | タイプ | セマンティクス |
---|---|---|
operand |
XlaOp |
n 次元の入力配列。 |
size |
XlaOp |
実行時の動的サイズを表す int32。 |
dimension |
int64 |
ディメンションを指定する [0, n) の範囲内の値。 |
オペランドを結果として渡し、動的ディメンションをコンパイラが追跡します。
パディングされた値は、ダウンストリームの減算オペレーションによって無視されます。
let v: f32[10] = f32[10]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
let five: s32 = 5;
let six: s32 = 6;
// Setting dynamic dimension size doesn't change the upper bound of the static
// shape.
let padded_v_five: f32[10] = set_dimension_size(v, five, /*dimension=*/0);
let padded_v_six: f32[10] = set_dimension_size(v, six, /*dimension=*/0);
// sum == 1 + 2 + 3 + 4 + 5
let sum:f32[] = reduce_sum(padded_v_five);
// product == 1 * 2 * 3 * 4 * 5
let product:f32[] = reduce_product(padded_v_five);
// Changing padding size will yield different result.
// sum == 1 + 2 + 3 + 4 + 5 + 6
let sum:f32[] = reduce_sum(padded_v_six);
GetTupleElement
XlaBuilder::GetTupleElement
もご覧ください。
コンパイル時定数値を持つタプルへのインデックス。
値はコンパイル時の定数にする必要があります。これにより、形状推論によって結果の値の型が決定されます。
これは C++ の std::get<int N>(t)
に似ています。概念的に:
let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
let s: s32 = 5;
let t: (f32[10], s32) = tuple(v, s);
let element_1: s32 = gettupleelement(t, 1); // Inferred shape matches s32.
tf.tuple
もご覧ください。
Infeed
XlaBuilder::Infeed
もご覧ください。
Infeed(shape)
引数 | タイプ | セマンティクス |
---|---|---|
shape |
Shape |
インフィード インターフェースから読み取られるデータの形状。シェイプの layout フィールドは、デバイスに送信されるデータのレイアウトと一致するように設定する必要があります。一致しない場合、動作は未定義です。 |
デバイスの暗黙的なインフィード ストリーミング インターフェースから単一のデータアイテムを読み取り、データを指定されたシェイプとそのレイアウトとして解釈し、データの XlaOp
を返します。1 つの計算で複数のインフィード オペレーションを実行できますが、それらのオペレーションはすべての順序で並べる必要があります。たとえば、次のコードの 2 つの Infeed には合計順序があります。これは、while ループ間に依存関係があるためです。
result1 = while (condition, init = init_value) {
Infeed(shape)
}
result2 = while (condition, init = result1) {
Infeed(shape)
}
ネストされたタプルシェイプはサポートされていません。空のタプルシェイプの場合、インフィード オペレーションは実質的に何も行われず、デバイスのインフィードからデータを読み取らずに続行されます。
Iota
XlaBuilder::Iota
もご覧ください。
Iota(shape, iota_dimension)
ホスト転送のサイズが大きくなる可能性があるのではなく、デバイスに定数リテラルをビルドします。指定されたシェイプの配列を作成し、指定されたディメンションに沿ってゼロから 1 ずつ増加する値を保持します。浮動小数点型の場合、生成される配列は ConvertElementType(Iota(...))
と同等です。ここで、Iota
は整数型で、変換は浮動小数点型です。
引数 | タイプ | セマンティクス |
---|---|---|
shape |
Shape |
Iota() によって作成された配列の形状 |
iota_dimension |
int64 |
増分するディメンション。 |
たとえば、Iota(s32[4, 8], 0)
は
[[0, 0, 0, 0, 0, 0, 0, 0 ],
[1, 1, 1, 1, 1, 1, 1, 1 ],
[2, 2, 2, 2, 2, 2, 2, 2 ],
[3, 3, 3, 3, 3, 3, 3, 3 ]]
返品可能(返品手数料: Iota(s32[4, 8], 1)
)
[[0, 1, 2, 3, 4, 5, 6, 7 ],
[0, 1, 2, 3, 4, 5, 6, 7 ],
[0, 1, 2, 3, 4, 5, 6, 7 ],
[0, 1, 2, 3, 4, 5, 6, 7 ]]
地図
XlaBuilder::Map
もご覧ください。
Map(operands..., computation)
引数 | タイプ | セマンティクス |
---|---|---|
operands |
N 個の XlaOp のシーケンス |
T0..T{N-1} 型の N 配列 |
computation |
XlaComputation |
T 型の N 個のパラメータと任意の型の M のパラメータによる T_0, T_1, .., T_{N + M -1} -> S 型の計算 |
dimensions |
int64 配列 |
地図のサイズの配列 |
指定された operands
配列にスカラー関数を適用し、同じ次元の配列を生成します。各要素は、入力配列内の対応する要素に適用されたマッピングされた関数の結果です。
マッピングされた関数は任意の計算ですが、スカラー型 T
の入力が N 個、型が S
の単一の出力が含まれるという制限があります。出力の次元はオペランドと同じですが、要素の型 T が S で置き換えられています。
たとえば、Map(op1, op2, op3, computation, par1)
は入力配列内の各(多次元)インデックスで elem_out <-
computation(elem1, elem2, elem3, par1)
をマッピングして出力配列を生成します。
OptimizationBarrier
最適化パスをブロックして、バリアを越えて計算を移動できないようにします。
バリアの出力に依存する演算子の前にすべての入力が評価されるようにします。
パッド
XlaBuilder::Pad
もご覧ください。
Pad(operand, padding_value, padding_config)
引数 | タイプ | セマンティクス |
---|---|---|
operand |
XlaOp |
T 型の配列 |
padding_value |
XlaOp |
追加されたパディングを埋める T 型のスカラー |
padding_config |
PaddingConfig |
各寸法の両端(低、高)および要素間のパディング量 |
指定された operand
配列を拡張します。配列の周囲と配列の要素間に指定された padding_value
を挿入します。padding_config
には、各ディメンションのエッジ パディングと内部パディングの量を指定します。
PaddingConfig
は PaddingConfigDimension
の繰り返しフィールドで、各ディメンションに edge_padding_low
、edge_padding_high
、interior_padding
の 3 つのフィールドが含まれています。
edge_padding_low
と edge_padding_high
は、各ディメンションの下端(インデックス 0 の隣)と上端(最大インデックスの隣)に追加されるパディングの量をそれぞれ指定します。エッジ パディングの量は負の値にすることができます。負のパディングの絶対値は、指定されたディメンションから削除する要素の数を示します。
interior_padding
は、各寸法の 2 つの要素間に追加されるパディングの量を指定します。負の値にすることはできません。内部パディングはエッジ パディングの前に論理的に発生するため、負のエッジ パディングの場合、要素は内部パディング オペランドから削除されます。
エッジ パディング ペアがすべて(0, 0)で、内部パディング値がすべて 0 の場合、このオペレーションは no-op です。次の図は、2 次元配列のさまざまな edge_padding
値と interior_padding
値の例を示しています。
Recv
XlaBuilder::Recv
もご覧ください。
Recv(shape, channel_handle)
引数 | タイプ | セマンティクス |
---|---|---|
shape |
Shape |
受信するデータの形状 |
channel_handle |
ChannelHandle |
各送受信ペアの一意の識別子 |
同じチャネル ハンドルを共有する別のコンピューティングの Send
命令から、指定されたシェイプのデータを受け取ります。受信したデータの XlaOp を返します。
Recv
オペレーションのクライアント API は、同期通信を表します。ただし、この命令は内部で 2 つの HLO 命令(Recv
と RecvDone
)に分解され、非同期のデータ転送が可能になります。HloInstruction::CreateRecv
と HloInstruction::CreateRecvDone
もご覧ください。
Recv(const Shape& shape, int64 channel_id)
同じ channel_id を持つ Send
命令からデータを受信するために必要なリソースを割り当てます。割り当てられたリソースのコンテキストを返します。これは、次の RecvDone
命令でデータ転送の完了を待つために使用されます。コンテキストは {受信バッファ(シェイプ)、リクエスト識別子(U32)} のタプルであり、RecvDone
命令でのみ使用できます。
RecvDone(HloInstruction context)
Recv
命令によって作成されたコンテキストを受け取り、データ転送が完了するまで待機し、受信したデータを返します。
削減
XlaBuilder::Reduce
もご覧ください。
リダクション関数を 1 つ以上の配列に並列に適用します。
Reduce(operands..., init_values..., computation, dimensions)
引数 | タイプ | セマンティクス |
---|---|---|
operands |
N XlaOp のシーケンス |
T_0, ..., T_{N-1} 型の N 配列。 |
init_values |
N XlaOp のシーケンス |
T_0, ..., T_{N-1} 型の N 個のスカラー。 |
computation |
XlaComputation |
型 T_0, ..., T_{N-1}, T_0, ..., T_{N-1} -> Collate(T_0, ..., T_{N-1}) の計算。 |
dimensions |
int64 配列 |
削減するディメンションの順序なし配列。 |
ここで
- N は 1 以上の値にする必要があります。
- 計算は「おおまかな」結合的である必要があります(以下を参照)。
- すべての入力配列のディメンションは同じである必要があります。
- すべての初期値は、
computation
で ID を形成する必要があります。 N = 1
の場合、Collate(T)
はT
です。N > 1
の場合、Collate(T_0, ..., T_{N-1})
はT
型のN
要素のタプルです。
この演算により、各入力配列の 1 つ以上のディメンションがスカラーに縮小されます。返される各配列のランクは rank(operand) - len(dimensions)
です。op の出力は Collate(Q_0, ..., Q_N)
です。ここで、Q_i
は T_i
型の配列です。ディメンションについては後述します。
異なるバックエンドで、減算計算を再関連付けることができます。加算などの一部の減算関数は浮動小数点数に対して結合的ではないため、数値の差異が生じる可能性があります。ただし、データの範囲が制限されている場合、ほとんどの実用的な用途では、浮動小数点加算は結合性に非常に近くなります。
例
リダクション関数 f
(これは computation
)を使用して、値 [10, 11,
12, 13]
を持つ単一の 1D 配列で 1 つの次元を横断する場合、次のように計算できます。
f(10, f(11, f(12, f(init_value, 13)))
他にも次のような方法があります。
f(init_value, f(f(10, f(init_value, 11)), f(f(init_value, 12), f(init_value, 13))))
次の例は、初期値が 0 の減算計算として合計を使用する減算を実装する方法の概要を示しています。
result_shape <- remove all dims in dimensions from operand_shape
# Iterate over all elements in result_shape. The number of r's here is equal
# to the rank of the result
for r0 in range(result_shape[0]), r1 in range(result_shape[1]), ...:
# Initialize this result element
result[r0, r1...] <- 0
# Iterate over all the reduction dimensions
for d0 in range(dimensions[0]), d1 in range(dimensions[1]), ...:
# Increment the result element with the value of the operand's element.
# The index of the operand's element is constructed from all ri's and di's
# in the right order (by construction ri's and di's together index over the
# whole operand shape).
result[r0, r1...] += operand[ri... di]
2 次元配列(行列)を減らす例を次に示します。このシェイプは、ランク 2、サイズ 2 の次元 0、サイズ 3 の次元 1 です。
「加算」関数で次元 0 または 1 を削減した結果:
どちらの減算結果も 1 次元配列です。この図では 見やすくするために 1 つを列としもう 1 つを行として示しています
より複雑な例として、3D 配列を次に示します。ランクは 3、ディメンション 0 のサイズは 4、ディメンション 1 のサイズは 2、ディメンション 2 のサイズは 3 です。わかりやすくするため、値 1 ~ 6 はディメンション 0 全体に複製されます。
2D の例と同様に、1 つのディメンションのみを削減できます。たとえば、ディメンション 0 を削減すると、ディメンション 0 のすべての値がスカラーに折りたたまれたランク 2 配列が得られます。
| 4 8 12 |
| 16 20 24 |
ディメンション 2 を削減すると、ディメンション 2 のすべての値がスカラーに折りたたまれたランク 2 配列も得られます。
| 6 15 |
| 6 15 |
| 6 15 |
| 6 15 |
入力内の残りのディメンション間の相対順序は出力で保持されますが、一部のディメンションには新しい番号が割り当てられる場合があります(ランクが変更されるため)。
複数の次元を減らすこともできます。追加減算ディメンション 0 と 1 を指定すると、1 次元配列 [20, 28, 36]
が生成されます。
3D 配列をすべてのディメンションにわたって減算すると、スカラー 84
が生成されます。
可変引数 Reduce
N > 1
の場合、Reduce 関数はすべての入力に同時に適用されるため、やや複雑になります。オペランドは、次の順序で計算に渡されます。
- 最初のオペランドの実行された減算値
- ...
- N 番目のオペランドの縮小された値を実行中
- 最初のオペランドの入力値
- ...
- N 番目のオペランドの入力値
たとえば、次の減算関数は、1 次元配列の最大値と argmax を並列で計算するために使用できます。
f: (Float, Int, Float, Int) -> Float, Int
f(max, argmax, value, index):
if value >= max:
return (value, index)
else:
return (max, argmax)
1 次元入力配列 V = Float[N], K = Int[N]
と初期値 I_V = Float, I_K = Int
の場合、唯一の入力次元で削減した結果 f_(N-1)
は、次の再帰適用と同等です。
f_0 = f(I_V, I_K, V_0, K_0)
f_1 = f(f_0.first, f_0.second, V_1, K_1)
...
f_(N-1) = f(f_(N-2).first, f_(N-2).second, V_(N-1), K_(N-1))
この減算を値の配列と連続インデックスの配列(iota)に適用すると、配列が同時に反復処理され、最大値と一致するインデックスを含むタプルが返されます。
ReducePrecision
XlaBuilder::ReducePrecision
もご覧ください。
浮動小数点値を低精度形式(IEEE-FP16 など)に変換し、元の形式に戻す効果をモデル化します。低精度形式の指数と小数点以下ビットの数は任意で指定できますが、すべてのビットサイズがすべてのハードウェア実装でサポートされているわけではありません。
ReducePrecision(operand, mantissa_bits, exponent_bits)
引数 | タイプ | セマンティクス |
---|---|---|
operand |
XlaOp |
浮動小数点型 T の配列。 |
exponent_bits |
int32 |
低精度形式の指数ビット数 |
mantissa_bits |
int32 |
低精度形式のマンシサビット数 |
結果は T
型の配列です。入力値は、指定された小数点以下ビット数で表せる値に切り上げられます(「偶数に切り上げ」セマンティクスを使用)。指数ビット数で指定された範囲を超える値は、正または負の無限大にクランプされます。NaN
値は保持されますが、正規の NaN
値に変換される場合があります。
精度が低い形式には、少なくとも 1 つの指数ビットが必要です(ゼロの仮数を持つ値と無限大の値を区別するため)。また、仮数ビットの数は正の整数にする必要があります。指数または小数点以下ビットの数が、型 T
の対応する値を超える場合があります。その場合、変換の対応する部分は単に no-op になります。
ReduceScatter
XlaBuilder::ReduceScatter
もご覧ください。
ReduceScatter は、all-reduce を効果的に実行し、結果を scatter_dimension
に沿って shard_count
ブロックに分割して結果を散布する集団演算です。レプリカ グループ内のレプリカ i
は ith
シャードを受信します。
ReduceScatter(operand, computation, scatter_dim, shard_count,
replica_group_ids, channel_id)
引数 | タイプ | セマンティクス |
---|---|---|
operand |
XlaOp |
レプリカ間で減算する配列または配列の空でないタプル。 |
computation |
XlaComputation |
削減の計算 |
scatter_dimension |
int64 |
散布するディメンション。 |
shard_count |
int64 |
scatter_dimension を分割するブロックの数 |
replica_groups |
int64 のベクトルのベクトル |
削減を実行するグループ |
channel_id |
省略可 int64 |
オプションのモジュール間通信用のチャンネル ID |
operand
が配列のチュープルの場合、チュープルの各要素に対して reduce-scatter が実行されます。replica_groups
は、縮小が実行されるレプリカ グループのリストです(現在のレプリカのレプリカ ID はReplicaId
を使用して取得できます)。各グループのレプリカの順序によって、all-reduce の結果が分散される順序が決まります。replica_groups
は空であるか(この場合、すべてのレプリカが 1 つのグループに属します)、レプリカの数と同じ数の要素を含める必要があります。レプリカ グループが複数ある場合は、すべて同じサイズにする必要があります。たとえば、replica_groups = {0, 2}, {1, 3}
はレプリカ0
と2
、1
と3
の間で減算を実行し、結果を分散します。shard_count
は、各レプリカ グループのサイズです。これは、replica_groups
が空の場合に必要です。replica_groups
が空でない場合、shard_count
は各レプリカ グループのサイズと等しくする必要があります。channel_id
はモジュール間通信に使用されます。同じchannel_id
を持つreduce-scatter
オペレーションのみが相互に通信できます。
出力シェイプは、scatter_dimension
が shard_count
倍小さくなった入力シェイプです。たとえば、2 つのレプリカがあり、2 つのレプリカでオペランドの値がそれぞれ [1.0, 2.25]
と [3.0, 5.25]
の場合、scatter_dim
が 0
であるこのオペレーションの出力値は、最初のレプリカでは [4.0]
、2 番目のレプリカでは [7.5]
になります。
ReduceWindow
XlaBuilder::ReduceWindow
もご覧ください。
N 個の多次元配列のシーケンスの各ウィンドウ内のすべての要素にリダクション関数を適用し、N 個の多次元配列の単一またはタプルを出力として生成します。各出力配列には、ウィンドウの有効な位置の数と同じ数の要素があります。プーリング レイヤは ReduceWindow
として表すことができます。Reduce
と同様に、適用された computation
には常に左側の init_values
が渡されます。
ReduceWindow(operands..., init_values..., computation, window_dimensions,
window_strides, padding)
引数 | タイプ | セマンティクス |
---|---|---|
operands |
N XlaOps |
T_0,..., T_{N-1} 型の N 個の多次元配列のシーケンス。各配列は、ウィンドウが配置されるベース領域を表します。 |
init_values |
N XlaOps |
減算の N 個の開始値(N 個のオペランドごとに 1 つ)。詳しくは、削減をご覧ください。 |
computation |
XlaComputation |
すべての入力オペランドの各ウィンドウ内の要素に適用する、T_0, ..., T_{N-1}, T_0, ..., T_{N-1} -> Collate(T_0, ..., T_{N-1}) 型のリダクション関数。 |
window_dimensions |
ArraySlice<int64> |
ウィンドウ ディメンション値の整数の配列 |
window_strides |
ArraySlice<int64> |
ウィンドウ ストライド値の整数の配列 |
base_dilations |
ArraySlice<int64> |
ベース拡張値の整数の配列 |
window_dilations |
ArraySlice<int64> |
ウィンドウ拡張値の整数の配列 |
padding |
Padding |
ウィンドウのパディング タイプ(Padding::kSame: ストライドが 1 の場合に入力と同じ出力シェイプになるようにパディングします。Padding::kValid: パディングを使用し、ウィンドウが適合しなくなったら「停止」します)。 |
ここで
- N は 1 以上にする必要があります。
- すべての入力配列のディメンションは同じである必要があります。
N = 1
の場合、Collate(T)
はT
です。N > 1
の場合、Collate(T_0, ..., T_{N-1})
は(T0,...T{N-1})
型のN
要素のタプルです。
以下のコードと図は、ReduceWindow
の使用例を示しています。入力はサイズ [4x6] のマトリックスで、window_dimensions と window_stride_dimensions の両方が [2x3] です。
// Create a computation for the reduction (maximum).
XlaComputation max;
{
XlaBuilder builder(client_, "max");
auto y = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "y");
auto x = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "x");
builder.Max(y, x);
max = builder.Build().value();
}
// Create a ReduceWindow computation with the max reduction computation.
XlaBuilder builder(client_, "reduce_window_2x3");
auto shape = ShapeUtil::MakeShape(F32, {4, 6});
auto input = builder.Parameter(0, shape, "input");
builder.ReduceWindow(
input,
/*init_val=*/builder.ConstantLiteral(LiteralUtil::MinValue(F32)),
*max,
/*window_dimensions=*/{2, 3},
/*window_stride_dimensions=*/{2, 3},
Padding::kValid);
ディメンションのストライドが 1 の場合、ディメンション内のウィンドウの位置は、隣接するウィンドウから 1 要素離れていることを指定します。ウィンドウが重ならないように指定するには、window_stride_dimensions を window_dimensions にする必要があります。次の図は、2 つの異なるストライド値の使用を示しています。パディングは入力の各ディメンションに適用され、計算は、入力がパディング後のディメンションで入力された場合と同じです。
複雑なパディングの例として、入力配列 [10000, 1000, 100, 10, 1]
に対して、ディメンション 3
とストライド 2
で、減算ウィンドウの最小値(初期値は MAX_FLOAT
)を計算することを検討してください。パディング kValid
では、2 つの有効なウィンドウ([10000, 1000, 100]
と [100, 10, 1]
)の最小値が計算され、結果は [100, 1]
になります。パディング kSame
は、両側に初期要素を追加して [MAX_VALUE, 10000, 1000, 100, 10, 1,
MAX_VALUE]
を取得することで、まず配列をパディングし、減算ウィンドウ後の形状がストライド 1 の入力と同じになるようにします。パディングされた配列に対して reduce-window を実行すると、3 つのウィンドウ([MAX_VALUE, 10000, 1000]
、[1000, 100, 10]
、[10, 1, MAX_VALUE]
)で動作し、[1000, 10, 1]
が生成されます。
リダクション関数の評価順序は任意で、非決定的になる可能性があります。したがって、減算関数は再結合に過度に敏感であってはなりません。詳細については、Reduce
のコンテキストでの結合性に関する説明をご覧ください。
ReplicaId
XlaBuilder::ReplicaId
もご覧ください。
レプリカの一意の ID(U32 スカラー)を返します。
ReplicaId()
各レプリカの一意の ID は、間隔 [0, N)
の符号なし整数です。ここで、N
はレプリカの数です。すべてのレプリカが同じプログラムを実行しているため、プログラム内の ReplicaId()
呼び出しは、レプリカごとに異なる値を返します。
Reshape
XlaBuilder::Reshape
と Collapse
オペレーションもご覧ください。
配列のディメンションを新しい構成に変更します。
Reshape(operand, new_sizes)
Reshape(operand, dimensions, new_sizes)
引数 | タイプ | セマンティクス |
---|---|---|
operand |
XlaOp |
型 T の配列 |
dimensions |
int64 ベクトル |
ディメンションが折りたたまれた順序 |
new_sizes |
int64 ベクトル |
新しいディメンションのサイズのベクトル |
概念的には、reshape はまず配列をデータ値の 1 次元ベクトルにフラット化し、次にこのベクトルを新しいシェイプに絞り込みます。入力引数は、型 T の任意の配列、ディメンション インデックスのコンパイル時定数ベクトル、結果のディメンション サイズのコンパイル時定数ベクトルです。dimension
ベクトルの値(指定されている場合)は、T のすべてのディメンションの並べ替えである必要があります。指定されていない場合のデフォルトは {0, ..., rank - 1}
です。dimensions
のディメンションの順序は、ループ ネストで最も変化の少ないディメンション(最も大きい)から最も変化の大きいディメンション(最も小さい)です。これにより、入力配列が 1 つのディメンションに圧縮されます。new_sizes
ベクトルにより、出力配列のサイズが決まります。new_sizes
のインデックス 0 の値はディメンション 0 のサイズ、インデックス 1 の値はディメンション 1 のサイズです。new_size
ディメンションの積は、オペランドのディメンション サイズの積と等しくする必要があります。圧縮された配列を new_sizes
で定義された多次元配列に絞り込む場合、new_sizes
内のディメンションは、変化が最も遅い(最も大きなディメンション)から変化が最も速い(最も小さなディメンション)の順に並べ替えられます。
たとえば、v を 24 個の要素からなる配列とします。
let v = f32[4x2x3] { { {10, 11, 12}, {15, 16, 17} },
{ {20, 21, 22}, {25, 26, 27} },
{ {30, 31, 32}, {35, 36, 37} },
{ {40, 41, 42}, {45, 46, 47} } };
In-order collapse:
let v012_24 = Reshape(v, {0,1,2}, {24});
then v012_24 == f32[24] {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27,
30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47};
let v012_83 = Reshape(v, {0,1,2}, {8,3});
then v012_83 == f32[8x3] { {10, 11, 12}, {15, 16, 17},
{20, 21, 22}, {25, 26, 27},
{30, 31, 32}, {35, 36, 37},
{40, 41, 42}, {45, 46, 47} };
Out-of-order collapse:
let v021_24 = Reshape(v, {1,2,0}, {24});
then v012_24 == f32[24] {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42,
15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47};
let v021_83 = Reshape(v, {1,2,0}, {8,3});
then v021_83 == f32[8x3] { {10, 20, 30}, {40, 11, 21},
{31, 41, 12}, {22, 32, 42},
{15, 25, 35}, {45, 16, 26},
{36, 46, 17}, {27, 37, 47} };
let v021_262 = Reshape(v, {1,2,0}, {2,6,2});
then v021_262 == f32[2x6x2] { { {10, 20}, {30, 40},
{11, 21}, {31, 41},
{12, 22}, {32, 42} },
{ {15, 25}, {35, 45},
{16, 26}, {36, 46},
{17, 27}, {37, 47} } };
特別なケースとして、reshape は単一要素の配列をスカラーに変換したり、その逆を行ったりできます。次に例を示します。
Reshape(f32[1x1] { {5} }, {0,1}, {}) == 5;
Reshape(5, {}, {1,1}) == f32[1x1] { {5} };
リバース(逆)
XlaBuilder::Rev
もご覧ください。
Rev(operand, dimensions)
引数 | タイプ | セマンティクス |
---|---|---|
operand |
XlaOp |
T 型の配列 |
dimensions |
ArraySlice<int64> |
ディメンションを逆にして |
指定された dimensions
に沿って operand
配列内の要素の順序を逆にして、同じ形状の出力配列を生成します。多次元インデックスの演算子配列の各要素は、変換されたインデックスの出力配列に格納されます。多次元インデックスは、反転する各ディメンションのインデックスを反転することで変換されます(サイズ N のディメンションが反転ディメンションの 1 つである場合、そのインデックス i は N - 1 - i に変換されます)。
Rev
演算の 1 つの用途は、ニューラル ネットワークの勾配計算中に、2 つのウィンドウ ディメンションに沿って畳み込み重み配列を反転することです。
RngNormal
XlaBuilder::RngNormal
もご覧ください。
\(N(\mu, \sigma)\) 正規分布に従って生成された乱数を使用して、指定されたシェイプの出力を構築します。パラメータ \(\mu\) と \(\sigma\)、出力シェイプは浮動小数点要素型である必要があります。さらに、パラメータにはスカラー値を指定する必要があります。
RngNormal(mu, sigma, shape)
引数 | タイプ | セマンティクス |
---|---|---|
mu |
XlaOp |
生成された数値の平均を指定する T 型のスカラー |
sigma |
XlaOp |
生成された値の標準偏差を指定する T 型のスカラー |
shape |
Shape |
T 型の出力形状 |
RngUniform
XlaBuilder::RngUniform
もご覧ください。
区間 \([a,b)\)で均一分布に従って生成された乱数を使用して、指定された形状の出力を構築します。パラメータと出力要素の型は、ブール型、整数型、または浮動小数点型である必要があり、型は一貫している必要があります。現在、CPU バックエンドと GPU バックエンドは F64、F32、F16、BF16、S64、U64、S32、U32 のみをサポートしています。さらに、パラメータはスカラー値である必要があります。 \(b <= a\) の場合、結果は実装定義です。
RngUniform(a, b, shape)
引数 | タイプ | セマンティクス |
---|---|---|
a |
XlaOp |
区間の下限を指定する型 T のスカラー |
b |
XlaOp |
区間の上限を指定する T 型のスカラー |
shape |
Shape |
型 T の出力シェイプ |
RngBitGenerator
指定されたアルゴリズム(またはバックエンドのデフォルト)を使用して、均一なランダムビットで埋められた特定のシェイプで出力を生成し、更新された状態(初期状態と同じシェイプ)と生成されたランダムデータを返します。
初期状態は、現在の乱数生成の初期状態です。必要な形状と有効な値は、使用するアルゴリズムによって異なります。
出力は初期状態の確定的関数であることが保証されますが、バックエンドと異なるコンパイラ バージョン間で確定的であることは保証されません。
RngBitGenerator(algorithm, key, shape)
引数 | タイプ | セマンティクス |
---|---|---|
algorithm |
RandomAlgorithm |
使用する PRNG アルゴリズム。 |
initial_state |
XlaOp |
PRNG アルゴリズムの初期状態。 |
shape |
Shape |
生成されたデータの出力シェイプ。 |
algorithm
に指定できる値:
rng_default
: バックエンド固有のシェイプ要件を持つバックエンド固有のアルゴリズム。rng_three_fry
: ThreeFry カウンタベースの PRNG アルゴリズム。initial_state
シェイプは、任意の値を持つu64[2]
です。Salmon et al. SC 2011. 並列乱数: 1、2、3 で簡単に生成。rng_philox
: 乱数を並行して生成する Philox アルゴリズム。initial_state
シェイプは、任意の値を持つu64[3]
です。Salmon et al. SC 2011. 並列乱数: 1、2、3 で簡単に生成。
散布
XLA スキャッター オペレーションは、入力配列 operands
の値である一連の結果を生成します。複数のスライス(scatter_indices
で指定されたインデックス)は、update_computation
を使用して updates
の値のシーケンスで更新されます。
XlaBuilder::Scatter
もご覧ください。
scatter(operands..., scatter_indices, updates..., update_computation,
index_vector_dim, update_window_dims, inserted_window_dims,
scatter_dims_to_operand_dims)
引数 | タイプ | セマンティクス |
---|---|---|
operands |
N XlaOp のシーケンス |
分散される T_0, ..., T_N 型の N 個の配列。 |
scatter_indices |
XlaOp |
分散するスライスの開始インデックスを含む配列。 |
updates |
N 個の XlaOp のシーケンス |
T_0, ..., T_N 型の N 配列。updates[i] には、operands[i] の散乱に使用する値が含まれています。 |
update_computation |
XlaComputation |
入力配列内の既存の値と、散布時の更新を組み合わせるために使用される計算。この計算の型は T_0, ..., T_N, T_0, ..., T_N -> Collate(T_0, ..., T_N) にする必要があります。 |
index_vector_dim |
int64 |
開始インデックスを含む scatter_indices のディメンション。 |
update_window_dims |
ArraySlice<int64> |
ウィンドウのサイズである updates シェイプのディメンションのセット。 |
inserted_window_dims |
ArraySlice<int64> |
updates シェイプに挿入する必要があるウィンドウのサイズのセット。 |
scatter_dims_to_operand_dims |
ArraySlice<int64> |
スキャッター インデックスからオペランド インデックス空間へのディメンション マップ。この配列は、i から scatter_dims_to_operand_dims[i] へのマッピングとして解釈されます。1 対 1 で合計数を指定する必要があります。 |
indices_are_sorted |
bool |
呼び出し元によってインデックスが並べ替えられることが保証されているかどうか。 |
unique_indices |
bool |
呼び出し元によってインデックスが一意であることが保証されているかどうか。 |
ここで
- N は 1 以上にする必要があります。
operands
[0
], ...,operands
[N-1
] はすべて同じディメンションである必要があります。updates
[0
]、...、updates
[N-1
] はすべて同じディメンションである必要があります。N = 1
の場合、Collate(T)
はT
です。N > 1
の場合、Collate(T_0, ..., T_N)
はT
型のN
要素のタプルです。
index_vector_dim
が scatter_indices.rank
の場合、scatter_indices
には末尾に 1
ディメンションがあると暗黙的にみなされます。
ArraySlice<int64>
型の update_scatter_dims
を、update_window_dims
にない updates
シェイプのディメンションのセットとして、昇順のものとして定義します。
scatter の引数は、次の制約に従う必要があります。
各
updates
配列のランクはupdate_window_dims.size + scatter_indices.rank - 1
にする必要があります。各
updates
配列内のディメンションi
の境界は、次のようにする必要があります。i
がupdate_window_dims
に存在する(k
でupdate_window_dims
[k
] に等しい)場合、updates
のディメンションi
の境界は、inserted_window_dims
を考慮した後のoperand
の対応する境界(adjusted_window_bounds
[k
])を超えてはなりません。ここで、adjusted_window_bounds
にはoperand
の境界が含まれ、インデックスinserted_window_dims
の境界は取り除かれます。i
がupdate_scatter_dims
に存在する(k
でupdate_scatter_dims
[k
] に等しい)場合、updates
のディメンションi
の境界は、index_vector_dim
をスキップして、scatter_indices
の対応する境界と等しくする必要があります(k
<index_vector_dim
の場合はscatter_indices.shape.dims
[k
]、それ以外の場合はscatter_indices.shape.dims
[k+1
])。
update_window_dims
は昇順で、ディメンション番号が重複しておらず、[0, updates.rank)
の範囲内である必要があります。inserted_window_dims
は昇順で、ディメンション番号の繰り返しを持たず、[0, operand.rank)
の範囲内にある必要があります。operand.rank
は、update_window_dims.size
とinserted_window_dims.size
の合計と等しくする必要があります。scatter_dims_to_operand_dims.size
はscatter_indices.shape.dims
[index_vector_dim
] と等しくなければならず、その値は[0, operand.rank)
の範囲内になければなりません。
各 updates
配列内の特定のインデックス U
について、この更新を適用する対応する operands
配列内の対応するインデックス I
は、次のように計算されます。
G
= {update_scatter_dims
のk
のU
[k
] } とします。G
を使用して、scatter_indices
配列のインデックス ベクトルS
を検索します。S
[i
] =scatter_indices
[Combine(G
,i
)] となるようにします。ここで、Combine(A, b) は、A のindex_vector_dim
位置に b を挿入します。scatter_dims_to_operand_dims
マップでS
を分散させて、S
を使用してoperand
にインデックスS
in
を作成します。よりフォーマルな表現:S
in
[scatter_dims_to_operand_dims
[k
]] =S
[k
](k
<scatter_dims_to_operand_dims.size
の場合)。S
in
[_
] =0
(それ以外の場合)。
inserted_window_dims
に従ってU
のupdate_window_dims
でインデックスを分散させることで、各operands
配列にインデックスW
in
を作成します。よりフォーマルな表現:k
がupdate_window_dims
にある場合、W
in
[window_dims_to_operand_dims
(k
)] =U
[k
]。ここで、window_dims_to_operand_dims
は、ドメイン [0
、update_window_dims.size
] と範囲 [0
、operand.rank
] \inserted_window_dims
の単調関数です。(たとえば、update_window_dims.size
が4
、operand.rank
が6
、inserted_window_dims
が {0
、2
} の場合、window_dims_to_operand_dims
は {0
→1
、1
→3
、2
→4
、3
→5
} です)。W
in
[_
] =0
(それ以外の場合)。
I
はW
in
+S
in
です。ここで、+ は要素単位の加算です。
要約すると、散布オペレーションは次のように定義できます。
output
をoperands
で初期化します(すべてのインデックスJ
、operands
[J
] 配列内のすべてのインデックスO
について)。
output
[J
][O
] =operands
[J
][O
]updates
[J
] 配列内のすべてのインデックスU
と、operand
[J
] 配列内の対応するインデックスO
について、O
がoutput
の有効なインデックスである場合:
(output
[0
][O
], ...,output
[N-1
][O
]) =update_computation
(output
[0
][O
], ..., ,output
[N-1
][O
],updates
[0
][U
], ...,updates
[N-1
][U
])
更新が適用される順序は非決定的です。そのため、updates
内の複数のインデックスが operands
内の同じインデックスを参照している場合、output
内の対応する値は非決定的になります。
update_computation
に渡される最初のパラメータは常に output
配列の現在の値で、2 番目のパラメータは常に updates
配列の値です。これは、特に update_computation
が交換可能でないケースで重要です。
indices_are_sorted
が true に設定されている場合、XLA は scatter_indices
がユーザーによって並べ替えられていると想定できます(scatter_dims_to_operand_dims
に従って値を分散した後、昇順)。そうでない場合、セマンティクスは実装で定義されます。
unique_indices
が true に設定されている場合、XLA は、分散されたすべての要素が一意であると想定できます。そのため、XLA ではアトミック以外のオペレーションを使用できます。unique_indices
が true に設定され、分散先のインデックスが一意でない場合、セマンティクスは実装定義です。
非公式には、scatter op は collect op の逆と見なすことができます。scatter op は、対応する collect op によって抽出された入力要素を更新します。
詳細な非公式の説明と例については、Gather
の「非公式の説明」セクションをご覧ください。
選択
XlaBuilder::Select
もご覧ください。
述語配列の値に基づいて、2 つの入力配列の要素から出力配列を作成します。
Select(pred, on_true, on_false)
引数 | タイプ | セマンティクス |
---|---|---|
pred |
XlaOp |
PRED 型の配列 |
on_true |
XlaOp |
型 T の配列 |
on_false |
XlaOp |
T 型の配列 |
配列 on_true
と on_false
は同じ形状にする必要があります。これは出力配列の形状でもあります。配列 pred
は、PRED
要素型で、on_true
と on_false
と同じ次元数にする必要があります。
pred
の各要素 P
について、P
の値が true
の場合は on_true
から、P
の値が false
の場合は on_false
から、出力配列の対応する要素が取得されます。ブロードキャストの制限付き形式として、pred
は PRED
型のスカラーにできます。この場合、出力配列は、pred
が true
の場合は on_true
から、pred
が false
の場合は on_false
からすべて取得されます。
スカラー以外の pred
を使用する場合の例:
let pred: PRED[4] = {true, false, false, true};
let v1: s32[4] = {1, 2, 3, 4};
let v2: s32[4] = {100, 200, 300, 400};
==>
Select(pred, v1, v2) = s32[4]{1, 200, 300, 4};
スカラー pred
を使用した例:
let pred: PRED = true;
let v1: s32[4] = {1, 2, 3, 4};
let v2: s32[4] = {100, 200, 300, 400};
==>
Select(pred, v1, v2) = s32[4]{1, 2, 3, 4};
タプル間の選択がサポートされています。この目的では、タプルはスカラー型と見なされます。on_true
と on_false
がタプル(同じシェイプである必要があります)の場合、pred
は PRED
型のスカラーである必要があります。
SelectAndScatter
XlaBuilder::SelectAndScatter
もご覧ください。
この演算は、まず operand
配列の ReduceWindow
を計算して各ウィンドウから要素を選択し、source
配列を選択された要素のインデックスに散布して、オペランド配列と同じ形状の出力配列を作成する複合演算と見なすことができます。バイナリ select
関数は、各ウィンドウに適用して各ウィンドウから要素を選択するために使用されます。この関数は、最初のパラメータのインデックス ベクトルが 2 番目のパラメータのインデックス ベクトルよりも辞書順で小さいプロパティで呼び出されます。select
関数は、最初のパラメータが選択された場合は true
を返し、2 番目のパラメータが選択された場合は false
を返します。また、選択された要素が特定のウィンドウで走査される要素の順序に依存しないように、関数は推移性を保持する必要があります(つまり、select(a, b)
と select(b, c)
が true
であれば、select(a, c)
も true
である必要があります)。
関数 scatter
は、出力配列内の選択した各インデックスに適用されます。次の 2 つのスカラー パラメータを取ります。
- 出力配列で選択されたインデックスの現在の値
- 選択したインデックスに適用される
source
の散布値
2 つのパラメータを結合し、出力配列の選択したインデックスの値を更新するために使用されるスカラー値を返します。初期状態では、出力配列のすべてのインデックスが init_value
に設定されています。
出力配列は operand
配列と同じ形状になり、source
配列は operand
配列に ReduceWindow
オペレーションを適用した結果と同じ形状にする必要があります。SelectAndScatter
は、ニューラル ネットワークのプーリング層の勾配値を逆伝播するために使用できます。
SelectAndScatter(operand, select, window_dimensions, window_strides,
padding, source, init_value, scatter)
引数 | タイプ | セマンティクス |
---|---|---|
operand |
XlaOp |
ウィンドウがスライドする T 型の配列 |
select |
XlaComputation |
T, T -> PRED 型のバイナリ計算。各ウィンドウ内のすべての要素に適用されます。最初のパラメータが選択されている場合は true を返し、2 番目のパラメータが選択されている場合は false を返します。 |
window_dimensions |
ArraySlice<int64> |
ウィンドウ ディメンション値の整数の配列 |
window_strides |
ArraySlice<int64> |
ウィンドウのストライド値を示す整数の配列 |
padding |
Padding |
ウィンドウのパディング タイプ(Padding::kSame または Padding::kValid) |
source |
XlaOp |
散布する値を含む T 型の配列 |
init_value |
XlaOp |
出力配列の初期値の型 T のスカラー値 |
scatter |
XlaComputation |
T, T -> T 型のバイナリ計算。各散布図のソース要素を宛先要素に適用します。 |
次の図は、SelectAndScatter
の使用例を示しています。select
関数は、パラメータの最大値を計算します。以下の図(2)のようにウィンドウが重複している場合、operand
配列のインデックスが異なるウィンドウによって複数回選択される可能性があります。この図では、値 9 の要素が上のウィンドウ(青と赤)の両方によって選択され、バイナリ加算 scatter
関数によって値 8(2 + 6)の出力要素が生成されます。
scatter
関数の評価順序は任意であり、非決定的である場合もあります。したがって、scatter
関数は再アソシエーションに過度に敏感であってはなりません。詳細については、Reduce
のコンテキストでの結合性に関する説明をご覧ください。
送信
XlaBuilder::Send
もご覧ください。
Send(operand, channel_handle)
引数 | タイプ | セマンティクス |
---|---|---|
operand |
XlaOp |
送信するデータ(型 T の配列) |
channel_handle |
ChannelHandle |
送受信ペアごとの一意の識別子 |
指定されたオペランド データを、同じチャネル ハンドルを共有する別の計算の Recv
命令に送信します。データを返しません。
Recv
オペレーションと同様に、Send
オペレーションのクライアント API は同期通信を表し、非同期のデータ転送を可能にするために内部的に 2 つの HLO 命令(Send
と SendDone
)に分解されます。HloInstruction::CreateSend
と HloInstruction::CreateSendDone
もご覧ください。
Send(HloInstruction operand, int64 channel_id)
同じチャンネル ID を持つ Recv
命令によって割り振られたリソースへのオペランドの非同期転送を開始します。コンテキストを返します。これは、次の SendDone
命令で使用され、データ転送の完了を待機します。コンテキストは {オペランド(シェイプ)、リクエスト ID(U32)} のタプルであり、SendDone
命令でのみ使用できます。
SendDone(HloInstruction context)
Send
命令によって作成されたコンテキストを受け取り、データ転送が完了するまで待機します。この命令はデータを返しません。
チャンネルに関する手順のスケジュール設定
各チャンネルの 4 つの命令(Recv
、RecvDone
、Send
、SendDone
)の実行順序は次のとおりです。
Recv
はSend
の前に発生します。Send
はRecvDone
より前に発生しますRecv
はRecvDone
より前に発生しますSend
はSendDone
より前に発生します
バックエンド コンパイラが、チャネル命令を介して通信する計算ごとに線形スケジュールを生成する場合、計算間でサイクルが発生しないようにする必要があります。たとえば、次のスケジュールはデッドロックにつながります。
スライス
XlaBuilder::Slice
もご覧ください。
スライスでは、入力配列からサブ配列を抽出します。サブ配列は入力と同じランクで、入力配列内の境界ボックス内の値が含まれます。境界ボックスのディメンションとインデックスは、スライス オペレーションの引数として指定されます。
Slice(operand, start_indices, limit_indices, strides)
引数 | タイプ | セマンティクス |
---|---|---|
operand |
XlaOp |
T 型の N 次元配列 |
start_indices |
ArraySlice<int64> |
各ディメンションのスライスの開始インデックスを含む N 個の整数のリスト。0 以上の値を指定してください。 |
limit_indices |
ArraySlice<int64> |
各ディメンションのスライスの終了インデックス(除く)を含む N 個の整数のリスト。各値は、ディメンションのそれぞれの start_indices 値以上で、ディメンションのサイズ以下である必要があります。 |
strides |
ArraySlice<int64> |
スライスの入力ストライドを決定する N 個の整数のリスト。スライスは、ディメンション d 内のすべての strides[d] 要素を選択します。 |
1 次元の例:
let a = {0.0, 1.0, 2.0, 3.0, 4.0}
Slice(a, {2}, {4}) produces:
{2.0, 3.0}
2 次元の例:
let b =
{ {0.0, 1.0, 2.0},
{3.0, 4.0, 5.0},
{6.0, 7.0, 8.0},
{9.0, 10.0, 11.0} }
Slice(b, {2, 1}, {4, 3}) produces:
{ { 7.0, 8.0},
{10.0, 11.0} }
並べ替え
XlaBuilder::Sort
もご覧ください。
Sort(operands, comparator, dimension, is_stable)
引数 | タイプ | セマンティクス |
---|---|---|
operands |
ArraySlice<XlaOp> |
並べ替えるオペランド。 |
comparator |
XlaComputation |
使用する比較演算。 |
dimension |
int64 |
並べ替えるディメンション。 |
is_stable |
bool |
安定した並べ替えを使用するかどうか。 |
オペランドを 1 つだけ指定した場合:
オペランドがランク 1 テンソル(配列)の場合、結果は並べ替えられた配列になります。配列を昇順に並べ替える場合は、比較演算子で小なり比較を実行する必要があります。正式には、配列が並べ替えられた後、
i < j
がcomparator(value[i], value[j]) = comparator(value[j], value[i]) = false
またはcomparator(value[i], value[j]) = true
のすべてのインデックス位置i, j
に保持されます。オペランドの階数が大きい場合、オペランドは指定されたディメンションに沿って並べ替えられます。たとえば、ランク 2 テンソル(行列)の場合、ディメンション値が
0
であればすべての列が個別に並べ替えられ、ディメンション値が1
であれば各行が個別に並べ替えられます。ディメンション番号が指定されていない場合、デフォルトで最後のディメンションが選択されます。並べ替えられたディメンションには、ランク 1 の場合と同じ並べ替え順序が適用されます。
n > 1
オペランドが指定されている場合:
すべての
n
オペランドは、同じディメンションのテンサーである必要があります。テンソルの要素タイプは異なる場合がありますすべてのオペランドは個別ではなく、一緒に並べ替えられます。概念的には、オペランドはタプルとして扱われます。インデックス位置
i
とj
の各オペランドの要素を入れ替える必要があるかどうかを確認するときに、2 * n
スカラー パラメータを使用して比較演算子が呼び出されます。パラメータ2 * k
はk-th
オペランドの位置i
の値に対応し、パラメータ2 * k + 1
はk-th
オペランドの位置j
の値に対応します。通常、比較オペレーターはパラメータ2 * k
と2 * k + 1
を比較し、必要に応じて他のパラメータペアをタイブレークとして使用します。結果は、(上記のように指定されたディメンションに沿って)並べ替えられたオペランドで構成されるタプルです。タプルの
i-th
オペランドは、Sort のi-th
オペランドに対応します。
たとえば、3 つのオペランド operand0 = [3, 1]
、operand1 = [42, 50]
、operand2 = [-3.0, 1.1]
があり、比較演算子が operand0
の値のみを小なりで比較する場合、並べ替えの出力はタプル ([1, 3], [50, 42], [1.1, -3.0])
になります。
is_stable
が true に設定されている場合、並べ替えは安定することが保証されます。つまり、比較演算子によって等しいと見なされる要素がある場合、等しい値の相対順序は保持されます。2 つの要素 e1
と e2
は、comparator(e1, e2) = comparator(e2, e1) = false
の場合にのみ等しくなります。デフォルトでは、is_stable
は false に設定されています。
行 / 列の入れ替え
tf.reshape
オペレーションもご覧ください。
Transpose(operand)
引数 | タイプ | セマンティクス |
---|---|---|
operand |
XlaOp |
転置するオペランド。 |
permutation |
ArraySlice<int64> |
ディメンションを並べ替える方法。 |
オペランドのディメンションを指定された順序で並べ替えます(∀ i . 0 ≤ i < rank ⇒ input_dimensions[permutation[i]] = output_dimensions[i]
)。
これは、Reshape(operand, permutation, Permute(permutation, operand.shape.dimensions)) と同じです。
TriangularSolve
XlaBuilder::TriangularSolve
もご覧ください。
前方置換または逆置換により、下または上の三角係数行列を持つ連立一次方程式を解く。先頭のディメンションに沿ってブロードキャストするこのルーティンは、a
と b
が指定された変数 x
について、行列システム op(a) * x =
b
または x * op(a) = b
のいずれかを解きます。ここで、op(a)
は op(a) = a
、op(a) = Transpose(a)
、または op(a) = Conj(Transpose(a))
です。
TriangularSolve(a, b, left_side, lower, unit_diagonal, transpose_a)
引数 | タイプ | セマンティクス |
---|---|---|
a |
XlaOp |
形状が [..., M, M] の複素数型または浮動小数点型の rank > 2 の配列。 |
b |
XlaOp |
left_side が true の場合 [..., M, K] 、それ以外の場合は [..., K, M] の形状を持つ、同じ型のランク 3 以上の配列。 |
left_side |
bool |
op(a) * x = b (true )または x * op(a) = b (false )形式のシステムを解くかどうかを示します。 |
lower |
bool |
a の上三角形と下三角形のどちらを使用するかを指定します。 |
unit_diagonal |
bool |
true の場合、a の対角要素は 1 とみなされ、アクセスされません。 |
transpose_a |
Transpose |
a をそのまま使用するか、転置するか、共役転置するかを指定します。 |
入力データは、lower
の値に応じて、a
の下部または上部の三角形からのみ読み取られます。他の三角形の値は無視されます。出力データは同じ三角形で返されます。他の三角形の値は実装定義であり、任意の値にすることができます。
a
と b
の階数が 2 より大きい場合、これらは行列のバッチとして扱われます。この場合、マイナーな 2 つのディメンション以外のすべてのディメンションがバッチ ディメンションになります。a
と b
のバッチ ディメンションは同じである必要があります。
タプル
XlaBuilder::Tuple
もご覧ください。
可変の数のデータハンドルを含むタプル。各データハンドルには独自のシェイプがあります。
これは C++ の std::tuple
に似ています。コンセプト的には次のようになります。
let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
let s: s32 = 5;
let t: (f32[10], s32) = tuple(v, s);
タプルは、GetTupleElement
オペレーションを介してデコンストラクト(アクセス)できます。
しばらく
XlaBuilder::While
もご覧ください。
While(condition, body, init)
引数 | タイプ | セマンティクス |
---|---|---|
condition |
XlaComputation |
ループの終了条件を定義する T -> PRED 型の XlaComputation。 |
body |
XlaComputation |
ループの本文を定義する T -> T 型の XlaComputation。 |
init |
T |
condition と body のパラメータの初期値。 |
condition
が失敗するまで body
を順番に実行します。これは他の多くの言語での一般的な while ループと似ていますが、下記の違いと制限事項があります。
While
ノードは、body
の最後の実行結果であるT
型の値を返します。- 型
T
の形状は静的に決定され、すべての反復処理で同じである必要があります。
計算の T パラメータは、最初のイテレーションで init
値で初期化され、後続の各イテレーションで body
の新しい結果に自動的に更新されます。
While
ノードの主なユースケースの 1 つは、ニューラル ネットワークでトレーニングの繰り返し実行を実装することです。計算を表すグラフとともに、簡素化された疑似コードを以下に示します。コードは while_test.cc
にあります。この例の型 T
は、反復回数用の int32
と累積用の vector[10]
で構成される Tuple
です。1,000 回の反復処理で、ループは定数ベクトルをアキュムレータに追加し続けます。
// Pseudocode for the computation.
init = {0, zero_vector[10]} // Tuple of int32 and float[10].
result = init;
while (result(0) < 1000) {
iteration = result(0) + 1;
new_vector = result(1) + constant_vector[10];
result = {iteration, new_vector};
}