範例:JAX CUDA 外掛程式
- 透過包裝函式實作 PJRT C API (pjrt_c_api_gpu.h)。
- 設定套件的進入點 (setup.py)。
- 實作 initialize() 方法 (__init__.py)。
- 可透過任何 CUDA 的 jax 測試進行測試。
架構實作
在架構端使用 PJRT 的參考資料,以便與 PJRT 裝置連接:
- JAX
- jax-ml/jax 會透過
xla_client
API 與 PJRT API 互動
- jax-ml/jax 會透過
- GoMLX
- ZML
- PJRT API 包裝函式 pjrt.zig
- 載入 PJRT 外掛程式 context.zig
- 與 PJRT 緩衝區互動 buffer.zig
- 透過 PJRT module.zig 執行模組
硬體實作
- 完整整合外掛程式 (PJRT+MLIR+XLA):
- 輕量整合外掛程式 (PJRT+MLIR):
- StableHLO 參考解譯器外掛程式 (以 MLIR 為基礎的 C++ 外掛程式,會在 devlabs 之後連結)
- Tenstorrent-XLA 外掛程式 (以 MLIR 為基礎,C 外掛程式)