Background
PJRT is the uniform Device API that we want to add to the ML ecosystem. The long-term vision is that (1) frameworks (JAX, TF, etc.) call PJRT, which has device-specific implementations that are opaque to the frameworks; and (2) each device focuses on implementing the PJRT APIs, and can remain opaque to the frameworks.
This doc gives recommendations on how to integrate with PJRT and how to test a PJRT integration with JAX.
How to integrate with PJRT
Step 1: Implement PJRT C API interface
Option A: You can implement the PJRT C API directly.
Option B: If you're able to build against C++ code in the xla repo (via forking or bazel), you can also implement the PJRT C++ API and use the C→C++ wrapper:
- Implement a C++ PJRT client inheriting from the base PJRT client (and related PJRT classes). Here are some examples of C++ PJRT clients: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
- Implement a few C API methods that are not part of C++ PJRT client:
- PJRT_Client_Create. Below is some sample pseudo code (assuming GetPluginPjRtClient returns a C++ PJRT client implemented above):
#include <memory>
#include <utility>

#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  // GetPluginPjRtClient is assumed to return the C++ PJRT client implemented above.
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;  // nullptr signals success
}
}  // namespace my_plugin
Note PJRT_Client_Create can take options passed from the framework. Here is an example of how a GPU client uses this feature.
- [Optional] PJRT_TopologyDescription_Create.
- [Optional] PJRT_Plugin_Initialize. This is a one-time plugin setup, which will be called by the framework before any other functions are called.
- [Optional] PJRT_Plugin_Attributes.
With the wrapper, you do not need to implement the remaining C APIs.
Step 2: Implement GetPjrtApi
You need to implement a method GetPjrtApi which returns a PJRT_Api* containing function pointers to PJRT C API implementations. Below is an example that assumes you implement it through the wrapper (similar to pjrt_c_api_cpu.cc):
const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}
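Since GetPjrtApi is the entry point into your implementation, a quick sanity check before involving JAX is to confirm that the built shared library actually exports that symbol. The sketch below uses Python's ctypes for this. The helper name exports_symbol is a hypothetical name chosen for illustration, and the check only confirms that the symbol is present, not that the returned PJRT_Api is valid.

```python
import ctypes


def exports_symbol(library_path, symbol="GetPjrtApi"):
    """Return True if the shared library at library_path exports `symbol`.

    Passing None inspects the symbols already loaded into the current
    process (POSIX only), which is handy for trying the helper out.
    """
    # Raises OSError if the library cannot be loaded at all.
    lib = ctypes.CDLL(library_path)
    # ctypes resolves symbols lazily and raises AttributeError when one is
    # missing, so hasattr() doubles as an "is it exported?" check.
    return hasattr(lib, symbol)
```

For a plugin built as my_plugin.so, calling exports_symbol with the path to that file would be expected to return True once Step 2 is done.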
Step 3: Test C API implementations
You can call RegisterPjRtCApiTestFactory to run a small set of tests for basic PJRT C API behaviors.
How to use a PJRT plugin from JAX
Step 1: Set up JAX
You can use JAX nightly:
pip install --pre -U jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
pip install git+https://github.com/google/jax
For now, you need to match the jaxlib version with the PJRT C API version. It's usually sufficient to use a jaxlib nightly version from the same day as the TF commit you're building your plugin against, e.g.
pip install --pre -U jaxlib==0.4.2.dev20230103 -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
You can also build a jaxlib from source at exactly the XLA commit you're building against (instructions).
We will start supporting ABI compatibility soon.
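When picking a matching nightly, it can help to read the build date out of a jaxlib version string and compare it with the date of the commit you build against. This is a minimal sketch, assuming nightly versions follow the X.Y.Z.devYYYYMMDD pattern seen in the example above. nightly_build_date is a hypothetical helper, not part of JAX.

```python
import datetime
import re


def nightly_build_date(version):
    """Return the build date encoded in a nightly version string, or None."""
    match = re.search(r"\.dev(\d{8})$", version)
    if match is None:
        return None  # not a nightly-style version (e.g. a regular release)
    return datetime.datetime.strptime(match.group(1), "%Y%m%d").date()


print(nightly_build_date("0.4.2.dev20230103"))  # 2023-01-03
```

The version of the installed jaxlib can be obtained with importlib.metadata.version("jaxlib") and passed to this helper.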
Step 2: Use jax_plugins namespace or set up entry_point
There are two options for your plugin to be discovered by JAX.
- Using namespace packages (ref). Define a globally unique module under the jax_plugins namespace package (i.e. just create a jax_plugins directory and define your module below it). Here is an example directory structure:
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
- Using package metadata (ref). If building a package via pyproject.toml or setup.py, advertise your plugin module name by including an entry-point under the jax_plugins group which points to your full module name. Here is an example via pyproject.toml or setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
    "jax_plugins": [
        "my_plugin = my_plugin",
    ],
}
Here are examples of how openxla-pjrt-plugin is implemented using the second option (package metadata): https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120
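To check that your packaging is discoverable through either mechanism, the two lookups can be approximated with the Python standard library. The sketch below is an illustration of the idea only and is not JAX's actual discovery code. The function name discover_plugin_modules is hypothetical.

```python
import importlib
import importlib.metadata
import pkgutil


def discover_plugin_modules(namespace="jax_plugins", group="jax_plugins"):
    """Return sorted module names found via namespace package or entry points."""
    names = set()

    # Mechanism 1: modules living under the namespace package directory.
    try:
        ns = importlib.import_module(namespace)
    except ImportError:
        ns = None
    if ns is not None and hasattr(ns, "__path__"):
        for info in pkgutil.iter_modules(ns.__path__, ns.__name__ + "."):
            names.add(info.name)

    # Mechanism 2: entry points advertised under the given group.
    eps = importlib.metadata.entry_points()
    if hasattr(eps, "select"):  # Python 3.10 and later
        selected = eps.select(group=group)
    else:  # older Pythons return a plain dict keyed by group
        selected = eps.get(group, [])
    for ep in selected:
        names.add(ep.value)

    return sorted(names)


print(discover_plugin_modules())
```

With the directory layout shown earlier on sys.path, the result would be expected to include jax_plugins.my_plugin. With the entry-point configuration, it would include my_plugin.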
Step 3: Implement an initialize() method
You need to implement an initialize() method in your python module to register the plugin, for example:
import os
import jax._src.xla_bridge as xb

def initialize():
    path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
    xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)
Refer to the definition of xla_bridge.register_plugin for details on how to use it. It is currently a private method. A public API will be released in the future.
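A slightly more defensive variant of initialize() can fail early with a clear message when the shared library is missing from the package. This is a sketch under the same assumptions as the example above (a library named my_plugin.so next to __init__.py, and the private xb.register_plugin call). resolve_plugin_library is a hypothetical helper introduced here for illustration.

```python
import os


def resolve_plugin_library(module_file, library_name="my_plugin.so"):
    """Return the path of the plugin library next to module_file, or raise."""
    directory = os.path.dirname(os.path.abspath(module_file))
    path = os.path.join(directory, library_name)
    if not os.path.isfile(path):
        raise FileNotFoundError(f"PJRT plugin library not found: {path}")
    return path


def initialize():
    # Imported lazily so that merely importing this module does not need JAX.
    # register_plugin is currently private and may change.
    import jax._src.xla_bridge as xb

    path = resolve_plugin_library(__file__)
    xb.register_plugin("my_plugin", priority=500, library_path=path, options=None)
```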
You can run the line below to verify that the plugin is registered and raise an error if it can't be loaded.
jax.config.update("jax_platforms", "my_plugin")
JAX may have multiple backends/plugins. There are a few options to ensure your plugin is used as the default backend:
- Option 1: run jax.config.update("jax_platforms", "my_plugin") at the beginning of the program.
- Option 2: set the environment variable JAX_PLATFORMS=my_plugin.
- Option 3: set a high enough priority when calling xb.register_plugin (the default value is 400, which is higher than other existing backends). Note that the backend with the highest priority is used only when JAX_PLATFORMS=''. The default value of JAX_PLATFORMS is '', but it is sometimes overwritten.
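The interaction between priority and JAX_PLATFORMS described above can be summarized with a toy model. This is not JAX's real selection logic. pick_default_backend and the priorities in the example are made up for illustration, and the sketch assumes that when several platforms are listed, the first one is the default.

```python
def pick_default_backend(registered, jax_platforms=""):
    """Toy model of default-backend selection; not JAX's real logic.

    registered: dict mapping backend name -> priority.
    jax_platforms: value of JAX_PLATFORMS / the jax_platforms config option.
    """
    if jax_platforms:
        # An explicit setting wins regardless of priority. Assumption: the
        # first name in a comma-separated list is the default.
        name = jax_platforms.split(",")[0].strip()
        if name not in registered:
            raise RuntimeError(f"backend {name!r} is not registered")
        return name
    # With an empty setting, the highest-priority registered backend is used.
    return max(registered, key=registered.get)


backends = {"cpu": 0, "my_plugin": 500}  # illustrative priorities
print(pick_default_backend(backends))         # my_plugin
print(pick_default_backend(backends, "cpu"))  # cpu
```

This shows why a high priority alone is not enough when something in the environment has already set JAX_PLATFORMS to another backend.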
How to test with JAX
Some basic test cases to try:
import jax

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap
arr = jax.numpy.arange(jax.device_count())
print(jax.pmap(lambda x: x + jax.lax.psum(x, 'i'), axis_name='i')(arr))
# single device: [0]
# 4 devices: [6 7 8 9]
(We'll add instructions for running the jax unit tests against your plugin soon!)
Example: JAX CUDA plugin
- PJRT C API implementation through wrapper (pjrt_c_api_gpu.h).
- Set up the entry point for the package (setup.py).
- Implement an initialize() method (__init__.py).
- Can be tested with any jax tests for CUDA.