OpenXLA は線形代数用のドメイン固有のコンパイラです。ソースコードを変更せずに TensorFlow モデルを高速化できます。
はじめに
TensorFlow プログラムが実行されると、すべてのオペレーションが TensorFlow エグゼキュータによって個別に実行されます。各 TensorFlow オペレーションには、エグゼキュータがディスパッチするプリコンパイル済みの GPU カーネル実装があります。
XLA はモデルを実行する代替モードを提供します。これは、指定されたモデル用に特別に生成された一連のコンピューティング カーネルに TensorFlow グラフをコンパイルします。これらのカーネルはモデルに固有であるため、モデル固有の情報を利用して最適化できます。たとえば、単純な TensorFlow 計算のコンテキストで XLA が行う最適化を見てみましょう。
def model_fn(x, y, z):
return tf.reduce_sum(x + y * z)
XLA を使用せずに実行すると、グラフは 3 つのカーネル(乗算用、加算用、リダクション用)を起動します。ただし、XLA を使用すると、単一のカーネル起動で結果を計算するようにグラフを最適化できます。これは、加算、乗算、削減を単一の GPU カーネルに「融合」することで実現されます。さらに、この融合された演算では、y*z
と x+y*z
によって生成された中間値がメモリに書き出されません。代わりに、これらの中間計算の結果をユーザーに直接「ストリーミング」しながら、それらを完全に GPU レジスタに保持します。融合は XLA で最も重要な最適化です。通常、メモリ帯域幅はハードウェア アクセラレータで最も少ないリソースであるため、メモリ オペレーションを排除することが、パフォーマンスを向上させる最善の方法の一つです。
TensorFlow モデルで XLA を有効にする
tf.function(jit_compile=True)
を使用した明示的なコンパイル
明示的コンパイル API を使用すると、コンパイルする関数をきめ細かく制御できます。たとえば、MNIST トレーニングを実行する次の TensorFlow 関数は XLA でコンパイルされます。
@tf.function(jit_compile=True)
def train_mnist(images, labels):
images, labels = cast(images, labels)
with tf.GradientTape() as tape:
predicted_labels = layer(images)
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=predicted_labels, labels=labels
))
layer_variables = layer.trainable_variables
grads = tape.gradient(loss, layer_variables)
optimizer.apply_gradients(zip(grads, layer_variables))
jit_compile
API には、コンパイルが必須のセマンティクスがあります。関数全体が XLA でコンパイルされるか、errors.InvalidArgumentError
例外がスローされます。現在 XLA では、次元が推定できない関数、つまり、計算全体を実行せずにすべてのテンソルの次元を推測できない関数をコンパイルできません。たとえば、次の関数はコンパイルされません。
@tf.function
def not_compilable(x):
return tf.unique(x)
ただし、形状は実行によって異なる場合があります。
@tf.function(jit_compile=True)
def recompiled_on_launch(a, b):
return a + b
recompiled_on_launch(tf.ones([1, 10]), tf.ones([1, 10]))
recompiled_on_launch(tf.ones([1, 100]), tf.ones([1, 100]))
詳細な使用例については、チュートリアルの Colab と、jit_compile=True
の使用に関するチュートリアル動画をご覧ください。
Keras での使用方法
Keras モデルの場合、jit_compile=True
を引数として model.compile
に設定できます。
model.compile(optimizer="adam", jit_compile=True)
分散戦略と併用
XLA:GPU は、ステップ関数に jit_compile=True
アノテーションを付けることで、TF 分散戦略(MirroredStrategy
または MultiWorkerMirroredStrategy
)で使用できます。
@tf.function(jit_compile=True)
def step_fn():
t = tf.ones(shape=[100], dtype=tf.float32)
ctx = tf.distribute.get_replica_context()
return ctx.all_reduce(tf.distribute.ReduceOp.SUM, t)
@tf.function
def run_fn():
return strategy.run(step_fn)
自動クラスタリング
変更せずに TensorFlow モデルで XLA の使用を開始する簡単な方法は、自動クラスタリングを有効にすることです。これにより、XLA を使用してコンパイルおよび実行できる TensorFlow 関数内のクラスタ(接続されたサブグラフ)が自動的に検出されます。GPU の自動クラスタリングは、TF_XLA_FLAGS
環境変数を設定することで有効にできます。
$ TF_XLA_FLAGS=--tf_xla_auto_jit=2 path/to/your/tf/program
現在、自動クラスタリングは GPU ワークロード用に最適化されていますが、さらに --tf_xla_cpu_global_jit
フラグを使用することで CPU で有効にすることもできます。
$ TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit" path/to/your/program
詳細な使用例については、自動クラスタリングのチュートリアルの Colab をご覧ください。
tfcompile
を使用した CPU の AOT(事前)コンパイル
スタンドアロンの tfcompile
ツールを使用して、TensorFlow グラフを実行可能コードに変換することもできます(x86-64 CPU のみ)。
コンパイルされたプログラムを検査する
XLA には、生成されたプログラムを検査できるイントロスペクション機能が用意されています。生成されたプログラムをダンプするには、環境変数 XLA_FLAGS
を使用します。
$ XLA_FLAGS="--xla_dump_to=/tmp/generated" TF_XLA_FLAGS="--tf_xla_auto_jit=2" my/tensorflow/program
ダンプの実行後、/tmp/generated
に次のファイルが生成されます。
module_XXXX.*_optimizations.txt
: 生成された XLA プログラム(コンパイルされたクラスタごとに 1 つ)。XLA バグレポートを送信するときに添付しておくと非常に便利です。module_XXXX.ptx
: 生成された PTX ファイル。
また、TensorFlow グラフ内に XLA クラスタの埋め込みを可視化したグラフをダンプすることもできます。
$ TF_DUMP_GRAPH_PREFIX=/tmp/generated TF_XLA_FLAGS="--tf_xla_clustering_debug"
再現可能なバグレポート
生成された XLA プログラムのダンプと使用された自動クラスタリング埋め込みが含まれていると、バグレポートを簡単に再現できます。自動クラスタリングで動作する TensorFlow プログラム用に生成するには、次のコマンドを実行します。
$ TF_DUMP_GRAPH_PREFIX=/tmp/generated \
TF_XLA_FLAGS="--tf_xla_clustering_debug --tf_xla_auto_jit=2" \
XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=/tmp/generated" \
my/tensorflow/program"
バグを報告する際は、上記の /tmp/generated
ディレクトリの内容を添付してください。
可能であれば、run_hlo_module
を使用し、生成されたプログラムで繰り返し実行することで、バグを単一の XLA プログラムに分離するようにしてください。
関連情報
- OpenXLA のドキュメント OpenXLA のドキュメント
- 既知の問題: XLA+TF に関する既知の問題のリスト
- XLA - コンパイル済み TensorFlow: Google Developers ブログを読む
- GitHub で XLA のソースをご確認ください。
XLA フロントエンド
XLA プログラムは、TensorFlow とは別に、次の方法で生成できます。
- JAX: Python+NumPy プログラムのコンポーズ可能な変換
- Julia: 科学計算のためのジュリア言語
- PyTorch: PyTorch フレームワーク
- Nx: Elixir プログラミング言語の数値コンピューティング ライブラリ