XLA Op の構造
次の HLO の例を考えてみましょう。
add.936 = bf16[8,1,1280,16384]{3,2,0,1:T(8,128)(2,1)}
add(exponential.183, broadcast.3115)
これは次のコンポーネントで構成されています。
- Op 名:
add.936- これはオペレーションの一意の名前です。
- 形状:
bf16[8,1,1280,16384]- これは Op の出力形状です。ここで、dtype は bf16 で、形状は
[8,1,1280,16384]です。
- これは Op の出力形状です。ここで、dtype は bf16 で、形状は
- レイアウト(タイリングあり):
3,2,0,1:T(8,128)(2,1)- これは、配列がメモリにどのように格納されるかを示しています。
3,2,0,1はメモリ内の軸の順序(列優先、行優先など)を示し、T(8,128)(2,1)は使用されるタイリングとパディングを示します。 - レイアウトは省略可能です。指定しない場合、タイリングは行われず、ディメンションは最上位から最下位の順に並べられていると見なされます。
- これは、配列がメモリにどのように格納されるかを示しています。
- オペレーション:
add- 実行中のオペレーション。ここでは Add です。これは Op 名にも記載されています。
- 引数:
exponential.183、broadcast.3115- このオペレーションは、一意の名前で指定された 2 つの引数を取ります。
別の例として、融合 Op を考えてみましょう。
%fusion.3 = bf16[32,32,4096]{2,1,0:T(8,128)(2,1)S(1)}
fusion(bf16[32,32,8192]{2,1,0:T(8,128)(2,1)S(1)} %fusion.32),
kind=kCustom, calls=%all-reduce-scatter.3
前述のコンポーネントに加えて、次のコンポーネントで構成されています。
- 属性:
kindとcalls- これらは、実行されるオペレーション(この場合は融合)に関する詳細情報を提供します。
- メモリ位置(メモリ空間識別子):
S(1)- これは、配列が保存されるメモリ容量/ロケーションを示します。
S(1)は、この配列が VMEM(TPU 上)に存在することを示します。
- これは、配列が保存されるメモリ容量/ロケーションを示します。
- 入力引数
%fusion.32の形状とレイアウトの詳細
以降のセクションでは、シェイプ、レイアウト、メモリー空間識別子について説明します。タイリングについて詳しくは、タイル レイアウトをご覧ください。
図形
XLA ShapeProto プロトコル(xla_data.proto)は、N 次元配列(略して array)の次元数、サイズ、データ型を記述します。
用語、表記、規則
配列の実際の次元数は、サイズが 1 より大きい次元の数です。
N次元配列の場合、ディメンションには0からN-1までの番号が付けられます。ディメンションのサイズは非負の整数です。特に、サイズ 0 は有効です。ディメンション番号は便宜上の任意のラベルです。これらのディメンション番号の順序は、シェイプのレイアウトにおける特定のマイナー/メジャーの順序を意味するものではありません。レイアウトはLayoutProtoプロトによって決定されます。慣例により、ディメンションはディメンション番号の昇順でリストされます。たとえば、サイズ
[A x B x C]の 3 次元配列の場合、ディメンション 0 のサイズはA、ディメンション 1 のサイズはB、ディメンション 2 のサイズはCになります。XLA の一部のユーティリティは、Python のような負のインデックスもサポートしています。ディメンション -1 は最後のディメンションです(
N次元配列の場合、N-1に相当します)。たとえば、上記の 3 次元配列の場合、ディメンション -1 のサイズはC、ディメンション -2 のサイズはBになります。2 次元、3 次元、4 次元配列には、ディメンションに関連付けられた特定の文字がよくあります。たとえば、2 次元配列の場合は次のようになります。
- ディメンション 0:
y - ディメンション 1:
x
3D 配列の場合:
- ディメンション 0:
z - ディメンション 1:
y - ディメンション 2:
x
4D 配列の場合:
- ディメンション 0:
p - ディメンション 1:
z - ディメンション 2:
y - ディメンション 3:
x
- ディメンション 0:
ディメンションを受け取る XLA API の関数は、ディメンション番号の昇順で受け取ります。これは、ディメンションを
initializer_listとして渡すときに使用される順序と一致します。例:ShapeUtil::MakeShape(F32, {A, B, C, D})は、ディメンション サイズ配列がシーケンス
[A, B, C, D]で構成される形状を作成します。
レイアウト
LayoutProto プロトコルは、メモリ内で配列がどのように表現されるかを記述します。次のフィールドが含まれます。
message LayoutProto {
repeated int64 minor_to_major;
int64 tail_padding_alignment_in_elements;
...
}
マイナーからメジャーへのディメンションの順序付け
必須フィールドは minor_to_major のみです。このフィールドは、シェイプ内のディメンションのマイナーからメジャーへの順序を表します。minor_to_major の値は、配列のディメンションの順序(N 次元配列の場合は 0~N-1)です。最初の値が最もマイナーなディメンションで、最後の値が最もメジャーなディメンションです。最下位のディメンションは、線形メモリにレイアウトされた配列の要素をステップ実行するときに最も急速に変化するディメンションです。
たとえば、サイズ [2 x 3] の次の 2 次元配列について考えてみましょう。
a b c
d e f
ここで、ディメンション 0 はサイズ 2、ディメンション 1 はサイズ 3 です。レイアウトの minor_to_major フィールドが [0, 1] の場合、ディメンション 0 は最もマイナーなディメンションで、ディメンション 1 は最もメジャーなディメンションです。これは、線形メモリの次のレイアウトに対応します。
a d b e c f
0 から N-1 までのこのマイナーからメジャーへのディメンションの順序は、列優先(2 次元の場合)に似ています。ディメンションの単調な順序付けを想定すると、コードでこのレイアウトを「ディメンション 0 はマイナー」と簡単に参照することもできます。
一方、レイアウトの minor_to_major フィールドが [1, 0] の場合、線形メモリのレイアウトは次のようになります。
a b c d e f
N 次元配列のマイナーからメジャーへのディメンション順序 N-1 から 0 は、行優先(2 次元の場合)に似ています。ディメンションの単調な順序付けを想定すると、コードでこのレイアウトを「dim 0 がメジャー」と呼ぶこともできます。
デフォルトのマイナーからメジャーへの順序付け
新しく作成されたシェイプのデフォルトのレイアウトは「次元の順序はメジャーからマイナー」(つまり [N-1, ..., 0])です。
パディング
tail_padding_alignment_in_elements フィールドは、要素の数で tiled 配列の配置を定義します。タイリングを適用すると、要素の合計数がこの値の倍数になるまで、パディングされた要素がレイアウトの末尾に追加されます。
配列のインデックス付け
index_util.h のクラス IndexUtil は、形状とレイアウトが指定された場合に、多次元インデックスと線形インデックスを変換するためのユーティリティを提供します。多次元インデックスには、ディメンションごとに int64 インデックスが含まれます。線形インデックスは、配列を保持するバッファにインデックスを付ける単一の int64 値です。同じディレクトリにある shape_util.h と layout_util.h をご覧ください。これらは、シェイプとレイアウトの作成と操作を簡素化するユーティリティです。
メモリ空間識別子
HLO では、各配列にメモリ空間識別子(S(n) と記述)を付加できます。
S(0)(省略されることが多い)は、デバイスの高帯域幅メモリ(HBM)を表します。S(1)は、デバイス上の仮想メモリ(VMEM)を表します。S(2)、S(3)などは、デバイス固有の追加のメモリ空間に対応します。S(5)はホストメモリを示します。