背景
PJRT は、ML エコシステムに追加する統一された Device API です。長期的なビジョンは、(1)フレームワーク(JAX、TF など)が PJRT を呼び出す(フレームワークに対して不透明なデバイス固有の実装を持つ)、(2)各デバイスが PJRT API の実装に注力し、フレームワークに対して不透明になることです。
このドキュメントでは、PJRT との統合方法と JAX との PJRT 統合のテスト方法に関する推奨事項に焦点を当てます。
PJRT と統合する方法
ステップ 1: PJRT C API インターフェースを実装する
オプション A: PJRT C API を直接実装できます。
オプション B: xla リポジトリ内の C++ コードに対して(フォークまたは bazel を介して)ビルドできる場合は、PJRT C++ API を実装し、C → C++ ラッパーを使用することもできます。
- ベース PJRT クライアント(および関連する PJRT クラス)を継承する C++ PJRT クライアントを実装します。C++ PJRT クライアントの例としては、pjrt_stream_executor_client.h、tfrt_cpu_pjrt_client.h があります。
- C++ PJRT クライアントの一部ではない、いくつかの C API メソッドを実装します。
- PJRT_Client_Create。擬似コードの例を次に示します(
GetPluginPjRtClient
が上記で実装した C++ PJRT クライアントを返す場合)。
- PJRT_Client_Create。擬似コードの例を次に示します(
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
args->client = pjrt::CreateWrapperClient(std::move(client));
return nullptr;
}
} // namespace my_plugin
PJRT_Client_Create は、フレームワークから渡されたオプションを受け取ることができます。GPU クライアントでこの機能がどのように使用されているかについては、こちらの例をご覧ください。
- (省略可)PJRT_TopologyDescription_Create。
- [省略可] PJRT_Plugin_Initialize。これは 1 回限りのプラグイン設定であり、他の関数が呼び出される前にフレームワークによって呼び出されます。
- [省略可] PJRT_Plugin_Attributes。
ラッパーを使用すると、残りの C API を実装する必要はありません。
ステップ 2: GetPjRtApi を実装する
PJRT C API 実装への関数ポインタを含む PJRT_Api*
を返すメソッド GetPjRtApi
を実装する必要があります。次に、ラッパー(pjrt_c_api_cpu.cc に類似)を介して実装する場合の例を示します。
const PJRT_Api* GetPjrtApi() {
static const PJRT_Api pjrt_api =
pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
return &pjrt_api;
}
ステップ 3: C API の実装をテストする
RegisterPjRtCApiTestFactory を呼び出して、PJRT C API の基本的な動作に関する小規模なテストセットを実行できます。
JAX から PJRT プラグインを使用する方法
ステップ 1: JAX を設定する
JAX を毎晩使用することも、
pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>
またはソースから JAX をビルドします。
現時点では、jaxlib のバージョンと PJRT C API のバージョンを一致させる必要があります。通常、プラグインをビルドする TF commit と同じ日から jaxlib ナイトリー バージョンを使用するだけで十分です。次に例を示します。
pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
ビルド対象の XLA commit の正確なソースから jaxlib をビルドすることもできます(instructions)。
ABI の互換性のサポートはまもなく開始します。
ステップ 2: jax_plugins 名前空間を使用するか、entry_point を設定する
プラグインを JAX で検出する方法は 2 つあります。
- 名前空間パッケージ(ref)を使用する。
jax_plugins
名前空間パッケージの下にグローバルに一意のモジュールを定義します(つまり、jax_plugins
ディレクトリを作成し、その下にモジュールを定義します)。ディレクトリ構造の例を次に示します。
jax_plugins/
my_plugin/
__init__.py
my_plugin.so
- パッケージ メタデータを使用する(参照)。pyproject.toml または setup.py でパッケージをビルドする場合は、
jax_plugins
グループの下に完全なモジュール名を指すエントリ ポイントを追加して、プラグイン モジュール名をアドバタイズします。pyproject.toml または setup.py を使用した例を次に示します。
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'
# use setup.py
entry_points={
"jax_plugins": [
"my_plugin = my_plugin",
],
}
オプション 2 を使用して openxla-pjrt-plugin を実装する方法の例を次に示します。https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120
ステップ 3: 初期化メソッドを実装する
プラグインを登録するために、Python モジュールに Initialize() メソッドを実装する必要があります。次に例を示します。
import os
import jax._src.xla_bridge as xb
def initialize():
path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)
xla_bridge.register_plugin
の使用方法については、こちらをご覧ください。現在は非公開メソッドです。公開 API は今後リリースされる予定です。
以下の行を実行して、プラグインが登録されていることを確認し、読み込めない場合はエラーを発生させます。
jax.config.update("jax_platforms", "my_plugin")
JAX には複数のバックエンド/プラグインが存在する場合があります。プラグインが確実にデフォルトのバックエンドとして使用されるようにする方法はいくつかあります。
- オプション 1: プログラムの先頭で
jax.config.update("jax_platforms", "my_plugin")
を実行します。 - オプション 2: ENV
JAX_PLATFORMS=my_plugin
を設定する。 - オプション 3: xb.register_plugin を呼び出すときに十分な高い優先度を設定する(デフォルト値は 400 で、他の既存のバックエンドより高い)。
JAX_PLATFORMS=''
の場合にのみ、優先度が最も高いバックエンドが使用されます。JAX_PLATFORMS
のデフォルト値は''
ですが、上書きされることがあります。
JAX でのテスト方法
試すことができる基本的なテストケース:
# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2
# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0
# pmap
arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))
# single device: [0]
# 4 devices: [6 7 8 9]
(プラグインに対して jax 単体テストを実行する手順はまもなく追加されます)。
例: JAX CUDA プラグイン
- ラッパー(pjrt_c_api_gpu.h)を介した PJRT C API の実装。
- パッケージのエントリ ポイントを設定します(setup.py)。
- 初期化メソッド(__init__.py)を実装します。
- CUDA の任意の jax テストでテストできます。```