PJRT 示例

示例:JAX CUDA 插件

  1. 通过封装容器实现 PJRT C API (pjrt_c_api_gpu.h)。
  2. 设置软件包的入口点 (setup.py)。
  3. 实现 initialize() 方法 (__init__.py)。
  4. 可以使用适用于 CUDA 的任何 jax 测试进行测试。

框架实现

以下是一些关于在框架端使用 PJRT 与 PJRT 设备交互的参考文档:

硬件实现