PJRT Examples

Example: JAX CUDA plugin

  1. Implement the PJRT C API through a wrapper (pjrt_c_api_gpu.h).
  2. Set up the entry point for the package (setup.py).
  3. Implement an initialize() method (__init__.py).
  4. Test the plugin with any of the JAX tests for CUDA.
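A minimal sketch of how steps 2 and 3 fit together, assuming plugin packages advertise themselves through a `jax_plugins` entry-point group in setup.py and that the framework imports each advertised module and calls its `initialize()` hook. The package and function names (`my_pjrt_plugin`, `find_plugin_entry_points`, `initialize_plugins`) are hypothetical, and this is not JAX's own discovery code; what `initialize()` actually does to register the PJRT shared library is left out.

```python
"""Sketch: entry-point based discovery of PJRT plugin packages.

Plugin side (hypothetical package "my_pjrt_plugin"), step 2, in setup.py:
    setup(
        name="my-pjrt-plugin",
        packages=["my_pjrt_plugin"],
        entry_points={"jax_plugins": ["my_pjrt_plugin = my_pjrt_plugin"]},
    )
and step 3, in my_pjrt_plugin/__init__.py:
    def initialize():
        ...  # register the PJRT C API shared library with the framework
"""
import importlib.metadata
import logging
from typing import Iterable, List, Optional

logger = logging.getLogger(__name__)

# Assumed entry-point group name that plugin packages register under.
PLUGIN_GROUP = "jax_plugins"


def find_plugin_entry_points(group: str = PLUGIN_GROUP) -> list:
    """Return the installed entry points registered under `group`."""
    eps = importlib.metadata.entry_points()
    if hasattr(eps, "select"):  # Python 3.10+
        return list(eps.select(group=group))
    return list(eps.get(group, []))  # Python 3.8/3.9 return a dict


def initialize_plugins(entry_points: Optional[Iterable] = None) -> List[str]:
    """Import each plugin module and call its initialize() hook.

    Returns the names of the plugins whose initialize() succeeded.
    """
    if entry_points is None:
        entry_points = find_plugin_entry_points()
    initialized = []
    for ep in entry_points:
        try:
            module = ep.load()  # imports the module named after "="
        except Exception:
            logger.exception("could not import plugin %r", ep.name)
            continue
        init = getattr(module, "initialize", None)
        if not callable(init):
            logger.warning("plugin %r has no initialize(); skipping", ep.name)
            continue
        try:
            init()
        except Exception:
            # One failing plugin (e.g. no device present) must not block others.
            logger.exception("plugin %r failed to initialize", ep.name)
            continue
        initialized.append(ep.name)
    return initialized


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    print("initialized plugins:", initialize_plugins())
```

Wrapping each plugin's import and `initialize()` call separately means a plugin that fails at load time, for example because the CUDA driver is missing, is skipped without preventing other installed plugins from registering.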

Framework Implementations

Some references for using PJRT on the framework side to interface with PJRT devices:

Hardware Implementations