PJRT 插件集成

本文档重点介绍了有关如何与 PJRT 集成以及如何测试 PJRT 与 JAX 的集成的建议。

如何与 PJRT 集成

第 1 步:实现 PJRT C API 接口

选项 A:您可以直接实现 PJRT C API。

选项 B:如果您能够针对 xla 代码库中的 C++ 代码进行构建(通过分叉或 bazel),还可以实现 PJRT C++ API 并使用 C→C++ 封装容器:

  1. 实现从基 PJRT 客户端(及相关 PJRT 类)继承的 C++ PJRT 客户端。下面是一些 C++ PJRT 客户端示例:pjrt_stream_executor_client.htfrt_cpu_pjrt_client.h
  2. 实现一些不属于 C++ PJRT 客户端的 C API 方法:
    • PJRT_Client_Create。以下是一些伪代码示例(假设 GetPluginPjRtClient 返回了上面实现的 C++ PJRT 客户端):
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;
}
}  // namespace my_plugin

请注意,PJRT_Client_Create 可以接受从框架传递的选项。此处展示了 GPU 客户端如何使用此功能的示例。

借助封装容器,您无需实现其余的 C API。

第 2 步:实现 GetPjRtApi

您需要实现一个方法 GetPjRtApi,该方法会返回一个 PJRT_Api*,其中包含指向 PJRT C API 实现的函数指针。以下示例假定通过封装容器进行实现(类似于 pjrt_c_api_cpu.cc):

const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}

第 3 步:测试 C API 实现

您可以调用 RegisterPjRtCApiTestFactory 来针对基本 PJRT C API 行为运行一小组测试。

如何使用 JAX 中的 PJRT 插件

第 1 步:设置 JAX

您可以使用 JAX 每夜 build

pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>

或者从源代码构建 JAX

目前,您需要将 jaxlib 版本与 PJRT C API 版本进行匹配。通常,只需使用与您要针对其构建插件的 TF 提交版本同一天的 jaxlib 每夜版本即可,例如

pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

您还可以根据要构建的 XLA 提交版本,从源代码构建 jaxlib(操作说明)。

我们很快就会开始支持 ABI 兼容性。

第 2 步:使用 jax_plugins 命名空间或设置 entry_point

您可以通过以下两种方式让 JAX 发现您的插件。

  1. 使用命名空间软件包(参考文档)。在 jax_plugins 命名空间软件包下定义一个全局唯一的模块(即只需创建一个 jax_plugins 目录,然后在其下定义模块即可)。以下是一个目录结构示例:
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
  1. 使用软件包元数据 (ref)。如果通过 pyproject.toml 或 setup.py 构建软件包,请在 jax_plugins 组下添加指向完整模块名称的入口点,以宣传您的插件模块名称。以下是通过 pyproject.toml 或 setup.py 设置的示例:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
  "jax_plugins": [
    "my_plugin = my_plugin",
  ],
}

以下是使用选项 2 实现 openxla-pjrt-plugin 的示例:https://github.com/openxla/openxla-pjrt-plugin/pull/119,https://github.com/openxla/openxla-pjrt-plugin/pull/120

第 3 步:实现 initialize() 方法

您需要在 Python 模块中实现 initialize() 方法以注册插件,例如:

import os
import jax._src.xla_bridge as xb

def initialize():
  path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
  xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)

如需了解如何使用 xla_bridge.register_plugin,请参阅此处。它目前是私有方法。我们将在未来发布一个公开 API。

您可以运行以下代码行,以验证插件是否已注册,并在无法加载插件时抛出错误。

jax.config.update("jax_platforms", "my_plugin")

JAX 可以有多个后端/插件。您可以通过以下几种方法确保您的插件被用作默认后端:

  • 方法 1:在程序开头运行 jax.config.update("jax_platforms", "my_plugin")
  • 方案 2:设置 ENV JAX_PLATFORMS=my_plugin
  • 方法 3:调用 xb.register_plugin 时设置足够高的优先级(默认值为 400,高于其他现有后端)。请注意,只有在 JAX_PLATFORMS='' 时,系统才会使用优先级最高的后端。JAX_PLATFORMS 的默认值为 '',但有时会被覆盖。

如何使用 JAX 进行测试

以下是一些基本测试用例:

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap

arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))

# single device: [0]

# 4 devices: [6 7 8 9]

(我们很快就会添加有关针对您的插件运行 jax 单元测试的说明!)

如需查看更多 PJRT 插件示例,请参阅 PJRT 示例