PJRT 예

예: JAX CUDA 플러그인

  1. 래퍼를 통한 PJRT C API 구현 (pjrt_c_api_gpu.h)
  2. 패키지의 진입점 (setup.py)을 설정합니다.
  3. initialize() 메서드 (__init__.py)를 구현합니다.
  4. CUDA용 jax 테스트로 테스트할 수 있습니다.

프레임워크 구현

PJRT 기기와 상호작용하기 위해 프레임워크 측에서 PJRT를 사용하는 방법에 관한 몇 가지 참조 자료는 다음과 같습니다.

하드웨어 구현