オペレーション セマンティクス

以下では、XlaBuilder インターフェースで定義されているオペレーションのセマンティクスについて説明します。通常、これらのオペレーションは、xla_data.proto の RPC インターフェースで定義されたオペレーションに 1 対 1 でマッピングされます。

命名規則に関する注記: XLA が扱う一般化データ型は、ある一様な型(32 ビット浮動小数点数など)の要素を保持する N 次元配列です。このドキュメントでは、任意の次元の配列を表すために「配列」という用語を使用します。便宜上、特殊なケースにはより具体的でわかりやすい名前が付けられています。たとえば、ベクトルは 1 次元配列、行列は 2 次元配列です。

AfterAll

XlaBuilder::AfterAll もご覧ください。

AfterAll は可変数のトークンを受け取って 1 つのトークンを生成します。トークンはプリミティブ タイプであり、副作用のあるオペレーション間でスレッド化されて順序付けを適用できます。AfterAll は、set オペレーションの後にオペレーションを並べ替えるためのトークンの結合として使用できます。

AfterAll(operands)

引数 タイプ セマンティクス
operands XlaOp 可変長のトークン数

AllGather

XlaBuilder::AllGather もご覧ください。

レプリカ間で連結を実行します。

AllGather(operand, all_gather_dim, shard_count, replica_group_ids, channel_id)

引数 タイプ セマンティクス
operand XlaOp レプリカ間で連結する配列
all_gather_dim int64 連結ディメンション
replica_groups int64 のベクトルのベクトル 連結が実行されるグループ
channel_id 省略可の int64 モジュール間通信用のオプションのチャンネル ID
  • replica_groups は、連結が実行されるレプリカ グループのリストです(現在のレプリカのレプリカ ID は ReplicaId を使用して取得できます)。各グループのレプリカの順序によって、結果に入力が配置される順序が決まります。replica_groups は空であるか(この場合、すべてのレプリカが 0N - 1 の順に単一のグループに属します)、レプリカの数と同じ数の要素を含める必要があります。たとえば、replica_groups = {0, 2}, {1, 3} はレプリカ 0213 の連結を行います。
  • shard_count は、各レプリカ グループのサイズです。これは、replica_groups が空の場合に必要です。
  • channel_id はモジュール間通信に使用されます。同じ channel_id を持つ all-gather オペレーションのみが相互に通信できます。

出力シェイプは、all_gather_dimshard_count 倍に拡大された入力シェイプです。たとえば、2 つのレプリカがあり、2 つのレプリカでオペランドの値がそれぞれ [1.0, 2.5][3.0, 5.25] の場合、all_gather_dim0 であるこのオペレーションの出力値は、両方のレプリカで [1.0, 2.5, 3.0, 5.25] になります。

AllReduce

XlaBuilder::AllReduce もご覧ください。

レプリカ間でカスタム計算を実行します。

AllReduce(operand, computation, replica_group_ids, channel_id)

引数 タイプ セマンティクス
operand XlaOp レプリカ間で減算する配列または配列の空でないタプル
computation XlaComputation 削減の計算
replica_groups int64 のベクトルのベクトル 削減を実行するグループ
channel_id 省略可の int64 モジュール間通信用のオプションのチャンネル ID
  • operand が配列のタプルの場合、タプルの各要素に対してオール リデュークが実行されます。
  • replica_groups は、削減が実行されるレプリカ グループのリストです(現在のレプリカのレプリカ ID は ReplicaId を使用して取得できます)。replica_groups は空であるか(この場合、すべてのレプリカが 1 つのグループに属します)、レプリカの数と同じ数の要素を含める必要があります。たとえば、replica_groups = {0, 2}, {1, 3} はレプリカ 0213 の間で減算を行います。
  • channel_id はモジュール間通信に使用されます。同じ channel_id を持つ all-reduce オペレーションのみが相互に通信できます。

出力の形状は入力の形状と同じです。たとえば、2 つのレプリカがあり、2 つのレプリカでオペランドの値がそれぞれ [1.0, 2.5][3.0, 5.25] の場合、このオペレーションと合計計算の出力値は両方のレプリカで [4.0, 7.75] になります。入力がタプルの場合、出力もタプルになります。

AllReduce の結果を計算するには、各レプリカから 1 つの入力が必要です。したがって、あるレプリカが別のレプリカよりも AllReduce ノードを実行する回数が多い場合、前者のレプリカは永遠に待機します。レプリカはすべて同じプログラムを実行しているため、このような状態が発生する方法は多くありませんが、while ループの条件がインフィードのデータに依存し、インフィードされたデータによって、一方のレプリカでもう一方のレプリカよりも多くの回数 while ループが反復される場合は発生する可能性があります。

AllToAll

XlaBuilder::AllToAll もご覧ください。

AllToAll は、すべてのコアからすべてのコアにデータを送信する集約オペレーションです。次の 2 つのフェーズがあります。

  1. 散布フェーズ。各コアで、オペランドは split_dimensions に沿って split_count 個のブロックに分割され、ブロックはすべてのコアに分散されます(たとえば、i 番目のブロックは i 番目のコアに送信されます)。
  2. 収集フェーズ。各コアは、受信したブロックを concat_dimension に沿って連結します。

参加するコアは、次の方法で構成できます。

  • replica_groups: 各 ReplicaGroup には、計算に参加するレプリカ ID のリストが含まれています(現在のレプリカのレプリカ ID は ReplicaId を使用して取得できます)。AllToAll は、指定された順序でサブグループ内に適用されます。たとえば、replica_groups = { {1,2,3}, {4,5,0} } は、AllToAll がレプリカ {1, 2, 3} 内と収集フェーズで適用され、受信したブロックが 1、2、3 の順序で連結されることを意味します。次に、レプリカ 4、5、0 内で別の AllToAll が適用され、連結順序も 4、5、0 になります。replica_groups が空の場合、すべてのレプリカは、出現順に連結されて 1 つのグループに属します。

前提条件:

  • split_dimension のオペランドのディメンション サイズは split_count で割り切れます。
  • オペランドの形状がタプルではありません。

AllToAll(operand, split_dimension, concat_dimension, split_count, replica_groups)

引数 タイプ セマンティクス
operand XlaOp n 次元の入力配列
split_dimension int64 演算対象が分割されるディメンションの名前を指定する [0, n) の範囲内の値
concat_dimension int64 分割ブロックを連結するディメンションの名前を指定する [0, n) の範囲内の値
split_count int64 このオペレーションに参加するコアの数。replica_groups が空の場合は、レプリカ数にする必要があります。それ以外の場合は、各グループのレプリカ数にする必要があります。
replica_groups ReplicaGroup ベクトル 各グループにはレプリカ ID のリストが含まれています。

Alltoall の例を次に示します。

XlaBuilder b("alltoall");
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, /*split_count=*/4);

この例では、Alltoall に参加しているコアは 4 つあります。各コアで、オペランドは次元 1 に沿って 4 つの部分に分割されるため、各部分の形状は f32[4,4] になります。4 つの部分はすべてのコアに分散されます。次に、各コアは、コア 0 ~ 4 の順序で、受信した部分をディメンション 0 に沿って連結します。したがって、各コアの出力は f32[16,4] の形状になります。

BatchNormGrad

アルゴリズムの詳細については、XlaBuilder::BatchNormGrad元のバッチ ノーマライゼーションの論文もご覧ください。

バッチ正規化の勾配を計算します。

BatchNormGrad(operand, scale, mean, variance, grad_output, epsilon, feature_index)

引数 タイプ セマンティクス
operand XlaOp 正規化する n 次元配列(x)
scale XlaOp 1 次元配列(\(\gamma\))
mean XlaOp 1 次元配列(\(\mu\))
variance XlaOp 1 次元配列(\(\sigma^2\))
grad_output XlaOp BatchNormTraining(\(\nabla y\))に渡されるグラデーション
epsilon float イプシロン値(\(\epsilon\))
feature_index int64 operand の特徴量ディメンションのインデックス

このオペレーションは、特徴ディメンション内の各特徴(feature_indexoperand の特徴ディメンションのインデックス)について、他のすべてのディメンション全体で operandoffsetscale に関する勾配を計算します。feature_index は、operand の特徴ディメンションの有効なインデックスである必要があります。

3 つの勾配は、次の式で定義されます(4 次元配列を operand とし、特徴ディメンション インデックスを l、バッチサイズを m、空間サイズを wh とします)。

\[ \begin{split} c_l&= \frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \left( \nabla y_{ijkl} \frac{x_{ijkl} - \mu_l}{\sigma^2_l+\epsilon} \right) \\\\ d_l&= \frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \nabla y_{ijkl} \\\\ \nabla x_{ijkl} &= \frac{\gamma_{l} }{\sqrt{\sigma^2_{l}+\epsilon} } \left( \nabla y_{ijkl} - d_l - c_l (x_{ijkl} - \mu_{l}) \right) \\\\ \nabla \gamma_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \left( \nabla y_{ijkl} \frac{x_{ijkl} - \mu_l}{\sqrt{\sigma^2_{l}+\epsilon} } \right) \\\\\ \nabla \beta_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \nabla y_{ijkl} \end{split} \]

入力 meanvariance は、バッチ ディメンションと空間ディメンション全体のモーメント値を表します。

出力型は、3 つのハンドルのタプルです。

出力 タイプ セマンティクス
grad_operand XlaOp 入力 operand に対する勾配($\nabla x$)
grad_scale XlaOp 入力 scale に対する勾配($\nabla \gamma$)
grad_offset XlaOp 入力 offset に対する勾配($\nabla \beta$)

BatchNormInference

アルゴリズムの詳細については、XlaBuilder::BatchNormInference元のバッチ ノーマライゼーションの論文もご覧ください。

バッチと空間ディメンション全体で配列を正規化します。

BatchNormInference(operand, scale, offset, mean, variance, epsilon, feature_index)

引数 タイプ セマンティクス
operand XlaOp 正規化する N 次元配列
scale XlaOp 1 次元配列
offset XlaOp 1 次元配列
mean XlaOp 1 次元配列
variance XlaOp 1 次元配列
epsilon float イプシロン値
feature_index int64 operand の特徴量ディメンションのインデックス

このオペレーションは、特徴ディメンション内の各特徴(feature_indexoperand の特徴ディメンションのインデックス)について、他のすべてのディメンションの平均と分散を計算し、その平均と分散を使用して operand の各要素を正規化します。feature_index は、operand の特徴ディメンションの有効なインデックスである必要があります。

BatchNormInference は、バッチごとに meanvariance を計算せずに BatchNormTraining を呼び出す場合と同じです。代わりに、入力 meanvariance が推定値として使用されます。このオペレーションの目的は推論のレイテンシを短縮することです。そのため、BatchNormInference という名前が付けられています。

出力は、入力 operand と同じ形状の n 次元正規化配列です。

BatchNormTraining

アルゴリズムの詳細については、XlaBuilder::BatchNormTrainingthe original batch normalization paper もご覧ください。

バッチと空間ディメンション全体で配列を正規化します。

BatchNormTraining(operand, scale, offset, epsilon, feature_index)

引数 タイプ セマンティクス
operand XlaOp 正規化する n 次元配列(x)
scale XlaOp 1 次元配列(\(\gamma\))
offset XlaOp 1 次元配列(\(\beta\))
epsilon float イプシロン値(\(\epsilon\))
feature_index int64 operand の特徴量ディメンションのインデックス

このオペレーションは、特徴ディメンション内の各特徴(feature_indexoperand の特徴ディメンションのインデックス)について、他のすべてのディメンションの平均と分散を計算し、その平均と分散を使用して operand の各要素を正規化します。feature_index は、operand の特徴ディメンションの有効なインデックスである必要があります。

operand \(x\) 内の各バッチで、空間ディメンションのサイズが whm 要素を含むバッチについて、アルゴリズムは次のように動作します(operand が 4 次元配列であると仮定します)。

  • 特徴ディメンション内の各特徴 l のバッチ平均 \(\mu_l\) を計算します。 \(\mu_l=\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h x_{ijkl}\)

  • バッチ分散を計算します。 \(\sigma^2_l\): $\sigma^2l=\frac{1}{mwh}\sum{i=1}^m\sum{j=1}^w\sum{k=1}^h (x_{ijkl} - \mu_l)^2$

  • 正規化、スケーリング、シフト: \(y_{ijkl}=\frac{\gamma_l(x_{ijkl}-\mu_l)}{\sqrt[2]{\sigma^2_l+\epsilon} }+\beta_l\)

イプシロン値(通常は小数値)は、ゼロ除算エラーを回避するために追加されます。

出力型は、3 つの XlaOp のタプルです。

出力 タイプ セマンティクス
output XlaOp 入力 operand(y)と同じ形状の n 次元配列
batch_mean XlaOp 1 次元配列(\(\mu\))
batch_var XlaOp 1 次元配列(\(\sigma^2\))

batch_meanbatch_var は、上記の式を使用してバッチ ディメンションと空間ディメンション全体で計算されるモーメントです。

BitcastConvertType

XlaBuilder::BitcastConvertType もご覧ください。

TensorFlow の tf.bitcast と同様に、データ シェイプからターゲット シェイプへの要素ごとのビットキャスト オペレーションを実行します。入力と出力のサイズが一致している必要があります。たとえば、s32 要素はビットキャスト ルーティンを介して f32 要素になり、1 つの s32 要素は 4 つの s8 要素になります。ビットキャスト キャストは低レベルのキャストとして実装されているため、浮動小数点表現が異なるマシンでは異なる結果が得られます。

BitcastConvertType(operand, new_element_type)

引数 タイプ セマンティクス
operand XlaOp サイズ D の型 T の配列
new_element_type PrimitiveType タイプ U

オペランドとターゲット シェイプのディメンションは、変換前後のプリミティブ サイズの比率によって変化する最後のディメンションを除き、一致している必要があります。

ソース要素と宛先要素の型はタプルにできません。

異なる幅のプリミティブ型へのビットキャスト変換

BitcastConvert HLO 命令は、出力要素型 T' のサイズが入力要素 T のサイズと等しくないケースをサポートしています。オペレーション全体は概念的にはビットキャストであり、基盤となるバイトは変更されないため、出力要素の形状を変更する必要があります。B = sizeof(T), B' = sizeof(T') の場合、次の 2 つのケースが考えられます。

まず、B > B' の場合、出力シェイプにサイズ B/B' の新しい最小ディメンションが追加されます。次に例を示します。

  f16[10,2]{1,0} %output = f16[10,2]{1,0} bitcast-convert(f32[10]{0} %input)

有効なスカラーの場合も、ルールは同じです。

  f16[2]{0} %output = f16[2]{0} bitcast-convert(f32[] %input)

B' > B の場合、この命令では入力シェイプの最後の論理ディメンションが B'/B に等しくする必要があります。このディメンションは変換中に破棄されます。

  f32[10]{0} %output = f32[10]{0} bitcast-convert(f16[10,2]{1,0} %input)

異なるビット幅間での変換は要素ごとではありません。

ブロードキャスト

XlaBuilder::Broadcast もご覧ください。

配列内のデータを複製して、配列にディメンションを追加します。

Broadcast(operand, broadcast_sizes)

引数 タイプ セマンティクス
operand XlaOp 複製する配列
broadcast_sizes ArraySlice<int64> 新しいディメンションのサイズ

新しいディメンションは左側に挿入されます。つまり、broadcast_sizes の値が {a0, ..., aN} で、オペランドの形状のディメンションが {b0, ..., bM} の場合、出力の形状のディメンションは {a0, ..., aN, b0, ..., bM} になります。

新しいディメンションは、オペランドのコピーをインデックスします。

output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]

たとえば、operand が値 2.0f のスカラー f32 で、broadcast_sizes{2, 3} の場合、結果は f32[2, 3] の形状を持つ配列になり、結果のすべての値は 2.0f になります。

BroadcastInDim

XlaBuilder::BroadcastInDim もご覧ください。

配列内のデータを複製して、配列のサイズとディメンション数を拡張します。

BroadcastInDim(operand, out_dim_size, broadcast_dimensions)

引数 タイプ セマンティクス
operand XlaOp 複製する配列
out_dim_size ArraySlice<int64> ターゲット シェイプのサイズ
broadcast_dimensions ArraySlice<int64> オペランド シェイプの各ディメンションがターゲット シェイプのどのディメンションに対応しているか

Broadcast に似ていますが、任意の場所にディメンションを追加し、サイズ 1 の既存のディメンションを拡張できます。

operand は、out_dim_size で記述されたシェイプにブロードキャストされます。broadcast_dimensions は、operand のディメンションをターゲット シェイプのディメンションにマッピングします。つまり、オペランドの i 番目のディメンションは、出力シェイプの broadcast_dimension[i] 番目のディメンションにマッピングされます。operand のディメンションは、サイズが 1 であるか、マッピング先の出力シェイプのディメンションと同じサイズである必要があります。残りのディメンションは、サイズ 1 のディメンションで埋められます。不完全なディメンション ブロードキャストでは、これらの不完全なディメンションに沿ってブロードキャストを行い、出力シェイプに到達します。セマンティクスについては、ブロードキャスト ページで詳しく説明しています。

電話

XlaBuilder::Call もご覧ください。

指定された引数を使用して計算を呼び出します。

Call(computation, args...)

引数 タイプ セマンティクス
computation XlaComputation 任意の型の N 個のパラメータを持つ T_0, T_1, ..., T_{N-1} -> S 型の計算
args N 個の XlaOp のシーケンス 任意の型の N 個の引数

args の arity と型は、computation のパラメータと一致する必要があります。args がなくても構いません。

CompositeCall

XlaBuilder::CompositeCall もご覧ください。

他の StableHLO オペレーションで構成された(コンポーズされた)オペレーションをカプセル化し、入力と composite_attributes を受け取って結果を生成します。op のセマンティクスは、分解属性によって実装されます。複合オペレーションは、プログラムのセマンティクスを変更せずに、分解に置き換えることができます。分解をインライン化しても同じオペレーション セマンティクスが提供されない場合は、custom_call を使用することをおすすめします。

バージョン フィールド(デフォルトは 0)は、コンポジットのセマンティクスが変更されたときを示すために使用されます。

このオペレーションは、属性 is_composite=true を持つ kCall として実装されます。decomposition フィールドは computation 属性で指定します。フロントエンド属性には、接頭辞 composite. が付いた残りの属性が格納されます。

CompositeCall オペレーションの例:

f32[] call(f32[] %cst), to_apply=%computation, is_composite=true,
frontend_attributes = {
  composite.name="foo.bar",
  composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},
  composite.version="1"
}

Call(computation, args..., name, composite_attributes, version)

引数 タイプ セマンティクス
inputs XlaOp 可変長の値の数
name string 複合体の名前
composite_attributes 省略可能な string 属性の文字列化された辞書(省略可)
decomposition XlaComputation 任意の型の N 個のパラメータを持つ T_0, T_1, ..., T_{N-1} -> S 型の計算
version int64 複合オペレーションのセマンティクスのバージョン アップデート

Cholesky

XlaBuilder::Cholesky もご覧ください。

対称(ヘルミチアン)正定値行列のバッチの Cholesky 分解を計算します。

Cholesky(a, lower)

引数 タイプ セマンティクス
a XlaOp 2 次元を超える複素型または浮動小数点型の配列。
lower bool a の上部または下部の三角形を使用するかどうか。

lowertrue の場合、$a = l となる下三角行列 l を計算します。l^T$ です。lowerfalse の場合、\(a = u^T . u\)となる上三角行列 u を計算します。

入力データは、lower の値に応じて、a の下部または上部の三角形からのみ読み取られます。他の三角形の値は無視されます。出力データは同じ三角形で返されます。他の三角形の値は実装定義であり、任意の値にすることができます。

a のディメンションが 2 つを超える場合、a は行列のバッチとして扱われます。この場合、マイナーな 2 つのディメンション以外のすべてのディメンションがバッチ ディメンションになります。

a が対称(ヘルミチアン)正定値でない場合、結果は実装定義です。

範囲制限

XlaBuilder::Clamp もご覧ください。

オペランドを最小値と最大値の範囲内にクランプします。

Clamp(min, operand, max)

引数 タイプ セマンティクス
min XlaOp 型 T の配列
operand XlaOp 型 T の配列
max XlaOp 型 T の配列

オペランドと最小値と最大値が指定された場合、オペランドが最小値と最大値の範囲内にある場合はオペランドを返します。範囲外の場合は、オペランドが範囲内より小さい場合は最小値を、範囲外より大きい場合は最大値を返します。つまり、clamp(a, x, b) = min(max(a, x), b) のようになります。

3 つの配列はすべて同じ形状である必要があります。または、制限付きのブロードキャストとして、min または maxT 型のスカラーにすることもできます。

スカラー minmax を使用した例:

let operand: s32[3] = {-1, 5, 9};
let min: s32 = 0;
let max: s32 = 6;
==>
Clamp(min, operand, max) = s32[3]{0, 5, 6};

閉じる

XlaBuilder::Collapsetf.reshape オペレーションもご覧ください。

配列のディメンションを 1 つのディメンションに圧縮します。

Collapse(operand, dimensions)

引数 タイプ セマンティクス
operand XlaOp 型 T の配列
dimensions int64 ベクトル T のディメンションの順番に並んだ連続したサブセット。

圧縮では、オペランドのディメンションの指定されたサブセットが 1 つのディメンションに置き換えられます。入力引数は、T 型の任意の配列と、次元インデックスのビルド時定数ベクトルです。ディメンション インデックスは、T のディメンションの順序付き(ディメンション番号が低い順)連続サブセットである必要があります。したがって、{0, 1, 2}、{0, 1}、{1, 2} はすべて有効なディメンション セットですが、{1, 0} や {0, 2} は無効です。これらのディメンションは、置き換えるディメンションと同じ位置に、1 つの新しいディメンションに置き換えられます。新しいディメンションのサイズは、元のディメンションのサイズの積に等しくなります。dimensions の最小ディメンション番号は、これらのディメンションを圧縮するループ ネストで最も変化が遅いディメンション(最もメジャー)であり、最大ディメンション番号は最も変化が速いディメンション(最もマイナー)です。より一般的な集約順序が必要な場合は、tf.reshape 演算子をご覧ください。

たとえば、v を 24 要素の配列とします。

let v = f32[4x2x3] { { {10, 11, 12},  {15, 16, 17} },
{ {20, 21, 22},  {25, 26, 27} },
{ {30, 31, 32},  {35, 36, 37} },
{ {40, 41, 42},  {45, 46, 47} } };

// Collapse to a single dimension, leaving one dimension.
let v012 = Collapse(v, {0,1,2});
then v012 == f32[24] {10, 11, 12, 15, 16, 17,
20, 21, 22, 25, 26, 27,
30, 31, 32, 35, 36, 37,
40, 41, 42, 45, 46, 47};

// Collapse the two lower dimensions, leaving two dimensions.
let v01 = Collapse(v, {0,1});
then v01 == f32[4x6] { {10, 11, 12, 15, 16, 17},
{20, 21, 22, 25, 26, 27},
{30, 31, 32, 35, 36, 37},
{40, 41, 42, 45, 46, 47} };

// Collapse the two higher dimensions, leaving two dimensions.
let v12 = Collapse(v, {1,2});
then v12 == f32[8x3] { {10, 11, 12},
{15, 16, 17},
{20, 21, 22},
{25, 26, 27},
{30, 31, 32},
{35, 36, 37},
{40, 41, 42},
{45, 46, 47} };

CollectivePermute

XlaBuilder::CollectivePermute もご覧ください。

CollectivePermute は、レプリカ間でデータを送受信する集約オペレーションです。

CollectivePermute(operand, source_target_pairs)

引数 タイプ セマンティクス
operand XlaOp n 次元の入力配列
source_target_pairs <int64, int64> ベクトル (source_replica_id、target_replica_id)ペアのリスト。ペアごとに、オペランドはソースレプリカからターゲット レプリカに送信されます。

source_target_pair には次の制限があります。

  • 2 つのペアでターゲット レプリカ ID が同じで、ソース レプリカ ID が同じにならないようにします。
  • レプリカ ID がどのペアでもターゲットでない場合は、そのレプリカの出力は、入力と同じ形状の 0 で構成されるテンソルになります。

Concatenate

XlaBuilder::ConcatInDim もご覧ください。

連結は、複数の配列オペランドから配列を作成します。配列のディメンション数は、各入力配列オペランドと同じ数(ディメンション数が同じである必要があります)で、引数は指定された順序で格納されます。

Concatenate(operands..., dimension)

引数 タイプ セマンティクス
operands N 個の XlaOp のシーケンス 次元が [L0、L1、...] の型 T の N 個の配列。N >= 1 が必要です。
dimension int64 operands の間に連結するディメンションの名前を指定する [0, N) の範囲内の値。

dimension を除くすべてのディメンションは同じである必要があります。これは、XLA が「不規則な」配列をサポートしていないためです。また、0 次元の値は連結できません(連結するディメンションに名前を付けることができないためです)。

1 次元の例:

Concat({ {2, 3}, {4, 5}, {6, 7} }, 0)
>>> {2, 3, 4, 5, 6, 7}

2 次元の例:

let a = {
{1, 2},
{3, 4},
{5, 6},
};
let b = {
{7, 8},
};
Concat({a, b}, 0)
>>> {
{1, 2},
{3, 4},
{5, 6},
{7, 8},
}

図:

条件文

XlaBuilder::Conditional もご覧ください。

Conditional(pred, true_operand, true_computation, false_operand, false_computation)

引数 タイプ セマンティクス
pred XlaOp PRED 型のスカラー
true_operand XlaOp 型の引数 \(T_0\)
true_computation XlaComputation 型の XlaComputation \(T_0 \to S\)
false_operand XlaOp 型の引数 \(T_1\)
false_computation XlaComputation 型の XlaComputation \(T_1 \to S\)

predtrue の場合は true_computation を、predfalse の場合は false_computation を実行し、結果を返します。

true_computation は \(T_0\) 型の単一の引数を取り、同じ型の true_operand で呼び出されます。false_computation は \(T_1\) 型の単一の引数を取り、同じ型の false_operand で呼び出されます。返される値の型(true_computationfalse_computation)は同じである必要があります。

pred の値に応じて、true_computationfalse_computation のいずれか 1 つだけが実行されます。

Conditional(branch_index, branch_computations, branch_operands)

引数 タイプ セマンティクス
branch_index XlaOp S32 型のスカラー
branch_computations N 個の XlaComputation のシーケンス 型の XlaComputation \(T_0 \to S , T_1 \to S , ..., T_{N-1} \to S\)
branch_operands N 個の XlaOp のシーケンス 型 \(T_0 , T_1 , ..., T_{N-1}\)の引数

branch_computations[branch_index] を実行し、結果を返します。branch_index が 0 未満または N 以上である S32 の場合、branch_computations[N-1] がデフォルトのブランチとして実行されます。

branch_computations[b] は \(T_b\) 型の単一の引数を取り、同じ型の branch_operands[b] で呼び出されます。各 branch_computations[b] の戻り値の型は同じである必要があります。

branch_index の値に応じて、branch_computations のいずれか 1 つだけが実行されます。

Conv(畳み込み)

XlaBuilder::Conv もご覧ください。

ConvWithGeneralPadding と同じですが、パディングは SAME または VALID という省略形で指定します。SAME パディングでは、入力(lhs)にゼロをパディングして、ストライドを考慮しない場合の出力が入力と同じシェイプになるようにします。VALID パディングは、パディングなしを意味します。

ConvWithGeneralPadding(畳み込み)

XlaBuilder::ConvWithGeneralPadding もご覧ください。

ニューラル ネットワークで使用される種類の畳み込みを計算します。ここで、畳み込みは、n 次元のベース領域を移動する n 次元のウィンドウと見なすことができます。ウィンドウの可能な位置ごとに計算が行われます。

引数 タイプ セマンティクス
lhs XlaOp 入力の(n+2)次元配列
rhs XlaOp カーネル重みの(n+2)次元配列
window_strides ArraySlice<int64> カーネル ストライドの n-d 配列
padding ArraySlice< pair<int64,int64>> (低、高)パディングの n-d 配列
lhs_dilation ArraySlice<int64> n-d lhs dilation factor array
rhs_dilation ArraySlice<int64> n-d 右辺拡張係数配列
feature_group_count int64 特徴グループの数
batch_group_count int64 バッチグループの数

n は空間ディメンションの数です。lhs 引数は、ベース領域を記述する(n+2)次元配列です。これは入力と呼ばれますが、もちろん右辺も入力です。ニューラル ネットワークでは、これらは入力アクティベーションです。n+2 ディメンションは次の順序で表示されます。

  • batch: このディメンションの各座標は、畳み込みが実行される独立した入力を表します。
  • z/depth/features: ベースエリア内の各(y,x)位置には、このディメンションに格納されるベクトルが関連付けられています。
  • spatial_dims: ウィンドウが移動するベース領域を定義する n 空間ディメンションを記述します。

rhs 引数は、畳み込みフィルタ/カーネル/ウィンドウを記述する(n+2)次元配列です。ディメンションの順序は次のとおりです。

  • output-z: 出力の z ディメンション。
  • input-z: このディメンションのサイズに feature_group_count を掛けた値は、lhs の z ディメンションのサイズと等しくする必要があります。
  • spatial_dims: ベースエリアを移動する n-d ウィンドウを定義する n 空間ディメンションを記述します。

window_strides 引数には、空間ディメンションで畳み込みウィンドウのストライドを指定します。たとえば、最初の空間ディメンションのストライドが 3 の場合、ウィンドウは、最初の空間インデックスが 3 で割り切れる座標にのみ配置できます。

padding 引数には、ベース領域に適用するゼロパディングの量を指定します。パディングの量は負の値にすることができます。負のパディングの絶対値は、畳み込みを行う前に指定されたディメンションから削除する要素の数を示します。padding[0] はディメンション y のパディングを指定し、padding[1] はディメンション x のパディングを指定します。各ペアでは、低いパディングが最初の要素、高いパディングが 2 番目の要素になります。低いパディングはインデックスが低い方向に適用され、高いパディングはインデックスが大きい方向に適用されます。たとえば、padding[1](2,3) の場合、2 番目の空間ディメンションの左側に 2 つのゼロ、右側に 3 つのゼロがパディングされます。パディングを使用することは、畳み込みを行う前に同じゼロ値を入力(lhs)に挿入することと同じです。

lhs_dilation 引数と rhs_dilation 引数には、各空間ディメンションで lhs と rhs に適用される拡大係数を指定します。空間ディメンションの拡張係数が d の場合、そのディメンションの各エントリの間に d-1 個の穴が暗黙的に配置され、配列のサイズが増加します。穴は no-op 値で埋められます。畳み込みの場合、これはゼロを意味します。

右辺の拡張は、アトラス畳み込みとも呼ばれます。詳しくは、tf.nn.atrous_conv2d をご覧ください。左辺の拡張は転置畳み込みとも呼ばれます。詳しくは、tf.nn.conv2d_transpose をご覧ください。

feature_group_count 引数(デフォルト値 1)は、グループ化された畳み込みに使用できます。feature_group_count は、入力特徴量と出力特徴量の両方のディメンションの除数である必要があります。feature_group_count が 1 より大きい場合、概念的には、入力特徴と出力特徴のディメンションと rhs 出力特徴のディメンションが、多くの feature_group_count グループに均等に分割され、各グループが連続する特徴のサブシーケンスで構成されます。入力特徴量のディメンション rhs は、lhs 入力特徴量のディメンションを feature_group_count で除算した値と等しくする必要があります(つまり、入力特徴量のグループのサイズがすでに設定されています)。i 番目のグループは、多くの個別の畳み込みの feature_group_count を計算するために一緒に使用されます。これらの畳み込みの結果は、出力特徴ディメンションで連結されます。

深度方向畳み込みの場合、feature_group_count 引数は入力特徴の次元に設定され、フィルタは [filter_height, filter_width, in_channels, channel_multiplier] から [filter_height, filter_width, 1, in_channels * channel_multiplier] に再構成されます。詳しくは、tf.nn.depthwise_conv2d をご覧ください。

batch_group_count(デフォルト値 1)引数は、バックプロパゲーション中にグループ化されたフィルタに使用できます。batch_group_count は、lhs(入力)バッチ ディメンションのサイズの除数である必要があります。batch_group_count が 1 より大きい場合、出力バッチ ディメンションのサイズは input batch / batch_group_count にする必要があります。batch_group_count は、出力特徴サイズの除数である必要があります。

出力シェイプのディメンションは次の順序になります。

  • batch: このディメンションのサイズに batch_group_count を掛けた値は、lhs の batch ディメンションのサイズと等しくする必要があります。
  • z: カーネルの output-z と同じサイズ(rhs)。
  • spatial_dims: 畳み込みウィンドウの有効な配置ごとに 1 つの値。

上の図は、batch_group_count フィールドの仕組みを示しています。つまり、各 lhs バッチを batch_group_count グループにスライスし、出力特徴に対しても同様のことを行います。次に、これらのグループごとにペアワイズ畳み込みを行い、出力特徴ディメンションに沿って出力を連結します。他のすべてのディメンション(特徴と空間)のオペレーション セマンティクスは同じです。

畳み込みウィンドウの有効な配置は、ストライドと、パディング後のベース領域のサイズによって決まります。

畳み込みの動作を説明するには、2D 畳み込みを検討し、出力内の固定の batchzyx 座標を選択します。ここで、(y,x) はベース領域内のウィンドウの角の位置です(空間の寸法の解釈に応じて、左上など)。ベース領域から取得した 2D ウィンドウが作成されました。各 2D ポイントは 1D ベクトルに関連付けられているため、3D ボックスが作成されます。畳み込みカーネルから、出力座標 z を固定したので、3D ボックスも生成されます。2 つのボックスのディメンションは同じであるため、2 つのボックス間の要素ごとの積の合計を取得できます(ドット積に似ています)。これが出力値です。

output-z が5 の場合、ウィンドウの各位置で、出力の z ディメンションに 5 つの値が出力されます。これらの値は、畳み込みカーネルのどの部分が使用されるかによって異なります。output-z 座標ごとに使用される値の 3D ボックスが別にあります。つまり、5 つの個別の畳み込みで、それぞれに異なるフィルタを使用すると考えることができます。

パディングとストライドを使用した 2D 畳み込みの擬似コードは次のとおりです。

for (b, oz, oy, ox) {  // output coordinates
  value = 0;
  for (iz, ky, kx) {  // kernel coordinates and input z
    iy = oy*stride_y + ky - pad_low_y;
    ix = ox*stride_x + kx - pad_low_x;
    if ((iy, ix) inside the base area considered without padding) {
      value += input(b, iz, iy, ix) * kernel(oz, iz, ky, kx);
    }
  }
  output(b, oz, oy, ox) = value;
}

ConvertElementType

XlaBuilder::ConvertElementType もご覧ください。

C++ の要素ごとの static_cast と同様に、データシェイプからターゲット シェイプへの要素ごとの変換オペレーションを実行します。ディメンションは一致している必要があります。変換は要素単位で行われます。たとえば、s32 要素は s32 から f32 への変換ルーティンを介して f32 要素になります。

ConvertElementType(operand, new_element_type)

引数 タイプ セマンティクス
operand XlaOp サイズ D の型 T の配列
new_element_type PrimitiveType タイプ U

オペランドとターゲット シェイプのディメンションは一致している必要があります。ソース要素と宛先要素の型はタプルにできません。

T=s32 から U=f32 への変換では、最近接偶数丸めなどの正規化整数から浮動小数点への変換ルーティンが実行されます。

let a: s32[3] = {0, 1, 2};
let b: f32[3] = convert(a, f32);
then b == f32[3]{0.0, 1.0, 2.0}

CrossReplicaSum

合計計算で AllReduce を実行します。

CustomCall

XlaBuilder::CustomCall もご覧ください。

計算内でユーザー提供の関数を呼び出します。

CustomCall(target_name, args..., shape)

引数 タイプ セマンティクス
target_name string 関数名。このシンボル名をターゲットとする呼び出し命令が生成されます。
args N 個の XlaOp のシーケンス 任意の型の N 個の引数。関数に渡されます。
shape Shape 関数の出力シェイプ

関数シグネチャは、引数の arity や型に関係なく同じです。

extern "C" void target_name(void* out, void** in);

たとえば、CustomCall が次のように使用されている場合:

let x = f32[2] {1,2};
let y = f32[2x3] { {10, 20, 30}, {40, 50, 60} };

CustomCall("myfunc", {x, y}, f32[3x3])

myfunc の実装例を次に示します。

extern "C" void myfunc(void* out, void** in) {
  float (&x)[2] = *static_cast<float(*)[2]>(in[0]);
  float (&y)[2][3] = *static_cast<float(*)[2][3]>(in[1]);
  EXPECT_EQ(1, x[0]);
  EXPECT_EQ(2, x[1]);
  EXPECT_EQ(10, y[0][0]);
  EXPECT_EQ(20, y[0][1]);
  EXPECT_EQ(30, y[0][2]);
  EXPECT_EQ(40, y[1][0]);
  EXPECT_EQ(50, y[1][1]);
  EXPECT_EQ(60, y[1][2]);
  float (&z)[3][3] = *static_cast<float(*)[3][3]>(out);
  z[0][0] = x[1] + y[1][0];
  // ...
}

ユーザー提供の関数には副作用がなく、その実行がべき等である必要があります。

ドット

XlaBuilder::Dot もご覧ください。

Dot(lhs, rhs)

引数 タイプ セマンティクス
lhs XlaOp 型 T の配列
rhs XlaOp 型 T の配列

このオペレーションの正確なセマンティクスは、オペランドのランクによって異なります。

入力 出力 セマンティクス
ベクトル [n] dot ベクトル [n] スカラー ベクトルのドット積
行列 [m x k] dot ベクトル [k] ベクトル [m] 行列とベクトルの乗算
行列 [m x k] dot 行列 [k x n] 行列 [m x n] 行列乗算

このオペレーションは、lhs の 2 番目のディメンション(または 1 つのディメンションがある場合は 1 番目)と rhs の 1 番目のディメンションに対して積の合計を実行します。これらは「圧縮」されたディメンションです。lhsrhs の圧縮されたディメンションは同じサイズにする必要があります。実際には、ベクトル間のドット積、ベクトル/行列の乗算、行列/行列の乗算を実行するために使用できます。

DotGeneral

XlaBuilder::DotGeneral もご覧ください。

DotGeneral(lhs, rhs, dimension_numbers)

引数 タイプ セマンティクス
lhs XlaOp 型 T の配列
rhs XlaOp 型 T の配列
dimension_numbers DotDimensionNumbers 契約とバッチのディメンション番号

ドットと似ていますが、lhsrhs の両方に、圧縮とバッチのディメンション番号を指定できます。

DotDimensionNumbers フィールド タイプ セマンティクス
lhs_contracting_dimensions repeated int64 lhs 契約ディメンション番号
rhs_contracting_dimensions repeated int64 rhs 契約ディメンション番号
lhs_batch_dimensions repeated int64 lhs バッチ ディメンション番号
rhs_batch_dimensions repeated int64 rhs バッチ ディメンション番号

DotGeneral は、dimension_numbers で指定された収縮ディメンションの積の合計を実行します。

lhsrhs の関連する契約ディメンション番号は同じである必要はありませんが、同じディメンション サイズにする必要があります。

ディメンション番号を短縮した例:

lhs = { {1.0, 2.0, 3.0},
{4.0, 5.0, 6.0} }

rhs = { {1.0, 1.0, 1.0},
{2.0, 2.0, 2.0} }

DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(1);
dnums.add_rhs_contracting_dimensions(1);

DotGeneral(lhs, rhs, dnums) -> { {6.0, 12.0},
{15.0, 30.0} }

lhsrhs の関連するバッチ ディメンション番号のディメンション サイズは同じである必要があります。

バッチ ディメンション番号の例(バッチサイズ 2、2x2 行列):

lhs = { { {1.0, 2.0},
{3.0, 4.0} },
{ {5.0, 6.0},
{7.0, 8.0} } }

rhs = { { {1.0, 0.0},
{0.0, 1.0} },
{ {1.0, 0.0},
{0.0, 1.0} } }

DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(2);
dnums.add_rhs_contracting_dimensions(1);
dnums.add_lhs_batch_dimensions(0);
dnums.add_rhs_batch_dimensions(0);

DotGeneral(lhs, rhs, dnums) -> { { {1.0, 2.0},
{3.0, 4.0} },
{ {5.0, 6.0},
{7.0, 8.0} } }
入力 出力 セマンティクス
[b0, m, k] dot [b0, k, n] [b0, m, n] バッチ matmul
[b0, b1, m, k] dot [b0, b1, k, n] [b0, b1, m, n] バッチ matmul

したがって、結果のディメンション番号は、バッチ ディメンション、lhs 非収縮/非バッチ ディメンション、rhs 非収縮/非バッチ ディメンションの順に始まります。

DynamicSlice

XlaBuilder::DynamicSlice もご覧ください。

DynamicSlice は、動的 start_indices の入力配列からサブ配列を抽出します。各ディメンションのスライスのサイズは size_indices で渡されます。これは、各ディメンションのスライス区間の終了ポイント([開始、開始 + サイズ])を指定します。start_indices の形状は 1 次元で、ディメンション サイズは operand のディメンション数にする必要があります。

DynamicSlice(operand, start_indices, size_indices)

引数 タイプ セマンティクス
operand XlaOp 型 T の N 次元配列
start_indices N 個の XlaOp のシーケンス 各ディメンションのスライスの開始インデックスを含む N 個のスカラー整数のリスト。0 以上の値を指定してください。
size_indices ArraySlice<int64> 各ディメンションのスライスサイズを含む N 個の整数のリスト。各値は 0 より大きく、開始値 + サイズがディメンションのサイズ以下である必要があります。これにより、ディメンションのサイズのモジュロでラップされるのを回避できます。

有効なスライス インデックスは、スライスを実行する前に [1, N) 内の各インデックス i に次の変換を適用することで計算されます。

start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - size_indices[i])

これにより、抽出されたスライスはオペランド配列に対して常に境界内に収まります。変換が適用される前にスライスが境界内にある場合、変換は適用されません。

1 次元の例:

let a = {0.0, 1.0, 2.0, 3.0, 4.0}
let s = {2}

DynamicSlice(a, s, {2}) produces:
{2.0, 3.0}

2 次元の例:

let b =
{ {0.0,  1.0,  2.0},
{3.0,  4.0,  5.0},
{6.0,  7.0,  8.0},
{9.0, 10.0, 11.0} }
let s = {2, 1}

DynamicSlice(b, s, {2, 2}) produces:
{ { 7.0,  8.0},
{10.0, 11.0} }

DynamicUpdateSlice

XlaBuilder::DynamicUpdateSlice もご覧ください。

DynamicUpdateSlice は、入力配列 operand の値である結果を生成します。スライス updatestart_indices で上書きされます。update の形状によって、更新される結果のサブ配列の形状が決まります。start_indices の形状は 1 次元で、ディメンション サイズは operand のディメンション数にする必要があります。

DynamicUpdateSlice(operand, update, start_indices)

引数 タイプ セマンティクス
operand XlaOp 型 T の N 次元配列
update XlaOp スライスの更新を含む T 型の N 次元配列。更新シェイプの各ディメンションは 0 より大きく、開始位置 + 更新量は各ディメンションのオペランドサイズ以下にする必要があります。これにより、範囲外の更新インデックスが生成されるのを回避できます。
start_indices N 個の XlaOp のシーケンス 各ディメンションのスライスの開始インデックスを含む N 個のスカラー整数のリスト。0 以上の値を指定してください。

有効なスライス インデックスは、スライスを実行する前に [1, N) 内の各インデックス i に次の変換を適用することで計算されます。

start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - update.dimension_size[i])

これにより、更新されたスライスはオペランド配列に対して常にインバウンドになります。変換が適用される前にスライスが境界内にある場合、変換は適用されません。

1 次元の例:

let a = {0.0, 1.0, 2.0, 3.0, 4.0}
let u = {5.0, 6.0}
let s = {2}

DynamicUpdateSlice(a, u, s) produces:
{0.0, 1.0, 5.0, 6.0, 4.0}

2 次元の例:

let b =
{ {0.0,  1.0,  2.0},
{3.0,  4.0,  5.0},
{6.0,  7.0,  8.0},
{9.0, 10.0, 11.0} }
let u =
{ {12.0,  13.0},
{14.0,  15.0},
{16.0,  17.0} }

let s = {1, 1}

DynamicUpdateSlice(b, u, s) produces:
{ {0.0,  1.0,  2.0},
{3.0, 12.0, 13.0},
{6.0, 14.0, 15.0},
{9.0, 16.0, 17.0} }

要素ごとのバイナリ算術演算

XlaBuilder::Add もご覧ください。

要素ごとのバイナリ算術演算のセットがサポートされています。

Op(lhs, rhs)

ここで、OpAdd(加算)、Sub(減算)、Mul(乗算)、Div(除算)、Pow(べき乗)、Rem(剰余)、Max(最大値)、Min(最小値)、And(論理 AND)、Or(論理 OR)、Xor(論理 XOR)、ShiftLeft(左シフト)、ShiftRightArithmetic(算術右シフト)、ShiftRightLogical(論理右シフト)、Atan2(2 つの引数を持つ逆正弦)、Complex(実部と虚部を複素数に結合)のいずれかです。

引数 タイプ セマンティクス
lhs XlaOp 左側のオペランド: T 型の配列
rhs XlaOp 右側のオペランド: T 型の配列

引数のシェイプは類似しているか、互換性がある必要があります。シェイプの互換性について詳しくは、ブロードキャストのドキュメントをご覧ください。演算の結果は、2 つの入力配列をブロードキャストした結果のシェイプになります。このバリアントでは、オペランドの 1 つがスカラーでない限り、異なるランクの配列間の演算はサポートされていません。

OpRem の場合、結果の符号は被除数から取得され、結果の絶対値は常に除数の絶対値より小さくなります。

整数の除算オーバーフロー(ゼロによる符号付き/符号なしの除算/余り、または INT_SMIN-1 の符号付き除算/余り)は、実装定義の値を生成します。

これらのオペレーションには、異なるディメンションのブロードキャストをサポートする代替のバリアントがあります。

Op(lhs, rhs, broadcast_dimensions)

ここで、Op は上記と同じです。このオペレーションのバリエーションは、異なるランクの配列間の算術演算(ベクトルに行列を追加するなど)に使用する必要があります。

追加の broadcast_dimensions オペランドは、低次元オペランドのディメンション数を高次元オペランドのディメンション数まで拡張するために使用される整数のスライスです。broadcast_dimensions は、低次元シェイプのディメンションを高次元シェイプのディメンションにマッピングします。拡張されたシェイプのマッピングされていないディメンションは、サイズ 1 のディメンションで埋められます。次に、不完全なディメンション ブロードキャストは、これらの不完全なディメンションに沿ってシェイプをブロードキャストし、両方のオペランドのシェイプを等しくします。セマンティクスについては、ブロードキャスト ページで詳しく説明しています。

要素ごとの比較演算

XlaBuilder::Eq もご覧ください。

標準の要素ごとのバイナリ比較演算がサポートされています。浮動小数点型を比較する場合は、標準 IEEE 754 浮動小数点比較セマンティクスが適用されます。

Op(lhs, rhs)

ここで、Op は、Eq(等しい)、Ne(等しくない)、Ge(以上)、Gt(より大きい)、Le(以下)、Lt(未満)のいずれかです。別の演算子セットである EqTotalOrder、NeTotalOrder、GeTotalOrder、GtTotalOrder、LeTotalOrder、LtTotalOrder は、-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN を適用することで、浮動小数点数に対する全順序をサポートする点を除き、同じ機能を提供します。

引数 タイプ セマンティクス
lhs XlaOp 左側のオペランド: T 型の配列
rhs XlaOp 右側のオペランド: T 型の配列

引数のシェイプは類似しているか、互換性がある必要があります。シェイプの互換性について詳しくは、ブロードキャストのドキュメントをご覧ください。演算の結果のシェイプは、要素型 PRED の 2 つの入力配列をブロードキャストした結果です。このバリアントでは、オペランドの 1 つがスカラーでない限り、異なるランクの配列間の演算はサポートされていません。

これらのオペレーションには、異なるディメンションのブロードキャストをサポートする代替のバリアントがあります。

Op(lhs, rhs, broadcast_dimensions)

ここで、Op は上記と同じです。この演算のバリエーションは、異なるランクの配列間の比較演算(ベクトルに行列を追加するなど)に使用する必要があります。

追加の broadcast_dimensions オペランドは、オペランドのブロードキャストに使用するディメンションを指定する整数のスライスです。セマンティクスについては、ブロードキャスト ページで詳しく説明しています。

要素ごとの単項関数

XlaBuilder は、次の要素ごとの単項関数をサポートしています。

Abs(operand) 要素単位の絶対値 x -> |x|

Cbrt(operand) 要素ごとの立方根演算 x -> cbrt(x)

Ceil(operand) 要素ごとの天井関数 x -> ⌈x⌉

Clz(operand) 要素ごとに先頭のゼロをカウントします。

Cos(operand) 要素ごとのコサイン x -> cos(x)

Erf(operand) 要素ごとの誤差関数 x -> erf(x)

\(\text{erf}(x) = \frac{2}{\sqrt{\pi} }\int_0^x e^{-t^2} \, dt\)。

Exp(operand) 要素ごとの自然指数 x -> e^x

Expm1(operand) 要素ごとの自然指数マイナス 1 x -> e^x - 1

Floor(operand) 要素ごとの下限 x -> ⌊x⌋

Imag(operand) 複素(または実数)シェイプの要素ごとの虚部。x -> imag(x)。オペランドが浮動小数点型の場合は 0 を返します。

IsFinite(operand) operand の各要素が有限である(正または負の無限大ではなく、NaN でもない)かどうかをテストします。入力と同じ形状の PRED 値の配列を返します。対応する入力要素が有限である場合にのみ、各要素が true になります。

Log(operand) 要素ごとの自然対数 x -> ln(x)

Log1p(operand) 要素ごとシフトされた自然対数 x -> ln(1+x)

Logistic(operand) 要素ごとのロジスティック関数の計算 x -> logistic(x)

Neg(operand) 要素ごとの否定 x -> -x

Not(operand) 要素ごとの論理否定 x -> !(x)

PopulationCount(operand) operand の各要素で設定されているビット数を計算します。

Real(operand) 複素(または実数)シェイプの要素ごとの実部。x -> real(x)。オペランドが浮動小数点型の場合、同じ値を返します。

Round(operand) 要素ごとの丸め、同数の場合はゼロから遠ざかるように丸められます。

RoundNearestEven(operand) 要素ごとの丸め、最も近い偶数に丸められます。

Rsqrt(operand) 平方根演算 x -> 1.0 / sqrt(x) の要素ごとの逆数。

Sign(operand) 要素ごとの符号演算 x -> sgn(x)

\[\text{sgn}(x) = \begin{cases} -1 & x < 0\\ -0 & x = -0\\ NaN & x = NaN\\ +0 & x = +0\\ 1 & x > 0 \end{cases}\]

要素型 operand の比較演算子を使用します。

Sin(operand) 要素ごとの正弦 x -> sin(x)

Sqrt(operand) 要素ごとの平方根演算 x -> sqrt(x)

Tan(operand) 要素ごとの接線 x -> tan(x)

Tanh(operand) 要素ごとの双曲線正接 x -> tanh(x)

引数 タイプ セマンティクス
operand XlaOp 関数のオペランド

この関数は operand 配列内の各要素に適用され、同じシェイプの配列が生成されます。operand はスカラー(0 次元)にすることができます。

Fft

XLA FFT オペレーションは、実数と複素数の入力/出力に対して正規化と逆正規化の 2 つのフーリエ変換を実装します。最大 3 つの軸での多変量 FFT がサポートされています。

XlaBuilder::Fft もご覧ください。

引数 タイプ セマンティクス
operand XlaOp フーリエ変換する配列。
fft_type FftType 下の表をご覧ください。
fft_length ArraySlice<int64> 変換される軸の時間領域の長さ。これは、RFFT(fft_length=[16]) の出力形状が RFFT(fft_length=[17]) と同じであるため、IRFFT が最も内側の軸のサイズを適切に調整するために特に必要です。
FftType セマンティクス
FFT 複素数から複素数への FFT を前方処理します。形状は変更されません。
IFFT 複素数から複素数への逆 FFT。形状は変更されません。
RFFT 実数から複素数への FFT を前方処理します。fft_length[-1] がゼロ以外の値の場合、最内側の軸の形状は fft_length[-1] // 2 + 1 に縮小され、ナイキスト周波数を超える変換された信号の逆コンジュゲート部分が省略されます。
IRFFT 実数から複素数への逆 FFT(複素数を受け取り、実数を返します)。fft_length[-1] がゼロ以外の値の場合、最内側の軸の形状は fft_length[-1] に拡張され、1 から fft_length[-1] // 2 + 1 へのエントリの逆共役から、ナイキスト周波数を超える変換された信号の部分を推測します。

多次元 FFT

複数の fft_length が指定されている場合、これは、最も内側の各軸に FFT オペレーションのカスケードを適用することと同じです。実数から複素数への変換と複素数から実数への変換の場合、最も内側の軸変換が(実質的に)最初に実行されます(RFFT の場合、IRFFT の場合は最後)。そのため、最も内側の軸のサイズが変更されます。他の軸変換は複雑>複雑になります。

実装の詳細

CPU FFT は、Eigen の TensorFFT をベースにしています。GPU FFT は cuFFT を使用します。

収集

XLA 集約演算は、入力配列の複数のスライス(各スライスのランタイム オフセットが異なる場合があります)をつなぎ合わせます。

一般セマンティクス

XlaBuilder::Gather もご覧ください。より直感的な説明については、以下の「簡単な説明」をご覧ください。

gather(operand, start_indices, offset_dims, collapsed_slice_dims, slice_sizes, start_index_map)

引数 タイプ セマンティクス
operand XlaOp 収集元の配列。
start_indices XlaOp 収集するスライスの開始インデックスを含む配列。
index_vector_dim int64 開始インデックスが「含まれる」start_indices 内のディメンション。詳細については、以下をご覧ください。
offset_dims ArraySlice<int64> オペランドからスライスされた配列にオフセットする出力シェイプのディメンションのセット。
slice_sizes ArraySlice<int64> slice_sizes[i] は、ディメンション i のスライスの境界です。
collapsed_slice_dims ArraySlice<int64> 各スライスで閉じられたディメンションのセット。これらのディメンションのサイズは 1 にする必要があります。
start_index_map ArraySlice<int64> start_indices のインデックスをオペランドの有効なインデックスにマッピングする方法を表すマップ。
indices_are_sorted bool 呼び出し元によってインデックスが並べ替えられることが保証されているかどうか。

便宜上、出力配列のディメンションには offset_dims ではなく batch_dims というラベルを付けています。

出力は、batch_dims.size + offset_dims.size 次元の配列です。

operand.rankoffset_dims.sizecollapsed_slice_dims.size の合計と等しい必要があります。また、slice_sizes.sizeoperand.rank と等しい必要があります。

index_vector_dimstart_indices.rank の場合、start_indices には末尾に 1 ディメンションがあると暗黙的にみなされます(つまり、start_indices の型が [6,7] で、index_vector_dim2 の場合、start_indices の型は暗黙的に [6,7,1] とみなされます)。

ディメンション i に沿った出力配列の境界は、次のように計算されます。

  1. ibatch_dims に存在する場合(つまり、一部の k に対して batch_dims[k] と等しい場合)、index_vector_dim をスキップして start_indices.shape から対応するディメンションの境界を選択します(つまり、k < index_vector_dim の場合は start_indices.shape.dims[k] を選択し、それ以外の場合は start_indices.shape.dims[k+1] を選択します)。

  2. ioffset_dims に存在する場合(つまり、ある k に対して offset_dims[k] と等しい場合)、collapsed_slice_dims を考慮した後で slice_sizes から対応する境界を選択します(つまり、adjusted_slice_sizes[k] を選択します。ここで、adjusted_slice_sizes はインデックス collapsed_slice_dims の境界が削除された slice_sizes です)。

正式には、特定の出力インデックス Out に対応するオペランド インデックス In は次のように計算されます。

  1. G = { batch_dims 内の kOut[k] } とします。G を使用して、S[i] = start_indices[Combine(G, i)] となるようにベクトル S をスライスします。ここで、Combine(A, b) は、A の index_vector_dim 位置に b を挿入します。これは、G が空の場合でも明確に定義されています。G が空の場合、S = start_indices です。

  2. start_index_map を使用して S を分散し、S を使用して operand に開始インデックス Sin を作成します。具体的には次のようになります。

    1. k < start_index_map.size の場合、Sin[start_index_map[k]] = S[k]。

    2. Sin[_] = 0(それ以外の場合)。

  3. collapsed_slice_dims セットに従って Out のオフセット ディメンションでインデックスを分散し、operand にインデックス Oin を作成します。具体的には次のようになります。

    1. k < offset_dims.size の場合、Oin[remapped_offset_dims(k)] = Out[offset_dims[k]](remapped_offset_dims は後で定義されます)。

    2. Oin[_] = 0(それ以外の場合)。

  4. InOin + Sin です。ここで、+ は要素単位の加算です。

remapped_offset_dims は、ドメイン [0offset_dims.size] と範囲 [0operand.rank] \ collapsed_slice_dims の単調関数です。たとえば、offset_dims.size4operand.rank6collapsed_slice_dims が {0, 2} の場合、remapped_offset_dims は {01132435} です。

indices_are_sorted が true に設定されている場合、XLA は start_indices がユーザーによって並べ替えられていると想定できます(start_index_map に従って値を分散した、昇順)。そうでない場合、セマンティクスは実装定義です。

非公式の説明と例

非公式には、出力配列内のすべてのインデックス Out は、オペランド配列内の要素 E に対応し、次のように計算されます。

  • Out のバッチ ディメンションを使用して、start_indices から開始インデックスを検索します。

  • start_index_map を使用して、開始インデックス(サイズが operand.rank より小さい場合があります)を operand の「完全な」開始インデックスにマッピングします。

  • 完全な開始インデックスを使用して、サイズ slice_sizes のスライスを動的にスライスします。

  • collapsed_slice_dims ディメンションを閉じて、スライスの形状を変更します。すべての圧縮スライス ディメンションの境界は 1 である必要があるため、この再シェイプは常に有効です。

  • Out のオフセット ディメンションを使用してこのスライスにインデックスを付け、出力インデックス Out に対応する入力要素 E を取得します。

以降のすべての例では、index_vector_dimstart_indices.rank1 に設定されています。index_vector_dim のより興味深い値は、オペレーションを根本的に変更するものではありませんが、視覚的な表現をより複雑にします。

上記のすべてがどのように連携しているかを直感的に理解するには、[16,11] 配列から 5 つのスライス シェイプ [8,6] を収集する例を見てみましょう。[16,11] 配列内のスライスの位置は、形状 S64[2] のインデックス ベクトルとして表すことができるため、5 つの位置のセットは S64[5,2] 配列として表すことができます。

集約オペレーションの動作は、出力シェイプのインデックスである [GO0O1] を受け取り、次のように入力配列の要素にマッピングするインデックス変換として表すことができます。

まず、G を使用して、集計インデックス アレイから(XY)ベクトルを選択します。出力配列のインデックス [G,O0,O1] の要素は、入力配列のインデックス [X+O0,Y+O1] の要素になります。

slice_sizes[8,6] で、O0 と O1 の範囲を決定します。これにより、スライスの境界が決まります。

この集計オペレーションは、G を一括ディメンションとして使用する一括動的スライスとして機能します。

収集インデックスは多次元にすることができます。たとえば、上記の例のより一般的なバージョンでは、[4,5,2] 型の「集計インデックス」配列を使用して、次のようにインデックスを変換します。

これも、バッチ動的スライス G0 として機能し、G1 はバッチ ディメンションとして機能します。スライスサイズは引き続き [8,6] です。

XLA の集計オペレーションは、上記の非公式なセマンティクスを次のように一般化します。

  1. 出力シェイプのどのディメンションがオフセット ディメンションであるかを構成できます(最後の例では、O0O1 を含むディメンション)。出力バッチ ディメンション(最後の例では G0G1 を含むディメンション)は、オフセット ディメンションではない出力ディメンションとして定義されます。

  2. 出力シェイプに明示的に存在する出力オフセット ディメンションの数は、入力ディメンションの数よりも少なくなる場合があります。これらの「欠落している」ディメンションは、collapsed_slice_dims として明示的にリストされており、スライスサイズが 1 である必要があります。スライスサイズが 1 であるため、有効なインデックスは 0 のみであり、省略しても曖昧さは生じません。

  3. [Gather Indices] 配列から抽出されたスライス(最後の例の XY)には、入力配列のディメンション数よりも要素が少ない場合があります。明示的なマッピングにより、入力と同じ数のディメンションを持つようにインデックスを拡張する方法が指定されます。

最後の例として、(2)と(3)を使用して tf.gather_nd を実装します。

G0G1 は、通常どおり、集計インデックス配列から開始インデックスをスライスするために使用されます。ただし、開始インデックスには X という要素が 1 つだけあります。同様に、値が O0 の出力オフセット インデックスは 1 つだけです。ただし、入力配列のインデックスとして使用される前に、これらのインデックスは「Gather Index Mapping」(正式な記述では start_index_map)と「Offset Mapping」(正式な記述では remapped_offset_dims)に従って拡張され、それぞれ [X,0] と [0,O0] になり、合計 [X,O0] になります。つまり、出力インデックス [G0,G1,O0] は入力インデックス [GatherIndices[G0,G1,0],O0] にマッピングされ、tf.gather_nd のセマンティクスが得られます。

このケースの slice_sizes[1,11] です。直感的に、これは、集計インデックス 配列内のすべてのインデックス X が行全体を選択し、結果がこれらの行の連結になることを意味します。

GetDimensionSize

XlaBuilder::GetDimensionSize もご覧ください。

オペランドの指定されたディメンションのサイズを返します。オペランドは配列型である必要があります。

GetDimensionSize(operand, dimension)

引数 タイプ セマンティクス
operand XlaOp n 次元の入力配列
dimension int64 ディメンションを指定する [0, n) の範囲内の値

SetDimensionSize

XlaBuilder::SetDimensionSize もご覧ください。

XlaOp の指定されたディメンションの動的サイズを設定します。オペランドは配列型である必要があります。

SetDimensionSize(operand, size, dimension)

引数 タイプ セマンティクス
operand XlaOp n 次元の入力配列。
size XlaOp 実行時の動的サイズを表す int32。
dimension int64 ディメンションを指定する [0, n) の範囲内の値。

オペランドを結果として渡し、動的ディメンションをコンパイラが追跡します。

パディングされた値は、ダウンストリームの減算オペレーションによって無視されます。

let v: f32[10] = f32[10]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
let five: s32 = 5;
let six: s32 = 6;

// Setting dynamic dimension size doesn't change the upper bound of the static
// shape.
let padded_v_five: f32[10] = set_dimension_size(v, five, /*dimension=*/0);
let padded_v_six: f32[10] = set_dimension_size(v, six, /*dimension=*/0);

// sum == 1 + 2 + 3 + 4 + 5
let sum:f32[] = reduce_sum(padded_v_five);
// product == 1 * 2 * 3 * 4 * 5
let product:f32[] = reduce_product(padded_v_five);

// Changing padding size will yield different result.
// sum == 1 + 2 + 3 + 4 + 5 + 6
let sum:f32[] = reduce_sum(padded_v_six);

GetTupleElement

XlaBuilder::GetTupleElement もご覧ください。

コンパイル時定数値を持つタプルのインデックス。

値はコンパイル時定数である必要があります。これにより、シェイプ推論で結果の値の型を決定できます。

これは C++ の std::get<int N>(t) に似ています。コンセプト的には次のようになります。

let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
let s: s32 = 5;
let t: (f32[10], s32) = tuple(v, s);
let element_1: s32 = gettupleelement(t, 1);  // Inferred shape matches s32.

tf.tuple もご覧ください。

インフィード

XlaBuilder::Infeed もご覧ください。

Infeed(shape)

引数 タイプ セマンティクス
shape Shape インフィード インターフェースから読み取られるデータの形式。シェイプの layout フィールドは、デバイスに送信されるデータのレイアウトと一致するように設定する必要があります。一致しない場合、動作は未定義です。

デバイスの暗黙的なインフィード ストリーミング インターフェースから単一のデータアイテムを読み取り、データを指定されたシェイプとそのレイアウトとして解釈し、データの XlaOp を返します。計算では複数のインフィード オペレーションを使用できますが、インフィード オペレーションの間には完全な順序が必要です。たとえば、次のコードの 2 つの Infeed には、while ループ間に依存関係があるため、合計順序があります。

result1 = while (condition, init = init_value) {
  Infeed(shape)
}

result2 = while (condition, init = result1) {
  Infeed(shape)
}

ネストされたタプル シェイプはサポートされていません。空のタプル シェイプの場合、インフィード オペレーションは事実上無効になり、デバイスのインフィードからデータを読み取らずに処理が続行されます。

Iota

XlaBuilder::Iota もご覧ください。

Iota(shape, iota_dimension)

ホスト転送のサイズが大きくなる可能性がある代わりに、デバイスに定数リテラルをビルドします。指定されたシェイプの配列を作成し、指定されたディメンションに沿ってゼロから 1 ずつ増加する値を保持します。浮動小数点型の場合、生成される配列は ConvertElementType(Iota(...)) と同等です。ここで、Iota は整数型で、変換は浮動小数点型です。

引数 タイプ セマンティクス
shape Shape Iota() によって作成された配列の形状
iota_dimension int64 インクリメントするディメンション。

たとえば、Iota(s32[4, 8], 0)

  [[0, 0, 0, 0, 0, 0, 0, 0 ],
   [1, 1, 1, 1, 1, 1, 1, 1 ],
   [2, 2, 2, 2, 2, 2, 2, 2 ],
   [3, 3, 3, 3, 3, 3, 3, 3 ]]

返品可能(返品手数料: Iota(s32[4, 8], 1)

  [[0, 1, 2, 3, 4, 5, 6, 7 ],
   [0, 1, 2, 3, 4, 5, 6, 7 ],
   [0, 1, 2, 3, 4, 5, 6, 7 ],
   [0, 1, 2, 3, 4, 5, 6, 7 ]]

地図

XlaBuilder::Map もご覧ください。

Map(operands..., computation)

引数 タイプ セマンティクス
operands N 個の XlaOp のシーケンス T0..T{N-1} 型の N 個の配列
computation XlaComputation T 型の N 個のパラメータと任意の型の M 個のパラメータを持つ T_0, T_1, .., T_{N + M -1} -> S 型の計算
dimensions int64 配列 地図のディメンションの配列

指定された operands 配列にスカラー関数を適用し、同じディメンションの配列を生成します。各要素は、入力配列内の対応する要素に適用されたマッピング関数の結果です。

マッピングされた関数は任意の計算で、スカラー型 T の N 個の入力と、型 S の 1 つの出力を持つという制限があります。出力は、要素型 T が S に置き換えられることを除き、オペランドと同じディメンションになります。

たとえば、Map(op1, op2, op3, computation, par1) は入力配列の各(多次元)インデックスで elem_out <- computation(elem1, elem2, elem3, par1) をマッピングして、出力配列を生成します。

OptimizationBarrier

最適化パスが境界を越えて計算を移動できないようにします。

バリアの出力に依存する演算子よりも前に、すべての入力が評価されるようにします。

パッド

XlaBuilder::Pad もご覧ください。

Pad(operand, padding_value, padding_config)

引数 タイプ セマンティクス
operand XlaOp T 型の配列
padding_value XlaOp 追加されたパディングを埋める T 型のスカラー
padding_config PaddingConfig 両端のパディング量(低、高)と各ディメンションの要素間のパディング量

指定された operand 配列を拡張します。配列の周囲と配列の要素間に指定された padding_value をパディングします。padding_config には、各ディメンションのエッジ パディングと内部パディングの量を指定します。

PaddingConfigPaddingConfigDimension の繰り返しフィールドで、各ディメンションに edge_padding_lowedge_padding_highinterior_padding の 3 つのフィールドが含まれています。

edge_padding_lowedge_padding_high は、各ディメンションの下端(インデックス 0 の隣)と上端(最大インデックスの隣)に追加されるパディングの量をそれぞれ指定します。エッジ パディングの量は負の値にすることができます。負のパディングの絶対値は、指定されたディメンションから削除する要素の数を示します。

interior_padding には、各ディメンションの 2 つの要素間に追加されるパディングの量を指定します。負の値は指定できません。内部パディングは論理的にはエッジパディングの前に行われるため、エッジパディングが負の場合、内部パディングされたオペランドから要素が削除されます。

エッジ パディング ペアがすべて(0, 0)で、内部パディング値がすべて 0 の場合、このオペレーションは no-op です。次の図は、2 次元配列のさまざまな edge_padding 値と interior_padding 値の例を示しています。

Recv

XlaBuilder::Recv もご覧ください。

Recv(shape, channel_handle)

引数 タイプ セマンティクス
shape Shape 受信するデータの形状
channel_handle ChannelHandle 送受信ペアごとの一意の識別子

同じチャネル ハンドルを共有する別のコンピューティングの Send 命令から、指定されたシェイプのデータを受け取ります。受信したデータの XlaOp を返します。

Recv オペレーションのクライアント API は、同期通信を表します。ただし、この命令は内部で 2 つの HLO 命令(RecvRecvDone)に分解され、非同期データ転送が可能になります。HloInstruction::CreateRecvHloInstruction::CreateRecvDone もご覧ください。

Recv(const Shape& shape, int64 channel_id)

同じ channel_id を持つ Send 命令からデータを受信するために必要なリソースを割り当てます。割り当てられたリソースのコンテキストを返します。このコンテキストは、次の RecvDone 命令でデータ転送の完了を待機するために使用されます。コンテキストは {受信バッファ(シェイプ)、リクエスト ID(U32)} の タプルであり、RecvDone 命令でのみ使用できます。

RecvDone(HloInstruction context)

Recv 命令によって作成されたコンテキストを受け取り、データ転送が完了するまで待機し、受信したデータを返します。

削減

XlaBuilder::Reduce もご覧ください。

リダクション関数を 1 つ以上の配列に並行して適用します。

Reduce(operands..., init_values..., computation, dimensions)

引数 タイプ セマンティクス
operands N 個の XlaOp のシーケンス T_0, ..., T_{N-1} 型の N 個の配列。
init_values N 個の XlaOp のシーケンス T_0, ..., T_{N-1} 型の N 個のスカラー。
computation XlaComputation T_0, ..., T_{N-1}, T_0, ..., T_{N-1} -> Collate(T_0, ..., T_{N-1}) の計算。
dimensions int64 配列 削減するディメンションの順序なし配列。

ここで

  • N は 1 以上の値にする必要があります。
  • 計算は「おおよそ」結合的である必要があります(後述)。
  • 入力配列はすべて同じディメンションである必要があります。
  • すべての初期値は、computation で ID を形成する必要があります。
  • N = 1 の場合、Collate(T)T です。
  • N > 1 の場合、Collate(T_0, ..., T_{N-1})T 型の N 要素のタプルです。

このオペレーションは、各入力配列の 1 つ以上のディメンションをスカラーに減らします。返される各配列のディメンション数は number_of_dimensions(operand) - len(dimensions) です。op の出力は Collate(Q_0, ..., Q_N) です。ここで、Q_iT_i 型の配列です。そのディメンションは後述します。

異なるバックエンドで、減算計算を再関連付けることができます。加算などの一部の減算関数は浮動小数点数に対して結合的ではないため、数値の差異が生じる可能性があります。ただし、データの範囲が制限されている場合、ほとんどの実用的な用途では、浮動小数点加算は結合性に非常に近くなります。

値が [10, 11, 12, 13] の単一の 1D 配列の 1 つのディメンション全体を、減算関数 f(これは computation)で減算する場合、次のように計算できます。

f(10, f(11, f(12, f(init_value, 13)))

他にも次のような多くの方法があります。

f(init_value, f(f(10, f(init_value, 11)), f(f(init_value, 12), f(init_value, 13))))

次の例は、初期値が 0 の減算計算として合計を使用する減算の実装方法の概要を示しています。

result_shape <- remove all dims in dimensions from operand_shape

# Iterate over all elements in result_shape. The number of r's here is equal
# to the number of dimensions of the result.
for r0 in range(result_shape[0]), r1 in range(result_shape[1]), ...:
  # Initialize this result element
  result[r0, r1...] <- 0

  # Iterate over all the reduction dimensions
  for d0 in range(dimensions[0]), d1 in range(dimensions[1]), ...:
    # Increment the result element with the value of the operand's element.
    # The index of the operand's element is constructed from all ri's and di's
    # in the right order (by construction ri's and di's together index over the
    # whole operand shape).
    result[r0, r1...] += operand[ri... di]

2 次元配列(行列)を減らす例を次に示します。このシェイプには 2 つのディメンション(サイズ 2 のディメンション 0 とサイズ 3 のディメンション 1)があります。

「add」関数を使用して次元 0 または 1 を減らした場合の結果:

どちらの減算結果も 1 次元配列です。この図では、視認性を高めるために、1 つを列として、もう 1 つを行として示しています。

より複雑な例として、3D 配列を次に示します。次元数は 3 で、次元 0 のサイズは 4、次元 1 のサイズは 2、次元 2 のサイズは 3 です。単純にするため、値 1 ~ 6 はディメンション 0 全体に複製されます。

2D の例と同様に、1 つのディメンションのみを削減できます。たとえば、ディメンション 0 を削減すると、ディメンション 0 のすべての値がスカラーに折りたたまれた 2 次元配列が得られます。

|  4   8  12 |
| 16  20  24 |

ディメンション 2 を削減すると、ディメンション 2 のすべての値がスカラーに折りたたまれた 2 次元配列も得られます。

| 6  15 |
| 6  15 |
| 6  15 |
| 6  15 |

入力の残りのディメンション間の相対順序は出力で保持されますが、ディメンションの数が変わるため、一部のディメンションに新しい番号が割り当てられる場合があります。

複数のディメンションを削減することもできます。追加減算ディメンション 0 と 1 を指定すると、1 次元配列 [20, 28, 36] が生成されます。

3D 配列をすべてのディメンションにわたって減算すると、スカラー 84 が生成されます。

可変引数 Reduce

N > 1 の場合、reduce 関数の適用はすべての入力に同時に適用されるため、少し複雑になります。オペランドは次の順序で計算に渡されます。

  • 最初のオペランドの実行された減算値
  • ...
  • オペランド N 番目の実行時の減算値
  • 最初のオペランドの入力値
  • ...
  • 3 番目のオペランドの入力値

たとえば、次の減算関数は、1 次元配列の最大値と argmax を並列で計算するために使用できます。

f: (Float, Int, Float, Int) -> Float, Int
f(max, argmax, value, index):
  if value >= max:
    return (value, index)
  else:
    return (max, argmax)

1 次元の入力配列 V = Float[N], K = Int[N] と初期値 I_V = Float, I_K = Int の場合、唯一の入力ディメンション全体で減算した結果 f_(N-1) は、次の再帰適用と同等です。

f_0 = f(I_V, I_K, V_0, K_0)
f_1 = f(f_0.first, f_0.second, V_1, K_1)
...
f_(N-1) = f(f_(N-2).first, f_(N-2).second, V_(N-1), K_(N-1))

この減算を値の配列と連続インデックスの配列(iota)に適用すると、配列が同時に反復処理され、最大値と一致するインデックスを含むタプルが返されます。

ReducePrecision

XlaBuilder::ReducePrecision もご覧ください。

浮動小数点値を低精度形式(IEEE-FP16 など)に変換してから元の形式に戻す効果をモデル化します。低精度形式の指数と小数点以下ビットの数は任意で指定できますが、すべてのビットサイズがすべてのハードウェア実装でサポートされているわけではありません。

ReducePrecision(operand, mantissa_bits, exponent_bits)

引数 タイプ セマンティクス
operand XlaOp 浮動小数点型 T の配列。
exponent_bits int32 低精度形式の指数ビット数
mantissa_bits int32 低精度形式のマンシサビット数

結果は T 型の配列です。入力値は、指定された小数点以下ビット数で表せる値に切り上げられます(「偶数に切り上げ」セマンティクスを使用)。指数ビット数で指定された範囲を超える値は、正または負の無限大にクランプされます。NaN 値は保持されますが、正規の NaN 値に変換される場合があります。

精度が低い形式には、少なくとも 1 つの指数ビットが必要です(ゼロの仮数を持つ値と無限大の値を区別するため)。また、仮数ビットの数は正の整数にする必要があります。指数または小数点以下ビットの数が、型 T に対応する値を超える場合があります。その場合、変換の対応する部分は単に no-op になります。

ReduceScatter

XlaBuilder::ReduceScatter もご覧ください。

ReduceScatter は、all-reduce を効果的に実行し、結果を scatter_dimension に沿って shard_count ブロックに分割して結果を散布する集団演算です。レプリカ グループ内のレプリカ iith シャードを受信します。

ReduceScatter(operand, computation, scatter_dim, shard_count, replica_group_ids, channel_id)

引数 タイプ セマンティクス
operand XlaOp レプリカ間で減算する配列または配列の空でないタプル。
computation XlaComputation 削減の計算
scatter_dimension int64 散布するディメンション。
shard_count int64 scatter_dimension を分割するブロック数
replica_groups int64 のベクトルのベクトル 削減を実行するグループ
channel_id 省略可能な int64 モジュール間通信用のオプションのチャネル ID
  • operand が配列のタプルの場合、タプルの各要素に対して reduce-scatter が実行されます。
  • replica_groups は、縮小が実行されるレプリカ グループのリストです(現在のレプリカのレプリカ ID は ReplicaId を使用して取得できます)。各グループのレプリカの順序によって、all-reduce の結果が分散される順序が決まります。replica_groups は空であるか(この場合、すべてのレプリカが単一のグループに属します)、レプリカの数と同じ数の要素を含める必要があります。レプリカ グループが複数ある場合は、すべて同じサイズにする必要があります。たとえば、replica_groups = {0, 2}, {1, 3} はレプリカ 0213 の間で減算を実行し、結果を分散します。
  • shard_count は、各レプリカ グループのサイズです。これは、replica_groups が空の場合に必要です。replica_groups が空でない場合、shard_count は各レプリカ グループのサイズと等しくする必要があります。
  • channel_id はモジュール間の通信に使用されます。同じ channel_id を持つ reduce-scatter オペレーションのみが相互に通信できます。

出力シェイプは、scatter_dimensionshard_count 倍小さくなった入力シェイプです。たとえば、2 つのレプリカがあり、2 つのレプリカでオペランドの値がそれぞれ [1.0, 2.25][3.0, 5.25] の場合、scatter_dim0 であるこのオペレーションの出力値は、最初のレプリカでは [4.0]、2 番目のレプリカでは [7.5] になります。

ReduceWindow

XlaBuilder::ReduceWindow もご覧ください。

N 個の多次元配列のシーケンスの各ウィンドウ内のすべての要素に減算関数を適用し、単一の多次元配列または N 個の多次元配列のタプルを出力として生成します。各出力配列には、ウィンドウの有効な位置の数と同じ数の要素があります。プーリング レイヤは ReduceWindow として表すことができます。Reduce と同様に、適用された computation には常に左側の init_values が渡されます。

ReduceWindow(operands..., init_values..., computation, window_dimensions, window_strides, padding)

引数 タイプ セマンティクス
operands N XlaOps T_0,..., T_{N-1} 型の N 個の多次元配列のシーケンス。それぞれが、ウィンドウが配置されるベース領域を表します。
init_values N XlaOps 減算の N 個の開始値(N 個のオペランドごとに 1 つ)。詳しくは、削減をご覧ください。
computation XlaComputation すべての入力オペランドの各ウィンドウ内の要素に適用される、T_0, ..., T_{N-1}, T_0, ..., T_{N-1} -> Collate(T_0, ..., T_{N-1}) 型のリダクション関数。
window_dimensions ArraySlice<int64> ウィンドウ ディメンション値の整数の配列
window_strides ArraySlice<int64> ウィンドウ ストライド値の整数の配列
base_dilations ArraySlice<int64> ベース拡張値の整数の配列
window_dilations ArraySlice<int64> ウィンドウ拡張値の整数の配列
padding Padding ウィンドウのパディング タイプ(Padding::kSame: ストライドが 1 の場合に入力と同じ出力シェイプになるようにパディングします。Padding::kValid: パディングを使用し、ウィンドウが適合しなくなったら「停止」します)。

ここで

  • N は 1 以上の値にする必要があります。
  • 入力配列はすべて同じディメンションである必要があります。
  • N = 1 の場合、Collate(T)T です。
  • N > 1 の場合、Collate(T_0, ..., T_{N-1})(T0,...T{N-1}) 型の N 要素のタプルです。

次のコードと図は、ReduceWindow の使用例を示しています。入力はサイズ [4x6] のマトリックスで、window_dimensions と window_stride_dimensions の両方が [2x3] です。

// Create a computation for the reduction (maximum).
XlaComputation max;
{
  XlaBuilder builder(client_, "max");
  auto y = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "y");
  auto x = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "x");
  builder.Max(y, x);
  max = builder.Build().value();
}

// Create a ReduceWindow computation with the max reduction computation.
XlaBuilder builder(client_, "reduce_window_2x3");
auto shape = ShapeUtil::MakeShape(F32, {4, 6});
auto input = builder.Parameter(0, shape, "input");
builder.ReduceWindow(
    input,
    /*init_val=*/builder.ConstantLiteral(LiteralUtil::MinValue(F32)),
    *max,
    /*window_dimensions=*/{2, 3},
    /*window_stride_dimensions=*/{2, 3},
    Padding::kValid);

ディメンションのストライドが 1 の場合、ディメンション内のウィンドウの位置は、隣接するウィンドウから 1 要素離れていることを指定します。ウィンドウが重ならないように指定するには、window_stride_dimensions を window_dimensions にする必要があります。次の図は、2 つの異なるストライド値の使用を示しています。パディングは入力の各ディメンションに適用され、計算は、入力がパディング後のディメンションで入力された場合と同じです。

単純でないパディングの例として、入力配列 [10000, 1000, 100, 10, 1] に対して、ディメンション 3 とストライド 2 で、減算ウィンドウの最小値(初期値は MAX_FLOAT)を計算することを考えてみましょう。パディング kValid は、2 つの有効なウィンドウ([10000, 1000, 100][100, 10, 1])の最小値を計算し、出力 [100, 1] を生成します。パディング kSame は、両側に初期要素を追加して [MAX_VALUE, 10000, 1000, 100, 10, 1, MAX_VALUE] を取得することで、まず配列をパディングし、減算ウィンドウ後の形状がストライド 1 の入力と同じになるようにします。パディングされた配列に対して reduce-window を実行すると、3 つのウィンドウ([MAX_VALUE, 10000, 1000][1000, 100, 10][10, 1, MAX_VALUE])で動作し、[1000, 10, 1] が生成されます。

リダクション関数の評価順序は任意で、非確定的になる可能性があります。したがって、減算関数は再結合に過度に敏感であってはなりません。詳細については、Reduce のコンテキストでの結合性に関する説明をご覧ください。

ReplicaId

XlaBuilder::ReplicaId もご覧ください。

レプリカの一意の ID(U32 スカラー)を返します。

ReplicaId()

各レプリカの一意の ID は、[0, N) の範囲内の符号なし整数です。ここで、N はレプリカの数です。すべてのレプリカが同じプログラムを実行しているため、プログラム内の ReplicaId() 呼び出しは、レプリカごとに異なる値を返します。

Reshape

XlaBuilder::ReshapeCollapse オペレーションもご覧ください。

配列のディメンションを新しい構成に変更します。

Reshape(operand, dimensions)

引数 タイプ セマンティクス
operand XlaOp 型 T の配列
dimensions int64 ベクトル 新しいディメンションのサイズのベクトル

概念的には、reshape はまず配列をデータ値の 1 次元ベクトルにフラット化し、次にこのベクトルを新しいシェイプに絞り込みます。入力引数は、型 T の任意の配列、ディメンション インデックスのコンパイル時定数ベクトル、結果のディメンション サイズのコンパイル時定数ベクトルです。dimensions ベクトルによって出力配列のサイズが決まります。dimensions のインデックス 0 の値はディメンション 0 のサイズ、インデックス 1 の値はディメンション 1 のサイズです。dimensions ディメンションの積は、オペランドのディメンション サイズの積と等しくする必要があります。圧縮された配列を dimensions で定義された多次元配列に絞り込む場合、dimensions のディメンションは、変化が最も遅い(最も大きい)ディメンションから変化が最も速い(最も小さい)ディメンションの順に並べ替えられます。

たとえば、v を 24 要素の配列とします。

let v = f32[4x2x3] { { {10, 11, 12}, {15, 16, 17} },
                    { {20, 21, 22}, {25, 26, 27} },
                    { {30, 31, 32}, {35, 36, 37} },
                    { {40, 41, 42}, {45, 46, 47} } };

let v012_24 = Reshape(v, {24});
then v012_24 == f32[24] {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27,
                         30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47};

let v012_83 = Reshape(v, {8,3});
then v012_83 == f32[8x3] { {10, 11, 12}, {15, 16, 17},
                          {20, 21, 22}, {25, 26, 27},
                          {30, 31, 32}, {35, 36, 37},
                          {40, 41, 42}, {45, 46, 47} };

特別なケースとして、reshape は単一要素の配列をスカラーに変換したり、その逆を行ったりできます。次に例を示します。

Reshape(f32[1x1] { {5} }, {}) == 5;
Reshape(5, {1,1}) == f32[1x1] { {5} };

リバース(逆)

XlaBuilder::Rev もご覧ください。

Rev(operand, dimensions)

引数 タイプ セマンティクス
operand XlaOp 型 T の配列
dimensions ArraySlice<int64> 反転するディメンション

指定された dimensions に沿って operand 配列内の要素の順序を逆にして、同じ形状の出力配列を生成します。多次元インデックスの演算子配列の各要素は、変換されたインデックスの出力配列に格納されます。多次元インデックスは、反転する各ディメンションのインデックスを反転することで変換されます(サイズ N のディメンションが反転ディメンションの 1 つである場合、そのインデックス i は N - 1 - i に変換されます)。

Rev 演算の 1 つの用途は、ニューラル ネットワークの勾配計算中に 2 つのウィンドウ ディメンションに沿って畳み込み重み配列を反転することです。

RngNormal

XlaBuilder::RngNormal もご覧ください。

\(N(\mu, \sigma)\) 正規分布に従って生成された乱数を使用して、指定されたシェイプの出力を構築します。パラメータ \(\mu\) と \(\sigma\)、出力シェイプには浮動小数点要素型が必要です。さらに、パラメータはスカラー値である必要があります。

RngNormal(mu, sigma, shape)

引数 タイプ セマンティクス
mu XlaOp 生成された数値の平均を指定する T 型のスカラー
sigma XlaOp 生成された標準偏差を指定する T 型のスカラー
shape Shape 型 T の出力シェイプ

RngUniform

XlaBuilder::RngUniform もご覧ください。

区間 \([a,b)\)で均一分布に従って生成された乱数を使用して、指定された形状の出力を構築します。パラメータと出力要素の型は、ブール型、整数型、浮動小数点型のいずれかであり、型は一貫している必要があります。現在、CPU バックエンドと GPU バックエンドは、F64、F32、F16、BF16、S64、U64、S32、U32 のみをサポートしています。さらに、パラメータはスカラー値である必要があります。 \(b <= a\) の場合、結果は実装定義です。

RngUniform(a, b, shape)

引数 タイプ セマンティクス
a XlaOp 区間の下限を指定する型 T のスカラー
b XlaOp 区間の上限を指定する T 型のスカラー
shape Shape 型 T の出力シェイプ

RngBitGenerator

指定されたアルゴリズム(またはバックエンドのデフォルト)を使用して、均一な乱数ビットで指定されたシェイプの出力を生成し、更新された状態(初期状態と同じシェイプ)と生成された乱数データを返します。

初期状態は、現在の乱数生成の初期状態です。必要なシェイプと有効な値は、使用されるアルゴリズムによって異なります。

出力は初期状態の確定的関数であることが保証されますが、バックエンドと異なるコンパイラ バージョン間で確定的であることは保証されません。

RngBitGenerator(algorithm, key, shape)

引数 タイプ セマンティクス
algorithm RandomAlgorithm 使用する PRNG アルゴリズム。
initial_state XlaOp PRNG アルゴリズムの初期状態。
shape Shape 生成されたデータの出力シェイプ。

algorithm に指定できる値:

散布

XLA スキャッター演算は、入力配列 operands の値である一連の結果を生成します。複数のスライス(scatter_indices で指定されたインデックス)は、update_computation を使用して updates の値のシーケンスで更新されます。

XlaBuilder::Scatter もご覧ください。

scatter(operands..., scatter_indices, updates..., update_computation, index_vector_dim, update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims)

引数 タイプ セマンティクス
operands N 個の XlaOp のシーケンス 分散される T_0, ..., T_N 型の配列 N 個。
scatter_indices XlaOp 分散するスライスの開始インデックスを含む配列。
updates N 個の XlaOp のシーケンス T_0, ..., T_N の N 個の配列。updates[i] には、operands[i] の散乱に使用する値が含まれています。
update_computation XlaComputation 入力配列内の既存の値と、散布時の更新を組み合わせるために使用される計算。この計算の型は T_0, ..., T_N, T_0, ..., T_N -> Collate(T_0, ..., T_N) にする必要があります。
index_vector_dim int64 開始インデックスを含む scatter_indices 内のディメンション。
update_window_dims ArraySlice<int64> ウィンドウのサイズである updates シェイプのディメンションのセット。
inserted_window_dims ArraySlice<int64> updates シェイプに挿入する必要があるウィンドウのサイズのセット。
scatter_dims_to_operand_dims ArraySlice<int64> スキャッター インデックスからオペランド インデックス空間へのディメンション マップ。この配列は、iscatter_dims_to_operand_dims[i] にマッピングすると解釈されます。1 対 1 で合計である必要があります。
indices_are_sorted bool 呼び出し元によってインデックスが並べ替えられることが保証されているかどうか。
unique_indices bool 呼び出し元によってインデックスが一意であることが保証されているかどうか。

ここで

  • N は 1 以上の値にする必要があります。
  • operands[0]、...、operands[N-1] はすべて同じディメンションにする必要があります。
  • updates[0]、...、updates[N-1] はすべて同じディメンションにする必要があります。
  • N = 1 の場合、Collate(T)T です。
  • N > 1 の場合、Collate(T_0, ..., T_N)T 型の N 要素のタプルです。

index_vector_dimscatter_indices.rank の場合、scatter_indices には末尾に 1 ディメンションがあると暗黙的にみなされます。

ArraySlice<int64> 型の update_scatter_dims は、update_window_dims に含まれない updates シェイプのディメンションのセットとして、昇順で定義します。

scatter の引数は次の制約に従う必要があります。

  • updates 配列には update_window_dims.size + scatter_indices.rank - 1 ディメンションが必要です。

  • updates 配列のディメンション i の境界は、次の要件を満たす必要があります。

    • iupdate_window_dims に存在する場合(つまり、一部の k に対して update_window_dims[k] と等しい場合)、updates のディメンション i の上限は、inserted_window_dims を考慮した後の operand の対応する上限を超えてはなりません(つまり、adjusted_window_bounds[k]。ここで、adjusted_window_bounds には、インデックス inserted_window_dims の上限が削除された operand の上限が含まれています)。
    • iupdate_scatter_dims に存在する場合(つまり、一部の k に対して update_scatter_dims[k] に等しい場合)、updates のディメンション i の上限は、index_vector_dim をスキップして scatter_indices の対応する上限に等しくする必要があります(つまり、k < index_vector_dim の場合は scatter_indices.shape.dims[k]、それ以外の場合は scatter_indices.shape.dims[k+1])。
  • update_window_dims は昇順で、ディメンション番号が重複しておらず、[0, updates.rank) の範囲内である必要があります。

  • inserted_window_dims は昇順で、ディメンション番号が重複しておらず、[0, operand.rank) の範囲内である必要があります。

  • operand.rankupdate_window_dims.sizeinserted_window_dims.size の合計と等しい必要があります。

  • scatter_dims_to_operand_dims.sizescatter_indices.shape.dims[index_vector_dim] と等しく、値は [0, operand.rank) の範囲内である必要があります。

updates 配列内の特定のインデックス U について、この更新を適用する対応する operands 配列内の対応するインデックス I は、次のように計算されます。

  1. G = { update_scatter_dimskU[k] } とします。G を使用して、scatter_indices 配列のインデックス ベクトル S を検索します。S[i] = scatter_indices[Combine(G, i)] となるようにします。ここで、Combine(A, b) は、A の index_vector_dim 位置に b を挿入します。
  2. scatter_dims_to_operand_dims マップを使用して S を散布し、S を使用して operand にインデックス Sin を作成します。より正式には、次のようになります。
    1. k < scatter_dims_to_operand_dims.size の場合、Sin[scatter_dims_to_operand_dims[k]] = S[k]。
    2. Sin[_] = 0(それ以外の場合)。
  3. inserted_window_dims に従って Uupdate_window_dims にインデックスを分散して、各 operands 配列にインデックス Win を作成します。より正式には、次のようになります。
    1. kupdate_window_dims にある場合、Win[window_dims_to_operand_dims(k)] = U[k]。ここで、window_dims_to_operand_dims は、ドメイン [0update_window_dims.size] と範囲 [0operand.rank] \ inserted_window_dims の単調関数です。(たとえば、update_window_dims.size4operand.rank6inserted_window_dims が {02} の場合、window_dims_to_operand_dims は {01132435} です)。
    2. Win[_] = 0(それ以外の場合)。
  4. IWin + Sin です。ここで、+ は要素単位の加算です。

要約すると、散布図オペレーションは次のように定義できます。

  • outputoperands で初期化します。つまり、すべてのインデックス J について、operands[J] 配列内のすべてのインデックス O について、
    output[J][O] = operands[J][O] です。
  • updates[J] 配列内のすべてのインデックス U と、operand[J] 配列内の対応するインデックス O について、Ooutput の有効なインデックスである場合:
    (output[0][O], ..., output[N-1][O]) =update_computation(output[0][O], ..., ,output[N-1][O],updates[0][U], ...,updates[N-1][U])

更新が適用される順序は確定的ではありません。したがって、updates 内の複数のインデックスが operands 内の同じインデックスを参照している場合、output 内の対応する値は不確定になります。

update_computation に渡される最初のパラメータは常に output 配列の現在の値であり、2 番目のパラメータは常に updates 配列の値になります。これは、update_computation可換性がない場合に特に重要です。

indices_are_sorted が true に設定されている場合、XLA は scatter_indices がユーザーによって並べ替えられていると想定できます(scatter_dims_to_operand_dims に従って値を分散した、昇順)。そうでない場合、セマンティクスは実装で定義されます。

unique_indices が true に設定されている場合、XLA は、分散されたすべての要素が一意であると想定できます。したがって、XLA ではアトミック以外のオペレーションを使用できます。unique_indices が true に設定され、分散されるインデックスが一意でない場合、セマンティクスは実装定義です。

非公式には、スキャッター演算はガザー演算のとして見ることができます。つまり、スキャッター演算は、対応するガザー演算によって抽出された入力内の要素を更新します。

詳細な非公式の説明と例については、Gather の「非公式の説明」セクションを参照してください。

選択

XlaBuilder::Select もご覧ください。

述語配列の値に基づいて、2 つの入力配列の要素から出力配列を構築します。

Select(pred, on_true, on_false)

引数 タイプ セマンティクス
pred XlaOp PRED 型の配列
on_true XlaOp 型 T の配列
on_false XlaOp 型 T の配列

配列 on_trueon_false は同じ形状にする必要があります。これは出力配列の形状でもあります。配列 pred は、PRED 要素型で、on_trueon_false と同じ次元数である必要があります。

pred の要素 P ごとに、出力配列の対応する要素は、P の値が true の場合は on_true から、P の値が false の場合は on_false から取得されます。制限付きのブロードキャストとして、predPRED 型のスカラーにできます。この場合、出力配列は、predtrue の場合は on_true から、predfalse の場合は on_false から取得されます。

スカラー以外の pred を使用した例:

let pred: PRED[4] = {true, false, false, true};
let v1: s32[4] = {1, 2, 3, 4};
let v2: s32[4] = {100, 200, 300, 400};
==>
Select(pred, v1, v2) = s32[4]{1, 200, 300, 4};

スカラー pred を使用した例:

let pred: PRED = true;
let v1: s32[4] = {1, 2, 3, 4};
let v2: s32[4] = {100, 200, 300, 400};
==>
Select(pred, v1, v2) = s32[4]{1, 2, 3, 4};

タプル間の選択がサポートされています。この目的では、タプルはスカラー型と見なされます。on_trueon_false がタプル(同じシェイプである必要があります)の場合、predPRED 型のスカラーである必要があります。

SelectAndScatter

XlaBuilder::SelectAndScatter もご覧ください。

このオペレーションは、まず operand 配列で ReduceWindow を計算して各ウィンドウから要素を選択し、次に選択した要素のインデックスに source 配列を散布して、オペランド配列と同じ形状の出力配列を構築する複合オペレーションと見なすことができます。バイナリ select 関数は、各ウィンドウに適用して各ウィンドウから要素を選択するために使用されます。この関数は、最初のパラメータのインデックス ベクトルが 2 番目のパラメータのインデックス ベクトルよりも辞書順で小さいプロパティで呼び出されます。select 関数は、最初のパラメータが選択された場合は true を返し、2 番目のパラメータが選択された場合は false を返します。また、選択された要素が特定のウィンドウで走査される要素の順序に依存しないように、関数は推移性を保持する必要があります(つまり、select(a, b)select(b, c)true であれば、select(a, c)true である必要があります)。

関数 scatter は、出力配列内の選択した各インデックスに適用されます。2 つのスカラー パラメータを受け取ります。

  1. 出力配列の選択したインデックスの現在の値
  2. 選択したインデックスに適用される source の散布値

2 つのパラメータを結合し、出力配列内の選択したインデックスの値を更新するために使用されるスカラー値を返します。最初は、出力配列のすべてのインデックスが init_value に設定されます。

出力配列は operand 配列と同じシェイプで、source 配列は operand 配列に ReduceWindow 演算を適用した結果と同じシェイプにする必要があります。SelectAndScatter は、ニューラル ネットワーク内のプーリング レイヤの勾配値をバックプロパゲートするために使用できます。

SelectAndScatter(operand, select, window_dimensions, window_strides, padding, source, init_value, scatter)

引数 タイプ セマンティクス
operand XlaOp ウィンドウがスライドする T 型の配列
select XlaComputation T, T -> PRED 型のバイナリ計算。各ウィンドウ内のすべての要素に適用されます。最初のパラメータが選択されている場合は true を返し、2 番目のパラメータが選択されている場合は false を返します。
window_dimensions ArraySlice<int64> ウィンドウ ディメンション値の整数の配列
window_strides ArraySlice<int64> ウィンドウ ストライド値の整数の配列
padding Padding ウィンドウのパディング タイプ(Padding::kSame または Padding::kValid)
source XlaOp 散布する値を含む T 型の配列
init_value XlaOp 出力配列の初期値の型 T のスカラー値
scatter XlaComputation T, T -> T 型のバイナリ計算。各散布図のソース要素を宛先要素に適用します。

次の図は、SelectAndScatter の使用例を示しています。select 関数は、パラメータの最大値を計算します。以下の図(2)のようにウィンドウが重複している場合、operand 配列のインデックスが異なるウィンドウによって複数回選択される可能性があります。この図では、値 9 の要素が上のウィンドウ(青と赤)の両方によって選択され、バイナリ加算 scatter 関数によって値 8(2 + 6)の出力要素が生成されます。

scatter 関数の評価順序は任意で、非確定的になる場合があります。したがって、scatter 関数は再アソシエーションに過度に敏感であってはなりません。詳細については、Reduce のコンテキストでの結合性に関する説明をご覧ください。

送信

XlaBuilder::Send もご覧ください。

Send(operand, channel_handle)

引数 タイプ セマンティクス
operand XlaOp 送信するデータ(T 型の配列)
channel_handle ChannelHandle 送受信ペアごとの一意の識別子

指定されたオペランド データを、同じチャネル ハンドルを共有する別の計算の Recv 命令に送信します。データを返しません。

Recv オペレーションと同様に、Send オペレーションのクライアント API は同期通信を表し、内部的には 2 つの HLO 命令(SendSendDone)に分解され、非同期データ転送を可能にします。HloInstruction::CreateSendHloInstruction::CreateSendDone もご覧ください。

Send(HloInstruction operand, int64 channel_id)

同じチャンネル ID を持つ Recv 命令によって割り振られたリソースへのオペランドの非同期転送を開始します。コンテキストを返します。これは、次の SendDone 命令で使用され、データ転送の完了を待機します。コンテキストは {オペランド(シェイプ)、リクエスト ID(U32)} のタプルであり、SendDone 命令でのみ使用できます。

SendDone(HloInstruction context)

Send 命令によって作成されたコンテキストを受け取り、データ転送が完了するまで待機します。この命令はデータを返しません。

チャンネル インストラクションのスケジュール設定

各チャンネルの 4 つの命令(RecvRecvDoneSendSendDone)の実行順序は次のとおりです。

  • RecvSend の前に発生します。
  • SendRecvDone の前に発生します。
  • RecvRecvDone の前に発生します。
  • SendSendDone の前に発生します。

バックエンド コンパイラがチャンネル命令を介して通信する計算ごとに線形スケジュールを生成する場合は、計算間でサイクルが発生しないようにする必要があります。たとえば、次のスケジュールはデッドロックにつながります。

命令の制約は、実行時に TPU にのみ適用されます。GPU では、sendrecv はブロックされ、送信元デバイスとターゲット デバイス間のハンドシェイクが完了するまで実際のデータは送信されません。

スライス

XlaBuilder::Slice もご覧ください。

スライスでは、入力配列からサブ配列を抽出します。サブ配列は入力と同じ数のディメンションがあり、入力配列内の境界ボックス内の値が含まれます。境界ボックスのディメンションとインデックスは、スライス オペレーションの引数として指定されます。

Slice(operand, start_indices, limit_indices, strides)

引数 タイプ セマンティクス
operand XlaOp 型 T の N 次元配列
start_indices ArraySlice<int64> 各ディメンションのスライスの開始インデックスを含む N 個の整数のリスト。値は 0 以上である必要があります。
limit_indices ArraySlice<int64> 各ディメンションのスライスの終了インデックス(除く)を含む N 個の整数のリスト。各値は、ディメンションのそれぞれの start_indices 値以上で、ディメンションのサイズ以下でなければなりません。
strides ArraySlice<int64> スライスの入力ストライドを決定する N 個の整数のリスト。スライスは、ディメンション d 内のすべての strides[d] 要素を選択します。

1 次元の例:

let a = {0.0, 1.0, 2.0, 3.0, 4.0}
Slice(a, {2}, {4}) produces:
  {2.0, 3.0}

2 次元の例:

let b =
 { {0.0,  1.0,  2.0},
   {3.0,  4.0,  5.0},
   {6.0,  7.0,  8.0},
   {9.0, 10.0, 11.0} }

Slice(b, {2, 1}, {4, 3}) produces:
  { { 7.0,  8.0},
    {10.0, 11.0} }

並べ替え

XlaBuilder::Sort もご覧ください。

Sort(operands, comparator, dimension, is_stable)

引数 タイプ セマンティクス
operands ArraySlice<XlaOp> 並べ替えるオペランド。
comparator XlaComputation 使用する比較演算。
dimension int64 並べ替えるディメンション。
is_stable bool 安定した並べ替えを使用するかどうか。

オペランドが 1 つだけ指定されている場合:

  • オペランドが 1 次元テンソル(配列)の場合、結果は並べ替えられた配列になります。配列を昇順で並べ替える場合、比較演算子は小なり比較を実行する必要があります。正式には、配列が並べ替えられた後、i < jcomparator(value[i], value[j]) = comparator(value[j], value[i]) = false または comparator(value[i], value[j]) = true であるすべてのインデックス位置 i, j で保持されます。

  • オペランドのディメンション数が大きい場合は、指定されたディメンションに沿ってオペランドが並べ替えられます。たとえば、2 次元テンソル(行列)の場合、ディメンション値が 0 であればすべての列が個別に並べ替えられ、ディメンション値が 1 であれば各行が個別に並べ替えられます。ディメンション番号が指定されていない場合、デフォルトで最後のディメンションが選択されます。並べ替えられたディメンションには、1 次元の場合と同じ並べ替え順序が適用されます。

n > 1 オペランドが指定されている場合:

  • すべての n オペランドは、同じディメンションのテンソルである必要があります。テンソルの要素型は異なる場合があります。

  • すべてのオペランドは個別ではなく、一緒に並べ替えられます。概念的には、オペランドはタプルとして扱われます。インデックス位置 ij の各オペランドの要素を入れ替える必要があるかどうかを確認するときに、2 * n スカラー パラメータを使用して比較演算子が呼び出されます。パラメータ 2 * kk-th オペランドの位置 i の値に対応し、パラメータ 2 * k + 1k-th オペランドの位置 j の値に対応します。通常、比較オペレーターはパラメータ 2 * k2 * k + 1 を比較し、必要に応じて他のパラメータペアをタイブレークとして使用します。

  • 結果は、(上記のように指定されたディメンションに沿って)並べ替えられたオペランドで構成されるタプルです。タプルの i-th オペランドは、Sort の i-th オペランドに対応しています。

たとえば、3 つのオペランド operand0 = [3, 1]operand1 = [42, 50]operand2 = [-3.0, 1.1] があり、比較演算子が operand0 の値のみを小なりで比較する場合、並べ替えの出力はタプル ([1, 3], [50, 42], [1.1, -3.0]) になります。

is_stable が true に設定されている場合、並べ替えは安定することが保証されます。つまり、比較演算子によって等しいと見なされる要素がある場合、等しい値の相対順序は保持されます。2 つの要素 e1e2 は、comparator(e1, e2) = comparator(e2, e1) = false の場合にのみ等しくなります。デフォルトでは、is_stable は false に設定されています。

トップ K

XlaBuilder::TopK もご覧ください。

TopK は、指定されたテンソルの最後のディメンションの最大または最小の要素の値とインデックスを検索します。k

TopK(operand, k, largest)

引数 タイプ セマンティクス
operand XlaOp 上位 k 要素を抽出するテンソル。テンソルの次元は 1 以上である必要があります。テンソルの最後のディメンションのサイズは k 以上にする必要があります。
k int64 抽出する要素の数。
largest bool 最大の k 要素と最小の k 要素のどちらを抽出するか。

1 次元の入力テンソル(配列)の場合、配列内の最大または最小のエントリを見つけて、2 つの配列の (values, indices) タプルを出力します。kしたがって、values[j]operandj 番目に大きい/小さいエントリであり、そのインデックスは indices[j] です。

1 つ以上のディメンションを持つ入力テンソルの場合、最後のディメンションに沿って上位 k エントリを計算し、出力の他のすべてのディメンション(行)を保持します。したがって、Q >= k のシェイプ [A, B, ..., P, Q] のオペランドの場合、出力はタプル (values, indices) です。ここで、

values.shape = indices.shape = [A, B, ..., P, k]

行内の 2 つの要素が等しい場合、インデックスが小さい要素が先に表示されます。

行 / 列の入れ替え

tf.reshape オペレーションもご覧ください。

Transpose(operand)

引数 タイプ セマンティクス
operand XlaOp 転置するオペランド。
permutation ArraySlice<int64> ディメンションを並べ替える方法。

オペランドのディメンションを指定された順序で並べ替えます(∀ i . 0 ≤ i < number of dimensions ⇒ input_dimensions[permutation[i]] = output_dimensions[i])。

これは、Reshape(operand, permutation, Permute(permutation, operand.shape.dimensions)) と同じです。

TriangularSolve

XlaBuilder::TriangularSolve もご覧ください。

下三角または上三角係数行列を持つ連立一次方程式を、順方向または逆方向の代入によって解きます。先頭のディメンションに沿ってブロードキャストするこのルーティンは、ab が指定された変数 x について、行列システム op(a) * x = b または x * op(a) = b のいずれかを解きます。ここで、op(a)op(a) = aop(a) = Transpose(a)、または op(a) = Conj(Transpose(a)) です。

TriangularSolve(a, b, left_side, lower, unit_diagonal, transpose_a)

引数 タイプ セマンティクス
a XlaOp 形状が [..., M, M] の複素数型または浮動小数点型の 3 次元以上の配列。
b XlaOp left_side が true の場合 [..., M, K]、それ以外の場合は [..., K, M] の形状を持つ、同じ型の 3 次元以上の配列。
left_side bool op(a) * x = b 形式(true)または x * op(a) = b 形式(false)のシステムを解くかどうかを示します。
lower bool a の上部または下部の三角形を使用するかどうか。
unit_diagonal bool true の場合、a の対角要素は 1 と見なされ、アクセスされません。
transpose_a Transpose a をそのまま使用するか、転置するか、共役転置を取るか。

入力データは、lower の値に応じて、a の下部または上部の三角形からのみ読み取られます。他の三角形の値は無視されます。出力データは同じ三角形で返されます。他の三角形の値は実装定義であり、任意の値にすることができます。

ab のディメンション数が 2 より大きい場合、これらは行列のバッチとして扱われます。この場合、マイナーな 2 つのディメンション以外のすべてのディメンションがバッチ ディメンションになります。ab のバッチ ディメンションは同じである必要があります。

タプル

XlaBuilder::Tuple もご覧ください。

可変数の各データハンドルを含むタプル。各データハンドルには独自のシェイプがあります。

これは C++ の std::tuple に似ています。コンセプト的には次のようになります。

let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
let s: s32 = 5;
let t: (f32[10], s32) = tuple(v, s);

タプルは、GetTupleElement オペレーションを介してデコンストラクト(アクセス)できます。

一方

XlaBuilder::While もご覧ください。

While(condition, body, init)

引数 タイプ セマンティクス
condition XlaComputation ループの終了条件を定義する T -> PRED タイプの XlaComputation。
body XlaComputation ループの本文を定義する T -> T タイプの XlaComputation。
init T conditionbody のパラメータの初期値。

condition が失敗するまで body を順番に実行します。これは、他の多くの言語の一般的な while ループと似ていますが、以下に示す違いと制限があります。

  • While ノードは、body の最後の実行結果である T 型の値を返します。
  • T の形状は静的に決定され、すべての反復処理で同じである必要があります。

計算の T パラメータは、最初の反復処理で init 値で初期化され、その後の反復処理で body の新しい結果に自動的に更新されます。

While ノードの主なユースケースの 1 つは、ニューラル ネットワークでのトレーニングの繰り返し実行を実装することです。以下に、計算を表すグラフとともに、簡素化された疑似コードを示します。コードは while_test.cc にあります。この例の型 T は、反復回数用の int32 と累積用の vector[10] で構成される Tuple です。1, 000 回の反復処理で、ループは定数ベクトルを累積器に追加し続けます。

// Pseudocode for the computation.
init = {0, zero_vector[10]} // Tuple of int32 and float[10].
result = init;
while (result(0) < 1000) {
  iteration = result(0) + 1;
  new_vector = result(1) + constant_vector[10];
  result = {iteration, new_vector};
}