PJRT 플러그인 통합

이 문서에서는 PJRT와 통합하는 방법에 관한 권장사항과 JAX와 PJRT 통합을 테스트하는 방법을 중점적으로 다룹니다.

PJRT와 통합하는 방법

1단계: PJRT C API 인터페이스 구현

옵션 A: PJRT C API를 직접 구현할 수 있습니다.

옵션 B: 포크 또는 bazel을 통해 xla 저장소의 C++ 코드에 대해 빌드할 수 있는 경우 PJRT C++ API를 구현하고 C→C++ 래퍼를 사용할 수도 있습니다.

  1. 기본 PJRT 클라이언트 (및 관련 PJRT 클래스)에서 상속받는 C++ PJRT 클라이언트를 구현합니다. 다음은 C++ PJRT 클라이언트의 예입니다.pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h
  2. C++ PJRT 클라이언트의 일부가 아닌 몇 가지 C API 메서드를 구현합니다.
    • PJRT_Client_Create. 다음은 샘플 가상 코드입니다 (GetPluginPjRtClient가 위에 구현된 C++ PJRT 클라이언트를 반환한다고 가정).
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;
}
}  // namespace my_plugin

참고 PJRT_Client_Create는 프레임워크에서 전달된 옵션을 사용할 수 있습니다. 다음은 GPU 클라이언트가 이 기능을 사용하는 방법의 예입니다.

래퍼를 사용하면 나머지 C API를 구현할 필요가 없습니다.

2단계: GetPjRtApi 구현

PJRT C API 구현에 대한 함수 포인터가 포함된 PJRT_Api*를 반환하는 GetPjRtApi 메서드를 구현해야 합니다. 다음은 래퍼를 통해 구현한다고 가정하는 예입니다 (pjrt_c_api_cpu.cc와 유사).

const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}

3단계: C API 구현 테스트

RegisterPjRtCApiTestFactory를 호출하여 기본 PJRT C API 동작에 관한 소수의 테스트를 실행할 수 있습니다.

JAX에서 PJRT 플러그인을 사용하는 방법

1단계: JAX 설정

JAX nightly를 사용하거나

pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>

또는 소스에서 JAX를 빌드합니다.

지금은 jaxlib 버전을 PJRT C API 버전과 일치시켜야 합니다. 일반적으로 플러그인을 빌드하는 TF 커밋과 동일한 날짜의 jaxlib nightly 버전을 사용하면 됩니다.예를 들면 다음과 같습니다.

pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

빌드하려는 XLA 커밋에서 소스에서 jaxlib를 빌드할 수도 있습니다 (안내).

곧 ABI 호환성을 지원할 예정입니다.

2단계: jax_plugins 네임스페이스 사용 또는 entry_point 설정

JAX에서 플러그인을 검색하는 방법에는 두 가지가 있습니다.

  1. 네임스페이스 패키지 사용 (참조) jax_plugins 네임스페이스 패키지 아래에 전역 고유 모듈을 정의합니다 (즉, jax_plugins 디렉터리를 만들고 그 아래에 모듈을 정의하면 됩니다). 다음은 디렉터리 구조의 예입니다.
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
  1. 패키지 메타데이터 (참조) 사용 pyproject.toml 또는 setup.py를 통해 패키지를 빌드하는 경우 전체 모듈 이름을 가리키는 진입점을 jax_plugins 그룹 아래에 포함하여 플러그인 모듈 이름을 광고합니다. 다음은 pyproject.toml 또는 setup.py를 통한 예입니다.
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
  "jax_plugins": [
    "my_plugin = my_plugin",
  ],
}

다음은 옵션 2를 사용하여 openxla-pjrt-plugin이 구현되는 예입니다. https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120

3단계: initialize() 메서드 구현

플러그인을 등록하려면 Python 모듈에서 initialize() 메서드를 구현해야 합니다. 예를 들면 다음과 같습니다.

import os
import jax._src.xla_bridge as xb

def initialize():
  path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
  xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)

xla_bridge.register_plugin 사용 방법은 여기를 참고하세요. 현재 비공개 메서드입니다. 향후 공개 API가 출시될 예정입니다.

아래 줄을 실행하여 플러그인이 등록되었는지 확인하고 로드할 수 없는 경우 오류를 발생시킬 수 있습니다.

jax.config.update("jax_platforms", "my_plugin")

JAX에는 여러 백엔드/플러그인이 있을 수 있습니다. 플러그인이 기본 백엔드로 사용되도록 하는 몇 가지 옵션이 있습니다.

  • 옵션 1: 프로그램 시작 부분에서 jax.config.update("jax_platforms", "my_plugin")를 실행합니다.
  • 옵션 2: ENV JAX_PLATFORMS=my_plugin를 설정합니다.
  • 옵션 3: xb.register_plugin을 호출할 때 충분히 높은 우선순위를 설정합니다 (기본값은 400으로 다른 기존 백엔드보다 높음). 가장 높은 우선순위의 백엔드는 JAX_PLATFORMS=''인 경우에만 사용됩니다. JAX_PLATFORMS의 기본값은 ''이지만 덮어쓰는 경우도 있습니다.

JAX로 테스트하는 방법

다음과 같은 기본 테스트 사례를 시도해 보세요.

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap

arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))

# single device: [0]

# 4 devices: [6 7 8 9]

(곧 플러그인에 대해 jax 단위 테스트를 실행하는 방법에 관한 안내가 추가될 예정입니다.)

PJRT 플러그인의 더 많은 예는 PJRT 예시를 참고하세요.