PJRT 插件集成

背景

PJRT 是我们希望添加到机器学习生态系统中的统一 Device API。长期愿景是:(1) 框架(JAX、TF 等)将调用 PJRT,PJRT 具有对框架不透明的设备特定实现;(2) 每台设备专注于实现 PJRT API,并且可能对框架不透明。

本文档重点介绍了有关如何与 PJRT 集成以及如何测试 PJRT 与 JAX 集成的建议。

如何与 PJRT 集成

第 1 步:实现 PJRT C API 接口

选项 A:您可以直接实现 PJRT C API。

方式 B:如果您能够在 xla 代码库中基于 C++ 代码进行构建(通过分支或 bazel),则还可以实现 PJRT C++ API 并使用 C→C++ 封装容器:

  1. 实现从基本 PJRT 客户端(和相关的 PJRT 类)继承的 C++ PJRT 客户端。以下是一些 C++ PJRT 客户端的示例:pjrt_stream_executor_client.htfrt_cpu_pjrt_client.h
  2. 实现一些不属于 C++ PJRT 客户端的 C API 方法:
    • PJRT_Client_Create。以下是一些伪代码示例(假设 GetPluginPjRtClient 返回上面实现的 C++ PJRT 客户端):
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;
}
}  // namespace my_plugin

请注意,PJRT_Client_Create 可以采用从框架传递的选项。点此示例展示了 GPU 客户端如何使用此功能。

使用封装容器时,您无需实现其余的 C API。

第 2 步:实现 GetPjRtApi

您需要实现一个 GetPjRtApi 方法,该方法会返回一个 PJRT_Api*,其中包含指向 PJRT C API 实现的函数指针。下面的示例假设通过封装容器实现(类似于 pjrt_c_api_cpu.cc):

const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}

第 3 步:测试 C API 实现

您可以调用 RegisterPjRtCApiTestFactory,针对基本 PJRT C API 行为运行一小组测试。

如何使用 JAX 中的 PJRT 插件

第 1 步:设置 JAX

您可以在每晚使用 JAX

pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>

从源代码构建 JAX

目前,您需要将 jaxlib 版本与 PJRT C API 版本进行匹配。通常,使用构建插件所基于的 TF 提交内容当天的 jaxlib 夜间版本就足够了,例如

pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

您也可以在构建所依据的 XLA 提交位置从源代码构建 jaxlib(instructions)。

我们将很快开始支持 ABI 兼容性。

第 2 步:使用 jax_plugins 命名空间或设置 entry_point

JAX 可通过两种方式发现您的插件。

  1. 使用命名空间软件包(参考文档)。在 jax_plugins 命名空间软件包下定义一个全局唯一的模块(即只需创建一个 jax_plugins 目录并在该目录下定义您的模块即可)。以下是目录结构示例:
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
  1. 使用软件包元数据(参考)。如果通过 pyproject.toml 或 setup.py 构建软件包,请通过在 jax_plugins 组下添加一个指向完整模块名称的入口点来通告您的插件模块名称。以下是通过 pyproject.toml 或 setup.py 创建的示例:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
  "jax_plugins": [
    "my_plugin = my_plugin",
  ],
}

以下示例说明了如何使用方法 2 实现 openxla-pjrt-plugin:https://github.com/openxla/openxla-pjrt-plugin/pull/119,https://github.com/openxla/openxla-pjrt-plugin/pull/120

第 3 步:实现 initialize() 方法

您需要在 Python 模块中实现 initialize() 方法来注册插件,例如:

import os
import jax._src.xla_bridge as xb

def initialize():
  path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
  xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)

请参阅此处,了解如何使用 xla_bridge.register_plugin。它目前是一种私有方法。未来我们会发布公共 API。

您可以运行下面这行代码来验证插件已注册,如果插件无法加载,则引发错误。

jax.config.update("jax_platforms", "my_plugin")

JAX 可能有多个后端/插件。您可以通过以下几种方法确保您的插件用作默认后端:

  • 方式 1:在程序的开头运行 jax.config.update("jax_platforms", "my_plugin")
  • 方法 2:设置 ENV JAX_PLATFORMS=my_plugin
  • 选项 3:调用 xb.register_plugin 时设置足够高的优先级(默认值为 400,高于其他现有后端)。请注意,只有在 JAX_PLATFORMS='' 时,系统才会使用优先级最高的后端。JAX_PLATFORMS 的默认值为 '',但有时会被覆盖。

如何使用 JAX 进行测试

可以尝试的一些基本测试用例:

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap

arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))

# single device: [0]

# 4 devices: [6 7 8 9]

(我们很快就会添加有关如何针对您的插件运行 jax 单元测试的说明!)

示例:JAX CUDA 插件

  1. 通过封装容器 (pjrt_c_api_gpu.h) 实现 PJRT C API。
  2. 设置软件包的入口点 (setup.py)。
  3. 实现 initialize() 方法 (__init__.py)。
  4. 可以使用针对 CUDA 的任何 Jax 测试进行测试。```