PJRT プラグインの統合

このドキュメントでは、PJRT との統合方法と、JAX との PJRT 統合をテストする方法について説明します。

PJRT と統合する方法

ステップ 1: PJRT C API インターフェースを実装する

オプション A: PJRT C API を直接実装できます。

オプション B: xla リポジトリの C++ コードに対して(フォークまたは bazel を使用して)ビルドできる場合は、PJRT C++ API を実装して C→C++ ラッパーを使用することもできます。

  1. ベースの PJRT クライアント(および関連する PJRT クラス)を継承する C++ PJRT クライアントを実装します。C++ PJRT クライアントの例を次に示します。pjrt_stream_executor_client.htfrt_cpu_pjrt_client.h
  2. C++ PJRT クライアントに含まれていない C API メソッドをいくつか実装します。
    • PJRT_Client_Create。以下にサンプルの疑似コードを示します(GetPluginPjRtClient が上で実装した C++ PJRT クライアントを返すと仮定しています)。
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;
}
}  // namespace my_plugin

注: PJRT_Client_Create は、フレームワークから渡されたオプションを使用できます。GPU クライアントがこの機能をどのように使用するかの例をこちらに示します。

ラッパーを使用すると、残りの C API を実装する必要はありません。

ステップ 2: GetPjRtApi を実装する

PJRT C API 実装への関数ポインタを含む PJRT_Api* を返すメソッド GetPjRtApi を実装する必要があります。ラッパーを介した実装を前提とした例を次に示します(pjrt_c_api_cpu.cc に類似)。

const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}

ステップ 3: C API の実装をテストする

RegisterPjRtCApiTestFactory を呼び出して、基本的な PJRT C API 動作に関する一連のテストを行うことができます。

JAX から PJRT プラグインを使用する方法

ステップ 1: JAX を設定する

JAX ナイトリー版を使用するか、

pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>

または、JAX をソースからビルドします。

現時点では、jaxlib のバージョンを PJRT C API のバージョンと一致させる必要があります。通常は、プラグインをビルドする TF コミットと同じ日の jaxlib ナイトリー バージョンを使用すれば十分です。

pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

ビルド対象の XLA コミットとまったく同じソースから jaxlib をビルドすることもできます(手順)。

まもなく ABI 互換性のサポートを開始します。

ステップ 2: jax_plugins Namespace を使用するか、entry_point を設定する

JAX によってプラグインが検出される方法は 2 つあります。

  1. Namespace パッケージを使用する(参照)。jax_plugins 名前空間パッケージの下にグローバルに一意のモジュールを定義します(つまり、jax_plugins ディレクトリを作成し、その下にモジュールを定義します)。ディレクトリ構造の例を次に示します。
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
  1. パッケージ メタデータを使用する(参照)。pyproject.toml または setup.py を使用してパッケージをビルドする場合は、jax_plugins グループにエントリ ポイントを含めて、モジュールの完全な名前を参照し、プラグイン モジュール名をアドバタイズします。pyproject.toml または setup.py を使用した例を次に示します。
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
  "jax_plugins": [
    "my_plugin = my_plugin",
  ],
}

オプション 2 を使用して openxla-pjrt-plugin を実装する方法の例: https://github.com/openxla/openxla-pjrt-plugin/pull/119,https://github.com/openxla/openxla-pjrt-plugin/pull/120

ステップ 3: initialize() メソッドを実装する

プラグインを登録するには、Python モジュールに initialize() メソッドを実装する必要があります。次に例を示します。

import os
import jax._src.xla_bridge as xb

def initialize():
  path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
  xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)

xla_bridge.register_plugin の使用方法については、こちらをご覧ください。現在は非公開メソッドです。今後、公開 API がリリースされる予定です。

次の行を実行して、プラグインが登録されていることを確認します。読み込めない場合はエラーを発生させます。

jax.config.update("jax_platforms", "my_plugin")

JAX には複数のバックエンド/プラグインが存在する場合があります。プラグインをデフォルトのバックエンドとして使用するには、いくつかの方法があります。

  • オプション 1: プログラムの開始時に jax.config.update("jax_platforms", "my_plugin") を実行します。
  • オプション 2: ENV JAX_PLATFORMS=my_plugin を設定します。
  • オプション 3: xb.register_plugin を呼び出すときに十分な優先度を設定します(デフォルト値は 400 で、他の既存のバックエンドよりも高くなります)。優先度が最も高いバックエンドは、JAX_PLATFORMS='' の場合にのみ使用されます。JAX_PLATFORMS のデフォルト値は '' ですが、上書きされることもあります。

JAX でテストする方法

試すことができる基本的なテストケースは次のとおりです。

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap

arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))

# single device: [0]

# 4 devices: [6 7 8 9]

(まもなく、プラグインに対して jax 単体テストを実行する手順を追加する予定です)。

PJRT プラグインのその他の例については、PJRT の例をご覧ください。