Example: JAX CUDA plugin
- Implement the PJRT C API through a wrapper (pjrt_c_api_gpu.h).
- Set up the entry point for the package (setup.py).
- Implement an initialize() method (__init__.py).
- Test the plugin with any of the JAX tests for CUDA.
Framework Implementations
Some references for using PJRT on the framework side, to interface with PJRT devices:
- JAX
- jax-ml/jax interacts with PJRT APIs via the xla_client APIs
- GoMLX
- ZML
- PJRT API wrapper: pjrt.zig
- Loading a PJRT plugin: context.zig
- Interacting with PJRT buffers: buffer.zig
- Executing a module via PJRT: module.zig
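Each framework integration above starts the same way: load the plugin's shared library and ask it for its API table. A PJRT plugin exports `GetPjrtApi`, which returns a pointer to a `PJRT_Api` struct. The Python `ctypes` sketch below shows only that first step (ZML does the equivalent in context.zig). The helper name `load_pjrt_api` is hypothetical, and it returns the raw address instead of a typed struct because mirroring the full `PJRT_Api` layout is out of scope here.

```python
import ctypes


def load_pjrt_api(library_path: str) -> int:
    """Load a PJRT plugin and return the address of its PJRT_Api table.

    Hypothetical helper: it only resolves the entry point; it does not
    mirror the PJRT_Api struct or call any function pointer inside it.
    """
    # dlopen the plugin; ctypes raises OSError if the file cannot be loaded.
    lib = ctypes.CDLL(library_path)
    try:
        get_api = lib.GetPjrtApi
    except AttributeError as err:
        raise RuntimeError(f"{library_path} does not export GetPjrtApi") from err
    # const PJRT_Api* GetPjrtApi(void): no arguments, returns a pointer.
    get_api.argtypes = []
    get_api.restype = ctypes.c_void_p
    api = get_api()
    if not api:
        raise RuntimeError("GetPjrtApi returned a null pointer")
    return api
```

A caller would pass the path of an installed plugin, then read function pointers out of the returned table (for example, to create a client) using a `ctypes.Structure` that matches the `PJRT_Api` definition in the PJRT C API header.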
Hardware Implementations
- Full integration plugins (PJRT+MLIR+XLA):
- Light integration plugins (PJRT+MLIR):
- StableHLO Reference Interpreter plugin (MLIR-based, C++ plugin, to be linked after devlabs)
- Tenstorrent-XLA plugin (MLIR-based, C plugin)