このドキュメントでは、XLA のブロードキャスト セマンティクスについて説明します。
ブロードキャストとは
ブロードキャストは、異なる形状の配列を算術演算で互換性のある形状にするプロセスです。この用語は NumPy ブロードキャストから借用されています。
ブロードキャストは、ランクの異なる多次元配列間、または形状は異なるが互換性のある多次元配列間のオペレーションで必要になることがあります。X が行列(2 次元配列)、v がベクトル(1 次元配列)である加算 X+v について考えてみましょう。要素ごとの加算を実行するには、XLA は v を一定回数複製して、ベクトル v を行列 X と同じ次元数に「ブロードキャスト」する必要があります。ベクトルの長さは、行列のディメンションの少なくとも 1 つと一致している必要があります。
次に例を示します。
|1 2 3| + |7 8 9|
|4 5 6|
行列のディメンションは(2,3)で、ベクトルのディメンションは(3)です。ベクトルは、行全体に複製されてブロードキャストされ、次のようになります。
|1 2 3| + |7 8 9| = |8 10 12|
|4 5 6| |7 8 9| |11 13 15|
原則
XLA 言語は可能な限り厳密で明示的であり、暗黙的な「魔法」のような機能は避けています。このような機能を使用すると、一部の計算の定義が若干容易になりますが、ユーザーコードに組み込まれる仮定が増え、長期的に変更が難しくなるというデメリットがあります。必要に応じて、クライアントレベルのラッパーに暗黙的な魔法の機能を追加できます。
ブロードキャストに関して、XLA では、ランクの異なる配列間のオペレーションで明示的なブロードキャスト仕様が必要です。これは、可能な場合に仕様を推測する NumPy とは異なります。
低次元配列を高次元配列にブロードキャストする
スカラーは、ブロードキャスト ディメンションを明示的に指定しなくても、常に配列にブロードキャストできます。スカラーと配列の要素ごとのバイナリ オペレーションは、スカラーを使用して配列内の各要素にオペレーションを適用することを意味します。たとえば、行列にスカラーを追加すると、各要素がスカラーと入力行列の対応する要素の合計である行列が生成されます。
|1 2 3| + 7 = |8 9 10|
|4 5 6| |11 12 13|
ほとんどのブロードキャスト ニーズは、バイナリ オペレーションのディメンションのタプルを使用してキャプチャできます。オペレーションの入力のランクが異なる場合、このブロードキャスト タプルは、高次元配列のどの次元を低次元配列と一致させるかを指定します。
前の例で考えてみましょう。スカラーを(2,3)行列に追加する代わりに、(3)次元のベクトルを(2,3)次元の行列に追加します。ブロードキャストを指定しないと、このオペレーションは無効になります。行列とベクトルの加算を正しくリクエストするには、ブロードキャスト ディメンションを(1)に指定します。これは、ベクトルのディメンションが行列のディメンション 1 と一致することを意味します。2D で、ディメンション 0 が行を表し、ディメンション 1 が列を表す場合、ベクトルの各要素は、行列の行数と一致するサイズの列になります。
|7 8 9| ==> |7 8 9|
|7 8 9|
より複雑な例として、3 要素のベクトル(次元(3))を 3x3 行列(次元(3,3))に追加することを考えます。この例では、ブロードキャストは次の 2 つの方法で行われます。
(1)ブロードキャスト ディメンション 1 を使用できます。各ベクトル要素が列になり、ベクトルは行列の各行に対して複製されます。
|7 8 9| ==> |7 8 9|
|7 8 9|
|7 8 9|
(2)ブロードキャスト ディメンション 0 を使用できます。各ベクトル要素は行になり、ベクトルは行列の各列に対して複製されます。
|7| ==> |7 7 7|
|8| |8 8 8|
|9| |9 9 9|
ブロードキャスト ディメンションは、低次元の形状を高次元の形状にブロードキャストする方法を記述するタプルにできます。たとえば、2x3x4 の直方体と 3x4 の行列が与えられた場合、ブロードキャスト タプル(1,2)は、行列を直方体のディメンション 1 と 2 に一致させることを意味します。
このタイプのブロードキャストは、broadcast_dimensions 引数が指定されている場合、XlaBuilder のバイナリ演算で使用されます。たとえば、XlaBuilder::Add をご覧ください。XLA ソースコードでは、このタイプのブロードキャストは「InDim」ブロードキャストと呼ばれることがあります。
正式な定義
ブロードキャスト属性を使用すると、高次元配列のどの次元を照合するかを指定することで、低次元配列を高次元配列に照合できます。たとえば、ディメンション MxNxPxQ の配列の場合、ディメンション T のベクトルは次のように照合できます。
MxNxPxQ
dim 3: T
dim 2: T
dim 1: T
dim 0: T
いずれの場合も、T は高次元配列の対応するディメンションと等しくなければなりません。次に、ベクトルの値が一致するディメンションから他のすべてのディメンションにブロードキャストされます。
TxV 行列を MxNxPxQ 配列に一致させるには、ブロードキャスト ディメンションのペアを使用します。
MxNxPxQ
dim 2,3: T V
dim 1,2: T V
dim 0,3: T V
etc...
ブロードキャスト タプルのディメンションの順序は、低次元配列のディメンションが高次元配列のディメンションと一致することが想定される順序である必要があります。タプルの最初の要素は、高次元配列のどの次元が低次元配列の次元 0 と一致する必要があるかを指定します。タプルの 2 番目の要素は、高次元配列のどの次元が低次元配列の次元 1 と一致する必要があるかを指定します。ブロードキャスト ディメンションの順序は厳密に昇順である必要があります。たとえば、前の例では、V を N に、T を P に一致させることはできません。また、V を P と N の両方に一致させることもできません。
縮退したディメンションを持つ類似したディメンションの配列のブロードキャスト
関連する問題として、ディメンションの数は同じでもディメンションのサイズが異なる 2 つの配列をブロードキャストすることがあります。NumPy と同様に、これは配列が互換性のある場合にのみ可能です。2 つの配列は、すべてのディメンションが互換性がある場合に互換性があります。2 つのディメンションが互換性を持つのは、次の条件を満たす場合です。
- 等しい場合
- そのうちの 1 つは 1(「縮退」次元)です。
互換性のある 2 つの配列が検出されると、結果の形状はすべてのディメンション インデックスで 2 つの入力の最大値になります。
例:
- (2,1) と (2,3) は (2,3) にブロードキャストします。
- (1,2,5) と (7,2,5) は (7,2,5) にブロードキャストされます。
- (7,2,5) と (7,1,5) が (7,2,5) にブロードキャストされます。
- (7,2,5)と(7,2,6)は互換性がないため、ブロードキャストできません。
入力配列のそれぞれが異なるインデックスに縮退次元を持つ特殊なケースも発生しますが、これもサポートされています。この場合、結果は「外部オペレーション」になります。つまり、(2,1)と(1,3)が(2,3)にブロードキャストされます。他の例については、ブロードキャストに関する NumPy ドキュメントをご覧ください。
ブロードキャストの構成
低次元配列から高次元配列へのブロードキャストと、縮退次元を使用したブロードキャストは、どちらも同じバイナリ オペレーションで実行できます。たとえば、サイズ 4 のベクトルとサイズ 1x2 の行列は、値(0)のブロードキャスト ディメンションを使用して加算できます。
|1 2 3 4| + [5 6] // [5 6] is a 1x2 matrix, not a vector.
まず、ブロードキャスト ディメンションを使用して、ベクトルが 2 次元(行列)までブロードキャストされます。ブロードキャスト ディメンションの単一の値(0)は、ベクトルのディメンション 0 が行列のディメンション 0 と一致することを示します。これにより、サイズ 4xM の行列が生成されます。ここで、M の値は 1x2 配列の対応するディメンション サイズと一致するように選択されます。したがって、4x2 行列が生成されます。
|1 1| + [5 6]
|2 2|
|3 3|
|4 4|
次に、「縮退ディメンション ブロードキャスト」は、1x2 行列のディメンション ゼロをブロードキャストして、右辺の対応するディメンション サイズと一致させます。
|1 1| + |5 6| |6 7|
|2 2| + |5 6| = |7 8|
|3 3| + |5 6| |8 9|
|4 4| + |5 6| |9 10|
より複雑な例として、ブロードキャスト ディメンション (1, 2) を使用して、サイズ 1x2 の行列をサイズ 4x3x1 の配列に追加する例があります。まず、ブロードキャスト ディメンションを使用して 1x2 行列を 3 次元までブロードキャストし、中間 Mx1x2 配列を生成します。ここで、ディメンション サイズ M は、大きいオペランド(4x3x1 配列)のサイズによって決定され、4x1x2 中間配列が生成されます。ディメンション 1 と 2 は元の 1x2 行列のディメンションにマッピングされ、ブロードキャスト ディメンションは(1, 2)であるため、M はディメンション 0(最も左のディメンション)にあります。この中間配列は、縮退したディメンションのブロードキャストを使用して 4x3x1 行列に追加し、4x3x2 配列の結果を生成できます。