Integracja wtyczki PJRT

Ten dokument skupia się na zaleceniach dotyczących integracji z PJRT oraz na testowaniu integracji PJRT z JAX.

Jak przeprowadzić integrację z PJRT

Krok 1. Zaimplementuj interfejs PJRT C API

Opcja A: możesz bezpośrednio zaimplementować interfejs API PJRT C.

Opcja B: jeśli możesz kompilować kod C++ w repozytorium xla (za pomocą forkingu lub bazel), możesz też zaimplementować interfejs API PJRT w C++ i użyć oprawy C→C++:

  1. Zaimplementuj klienta PJRT w C++, który dziedziczy z podstawowego klienta PJRT (i powiązanych klas PJRT). Oto kilka przykładów klienta PJRT w C++: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
  2. Wdrożyć kilka metod interfejsu C API, które nie są częścią klienta PJRT w C++:
    • PJRT_Client_Create. Poniżej znajduje się przykładowy kod pseudokodu (zakładając, że GetPluginPjRtClient zwraca zaimplementowanego powyżej klienta PJRT w C++):
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;
}
}  // namespace my_plugin

Uwaga: PJRT_Client_Create może przyjmować opcje przekazane z ram. Tutaj znajdziesz przykład użycia tej funkcji przez klienta GPU.

Dzięki opakowaniu nie musisz wdrażać pozostałych interfejsów C.

Krok 2. Zaimplementuj funkcję GetPjRtApi

Musisz zaimplementować metodę GetPjRtApi, która zwraca PJRT_Api* zawierający wskaźniki funkcji do implementacji interfejsu PJRT C API. Poniżej podano przykład implementacji za pomocą oprogramowania pośredniczącego (podobnego do pjrt_c_api_cpu.cc):

const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}

Krok 3. Testuj implementacje interfejsu C API

Aby uruchomić mały zestaw testów podstawowych zachowań interfejsu PJRT C API, możesz wywołać funkcję RegisterPjRtCApiTestFactory.

Jak używać wtyczki PJRT z JAX

Krok 1. Skonfiguruj JAX

Możesz użyć wersji JAX nightly.

pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>

lub utworzyć JAX na podstawie źródła.

Na razie musisz dopasować wersję jaxlib do wersji interfejsu PJRT C API. Zwykle wystarczy użyć wersji jaxlib nightly z tego samego dnia, w którym został przesłany commit w TF, na podstawie którego tworzysz wtyczkę.

pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

Możesz też skompilować jaxlib z źródła dokładnie w przypadku commita XLA, na którym skompilujesz (instrukcje).

Wkrótce zaczniemy obsługiwać zgodność ABI.

Krok 2. Użyj przestrzeni nazw jax_plugins lub skonfiguruj entry_point

Istnieją 2 sposoby na wykrycie wtyczki przez JAX.

  1. Używanie pakietów w przestrzeni nazw (ref). Zdefiniuj niepowtarzalny globalnie moduł w ramach pakietu przestrzeni nazw jax_plugins (czyli utwórz katalog jax_plugins i zdefiniuj w nim moduł). Oto przykładowa struktura katalogu:
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
  1. za pomocą metadanych pakietu (ref); Jeśli kompilujesz pakiet za pomocą pliku pyproject.toml lub setup.py, podaj nazwę modułu wtyczki, dodając punkt wejścia w grupie jax_plugins, który wskazuje na pełną nazwę modułu. Oto przykład za pomocą pliku pyproject.toml lub setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
  "jax_plugins": [
    "my_plugin = my_plugin",
  ],
}

Oto przykłady implementacji pakietu openxla-pjrt-plugin przy użyciu opcji 2: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120

Krok 3. Zaimplementuj metodę initialize().

Aby zarejestrować wtyczkę, musisz zaimplementować metodę initialize() w module Pythona, np.:

import os
import jax._src.xla_bridge as xb

def initialize():
  path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
  xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)

Informacje o używaniu xla_bridge.register_plugin znajdziesz tutaj. Obecnie jest to metoda prywatna. W przyszłości udostępnimy publiczny interfejs API.

Aby sprawdzić, czy wtyczka jest zarejestrowana, możesz uruchomić poniższy wiersz. Jeśli nie można jej załadować, zostanie wyświetlony błąd.

jax.config.update("jax_platforms", "my_plugin")

JAX może mieć wiele backendów lub wtyczek. Aby zapewnić, że wtyczka jest używana jako domyślny backend:

  • Opcja 1. Uruchom jax.config.update("jax_platforms", "my_plugin") na początku programu.
  • Opcja 2. Ustaw zmienną środowiskową ENV JAX_PLATFORMS=my_plugin.
  • Opcja 3. Ustaw wystarczająco wysoki priorytet podczas wywołania xb.register_plugin (wartość domyślna to 400, która jest wyższa niż w przypadku innych istniejących backendów). Pamiętaj, że backend o najwyższym priorytecie będzie używany tylko wtedy, gdy JAX_PLATFORMS=''. Wartość domyślna JAX_PLATFORMS to '', ale czasami zostaje ona zastąpiona.

Jak przetestować za pomocą JAX

Oto kilka podstawowych przypadków testowych:

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap

arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))

# single device: [0]

# 4 devices: [6 7 8 9]

(wkrótce dodamy instrukcje uruchamiania testów jednostkowych jax w przypadku Twojego wtyczka).

Więcej przykładów wtyczek PJRT znajdziesz w artykule Przykłady wtyczek PJRT.