本文件著重於說明如何整合 PJRT 的最佳做法,以及如何測試 PJRT 與 JAX 的整合。
如何整合 PJRT
步驟 1:實作 PJRT C API 介面
選項 A:您可以直接實作 PJRT C API。
選項 B:如果您可以針對 xla 存放區中的 C++ 程式碼進行建構 (透過分支或 Bazel),也可以實作 PJRT C++ API,並使用 C→C++ 包裝函式:
- 實作從基礎 PJRT 用戶端 (以及相關的 PJRT 類別) 繼承的 C++ PJRT 用戶端。以下是一些 C++ PJRT 用戶端的範例:pjrt_stream_executor_client.h、tfrt_cpu_pjrt_client.h。
- 實作幾個不在 C++ PJRT 用戶端中的 C API 方法:
- PJRT_Client_Create。以下是一些範例模擬程式碼 (假設
GetPluginPjRtClient
會傳回上述實作的 C++ PJRT 用戶端):
- PJRT_Client_Create。以下是一些範例模擬程式碼 (假設
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
args->client = pjrt::CreateWrapperClient(std::move(client));
return nullptr;
}
} // namespace my_plugin
請注意,PJRT_Client_Create 可採用從架構傳遞的選項。這裡是 GPU 用戶端如何使用這項功能的範例。
- [選用] PJRT_TopologyDescription_Create。
- [選用] PJRT_Plugin_Initialize。這是一次性的外掛程式設定,會在呼叫任何其他函式之前由架構呼叫。
- [選用] PJRT_Plugin_Attributes。
有了包裝函式,您就不需要實作其他 C API。
步驟 2:實作 GetPjRtApi
您需要實作方法 GetPjRtApi
,該方法會傳回 PJRT_Api*
,其中包含指向 PJRT C API 實作的函式指標。以下是假設透過包裝函式實作的範例 (類似 pjrt_c_api_cpu.cc):
const PJRT_Api* GetPjrtApi() {
static const PJRT_Api pjrt_api =
pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
return &pjrt_api;
}
步驟 3:測試 C API 實作
您可以呼叫 RegisterPjRtCApiTestFactory,針對基本 PJRT C API 行為執行一小組測試。
如何使用 JAX 中的 PJRT 外掛程式
步驟 1:設定 JAX
您可以使用 JAX 每晚版本
pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>
目前,您需要將 jaxlib 版本與 PJRT C API 版本配對。通常只要使用與建構外掛程式所用的 TF 版本相同的 jaxlib 夜間版本即可,例如
pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
您也可以根據要建構的 XLA 版本,從來源建構 jaxlib (操作說明)。
我們很快就會開始支援 ABI 相容性。
步驟 2:使用 jax_plugins 命名空間或設定 entry_point
您可以透過兩種方式讓 JAX 偵測外掛程式。
- 使用命名空間套件 (ref)。在
jax_plugins
命名空間套件下定義全域不重複的模組 (也就是建立jax_plugins
目錄,並在其下定義模組)。以下是目錄結構範例:
jax_plugins/
my_plugin/
__init__.py
my_plugin.so
- 使用套件中繼資料 (參考資料)。如果您是透過 pyproject.toml 或 setup.py 建構套件,請在
jax_plugins
群組下方加入入口點,指向完整的模組名稱,以宣傳外掛程式模組名稱。以下是透過 pyproject.toml 或 setup.py 設定的範例:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'
# use setup.py
entry_points={
"jax_plugins": [
"my_plugin = my_plugin",
],
}
以下是使用選項 2 實作 openxla-pjrt-plugin 的範例:https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120
步驟 3:實作 initialize() 方法
您需要在 Python 模組中實作 initialize() 方法,才能註冊外掛程式,例如:
import os
import jax._src.xla_bridge as xb
def initialize():
path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)
如要瞭解如何使用 xla_bridge.register_plugin
,請參閱這篇文章。目前為不公開方法。我們會在日後發布公開 API。
您可以執行下列程式碼行,驗證外掛程式是否已註冊,並在無法載入時發出錯誤。
jax.config.update("jax_platforms", "my_plugin")
JAX 可能會有多個後端/外掛程式。您可以透過以下幾種方式,確保插件會用做預設後端:
- 方法 1:在程式開頭執行
jax.config.update("jax_platforms", "my_plugin")
。 - 做法 2:設定 ENV
JAX_PLATFORMS=my_plugin
。 - 選項 3:在呼叫 xb.register_plugin 時設定足夠高的優先順序 (預設值為 400,高於其他現有後端)。請注意,只有在
JAX_PLATFORMS=''
時,系統才會使用最高優先順序的後端。JAX_PLATFORMS
的預設值為''
,但有時會遭到覆寫。
如何使用 JAX 進行測試
以下是一些可嘗試的基本測試案例:
# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2
# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0
# pmap
arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))
# single device: [0]
# 4 devices: [6 7 8 9]
(我們很快就會新增針對外掛程式執行 jax 單元測試的操作說明!)
如需更多 PJRT 外掛程式的範例,請參閱「PJRT 範例」。