例: JAX CUDA プラグイン
- ラッパーを介した PJRT C API の実装(pjrt_c_api_gpu.h)。
- パッケージのエントリ ポイント(setup.py)を設定します。
- initialize() メソッドを実装します(__init__.py)。
- CUDA の任意の jax テストでテストできます。
フレームワークの実装
フレームワーク側で PJRT を使用して PJRT デバイスとインターフェースを構築するための参考資料:
- JAX
- jax-ml/jax は、
xla_client
API を介して PJRT API とやり取りします。
- jax-ml/jax は、
- GoMLX
- ZML
- PJRT API ラッパー pjrt.zig
- PJRT プラグイン context.zig を読み込む
- PJRT バッファ buffer.zig の操作
- PJRT の module.zig を介してモジュールを実行する
ハードウェアの実装
- 完全統合プラグイン(PJRT+MLIR+XLA):
- ライト インテグレーション プラグイン(PJRT+MLIR):
- StableHLO リファレンス インタープリタ プラグイン(MLIR ベースの C++ プラグイン。devlabs 後にリンク)
- Tenstorrent-XLA プラグイン(MLIR ベースの C プラグイン)