背景
PJRT 是我們想新增至機器學習生態系統的統一 Device API。長期願景是:(1) 架構 (JAX、TF 等) 將呼叫 PJRT,其裝置專屬的實作項目與架構不透明;(2) 每部裝置都著重在實作 PJRT API,架構則不透明。
本文件著重說明如何與 PJRT 整合,以及如何測試 PJRT 與 JAX 的整合。
如何與 PJRT 整合
步驟 1:實作 PJRT C API 介面
方法 A:您可以直接實作 PJRT C API。
方法 B:如果能夠在 xla 存放區中根據 C++ 程式碼進行建構 (透過分層或 bazel),也可以實作 PJRT C++ API 並使用 C→C++ 包裝函式:
- 實作從基本 PJRT 用戶端 (和相關 PJRT 類別) 繼承的 C++ PJRT 用戶端。以下是 C++ PJRT 用戶端的範例:pjrt_stream_executor_client.h、tfrt_cpu_pjrt_client.h。
- 實作一些不屬於 C++ PJRT 用戶端的 C API 方法:
- PJRT_Client_Create。以下是一些虛擬程式碼範例 (假設
GetPluginPjRtClient
傳回上述實作的 C++ PJRT 用戶端):
- PJRT_Client_Create。以下是一些虛擬程式碼範例 (假設
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
args->client = pjrt::CreateWrapperClient(std::move(client));
return nullptr;
}
} // namespace my_plugin
請注意,PJRT_Client_Create 可以從架構傳遞的選項。以下是 GPU 用戶端如何使用這項功能的範例。
- [選用] PJRT_TopologyDescription_Create。
- [選用] PJRT_Plugin_Initialize。這是一次性的外掛程式設定,架構在呼叫任何其他函式之前會先呼叫這項設定。
- [選用] PJRT_Plugin_Attributes。
有了包裝函式,您不需要實作其餘的 C API。
步驟 2:實作 GetPjRtApi
您必須實作 GetPjRtApi
方法,該方法會傳回 PJRT_Api*
,其中包含指向 PJRT C API 實作的函式指標。以下範例假設透過包裝函式 (類似於 pjrt_c_api_cpu.cc) 導入:
const PJRT_Api* GetPjrtApi() {
static const PJRT_Api pjrt_api =
pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
return &pjrt_api;
}
步驟 3:測試 C API 實作
您可以呼叫 RegisterPjRtCApiTestFactory,執行一組基本 PJRT C API 行為的測試。
如何使用 JAX 的 PJRT 外掛程式
步驟 1:設定 JAX
您可以每晚使用 JAX
pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>
目前,您需要比對 jaxlib 版本與 PJRT C API 版本。您通常可以在要建構外掛程式的 TF 修訂版本的同一天內,使用 jaxlib 夜間版本,例如
pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
您也可以在要建構的 XLA 修訂版本上,從原始碼建構 jaxlib (instructions)。
我們很快就會開始支援 ABI 相容性。
步驟 2:使用 jax_plugins 命名空間或設定 Item_point
JAX 有兩個選項可用來尋找外掛程式。
- 使用命名空間套件 (參考資料)。在
jax_plugins
命名空間套件下定義全域專屬模組 (例如,只要建立jax_plugins
目錄,並在其下方定義模組)。以下是目錄結構範例:
jax_plugins/
my_plugin/
__init__.py
my_plugin.so
- 使用套件中繼資料 (參考資料)。如果是透過 pyproject.toml 或 setup.py 建構套件,請在
jax_plugins
群組底下加入指向完整模組名稱的進入點,即可通告外掛程式模組的名稱。以下是透過 pyproject.toml 或 setup.py 的範例:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'
# use setup.py
entry_points={
"jax_plugins": [
"my_plugin = my_plugin",
],
}
以下是 Openxla-pjrt-plugin 如何使用「選項 2」實作的示例:https://github.com/openxla/openxla-pjrt-plugin/pull/119,、https://github.com/openxla/openxla-pjrt-plugin/pull/120
步驟 3:實作 initialize() 方法
您需要在 Python 模組中實作 initialize() 方法,才能註冊外掛程式,例如:
import os
import jax._src.xla_bridge as xb
def initialize():
path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)
如要瞭解如何使用 xla_bridge.register_plugin
,請參閱這裡。這是私人方法。我們日後會發布公用 API。
您可以執行以下這行程式碼來確認外掛程式是否已註冊;如果無法載入,則引發錯誤。
jax.config.update("jax_platforms", "my_plugin")
JAX 可能有多個後端/外掛程式。以下幾個方法可確保外掛程式設為預設後端:
- 方法 1:在程式開始執行
jax.config.update("jax_platforms", "my_plugin")
。 - 方法 2:設定 ENV
JAX_PLATFORMS=my_plugin
。 - 選項 3:呼叫 xb.register_plugin 時,設定夠高的優先順序 (預設值為 400,高於其他現有的後端)。請注意,優先順序最高的後端只會在
JAX_PLATFORMS=''
時使用。JAX_PLATFORMS
的預設值為''
,但有時可能會遭到覆寫。
如何使用 JAX 進行測試
請試試下列基本測試案例:
# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2
# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0
# pmap
arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))
# single device: [0]
# 4 devices: [6 7 8 9]
(我們會盡快新增針對外掛程式執行 jax 單元測試的操作說明!)
範例:JAX CUDA 外掛程式
- 透過包裝函式 (pjrt_c_api_gpu.h) 實作 PJRT C API 實作。
- 設定套件的進入點 (setup.py)。
- 實作 initialize() 方法 (__init__.py)。
- 可以透過 CUDA 的任何 jax 測試進行測試。```