PJRT 플러그인 통합

배경

PJRT는 ML 생태계에 추가하려는 단일 Device API입니다. 장기적인 비전은 (1) 프레임워크 (JAX, TF 등)가 PJRT를 호출하고 이는 프레임워크에 불투명한 기기별 구현을 포함하고 (2) 각 기기가 PJRT API 구현에 중점을 두고 프레임워크에 불투명할 수 있다는 것입니다.

이 문서에서는 PJRT와 통합하는 방법 및 PJRT와 JAX의 통합을 테스트하는 방법에 대한 권장사항을 중점적으로 설명합니다.

PJRT와 통합하는 방법

1단계: PJRT C API 인터페이스 구현

옵션 A: PJRT C API를 직접 구현할 수 있습니다.

옵션 B: xla 저장소의 C++ 코드에 대해 포크 또는 bazel을 통해 빌드할 수 있는 경우 PJRT C++ API를 구현하고 C→C++ 래퍼를 사용할 수도 있습니다.

  1. 기본 PJRT 클라이언트 (및 관련 PJRT 클래스)에서 상속하는 C++ PJRT 클라이언트를 구현합니다. 다음은 C++ PJRT 클라이언트의 몇 가지 예입니다.pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h
  2. C++ PJRT 클라이언트의 일부가 아닌 몇 가지 C API 메서드를 구현합니다.
    • PJRT_Client_Create와 같은 메서드를 사용합니다. 다음은 의사 코드 샘플입니다 (GetPluginPjRtClient가 위에서 구현된 C++ PJRT 클라이언트를 반환한다고 가정).
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;
}
}  // namespace my_plugin

PJRT_Client_Create는 프레임워크에서 전달된 옵션을 사용할 수 있습니다. 여기에서 GPU 클라이언트에서 이 기능을 사용하는 방법의 예를 확인할 수 있습니다.

래퍼를 사용하면 나머지 C API를 구현할 필요가 없습니다.

2단계: GetPjRtApi 구현

PJRT C API 구현에 관한 함수 포인터가 포함된 PJRT_Api*를 반환하는 GetPjRtApi 메서드를 구현해야 합니다. 다음은 래퍼를 통해 구현한다고 가정한 예입니다 (pjrt_c_api_cpu.cc와 유사).

const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}

3단계: C API 구현 테스트

RegisterPjRtCApiTestFactory를 호출하여 기본 PJRT C API 동작에 관한 소규모 테스트를 실행할 수 있습니다.

JAX에서 PJRT 플러그인을 사용하는 방법

1단계: JAX 설정

JAX Nightly를 사용하거나

pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>

또는 소스에서 JAX를 빌드합니다.

지금은 jaxlib 버전을 PJRT C API 버전과 일치시켜야 합니다. 일반적으로 플러그인을 빌드하는 TF 커밋과 같은 날의 jaxlib 나이틀리 버전만 사용해도 충분합니다(예:

pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

빌드 중인 XLA 커밋과 정확히 일치하는 소스에서 jaxlib를 빌드할 수도 있습니다 (instructions).

곧 ABI 호환성 지원을 시작할 예정입니다.

2단계: jax_plugins 네임스페이스 사용 또는 진입점 설정

JAX에서 플러그인을 검색할 수 있는 두 가지 옵션이 있습니다.

  1. 네임스페이스 패키지 사용 (ref) jax_plugins 네임스페이스 패키지 아래에 전역적으로 고유한 모듈을 정의합니다. 즉, jax_plugins 디렉터리를 만들고 그 아래에 있는 모듈을 정의하기만 하면 됩니다. 다음은 디렉터리 구조의 예입니다.
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
  1. 패키지 메타데이터 사용 (참조) pyproject.toml 또는 setup.py를 통해 패키지를 빌드하는 경우 전체 모듈 이름을 가리키는 jax_plugins 그룹 아래에 진입점을 포함하여 플러그인 모듈 이름을 알립니다. pyproject.toml 또는 setup.py를 통한 예는 다음과 같습니다.
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
  "jax_plugins": [
    "my_plugin = my_plugin",
  ],
}

옵션 2를 사용하여 openxla-pjrt-plugin을 구현하는 방법의 예는 https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120입니다.

3단계: initialize() 메서드 구현

플러그인을 등록하려면 Python 모듈에서 initialize() 메서드를 구현해야 합니다. 예를 들면 다음과 같습니다.

import os
import jax._src.xla_bridge as xb

def initialize():
  path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
  xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)

xla_bridge.register_plugin 사용 방법은 여기를 참고하세요. 현재 비공개 메서드입니다. 향후 공개 API가 출시될 예정입니다.

아래 줄을 실행하여 플러그인이 등록되었는지 확인하고 로드할 수 없는 경우 오류를 발생시킬 수 있습니다.

jax.config.update("jax_platforms", "my_plugin")

JAX에는 여러 백엔드/플러그인이 있을 수 있습니다. 플러그인이 기본 백엔드로 사용되도록 하는 몇 가지 옵션이 있습니다.

  • 옵션 1: 프로그램 시작 부분에서 jax.config.update("jax_platforms", "my_plugin")를 실행합니다.
  • 옵션 2: ENV JAX_PLATFORMS=my_plugin 설정
  • 옵션 3: xb.register_plugin을 호출할 때 충분히 높은 우선순위를 설정합니다. 기본값은 400이며 다른 기존 백엔드보다 높습니다. 우선순위가 가장 높은 백엔드는 JAX_PLATFORMS=''인 경우에만 사용됩니다. JAX_PLATFORMS의 기본값은 ''이지만 간혹 덮어쓰기될 수 있습니다.

JAX로 테스트하는 방법

시도해 볼 만한 몇 가지 기본 테스트 사례:

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap

arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))

# single device: [0]

# 4 devices: [6 7 8 9]

(곧 플러그인에 대한 jax 단위 테스트 실행 안내를 추가할 예정입니다.)

예: JAX CUDA 플러그인

  1. 래퍼 (pjrt_c_api_gpu.h)를 통한 PJRT C API 구현
  2. 패키지의 진입점 (setup.py)을 설정합니다.
  3. initialize() 메서드 (__init__.py)를 구현합니다.
  4. CUDA에 대한 모든 jax 테스트로 테스트할 수 있습니다. ```