XLA: ML 用にコンパイラを最適化する

OpenXLA は線形代数用のドメイン固有のコンパイラです。ソースコードを変更せずに TensorFlow モデルを高速化できます。

はじめに

TensorFlow プログラムが実行されると、すべてのオペレーションが TensorFlow エグゼキュータによって個別に実行されます。各 TensorFlow オペレーションには、エグゼキュータがディスパッチするプリコンパイル済みの GPU カーネル実装があります。

XLA はモデルを実行する代替モードを提供します。これは、指定されたモデル用に特別に生成された一連のコンピューティング カーネルに TensorFlow グラフをコンパイルします。これらのカーネルはモデルに固有であるため、モデル固有の情報を利用して最適化できます。たとえば、単純な TensorFlow 計算のコンテキストで XLA が行う最適化を見てみましょう。

def model_fn(x, y, z):
  return tf.reduce_sum(x + y * z)

XLA を使用せずに実行すると、グラフは 3 つのカーネル(乗算用、加算用、リダクション用)を起動します。ただし、XLA を使用すると、単一のカーネル起動で結果を計算するようにグラフを最適化できます。これは、加算、乗算、削減を単一の GPU カーネルに「融合」することで実現されます。さらに、この融合された演算では、y*zx+y*z によって生成された中間値がメモリに書き出されません。代わりに、これらの中間計算の結果をユーザーに直接「ストリーミング」しながら、それらを完全に GPU レジスタに保持します。融合は XLA で最も重要な最適化です。通常、メモリ帯域幅はハードウェア アクセラレータで最も少ないリソースであるため、メモリ オペレーションを排除することが、パフォーマンスを向上させる最善の方法の一つです。

TensorFlow モデルで XLA を有効にする

tf.function(jit_compile=True) を使用した明示的なコンパイル

明示的コンパイル API を使用すると、コンパイルする関数をきめ細かく制御できます。たとえば、MNIST トレーニングを実行する次の TensorFlow 関数は XLA でコンパイルされます。

@tf.function(jit_compile=True)
def train_mnist(images, labels):
    images, labels = cast(images, labels)

    with tf.GradientTape() as tape:
      predicted_labels = layer(images)
      loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=predicted_labels, labels=labels
      ))
    layer_variables = layer.trainable_variables
    grads = tape.gradient(loss, layer_variables)
    optimizer.apply_gradients(zip(grads, layer_variables))

jit_compile API には、コンパイルが必須のセマンティクスがあります。関数全体が XLA でコンパイルされるか、errors.InvalidArgumentError 例外がスローされます。現在 XLA では、次元が推定できない関数、つまり、計算全体を実行せずにすべてのテンソルの次元を推測できない関数をコンパイルできません。たとえば、次の関数はコンパイルされません。

@tf.function
def not_compilable(x):
  return tf.unique(x)

ただし、形状は実行によって異なる場合があります。

@tf.function(jit_compile=True)
def recompiled_on_launch(a, b):
  return a + b

recompiled_on_launch(tf.ones([1, 10]), tf.ones([1, 10]))
recompiled_on_launch(tf.ones([1, 100]), tf.ones([1, 100]))

詳細な使用例については、チュートリアルの Colab と、jit_compile=True の使用に関するチュートリアル動画をご覧ください。

Keras での使用方法

Keras モデルの場合、jit_compile=True を引数として model.compile に設定できます。

model.compile(optimizer="adam", jit_compile=True)

分散戦略と併用

XLA:GPU は、ステップ関数に jit_compile=True アノテーションを付けることで、TF 分散戦略(MirroredStrategy または MultiWorkerMirroredStrategy)で使用できます。

@tf.function(jit_compile=True)
def step_fn():
  t = tf.ones(shape=[100], dtype=tf.float32)
  ctx = tf.distribute.get_replica_context()
  return ctx.all_reduce(tf.distribute.ReduceOp.SUM, t)

@tf.function
def run_fn():
  return strategy.run(step_fn)

自動クラスタリング

変更せずに TensorFlow モデルで XLA の使用を開始する簡単な方法は、自動クラスタリングを有効にすることです。これにより、XLA を使用してコンパイルおよび実行できる TensorFlow 関数内のクラスタ(接続されたサブグラフ)が自動的に検出されます。GPU の自動クラスタリングは、TF_XLA_FLAGS 環境変数を設定することで有効にできます。

$ TF_XLA_FLAGS=--tf_xla_auto_jit=2 path/to/your/tf/program

現在、自動クラスタリングは GPU ワークロード用に最適化されていますが、さらに --tf_xla_cpu_global_jit フラグを使用することで CPU で有効にすることもできます。

$ TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit" path/to/your/program

詳細な使用例については、自動クラスタリングのチュートリアルの Colab をご覧ください。

tfcompile を使用した CPU の AOT(事前)コンパイル

スタンドアロンの tfcompile ツールを使用して、TensorFlow グラフを実行可能コードに変換することもできます(x86-64 CPU のみ)。

コンパイルされたプログラムを検査する

XLA には、生成されたプログラムを検査できるイントロスペクション機能が用意されています。生成されたプログラムをダンプするには、環境変数 XLA_FLAGS を使用します。

$ XLA_FLAGS="--xla_dump_to=/tmp/generated" TF_XLA_FLAGS="--tf_xla_auto_jit=2" my/tensorflow/program

ダンプの実行後、/tmp/generated に次のファイルが生成されます。

  • module_XXXX.*_optimizations.txt: 生成された XLA プログラム(コンパイルされたクラスタごとに 1 つ)。XLA バグレポートを送信するときに添付しておくと非常に便利です。

  • module_XXXX.ir-*.ll: NVPTX 組み込み関数を使用して、LLVM 中間表現で生成されたファイル。

  • module_XXXX.ptx: 生成された PTX ファイル。

また、TensorFlow グラフ内に XLA クラスタの埋め込みを可視化したグラフをダンプすることもできます。

$ TF_DUMP_GRAPH_PREFIX=/tmp/generated TF_XLA_FLAGS="--tf_xla_clustering_debug"

再現可能なバグレポート

生成された XLA プログラムのダンプと使用された自動クラスタリング埋め込みが含まれていると、バグレポートを簡単に再現できます。自動クラスタリングで動作する TensorFlow プログラム用に生成するには、次のコマンドを実行します。

$ TF_DUMP_GRAPH_PREFIX=/tmp/generated \
  TF_XLA_FLAGS="--tf_xla_clustering_debug --tf_xla_auto_jit=2" \
  XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=/tmp/generated" \
    my/tensorflow/program"

バグを報告する際は、上記の /tmp/generated ディレクトリの内容を添付してください。

可能であれば、run_hlo_module を使用し、生成されたプログラムで繰り返し実行することで、バグを単一の XLA プログラムに分離するようにしてください。

関連情報

XLA フロントエンド

XLA プログラムは、TensorFlow とは別に、次の方法で生成できます。

  • JAX: Python+NumPy プログラムのコンポーズ可能な変換
  • Julia: 科学計算のためのジュリア言語
  • PyTorch: PyTorch フレームワーク
  • Nx: Elixir プログラミング言語の数値コンピューティング ライブラリ

講演

jit_compile=True を使用して TF から XLA を使用する

XLA の概要