Ejemplo: Plugin CUDA de JAX
- Implementación de la API de PJRT C a través de un wrapper (pjrt_c_api_gpu.h).
- Configura el punto de entrada del paquete (setup.py).
- Implementa un método initialize() (__init__.py).
- Se puede probar con cualquier prueba de Jax para CUDA.
Implementaciones de frameworks
Estas son algunas referencias para usar PJRT en el framework y establecer una interfaz con dispositivos PJRT:
- JAX
- jax-ml/jax interactúa con las APIs de PJRT a través de las APIs de
xla_client
.
- jax-ml/jax interactúa con las APIs de PJRT a través de las APIs de
- GoMLX
- ZML
- Wrapper de la API de PJRT pjrt.zig
- Carga el complemento PJRT context.zig.
- Interacción con búferes de PJRT buffer.zig
- Ejecuta un módulo a través de module.zig de PJRT
Implementaciones de hardware
- Complementos de integración completa (PJRT+MLIR+XLA):
- Complementos de integración ligeros (PJRT+MLIR):
- Complemento de intérprete de referencia de StableHLO (complemento de C++, basado en MLIR, que se vinculará después de devlabs)
- Complemento Tenstorrent-XLA (complemento C basado en MLIR)