背景
PJRT 是我们希望添加到机器学习生态系统中的统一 Device API。长期愿景是:(1) 框架(JAX、TF 等)将调用 PJRT,PJRT 具有对框架不透明的设备特定实现;(2) 每台设备专注于实现 PJRT API,并且可能对框架不透明。
本文档重点介绍了有关如何与 PJRT 集成以及如何测试 PJRT 与 JAX 集成的建议。
如何与 PJRT 集成
第 1 步:实现 PJRT C API 接口
选项 A:您可以直接实现 PJRT C API。
方式 B:如果您能够在 xla 代码库中基于 C++ 代码进行构建(通过分支或 bazel),则还可以实现 PJRT C++ API 并使用 C→C++ 封装容器:
- 实现从基本 PJRT 客户端(和相关的 PJRT 类)继承的 C++ PJRT 客户端。以下是一些 C++ PJRT 客户端的示例:pjrt_stream_executor_client.h、tfrt_cpu_pjrt_client.h。
- 实现一些不属于 C++ PJRT 客户端的 C API 方法:
- PJRT_Client_Create。以下是一些伪代码示例(假设
GetPluginPjRtClient
返回上面实现的 C++ PJRT 客户端):
- PJRT_Client_Create。以下是一些伪代码示例(假设
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
args->client = pjrt::CreateWrapperClient(std::move(client));
return nullptr;
}
} // namespace my_plugin
请注意,PJRT_Client_Create 可以采用从框架传递的选项。点此示例展示了 GPU 客户端如何使用此功能。
- [可选] PJRT_TopologyDescription_Create。
- [可选] PJRT_Plugin_Initialize。这是一个一次性的插件设置,框架将在调用任何其他函数之前调用该设置。
- [可选] PJRT_Plugin_Attributes。
第 2 步:实现 GetPjRtApi
您需要实现一个 GetPjRtApi
方法,该方法会返回一个 PJRT_Api*
,其中包含指向 PJRT C API 实现的函数指针。下面的示例假设通过封装容器实现(类似于 pjrt_c_api_cpu.cc):
const PJRT_Api* GetPjrtApi() {
static const PJRT_Api pjrt_api =
pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
return &pjrt_api;
}
第 3 步:测试 C API 实现
您可以调用 RegisterPjRtCApiTestFactory,针对基本 PJRT C API 行为运行一小组测试。
如何使用 JAX 中的 PJRT 插件
第 1 步:设置 JAX
您可以在每晚使用 JAX
pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>
目前,您需要将 jaxlib 版本与 PJRT C API 版本进行匹配。通常,使用构建插件所基于的 TF 提交内容当天的 jaxlib 夜间版本就足够了,例如
pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
您也可以在构建所依据的 XLA 提交位置从源代码构建 jaxlib(instructions)。
我们将很快开始支持 ABI 兼容性。
第 2 步:使用 jax_plugins 命名空间或设置 entry_point
JAX 可通过两种方式发现您的插件。
- 使用命名空间软件包(参考文档)。在
jax_plugins
命名空间软件包下定义一个全局唯一的模块(即只需创建一个jax_plugins
目录并在该目录下定义您的模块即可)。以下是目录结构示例:
jax_plugins/
my_plugin/
__init__.py
my_plugin.so
- 使用软件包元数据(参考)。如果通过 pyproject.toml 或 setup.py 构建软件包,请通过在
jax_plugins
组下添加一个指向完整模块名称的入口点来通告您的插件模块名称。以下是通过 pyproject.toml 或 setup.py 创建的示例:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'
# use setup.py
entry_points={
"jax_plugins": [
"my_plugin = my_plugin",
],
}
以下示例说明了如何使用方法 2 实现 openxla-pjrt-plugin:https://github.com/openxla/openxla-pjrt-plugin/pull/119,、https://github.com/openxla/openxla-pjrt-plugin/pull/120
第 3 步:实现 initialize() 方法
您需要在 Python 模块中实现 initialize() 方法来注册插件,例如:
import os
import jax._src.xla_bridge as xb
def initialize():
path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)
请参阅此处,了解如何使用 xla_bridge.register_plugin
。它目前是一种私有方法。未来我们会发布公共 API。
您可以运行下面这行代码来验证插件已注册,如果插件无法加载,则引发错误。
jax.config.update("jax_platforms", "my_plugin")
JAX 可能有多个后端/插件。您可以通过以下几种方法确保您的插件用作默认后端:
- 方式 1:在程序的开头运行
jax.config.update("jax_platforms", "my_plugin")
。 - 方法 2:设置 ENV
JAX_PLATFORMS=my_plugin
。 - 选项 3:调用 xb.register_plugin 时设置足够高的优先级(默认值为 400,高于其他现有后端)。请注意,只有在
JAX_PLATFORMS=''
时,系统才会使用优先级最高的后端。JAX_PLATFORMS
的默认值为''
,但有时会被覆盖。
如何使用 JAX 进行测试
可以尝试的一些基本测试用例:
# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2
# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0
# pmap
arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))
# single device: [0]
# 4 devices: [6 7 8 9]
(我们很快就会添加有关如何针对您的插件运行 jax 单元测试的说明!)
示例:JAX CUDA 插件
- 通过封装容器 (pjrt_c_api_gpu.h) 实现 PJRT C API。
- 设置软件包的入口点 (setup.py)。
- 实现 initialize() 方法 (__init__.py)。
- 可以使用针对 CUDA 的任何 Jax 测试进行测试。```