OpenXLA は、線形代数のためのドメイン固有のコンパイラで、ソースコードを変更せずに TensorFlow モデルを高速化することができます。
はじめに
TensorFlow プログラムを実行すると、すべてのオペレーションが TensorFlow エグゼキュータによって個別に実行されます。TensorFlow の各オペレーションは、エグゼキュータによってプリコンパイル済み GPU カーネル実装にディスパッチされます。
XLA では、モデルの実行を代替するモードが用意されています。代替モードでは、TensorFlow グラフが、所定のモデル用に生成された一連のコンピューティング カーネルにコンパイルされます。こうしたカーネルはモデルに固有であるため、モデル固有の情報を利用して最適化できます。TensorFlow で単純な演算を行う場合の XLA による最適化の例を次に示します。
def model_fn(x, y, z):
return tf.reduce_sum(x + y * z)
XLA なしで実行する場合、乗算、加算、削減用として 3 つのカーネルがグラフで起動されます。XLA を使用すると、グラフが最適化され、1 回のカーネル起動で演算が行われます。これは、加算、乗算、削減を単一の GPU カーネルに「融合」することで行われます。さらに、この融合された演算では、y*z
と x+y*z
で生成された中間値がメモリに書き出されません。代わりに、これらの中間演算の結果を GPU レジスタにすべて保持しながら、ユーザーに直接「ストリーミング」します。融合は、XLA の唯一かつ重要な最適化手法です。ハードウェア アクセラレータにおいては一般にメモリ帯域幅のリソース上の制約が大きく、このようにメモリ操作を省くことが、パフォーマンスを改善するうえで有効です。
TensorFlow モデルに対して XLA を有効にする
tf.function(jit_compile=True)
を使用した明示的なコンパイル
明示的コンパイル API を使用すると、コンパイルする関数を細かく制御できます。たとえば、次の TensorFlow 関数(MNIST トレーニングを実行する)は XLA でコンパイルされます。
@tf.function(jit_compile=True)
def train_mnist(images, labels):
images, labels = cast(images, labels)
with tf.GradientTape() as tape:
predicted_labels = layer(images)
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=predicted_labels, labels=labels
))
layer_variables = layer.trainable_variables
grads = tape.gradient(loss, layer_variables)
optimizer.apply_gradients(zip(grads, layer_variables))
jit_compile
API には「コンパイル必須」セマンティクスがあります。これにより、関数全体が XLA でコンパイルされるか、errors.InvalidArgumentError
例外がスローされます。XLA では現在、次元を推定できない関数はコンパイルできません。つまり、計算全体を実行しないと推定できないテンソルの次元がある場合は、コンパイルできないということです。たとえば、次の関数はコンパイルできません。
@tf.function
def not_compilable(x):
return tf.unique(x)
ただし、シェイプは実行ごとに変えることができます。
@tf.function(jit_compile=True)
def recompiled_on_launch(a, b):
return a + b
recompiled_on_launch(tf.ones([1, 10]), tf.ones([1, 10]))
recompiled_on_launch(tf.ones([1, 100]), tf.ones([1, 100]))
使用例について詳しくは、チュートリアルの Colab や jit_compile=True
の使用に関する チュートリアル動画 をご覧ください。
Keras での使用
Keras モデルの場合、jit_compile=True
は model.compile
の引数として設定できます。
model.compile(optimizer="adam", jit_compile=True)
分散戦略での使用
XLA:GPU は、ステップ関数に jit_compile=True
をアノテーションすることで、TF 分散戦略(MirroredStrategy
または MultiWorkerMirroredStrategy
)で使用できます。
@tf.function(jit_compile=True)
def step_fn():
t = tf.ones(shape=[100], dtype=tf.float32)
ctx = tf.distribute.get_replica_context()
return ctx.all_reduce(tf.distribute.ReduceOp.SUM, t)
@tf.function
def run_fn():
return strategy.run(step_fn)
自動クラスタリング
TensorFlow モデルで何も変更を加えずに XLA を使用する簡単な方法は、自動クラスタリングを有効にすることです。自動クラスタリングにより、XLA を使用してコンパイルと実行が行える TensorFlow 関数内のクラスタ(連結サブグラフ)が自動的に検索されます。GPU での自動クラスタリングは、TF_XLA_FLAGS
環境変数を設定することで有効にできます。
$ TF_XLA_FLAGS=--tf_xla_auto_jit=2 path/to/your/tf/program
現在、自動クラスタリングは GPU ワークロード用に最適化されていますが、次のように --tf_xla_cpu_global_jit
フラグを追加することで、CPU に対して有効にすることもできます。
$ TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit" path/to/your/program
詳細な使用例については、自動クラスタリングに関するチュートリアルの Colab をご覧ください。
tfcompile
による CPU の AOT(事前)コンパイル
スタンドアロンの tfcompile
ツールを使用して、TensorFlow グラフを実行可能コード(x86-64 CPU のみ)に変換することもできます。
コンパイルされたプログラムの検査
XLA には、生成されたプログラムを検査できるイントロスペクション機能が用意されています。生成されたプログラムをダンプするには、環境変数 XLA_FLAGS
を次のように使用します。
$ XLA_FLAGS="--xla_dump_to=/tmp/generated" TF_XLA_FLAGS="--tf_xla_auto_jit=2" my/tensorflow/program
ダンプが行われると、/tmp/generated
に次のファイルが生成されます。
module_XXXX.*_optimizations.txt
: 生成された XLA プログラム(コンパイルされたクラスタごとに 1 つ)。XLA バグレポートの送信時に添付していただけるととても役立ちます。module_XXXX.ir-*.ll
: NVPTX を組み込むことで、LLVM による中間表現として生成されたファイル。module_XXXX.ptx
: 生成された PTX ファイル。
次のようにして、TensorFlow グラフ内に埋め込まれた XLA クラスタを可視化するグラフをダンプすることもできます。
$ TF_DUMP_GRAPH_PREFIX=/tmp/generated TF_XLA_FLAGS="--tf_xla_clustering_debug"
再現可能なバグレポート
バグレポートに、生成された XLA プログラムと使用された自動クラスタリングの埋め込みのダンプが含まれていると、再現がはるかに簡単になります。自動クラスタリングを使用して実行する TensorFlow プログラムでこれらを生成するには、次のように起動します。
$ TF_DUMP_GRAPH_PREFIX=/tmp/generated \
TF_XLA_FLAGS="--tf_xla_clustering_debug --tf_xla_auto_jit=2" \
XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=/tmp/generated" \
my/tensorflow/program"
バグを報告する際は、上記で参照している /tmp/generated
ディレクトリの内容を添付します。
可能であれば、run_hlo_module
を使用して生成されたプログラムで繰り返し実行することで、バグを単一の XLA プログラムまで切り分けるようにしてください。
関連情報
- OpenXLA のドキュメント OpenXLA のドキュメント
- 既知の問題: XLA+TF に関する既知の問題のリスト
- XLA - コンパイル済みの TensorFlow: Google Developers ブログをご覧ください
- GitHub の XLA ソースもご確認ください
XLA フロントエンド
XLA プログラムは、TensorFlow とは別に、次の手段でも生成できます。
- JAX: Python+NumPy プログラムの変換構成ツール
- Julia: 科学計算のためのプログラミング言語
- PyTorch: PyTorch フレームワーク
- Nx: Elixir プログラミング言語向け数値計算ライブラリ