示例:JAX CUDA 插件
- 通过封装容器实现 PJRT C API (pjrt_c_api_gpu.h)。
- 设置软件包的入口点 (setup.py)。
- 实现 initialize() 方法 (__init__.py)。
- 可以使用适用于 CUDA 的任何 jax 测试进行测试。
框架实现
以下是一些关于在框架端使用 PJRT 与 PJRT 设备交互的参考文档:
- JAX
- jax-ml/jax 通过
xla_client
API 与 PJRT API 交互
- jax-ml/jax 通过
- GoMLX
- ZML
- PJRT API 封装容器 pjrt.zig
- 加载 PJRT 插件 context.zig
- 与 PJRT 缓冲区 buffer.zig 交互
- 通过 PJRT module.zig 执行模块
硬件实现
- 完整集成插件 (PJRT+MLIR+XLA):
- 轻量集成插件 (PJRT+MLIR):
- StableHLO 参考解释器插件(基于 MLIR 的 C++ 插件,将在 devlabs 之后关联)
- Tenstorrent-XLA 插件(基于 MLIR 的 C 插件)