このドキュメントでは、XLA のブロードキャスト セマンティクスについて説明します。
ブロードキャストとは
ブロードキャストとは、形状が異なる配列を算術演算に対して互換性のある形状にするプロセスです。用語は NumPy のブロードキャストから借用しています。
ブロードキャストは、異なるランクの多次元配列間のオペレーション、または異なるが互換性のある形状を持つ多次元配列間のオペレーションで必要になる場合があります。X+v
の加算について考えてみましょう。ここで、X
は行列(ランク 2 の配列)、v
はベクトル(ランク 1 の配列)です。要素ごとの加算を行うために、XLA は v
を特定の回数複製して、ベクトル v
を行列 X
と同じランクに「ブロードキャスト」する必要があります。ベクトルの長さは、行列の少なくとも 1 つの次元と一致する必要があります。
次に例を示します。
|1 2 3| + |7 8 9|
|4 5 6|
行列の次元は(2,3)で、ベクトルの次元は(3)です。ベクトルは、行の上に複製してブロードキャストされ、以下を取得します。
|1 2 3| + |7 8 9| = |8 10 12|
|4 5 6| |7 8 9| |11 13 15|
原則
XLA 言語は可能な限り厳格で明示的であり、暗黙の「魔法」機能を使用しません。このような機能によって、一部の計算は定義が若干容易になる場合がありますが、その代償としてユーザーコードに組み込まれている前提条件が増え、長期的には変更が困難になります。必要に応じて、クライアント レベルのラッパーに暗黙的なマジック機能を追加できます。
ブロードキャストに関して、XLA では、異なるランクの配列間の演算に明示的なブロードキャスト仕様が必要です。これは可能な場合に仕様を推測する NumPy とは異なります
低ランク配列を高ランク配列にブロードキャストする
スカラーは、ブロードキャスト ディメンションの明示的な指定なしに、常に配列でブロードキャストできます。スカラーと配列の間の要素ごとのバイナリ演算とは、スカラーを持つ演算を配列内の各要素に適用することを意味します。たとえば、行列にスカラーを追加すると、各要素がスカラーと入力行列の対応する要素の合計である行列が生成されます。
|1 2 3| + 7 = |8 9 10|
|4 5 6| |11 12 13|
ブロードキャストのほとんどのニーズは、バイナリ演算でディメンションのタプルを使用することで対応できます。演算への入力のランクが異なる場合、このブロードキャスト タプルは、上位ランク配列のどのディメンションを下位ランク配列と一致させるかを指定します。
前の例で考えてみましょう。スカラーを(2,3)行列に加える代わりに、次元(3)のベクトルを次元(2,3)の行列に追加します。ブロードキャストを指定していない場合、この操作は無効になります。行列ベクトルの加算を正しくリクエストするには、ブロードキャストの次元を(1)に指定します。これは、ベクトルの次元が行列の次元 1 と一致することを意味します。2D で、次元 0 が行を表し、次元 1 が列を表す場合、ベクトルの各要素は、行列内の行数と一致するサイズの列になります。
|7 8 9| ==> |7 8 9|
|7 8 9|
より複雑な例として、3 要素ベクトル(次元(3))を 3x3 行列(次元(3,3))に追加する場合を考えてみましょう。この例では、次の 2 つの方法でブロードキャストが行われます。
(1)1 のブロードキャスト ディメンションを使用できる。各ベクトル要素は 1 つの列になり、ベクトルは行列内の各行に複製されます。
|7 8 9| ==> |7 8 9|
|7 8 9|
|7 8 9|
(2)0 のブロードキャスト ディメンションを使用できる。各ベクトル要素は行になり、ベクトルは行列内の各列で複製されます。
|7| ==> |7 7 7|
|8| |8 8 8|
|9| |9 9 9|
ブロードキャスト ディメンションは、小さいランク形態がより大きなランク形態にブロードキャストされる方法を記述するタプルにできます。たとえば、2x3x4 の立方体と 3x4 の行列の場合、ブロードキャスト タプル (1,2) は行列を立方体の次元 1 と 2 に一致させることを意味します。
このタイプのブロードキャストは、broadcast_dimensions
引数が指定されている場合、XlaBuilder
のバイナリ演算で使用されます。たとえば、XlaBuilder::Add をご覧ください。XLA ソースコードでは、このタイプのブロードキャストは「InDim」ブロードキャストと呼ばれることもあります。
正式な定義
ブロードキャスト属性を使用すると、照合する高ランク配列の次元を指定することで、低ランク配列と高ランク配列をマッチングできます。たとえば、次元が MxNxPxQ の配列の場合、次元 T のベクトルは次のようにマッチングできます。
MxNxPxQ
dim 3: T
dim 2: T
dim 1: T
dim 0: T
いずれの場合も、T は上位配列の一致する次元と等しくなければなりません。ベクトルの値は、一致したディメンションから他のすべてのディメンションにブロードキャストされます。
TxV 行列を MxNxPxQ 配列に一致させるために、ブロードキャスト次元のペアが使用されます。
MxNxPxQ
dim 2,3: T V
dim 1,2: T V
dim 0,3: T V
etc...
ブロードキャスト タプルのディメンションの順序は、低ランク配列のディメンションが高ランク配列のディメンションと一致すると予想される順序にする必要があります。タプルの最初の要素は、高ランク配列のどのディメンションが低ランク配列のディメンション 0 と一致する必要があるかを指定します。タプルの 2 番目の要素は、高ランク配列のどのディメンションが低ランク配列のディメンション 1 と一致する必要があるかを指定します。ブロードキャスト ディメンションの順序は厳密な昇順にする必要があります。たとえば、前の例では、V を N に、T を P にマッチングさせることは不適切です。V を P と N の両方にマッチングさせることも違法です。
縮退次元を持つ類似ランク配列のブロードキャスト
関連する問題は、ランクが同じでディメンション サイズが異なる 2 つの配列をブロードキャストすることです。NumPy と同様に、これは配列に「互換性」がある場合にのみ可能です。すべての次元に互換性がある場合、2 つの配列には互換性があります。次の場合、2 つのディメンションは互換性があります。
- 等しい、または
- そのうちの一つが 1(「縮退」次元)
互換性のある配列が 2 つ検出された場合、結果の形状は各ディメンション インデックスで最大 2 つの入力を持ちます。
例:
- (2,1) と (2,3) は (2,3) にブロードキャストされます。
- (1,2,5) と (7,2,5) は (7,2,5) にブロードキャストされます。
- (7,2,5) と (7,1,5) は (7,2,5) にブロードキャストされます。
- (7,2,5) と (7,2,6) は互換性がないため、ブロードキャストできません。
各入力配列が異なるインデックスの縮退ディメンションを持つ特殊なケースもサポートされています。この場合、結果は「外部演算」になります。つまり、(2,1) と (1,3) は (2,3) にブロードキャストされます。その他の例については、ブロードキャストに関する NumPy のドキュメントをご覧ください。
配信の構成
低ランク配列の高ランク配列へのブロードキャストと、縮退ディメンションを使用したブロードキャストは、どちらも同じバイナリ演算で実行できます。たとえば、サイズ 4 のベクトルとサイズ 1x2 の行列は、値(0)のブロードキャスト ディメンションを使用して加算できます。
|1 2 3 4| + [5 6] // [5 6] is a 1x2 matrix, not a vector.
まず、ブロードキャスト ディメンションを使用して、ベクトルがランク 2(行列)までブロードキャストされます。ブロードキャスト ディメンションの単一の値(0)は、ベクトルのディメンション 0 がマトリックスのディメンション 0 と一致することを示します。これにより、サイズが 4×M の行列が生成されます。ここでは、1×2 配列内の対応する次元サイズと一致するように値 M が選択されます。したがって、4x2 行列は次のようになります。
|1 1| + [5 6]
|2 2|
|3 3|
|4 4|
次に、「縮退次元ブロードキャスト」は、右側の対応する次元サイズと一致するように 1 × 2 行列の次元 0 をブロードキャストします。
|1 1| + |5 6| |6 7|
|2 2| + |5 6| = |7 8|
|3 3| + |5 6| |8 9|
|4 4| + |5 6| |9 10|
より複雑な例としては、ブロードキャスト ディメンション(1, 2)を使用して、サイズ 4x3x1 の配列に追加されたサイズ 1x2 の行列があります。まず、ブロードキャスト ディメンションを使用して 1x2 行列がランク 3 までブロードキャストされ、中間 Mx1x2 配列が生成されます。この配列サイズ M は大きい方のオペランド(4x3x1 配列)のサイズによって決まり、4x1x2 中間配列が生成されます。ブロードキャストのディメンションが(1, 2)であるため、ディメンション 1 と 2 が元の 1x2 マトリックスのディメンションにマッピングされるため、M はディメンション 0(左端のディメンション)にあります。この中間配列を、縮退次元のブロードキャストを使用して 4x3x1 行列に追加して、4x3x2 配列結果を生成できます。