このドキュメントでは、XLA のブロードキャスト セマンティクスについて説明します。
ブロードキャストとは
ブロードキャストとは、さまざまな形状を持つ配列を算術演算に互換性のある形状にするプロセスです。用語は NumPy のブロードキャストから借用しています。
異なるランクの多次元配列間のオペレーションや、異なる形状の互換性を持つ多次元配列間のオペレーションには、ブロードキャストが必要になることがあります。加算 X+v
について考えてみましょう。ここで、X
は行列(ランク 2 の配列)、v
はベクトル(ランク 1 の配列)です。要素単位の加算を行うために、XLA は v
を一定回数複製して、ベクトル v
を行列 X
と同じランクに「ブロードキャスト」する必要があります。ベクトルの長さは、行列の少なくとも 1 つの次元と一致する必要があります。
次に例を示します。
|1 2 3| + |7 8 9|
|4 5 6|
行列の次元は(2,3)で、ベクトルの次元は(3)です。ベクトルは、行に複製されて以下を取得することでブロードキャストされます。
|1 2 3| + |7 8 9| = |8 10 12|
|4 5 6| |7 8 9| |11 13 15|
原則
XLA 言語は可能な限り厳格かつ明示的であり、暗黙的な「魔法」の機能を使用しません。このような機能により、一部の計算の定義は若干簡単になる可能性がありますが、ユーザーコードに組み込まれている前提条件が増えると、長期的には変更が困難になります。必要に応じて、クライアント レベルのラッパーに暗黙的なマジック機能を追加できます。
ブロードキャストに関して、XLA では、ランクが異なる配列間の演算に関して明示的なブロードキャスト仕様が必要です。これは、可能な場合に仕様を推測する NumPy とは異なります。
低ランク配列を高ランク配列にブロードキャストする
スカラーは、ブロードキャスト ディメンションの明示的な指定なしに、常に配列でブロードキャストできます。スカラーと配列間の要素単位バイナリ演算とは、スカラーを使用した演算を配列内の各要素に適用することを意味します。たとえば、行列にスカラーを追加すると、各要素がスカラーと入力行列の対応する要素の合計である行列が生成されます。
|1 2 3| + 7 = |8 9 10|
|4 5 6| |11 12 13|
ブロードキャストのほとんどのニーズは、バイナリ演算でディメンションのタプルを使用することで対応できます。演算への入力のランクが異なる場合、このブロードキャスト タプルは、高ランク配列のどのディメンションを低ランク配列と一致させるかを指定します。
前の例で考えてみましょう。(2,3)行列にスカラーを加算する代わりに、次元(3)のベクトルを次元(2,3)の行列に追加します。ブロードキャストを指定していない場合、この操作は無効です。行列ベクトルの加算を正しくリクエストするには、ブロードキャストの次元を(1)に指定します。つまり、ベクトルの次元が行列の次元 1 と一致します。2D で、次元 0 が行を表し、次元 1 が列を表す場合、ベクトルの各要素は、行列内の行数と一致するサイズの列になります。
|7 8 9| ==> |7 8 9|
|7 8 9|
より複雑な例として、3 要素ベクトル(次元(3))を 3x3 行列(次元(3,3))に加算してみましょう。この例では、次の 2 つの方法でブロードキャストが行われます。
(1)1 のブロードキャスト ディメンションを使用できる。各ベクトル要素は列になり、行列内の各行に対してベクトルが複製されます。
|7 8 9| ==> |7 8 9|
|7 8 9|
|7 8 9|
(2)0 のブロードキャスト ディメンションを使用できる。各ベクトル要素は行になり、ベクトルは行列内の各列で複製されます。
|7| ==> |7 7 7|
|8| |8 8 8|
|9| |9 9 9|
ブロードキャスト ディメンションは、小さいランクの形がより大きなランクの形にブロードキャストされる仕組みを記述するタプルになります。たとえば、2x3x4 の立方体と 3x4 の行列が与えられた場合、ブロードキャスト タプル (1,2) は行列を立方体の次元 1 と 2 に一致させることを意味します。
このタイプのブロードキャストは、broadcast_dimensions
引数が指定されている場合、XlaBuilder
のバイナリ演算で使用されます。たとえば、XlaBuilder::Add をご覧ください。XLA ソースコードでは、このタイプのブロードキャストは「InDim」ブロードキャストと呼ばれることもあります。
正式な定義
ブロードキャスト属性を使用すると、一致させる上位配列のディメンションを指定することで、下位配列と上位配列を一致させることができます。たとえば、次元が MxNxPxQ の配列の場合、次元 T のベクトルは次のようにマッチングできます。
MxNxPxQ
dim 3: T
dim 2: T
dim 1: T
dim 0: T
いずれの場合も、T は上位配列の一致する次元と等しくなければなりません。ベクトルの値は、一致したディメンションから他のすべてのディメンションにブロードキャストされます。
TxV 行列を MxNxPxQ 配列に一致させるために、ブロードキャスト ディメンションのペアが使用されます。
MxNxPxQ
dim 2,3: T V
dim 1,2: T V
dim 0,3: T V
etc...
ブロードキャスト タプルの次元の順序は、下位配列の次元が上位配列の次元と一致すると想定される順序にする必要があります。タプルの 1 つ目の要素は、高ランク配列のどのディメンションが低ランク配列のディメンション 0 と一致する必要があるかを指定します。タプルの 2 番目の要素は、高ランク配列のどのディメンションを低ランク配列のディメンション 1 と一致させる必要があるかを指定します。ブロードキャスト ディメンションの順序は厳密に増加する必要があります。たとえば前の例では、V を N に、T を P にマッチングすることは違法ですが、V を P と N の両方にマッチングさせることも不正です。
縮退ディメンションを含む類似ランク配列のブロードキャスト
関連する問題は、ランクが同じでディメンション サイズが異なる 2 つの配列をブロードキャストすることです。NumPy と同様に、これは配列に互換性がある場合にのみ可能です。すべてのディメンションに互換性がある場合、2 つの配列は互換性があります。次の場合、2 つのディメンションは互換性があります。
- 等しい、または
- そのうちの 1 つは 1(「縮退」ディメンション)です。
互換性のある配列が 2 つ検出された場合、結果の形状には各ディメンション インデックスで最大 2 つの入力があります。
例:
- (2,1) と (2,3) は、(2,3) にブロードキャストされます。
- (1,2,5) と (7,2,5) は (7,2,5) にブロードキャストされます。
- (7,2,5) と (7,1,5) は (7,2,5) にブロードキャストされます。
- (7,2,5) と (7,2,6) は互換性がないためブロードキャストできません。
特殊なケースもサポートされています。その場合、各入力配列の異なるインデックスに縮退ディメンションがあります。この場合、結果は (2,1) と (1,3) が (2,3) にブロードキャストされる「外部演算」となります。その他の例については、ブロードキャストに関する NumPy のドキュメントをご覧ください。
ブロードキャストの構成
低ランク配列の高ランク配列へのブロードキャストと、縮退ディメンションを使用したブロードキャストは、どちらも同じバイナリ演算で実行できます。たとえば、サイズ 4 のベクトルとサイズ 1x2 の行列は、値(0)のブロードキャスト ディメンションを使用して加算できます。
|1 2 3 4| + [5 6] // [5 6] is a 1x2 matrix, not a vector.
まず、ブロードキャスト ディメンションを使用して、ベクトルがランク 2(行列)までブロードキャストされます。ブロードキャスト ディメンションの単一の値(0)は、ベクトルのディメンション 0 がマトリックスのディメンション 0 と一致することを示します。これにより、サイズ 4×M の行列が生成されます。この行列では、値 M が 1×2 配列内の対応するディメンション サイズと一致するように選択されます。したがって、4x2 の行列は次のようになります。
|1 1| + [5 6]
|2 2|
|3 3|
|4 4|
次に、「縮退次元のブロードキャスト」により、右側の対応するディメンション サイズと一致するように 1×2 マトリックスのディメンション 0 がブロードキャストされます。
|1 1| + |5 6| |6 7|
|2 2| + |5 6| = |7 8|
|3 3| + |5 6| |8 9|
|4 4| + |5 6| |9 10|
より複雑な例は、ブロードキャスト次元(1, 2)を使用して、サイズ 4x3x1 の配列に追加されたサイズ 1x2 の行列です。まず、ブロードキャスト ディメンションを使用して 1 x 2 行列がランク 3 までブロードキャストされ、中間 Mx1x2 配列が生成されます。この配列では、ディメンション サイズ M は大きい方のオペランド(4x3x1 配列)のサイズによって決まり、4x1x2 の中間配列が生成されます。M はディメンション 0(左端のディメンション)にあります。ブロードキャストのディメンションが (1, 2) であるため、ディメンション 1 と 2 は元の 1x2 マトリックスのディメンションにマッピングされます。この中間配列を縮退次元のブロードキャストを使用して 4 × 3 × 1 の行列に追加し、4 × 3 × 2 の配列結果を生成できます。