PJRT plugin integration

This doc gives recommendations for integrating with PJRT and for testing a PJRT integration with JAX.

How to integrate with PJRT

Step 1: Implement PJRT C API interface

Option A: You can implement the PJRT C API directly.

Option B: If you're able to build against C++ code in the XLA repo (via forking or Bazel), you can also implement the PJRT C++ API and use the C→C++ wrapper:

  1. Implement a C++ PJRT client inheriting from the base PJRT client (and related PJRT classes). Examples of C++ PJRT clients: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
  2. Implement a few C API methods that are not part of C++ PJRT client:
    • PJRT_Client_Create. Below is some sample pseudo code (assuming GetPluginPjRtClient returns a C++ PJRT client implemented above):
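#include <memory>
#include <utility>

#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
// Creates the plugin's C++ client and wraps it so it can be returned through
// the C API. GetPluginPjRtClient() is your own factory for the client from item 1.
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;  // A null PJRT_Error* signals success.
}
}  // namespace my_plugin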

Note that PJRT_Client_Create can take options passed from the framework. Here is an example of how a GPU client uses this feature.

With the wrapper, you do not need to implement the remaining C APIs.

Step 2: Implement GetPjrtApi

You need to implement a function GetPjrtApi that returns a PJRT_Api* containing function pointers to the PJRT C API implementations. Below is an example that assumes you implement through the wrapper (similar to pjrt_c_api_cpu.cc):

// GetPjrtApi needs C linkage so the plugin loader can find it by name.
extern "C" const PJRT_Api* GetPjrtApi() {
  // Built once; the wrapper supplies the remaining C API entry points.
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}
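Because the plugin is loaded as a shared library and GetPjrtApi is looked up by its unmangled name, a quick sanity check is to confirm the built library actually exports that symbol. A minimal sketch using the standard-library ctypes module; `exports_symbol` is a hypothetical helper, and the plugin path you pass it (e.g. `my_plugin.so`) is an assumption about your build output:

```python
import ctypes


def exports_symbol(library_path: str, symbol: str = "GetPjrtApi") -> bool:
    """Return True if the shared library at library_path exports `symbol`.

    ctypes resolves attribute access with the platform's symbol lookup
    (dlsym on POSIX) and raises AttributeError when the symbol is missing,
    e.g. because a C++ function was compiled without extern "C" and its
    name was mangled.
    """
    # Note: loading the library runs its static initializers.
    lib = ctypes.CDLL(library_path)
    return hasattr(lib, symbol)
```

For example, `exports_symbol("./my_plugin.so")` should return True for a correctly built plugin; False usually means the C linkage declaration is missing.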

Step 3: Test C API implementations

You can call RegisterPjRtCApiTestFactory to run a small set of tests for basic PJRT C API behaviors.

How to use a PJRT plugin from JAX

Step 1: Set up JAX

You can either use JAX nightly:

pip install --pre -U jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

pip install git+https://github.com/google/jax

or build JAX from source.

For now, you need to match the jaxlib version with the PJRT C API version. It's usually sufficient to use a jaxlib nightly version from the same day as the TF commit you're building your plugin against, e.g.

pip install --pre -U jaxlib==0.4.2.dev20230103 -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

You can also build a jaxlib from source at exactly the XLA commit you're building against (instructions).

We will start supporting ABI compatibility soon.
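To compare your installed jaxlib nightly with the date of the commit you build against, the build date can be read from the version string. A minimal sketch, assuming nightly versions keep the `<release>.devYYYYMMDD` form shown above (e.g. `0.4.2.dev20230103`); `nightly_date` and `installed_jaxlib_date` are hypothetical helpers, not part of jaxlib:

```python
import datetime
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Optional

# Matches the ".devYYYYMMDD" suffix of nightly builds, e.g. "0.4.2.dev20230103".
_NIGHTLY_RE = re.compile(r"\.dev(\d{8})$")


def nightly_date(version_string: str) -> Optional[datetime.date]:
    """Return the build date encoded in a nightly version string, or None."""
    match = _NIGHTLY_RE.search(version_string)
    if match is None:
        return None  # Not a nightly version, e.g. a stable release like "0.4.1".
    return datetime.datetime.strptime(match.group(1), "%Y%m%d").date()


def installed_jaxlib_date() -> Optional[datetime.date]:
    """Build date of the installed jaxlib nightly, or None if unavailable."""
    try:
        return nightly_date(version("jaxlib"))
    except PackageNotFoundError:
        return None  # jaxlib is not installed.
```

If `installed_jaxlib_date()` returns a date far from your plugin's source commit, that is a likely cause of C API version mismatches.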

Step 2: Use jax_plugins namespace or set up entry_point

There are two options for your plugin to be discovered by JAX.

  1. Using namespace packages (ref). Define a globally unique module under the jax_plugins namespace package (i.e. just create a jax_plugins directory and define your module below it). Here is an example directory structure:
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
  2. Using package metadata (ref). If building a package via pyproject.toml or setup.py, advertise your plugin module name by including an entry point under the jax_plugins group that points to your full module name. Here is an example via pyproject.toml or setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
  "jax_plugins": [
    "my_plugin = my_plugin",
  ],
}

Here are examples of how openxla-pjrt-plugin is implemented using Option 2: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120
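Both discovery mechanisms can be enumerated with the Python standard library, which is handy for checking that your packaging is visible before involving JAX. This is an illustrative sketch, not JAX's own discovery code, which may differ in details; `discover_plugin_modules` is a hypothetical name:

```python
import importlib
import importlib.metadata
import pkgutil
from typing import List


def discover_plugin_modules() -> List[str]:
    """List module names advertised as JAX plugins via either mechanism."""
    names: List[str] = []

    # Option 1: submodules of the `jax_plugins` namespace package.
    try:
        namespace = importlib.import_module("jax_plugins")
    except ModuleNotFoundError:
        namespace = None
    if namespace is not None:
        for info in pkgutil.iter_modules(namespace.__path__, "jax_plugins."):
            names.append(info.name)  # e.g. "jax_plugins.my_plugin"

    # Option 2: entry points registered in the `jax_plugins` group.
    eps = importlib.metadata.entry_points()
    if hasattr(eps, "select"):  # Python 3.10+
        group = eps.select(group="jax_plugins")
    else:  # Python 3.8/3.9 return a dict keyed by group name.
        group = eps.get("jax_plugins", [])
    for ep in group:
        names.append(ep.value)  # The full module name, e.g. "my_plugin".

    return names
```

If your module does not appear in the returned list, JAX has no way to find it either, so the problem lies in the directory layout or the package metadata rather than in the plugin itself.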

Step 3: Implement an initialize() method

You need to implement an initialize() method in your Python module to register the plugin, for example:

import os
import jax._src.xla_bridge as xb

def initialize():
  path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
  xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)

Refer here for how to use xla_bridge.register_plugin. It is currently a private method; a public API will be released in the future.
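A slightly more defensive variant of initialize(), as a sketch: it fails early with a clear message if the shared library is missing, and accepts the registration function as a parameter so the logic can be exercised without JAX installed. The `plugin_dir` and `register` parameters are hypothetical additions for illustration; both have defaults, so the function can still be called with no arguments:

```python
import os
from typing import Callable, Optional


def initialize(plugin_dir: Optional[str] = None,
               register: Optional[Callable[..., None]] = None) -> None:
    # Default to the directory containing this module, as in the example above.
    plugin_dir = plugin_dir or os.path.dirname(os.path.abspath(__file__))
    path = os.path.join(plugin_dir, "my_plugin.so")
    if not os.path.isfile(path):
        raise FileNotFoundError(f"PJRT plugin library not found: {path}")
    if register is None:
        # Imported lazily so this module can be imported without JAX.
        import jax._src.xla_bridge as xb
        register = xb.register_plugin
    register("my_plugin", priority=500, library_path=path, options=None)
```

In a unit test, passing a recording function as `register` lets you check the arguments without loading the plugin.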

You can run the line below to verify that the plugin is registered; JAX raises an error if the plugin can't be loaded.

import jax
jax.config.update("jax_platforms", "my_plugin")

JAX may have multiple backends/plugins. There are a few options to ensure your plugin is used as the default backend:

  • Option 1: run jax.config.update("jax_platforms", "my_plugin") at the beginning of the program.
  • Option 2: set the environment variable JAX_PLATFORMS=my_plugin.
  • Option 3: set a high enough priority when calling xb.register_plugin (the default value is 400, which is higher than that of the other existing backends). Note that the backend with the highest priority is used only when JAX_PLATFORMS=''. The default value of JAX_PLATFORMS is '', but it may be overwritten elsewhere.
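The interaction between JAX_PLATFORMS and plugin priorities can be modeled in a few lines. This is a simplified sketch of the rules described above, not JAX's implementation: it assumes that a non-empty JAX_PLATFORMS is a comma-separated list whose first entry becomes the default, and that otherwise the registered backend with the highest priority wins. `default_backend` is a hypothetical helper and the priorities in the test are made up:

```python
from typing import Dict


def default_backend(priorities: Dict[str, int], jax_platforms: str = "") -> str:
    """Simplified model of default-backend selection (not JAX's actual code)."""
    requested = [p.strip() for p in jax_platforms.split(",") if p.strip()]
    if requested:
        # An explicit JAX_PLATFORMS list takes precedence over priorities.
        first = requested[0]
        if first not in priorities:
            raise ValueError(f"backend {first!r} is not registered")
        return first
    # JAX_PLATFORMS == '': the highest-priority registered backend wins.
    return max(priorities, key=lambda name: priorities[name])
```

Under this model, registering `my_plugin` with priority 500 makes it the default only while JAX_PLATFORMS stays empty; setting JAX_PLATFORMS=cpu would override it regardless of priority.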

How to test with JAX

Some basic test cases to try:

import jax

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap
arr = jax.numpy.arange(jax.device_count())
print(jax.pmap(lambda x: x + jax.lax.psum(x, 'i'), axis_name='i')(arr))
# single device: [0]
# 4 devices: [6 7 8 9]
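The expected pmap output follows from simple arithmetic. With n devices, `arr` is [0, 1, ..., n-1], `psum` over axis 'i' gives every device the total n(n-1)/2, and each device adds that total to its own element. A small hypothetical helper for checking results on any device count:

```python
from typing import List


def expected_pmap_output(n_devices: int) -> List[int]:
    """Expected x + psum(x) for x = arange(n_devices), one value per device."""
    total = n_devices * (n_devices - 1) // 2  # Sum of 0..n-1, the psum result.
    return [i + total for i in range(n_devices)]
```

For 4 devices this gives [6, 7, 8, 9], and for a single device [0], matching the comments above. The pmap result's `.tolist()` could be compared against `expected_pmap_output(jax.device_count())`.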

(We'll add instructions for running the JAX unit tests against your plugin soon!)

For more examples of PJRT plugins, see PJRT Examples.