PJRT 外掛程式整合

本文件著重於說明如何整合 PJRT 的最佳做法,以及如何測試 PJRT 與 JAX 的整合。

如何整合 PJRT

步驟 1:實作 PJRT C API 介面

選項 A:您可以直接實作 PJRT C API。

選項 B:如果您可以針對 xla 存放區中的 C++ 程式碼進行建構 (透過分支或 Bazel),也可以實作 PJRT C++ API,並使用 C→C++ 包裝函式:

  1. 實作從基礎 PJRT 用戶端 (以及相關的 PJRT 類別) 繼承的 C++ PJRT 用戶端。以下是一些 C++ PJRT 用戶端的範例:pjrt_stream_executor_client.htfrt_cpu_pjrt_client.h
  2. 實作幾個不在 C++ PJRT 用戶端中的 C API 方法:
    • PJRT_Client_Create。以下是一些範例模擬程式碼 (假設 GetPluginPjRtClient 會傳回上述實作的 C++ PJRT 用戶端):
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;
}
}  // namespace my_plugin

請注意,PJRT_Client_Create 可採用從架構傳遞的選項。這裡是 GPU 用戶端如何使用這項功能的範例。

有了包裝函式,您就不需要實作其他 C API。

步驟 2:實作 GetPjRtApi

您需要實作方法 GetPjRtApi,該方法會傳回 PJRT_Api*,其中包含指向 PJRT C API 實作的函式指標。以下是假設透過包裝函式實作的範例 (類似 pjrt_c_api_cpu.cc):

const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}

步驟 3:測試 C API 實作

您可以呼叫 RegisterPjRtCApiTestFactory,針對基本 PJRT C API 行為執行一小組測試。

如何使用 JAX 中的 PJRT 外掛程式

步驟 1:設定 JAX

您可以使用 JAX 每晚版本

pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>

從原始碼建構 JAX

目前,您需要將 jaxlib 版本與 PJRT C API 版本配對。通常只要使用與建構外掛程式所用的 TF 版本相同的 jaxlib 夜間版本即可,例如

pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

您也可以根據要建構的 XLA 版本,從來源建構 jaxlib (操作說明)。

我們很快就會開始支援 ABI 相容性。

步驟 2:使用 jax_plugins 命名空間或設定 entry_point

您可以透過兩種方式讓 JAX 偵測外掛程式。

  1. 使用命名空間套件 (ref)。在 jax_plugins 命名空間套件下定義全域不重複的模組 (也就是建立 jax_plugins 目錄,並在其下定義模組)。以下是目錄結構範例:
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
  1. 使用套件中繼資料 (參考資料)。如果您是透過 pyproject.toml 或 setup.py 建構套件,請在 jax_plugins 群組下方加入入口點,指向完整的模組名稱,以宣傳外掛程式模組名稱。以下是透過 pyproject.toml 或 setup.py 設定的範例:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
  "jax_plugins": [
    "my_plugin = my_plugin",
  ],
}

以下是使用選項 2 實作 openxla-pjrt-plugin 的範例:https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120

步驟 3:實作 initialize() 方法

您需要在 Python 模組中實作 initialize() 方法,才能註冊外掛程式,例如:

import os
import jax._src.xla_bridge as xb

def initialize():
  path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
  xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)

如要瞭解如何使用 xla_bridge.register_plugin,請參閱這篇文章。目前為不公開方法。我們會在日後發布公開 API。

您可以執行下列程式碼行,驗證外掛程式是否已註冊,並在無法載入時發出錯誤。

jax.config.update("jax_platforms", "my_plugin")

JAX 可能會有多個後端/外掛程式。您可以透過以下幾種方式,確保插件會用做預設後端:

  • 方法 1:在程式開頭執行 jax.config.update("jax_platforms", "my_plugin")
  • 做法 2:設定 ENV JAX_PLATFORMS=my_plugin
  • 選項 3:在呼叫 xb.register_plugin 時設定足夠高的優先順序 (預設值為 400,高於其他現有後端)。請注意,只有在 JAX_PLATFORMS='' 時,系統才會使用最高優先順序的後端。JAX_PLATFORMS 的預設值為 '',但有時會遭到覆寫。

如何使用 JAX 進行測試

以下是一些可嘗試的基本測試案例:

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap

arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))

# single device: [0]

# 4 devices: [6 7 8 9]

(我們很快就會新增針對外掛程式執行 jax 單元測試的操作說明!)

如需更多 PJRT 外掛程式的範例,請參閱「PJRT 範例」。