Exemple: Plug-in CUDA JAX
- Implémentation de l'API C PJRT via un wrapper (pjrt_c_api_gpu.h).
- Configurez le point d'entrée du package (setup.py).
- Implémentez une méthode initialize() (__init__.py).
- Peut être testé avec n'importe quel test jax pour CUDA.
Implémentations de frameworks
Quelques références pour utiliser PJRT côté framework, afin d'interagir avec les appareils PJRT:
- JAX
- jax-ml/jax interagit avec les API PJRT via les API
xla_client
.
- jax-ml/jax interagit avec les API PJRT via les API
- GoMLX
- ZML
- Encapsuleur d'API PJRT pjrt.zig
- Charger le plug-in PJRT context.zig
- Interagir avec les tampons PJRT buffer.zig
- Exécuter un module via PJRT module.zig
Implémentations matérielles
- Plugins d'intégration complète (PJRT+MLIR+XLA) :
- Plug-ins d'intégration légers (PJRT+MLIR) :
- Plug-in de l'interprète de référence StableHLO (plug-in C++ basé sur MLIR, à associer après les devlabs)
- Plug-in Tenstorrent-XLA (plug-in C basé sur le MLIR)