XLA:优化机器学习编译器

OpenXLA 是一种针对特定领域的线性代数编译器,可以加快 TensorFlow 模型的运行速度,而且可能无需更改源代码。

简介

运行 TensorFlow 程序时,所有操作均由 TensorFlow 执行程序单独执行。每个 TensorFlow 操作都有一个预编译的 GPU 内核实现,系统会将执行程序分派给该实现。

XLA 提供了一种运行模型的替代模式:它会将 TensorFlow 图编译成一系列专门为给定模型生成的计算内核。由于这些内核是模型独有的,因此它们可以利用模型专属信息进行优化。例如,我们来看看 XLA 在简单 TensorFlow 计算环境中进行的优化:

def model_fn(x, y, z):
  return tf.reduce_sum(x + y * z)

如果在不使用 XLA 的情况下运行,图会启动三个内核:一个用于乘法,一个用于加法,一个用于减法。不过,XLA 可以优化该图,使其只需一次内核启动即可计算结果。它通过将加法、乘法和减法“融合”到单个 GPU 内核中来实现这一点。此外,这种融合操作不会将 y*zx+y*z 生成的中间值写出到内存中;而是将这些中间计算的结果直接“流式传输”给用户,同时将它们完全保留在 GPU 寄存器中。融合是 XLA 采用的最重要的一项优化措施。 内存带宽通常是硬件加速器上最稀缺的资源,因此减少内存操作是提高性能的最佳方式之一。

为 TensorFlow 模型启用 XLA

使用 tf.function(jit_compile=True) 进行显式编译

显式编译 API 提供精细的控制,用于选择应编译哪些函数。例如,以下执行 MNIST 训练的 TensorFlow 函数使用 XLA 进行编译:

@tf.function(jit_compile=True)
def train_mnist(images, labels):
    images, labels = cast(images, labels)

    with tf.GradientTape() as tape:
      predicted_labels = layer(images)
      loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=predicted_labels, labels=labels
      ))
    layer_variables = layer.trainable_variables
    grads = tape.gradient(loss, layer_variables)
    optimizer.apply_gradients(zip(grads, layer_variables))

jit_compile API 具有必须编译语义:要么使用 XLA 编译整个函数,要么抛出 errors.InvalidArgumentError 异常。目前,如果维度无法推断,也就是说,如果不运行整个计算就无法推断所有张量的维度,则 XLA 无法编译这些函数。例如,以下函数将无法编译:

@tf.function
def not_compilable(x):
  return tf.unique(x)

不过,形状可能在各次运行之间有所不同:

@tf.function(jit_compile=True)
def recompiled_on_launch(a, b):
  return a + b

recompiled_on_launch(tf.ones([1, 10]), tf.ones([1, 10]))
recompiled_on_launch(tf.ones([1, 100]), tf.ones([1, 100]))

如需查看更详细的用法示例,请参阅教程 Colab,并观看有关 jit_compile=True 用法的教程视频

使用 Keras

对于 Keras 模型,可以将 jit_compile=True 设置为 model.compile 的参数:

model.compile(optimizer="adam", jit_compile=True)

分布式策略的使用

XLA:GPU 可用于 TF 分布式策略(MirroredStrategyMultiWorkerMirroredStrategy),只需使用 jit_compile=True 注释阶跃函数即可:

@tf.function(jit_compile=True)
def step_fn():
  t = tf.ones(shape=[100], dtype=tf.float32)
  ctx = tf.distribute.get_replica_context()
  return ctx.all_reduce(tf.distribute.ReduceOp.SUM, t)

@tf.function
def run_fn():
  return strategy.run(step_fn)

自动聚类

如需开始在 TensorFlow 模型中使用 XLA,而无需进行任何更改,一种简单方法是启用自动聚类,此功能会自动在 TensorFlow 函数中查找聚类(连接的子图),这些聚类可以使用 XLA 进行编译和执行。您可以通过设置 TF_XLA_FLAGS 环境变量在 GPU 上启用自动聚类功能:

$ TF_XLA_FLAGS=--tf_xla_auto_jit=2 path/to/your/tf/program

自动聚类目前已针对 GPU 工作负载进行了优化,但您也可以通过另外使用 --tf_xla_cpu_global_jit 标志在 CPU 上启用它:

$ TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit" path/to/your/program

如需查看详细的用法示例,请参阅自动聚类教程 Colab

使用 tfcompile 对 CPU 进行 AOT(预先)编译

您还可以使用独立的 tfcompile 工具,此工具会将 TensorFlow 图转换为可执行代码(仅适用于 x86-64 CPU)。

检查已编译的程序

XLA 提供了内省工具,可让您检查生成的程序。如需转储生成的程序,请使用环境变量 XLA_FLAGS

$ XLA_FLAGS="--xla_dump_to=/tmp/generated" TF_XLA_FLAGS="--tf_xla_auto_jit=2" my/tensorflow/program

执行转储后,您可以在 /tmp/generated 中找到以下文件:

  • module_XXXX.*_optimizations.txt:生成的 XLA 程序,每个已编译的聚类对应一个。在提交 XLA 错误报告时附上这些错误对我们很有帮助!

  • module_XXXX.ir-*.ll:采用 LLVM 中间表示法生成的文件,其中包含 NVPTX 内建函数。

  • module_XXXX.ptx:生成的 PTX 文件。

您还可以使用以下命令转储直观呈现 TensorFlow 图内部 XLA 聚类的嵌入情况的图:

$ TF_DUMP_GRAPH_PREFIX=/tmp/generated TF_XLA_FLAGS="--tf_xla_clustering_debug"

可重现的 bug 报告

如果 bug 报告包含生成的 XLA 程序的转储和所用的自动聚类嵌入,则重现则更容易重现。如需为使用自动聚类运行的 TensorFlow 程序生成此类报告,请启动以下命令:

$ TF_DUMP_GRAPH_PREFIX=/tmp/generated \
  TF_XLA_FLAGS="--tf_xla_clustering_debug --tf_xla_auto_jit=2" \
  XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=/tmp/generated" \
    my/tensorflow/program"

提交 bug 时,请附加 /tmp/generated 目录(如上所述)的内容。

如果可能,请尝试使用 run_hlo_module 并在生成的程序上以迭代方式运行 bug,从而将 bug 隔离到单个 XLA 程序中。

补充阅读材料

XLA 前端

除了 TensorFlow 之外,还可以通过以下方式生成 XLA 程序:

  • JAX:Python+NumPy 程序的可组合转换
  • Julia:用于科学计算的 Julia 语言
  • PyTorch:PyTorch 框架
  • Nx:适用于 Elixir 编程语言的数值计算库

讲话

通过 jit_compile=True 使用 TF 中的 XLA

XLA 概览