本文档重点介绍了有关如何与 PJRT 集成以及如何测试 PJRT 与 JAX 的集成的建议。
如何与 PJRT 集成
第 1 步:实现 PJRT C API 接口
选项 A:您可以直接实现 PJRT C API。
选项 B:如果您能够针对 xla 代码库中的 C++ 代码进行构建(通过分叉或 bazel),还可以实现 PJRT C++ API 并使用 C→C++ 封装容器:
- 实现从基 PJRT 客户端(及相关 PJRT 类)继承的 C++ PJRT 客户端。下面是一些 C++ PJRT 客户端示例:pjrt_stream_executor_client.h、tfrt_cpu_pjrt_client.h。
- 实现一些不属于 C++ PJRT 客户端的 C API 方法:
- PJRT_Client_Create。以下是一些伪代码示例(假设
GetPluginPjRtClient
返回了上面实现的 C++ PJRT 客户端):
- PJRT_Client_Create。以下是一些伪代码示例(假设
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
args->client = pjrt::CreateWrapperClient(std::move(client));
return nullptr;
}
} // namespace my_plugin
请注意,PJRT_Client_Create 可以接受从框架传递的选项。此处展示了 GPU 客户端如何使用此功能的示例。
- [可选] PJRT_TopologyDescription_Create。
- [可选] PJRT_Plugin_Initialize。这是一次性插件设置,框架会在调用任何其他函数之前调用此设置。
- [可选] PJRT_Plugin_Attributes。
借助封装容器,您无需实现其余的 C API。
第 2 步:实现 GetPjRtApi
您需要实现一个方法 GetPjRtApi
,该方法会返回一个 PJRT_Api*
,其中包含指向 PJRT C API 实现的函数指针。以下示例假定通过封装容器进行实现(类似于 pjrt_c_api_cpu.cc):
const PJRT_Api* GetPjrtApi() {
static const PJRT_Api pjrt_api =
pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
return &pjrt_api;
}
第 3 步:测试 C API 实现
您可以调用 RegisterPjRtCApiTestFactory 来针对基本 PJRT C API 行为运行一小组测试。
如何使用 JAX 中的 PJRT 插件
第 1 步:设置 JAX
您可以使用 JAX 每夜 build
pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>
或者从源代码构建 JAX。
目前,您需要将 jaxlib 版本与 PJRT C API 版本进行匹配。通常,只需使用与您要针对其构建插件的 TF 提交版本同一天的 jaxlib 每夜版本即可,例如
pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
您还可以根据要构建的 XLA 提交版本,从源代码构建 jaxlib(操作说明)。
我们很快就会开始支持 ABI 兼容性。
第 2 步:使用 jax_plugins 命名空间或设置 entry_point
您可以通过以下两种方式让 JAX 发现您的插件。
- 使用命名空间软件包(参考文档)。在
jax_plugins
命名空间软件包下定义一个全局唯一的模块(即只需创建一个jax_plugins
目录,然后在其下定义模块即可)。以下是一个目录结构示例:
jax_plugins/
my_plugin/
__init__.py
my_plugin.so
- 使用软件包元数据 (ref)。如果通过 pyproject.toml 或 setup.py 构建软件包,请在
jax_plugins
组下添加指向完整模块名称的入口点,以宣传您的插件模块名称。以下是通过 pyproject.toml 或 setup.py 设置的示例:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'
# use setup.py
entry_points={
"jax_plugins": [
"my_plugin = my_plugin",
],
}
以下是使用选项 2 实现 openxla-pjrt-plugin 的示例:https://github.com/openxla/openxla-pjrt-plugin/pull/119,、https://github.com/openxla/openxla-pjrt-plugin/pull/120
第 3 步:实现 initialize() 方法
您需要在 Python 模块中实现 initialize() 方法以注册插件,例如:
import os
import jax._src.xla_bridge as xb
def initialize():
path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)
如需了解如何使用 xla_bridge.register_plugin
,请参阅此处。它目前是私有方法。我们将在未来发布一个公开 API。
您可以运行以下代码行,以验证插件是否已注册,并在无法加载插件时抛出错误。
jax.config.update("jax_platforms", "my_plugin")
JAX 可以有多个后端/插件。您可以通过以下几种方法确保您的插件被用作默认后端:
- 方法 1:在程序开头运行
jax.config.update("jax_platforms", "my_plugin")
。 - 方案 2:设置 ENV
JAX_PLATFORMS=my_plugin
。 - 方法 3:调用 xb.register_plugin 时设置足够高的优先级(默认值为 400,高于其他现有后端)。请注意,只有在
JAX_PLATFORMS=''
时,系统才会使用优先级最高的后端。JAX_PLATFORMS
的默认值为''
,但有时会被覆盖。
如何使用 JAX 进行测试
以下是一些基本测试用例:
# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2
# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0
# pmap
arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))
# single device: [0]
# 4 devices: [6 7 8 9]
(我们很快就会添加有关针对您的插件运行 jax 单元测试的说明!)
如需查看更多 PJRT 插件示例,请参阅 PJRT 示例。