예: JAX CUDA 플러그인
- 래퍼를 통한 PJRT C API 구현 (pjrt_c_api_gpu.h)
- 패키지의 진입점 (setup.py)을 설정합니다.
- initialize() 메서드 (__init__.py)를 구현합니다.
- CUDA용 jax 테스트로 테스트할 수 있습니다.
프레임워크 구현
PJRT 기기와 상호작용하기 위해 프레임워크 측에서 PJRT를 사용하는 방법에 관한 몇 가지 참조 자료는 다음과 같습니다.
- JAX
- jax-ml/jax는
xla_client
API를 통해 PJRT API와 상호작용합니다.
- jax-ml/jax는
- GoMLX
- ZML
- PJRT API 래퍼 pjrt.zig
- PJRT 플러그인 context.zig 로드
- PJRT 버퍼 buffer.zig와 상호작용
- PJRT module.zig를 통해 모듈 실행
하드웨어 구현
- 전체 통합 플러그인 (PJRT+MLIR+XLA):
- 가벼운 통합 플러그인 (PJRT+MLIR):
- StableHLO 참조 인터프리터 플러그인(MLIR 기반 C++ 플러그인, devlabs 후 연결)
- Tenstorrent-XLA 플러그인(MLIR 기반, C 플러그인)