PJRT 外掛程式整合

背景

PJRT 是我們想新增至機器學習生態系統的統一 Device API。長期願景是:(1) 架構 (JAX、TF 等) 將呼叫 PJRT,其裝置專屬的實作項目與架構不透明;(2) 每部裝置都著重在實作 PJRT API,架構則不透明。

本文件著重說明如何與 PJRT 整合,以及如何測試 PJRT 與 JAX 的整合。

如何與 PJRT 整合

步驟 1:實作 PJRT C API 介面

方法 A:您可以直接實作 PJRT C API。

方法 B:如果能夠在 xla 存放區中根據 C++ 程式碼進行建構 (透過分層或 bazel),也可以實作 PJRT C++ API 並使用 C→C++ 包裝函式:

  1. 實作從基本 PJRT 用戶端 (和相關 PJRT 類別) 繼承的 C++ PJRT 用戶端。以下是 C++ PJRT 用戶端的範例:pjrt_stream_executor_client.htfrt_cpu_pjrt_client.h
  2. 實作一些不屬於 C++ PJRT 用戶端的 C API 方法:
    • PJRT_Client_Create。以下是一些虛擬程式碼範例 (假設 GetPluginPjRtClient 傳回上述實作的 C++ PJRT 用戶端):
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;
}
}  // namespace my_plugin

請注意,PJRT_Client_Create 可以從架構傳遞的選項。以下是 GPU 用戶端如何使用這項功能的範例。

有了包裝函式,您不需要實作其餘的 C API。

步驟 2:實作 GetPjRtApi

您必須實作 GetPjRtApi 方法,該方法會傳回 PJRT_Api*,其中包含指向 PJRT C API 實作的函式指標。以下範例假設透過包裝函式 (類似於 pjrt_c_api_cpu.cc) 導入:

const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}

步驟 3:測試 C API 實作

您可以呼叫 RegisterPjRtCApiTestFactory,執行一組基本 PJRT C API 行為的測試。

如何使用 JAX 的 PJRT 外掛程式

步驟 1:設定 JAX

您可以每晚使用 JAX

pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>

從原始碼建構 JAX

目前,您需要比對 jaxlib 版本與 PJRT C API 版本。您通常可以在要建構外掛程式的 TF 修訂版本的同一天內,使用 jaxlib 夜間版本,例如

pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

您也可以在要建構的 XLA 修訂版本上,從原始碼建構 jaxlib (instructions)。

我們很快就會開始支援 ABI 相容性。

步驟 2:使用 jax_plugins 命名空間或設定 Item_point

JAX 有兩個選項可用來尋找外掛程式。

  1. 使用命名空間套件 (參考資料)。在 jax_plugins 命名空間套件下定義全域專屬模組 (例如,只要建立 jax_plugins 目錄,並在其下方定義模組)。以下是目錄結構範例:
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
  1. 使用套件中繼資料 (參考資料)。如果是透過 pyproject.toml 或 setup.py 建構套件,請在 jax_plugins 群組底下加入指向完整模組名稱的進入點,即可通告外掛程式模組的名稱。以下是透過 pyproject.toml 或 setup.py 的範例:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
  "jax_plugins": [
    "my_plugin = my_plugin",
  ],
}

以下是 Openxla-pjrt-plugin 如何使用「選項 2」實作的示例:https://github.com/openxla/openxla-pjrt-plugin/pull/119,https://github.com/openxla/openxla-pjrt-plugin/pull/120

步驟 3:實作 initialize() 方法

您需要在 Python 模組中實作 initialize() 方法,才能註冊外掛程式,例如:

import os
import jax._src.xla_bridge as xb

def initialize():
  path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
  xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)

如要瞭解如何使用 xla_bridge.register_plugin,請參閱這裡。這是私人方法。我們日後會發布公用 API。

您可以執行以下這行程式碼來確認外掛程式是否已註冊;如果無法載入,則引發錯誤。

jax.config.update("jax_platforms", "my_plugin")

JAX 可能有多個後端/外掛程式。以下幾個方法可確保外掛程式設為預設後端:

  • 方法 1:在程式開始執行 jax.config.update("jax_platforms", "my_plugin")
  • 方法 2:設定 ENV JAX_PLATFORMS=my_plugin
  • 選項 3:呼叫 xb.register_plugin 時,設定夠高的優先順序 (預設值為 400,高於其他現有的後端)。請注意,優先順序最高的後端只會在 JAX_PLATFORMS='' 時使用。JAX_PLATFORMS 的預設值為 '',但有時可能會遭到覆寫。

如何使用 JAX 進行測試

請試試下列基本測試案例:

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap

arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))

# single device: [0]

# 4 devices: [6 7 8 9]

(我們會盡快新增針對外掛程式執行 jax 單元測試的操作說明!)

範例:JAX CUDA 外掛程式

  1. 透過包裝函式 (pjrt_c_api_gpu.h) 實作 PJRT C API 實作。
  2. 設定套件的進入點 (setup.py)。
  3. 實作 initialize() 方法 (__init__.py)。
  4. 可以透過 CUDA 的任何 jax 測試進行測試。```