Wprowadzenie
PJRT to jednolity interfejs Device API, który chcemy dodać do ekosystemu systemów uczących się. Zgodnie z długoterminową wizją: (1) platformy (JAX, TF itp.) będą wywoływały PJRT, które mają implementacje nieprzezroczyste dla tych platform; (2) każde urządzenie skupia się na implementowaniu interfejsów API PJRT i może być nieprzejrzyste dla tych platform.
Ten dokument koncentruje się na rekomendacjach dotyczących integracji z PJRT i testowaniu integracji PJRT z JAX.
Jak przeprowadzić integrację z PJRT
Krok 1. Zaimplementuj interfejs PJRT C API
Opcja A: możesz wdrożyć interfejs PJRT C API bezpośrednio.
Opcja B: jeśli możesz tworzyć kompilacje na podstawie kodu C++ w repozytorium Xla (za pomocą rozwidlenia lub bazela), możesz też wdrożyć interfejs PJRT C++ API i użyć otoki C→C++:
- Zaimplementuj dziedziczenie klienta PJRT w języku C++ z klienta podstawowego PJRT (i powiązanych klas PJRT). Oto kilka przykładów klienta PJRT C++: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
- Zaimplementuj kilka metod interfejsu API w języku C, które nie są częścią klienta PJRT w C++:
- PJRT_Client_Create. Poniżej znajdziesz przykładowy pseudokod (przy założeniu, że
GetPluginPjRtClient
zwraca klienta PJRT C++ zaimplementowanego powyżej):
- PJRT_Client_Create. Poniżej znajdziesz przykładowy pseudokod (przy założeniu, że
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
args->client = pjrt::CreateWrapperClient(std::move(client));
return nullptr;
}
} // namespace my_plugin
Uwaga: PJRT_Client_Create może pobierać opcje przekazane z platformy. Tutaj znajdziesz przykład użycia tej funkcji przez klienta GPU.
- [Opcjonalny] PJRT_TopologyDescription_Create.
- [Opcjonalny] PJRT_Plugin_Initialize. Jest to jednorazowa konfiguracja wtyczki, która zostanie wywołana przez platformę przed wywołaniem jakichkolwiek innych funkcji.
- [Opcjonalny] PJRT_Plugin_Attributes.
Dzięki kodowi nie musisz implementować pozostałych interfejsów API C.
Krok 2. Wdróż GetPjRtApi
Musisz zaimplementować metodę GetPjRtApi
, która zwraca wartość PJRT_Api*
zawierającą wskaźniki funkcji do implementacji interfejsu PJRT C API. Poniżej znajdziesz przykład przy założeniu, że chcesz wdrożyć je za pomocą wrapper (podobny do pjrt_c_api_cpu.cc):
const PJRT_Api* GetPjrtApi() {
static const PJRT_Api pjrt_api =
pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
return &pjrt_api;
}
Krok 3. Przetestuj implementacje interfejsu C API
Możesz wywołać metodę RegisterPjRtCApiTestFactory, aby przeprowadzić niewielki zestaw testów pod kątem podstawowych zachowań interfejsu PJRT C API.
Jak używać wtyczki PJRT z JAX
Krok 1. Skonfiguruj JAX
Możesz użyć aplikacji JAX co noc.
pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>
lub utwórz plik JAX ze źródła.
Na razie musisz dopasować wersję jaxlib do wersji interfejsu PJRT C API. Zwykle wystarczy nocna wersja jaxlib z tego samego dnia, pod kątem którego tworzysz wtyczkę TF, np.
pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
Plik jaxlib możesz też utworzyć ze źródła dokładnie w zatwierdzeniach XLA, na podstawie którego kompilujesz (instructions).
Wkrótce zaczniemy obsługiwać zgodność z interfejsem ABI.
Krok 2. Użyj przestrzeni nazw jax_Plugins lub skonfiguruj element entry_point
JAX może wykryć Twoją wtyczkę na 2 sposoby.
- Używanie pakietów przestrzeni nazw (ref). Zdefiniuj unikalny moduł globalnie w ramach pakietu przestrzeni nazw
jax_plugins
(np. utwórz katalogjax_plugins
i zdefiniuj moduł pod nim). Przykładowa struktura katalogów:
jax_plugins/
my_plugin/
__init__.py
my_plugin.so
- Używanie metadanych pakietu (ref). Jeśli tworzysz pakiet przy użyciu pliku pyproject.toml lub setup.py, reklamuj nazwę modułu wtyczki, uwzględniając w grupie
jax_plugins
punkt wejścia, który wskazuje pełną nazwę modułu. Oto przykład z użyciem pliku pyproject.toml lub setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'
# use setup.py
entry_points={
"jax_plugins": [
"my_plugin = my_plugin",
],
}
Oto przykłady wdrożenia wtyczki openxla-pjrt-plugin przy użyciu opcji 2: https://github.com/openxla/openxla-pjrt-plugin/pull/119,,https://github.com/openxla/openxla-pjrt-plugin/pull/120
Krok 3. Zaimplementuj metodę initial()
Aby zarejestrować wtyczkę, musisz zaimplementować w module Pythona metodę inicjowania(), np.:
import os
import jax._src.xla_bridge as xb
def initialize():
path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)
Jak korzystać z usługi xla_bridge.register_plugin
, znajdziesz tutaj. Obecnie jest to metoda prywatna. W przyszłości zostanie udostępniony publiczny interfejs API.
Możesz uruchomić poniższy wiersz, aby sprawdzić, czy wtyczka jest zarejestrowana, i zgłosić błąd, jeśli nie uda się jej wczytać.
jax.config.update("jax_platforms", "my_plugin")
JAX może mieć wiele backendów/wtyczek. Masz kilka możliwości używania wtyczki jako domyślnego backendu:
- Opcja 1. Uruchom
jax.config.update("jax_platforms", "my_plugin")
na początku programu. - Opcja 2. Ustaw ENV
JAX_PLATFORMS=my_plugin
. - Opcja 3: ustaw odpowiednio wysoki priorytet podczas wywoływania metody xb.register_plugin (wartość domyślna to 400, która jest wyższa niż w przypadku innych istniejących backendów). Backend o najwyższym priorytecie będzie używany tylko wtedy, gdy
JAX_PLATFORMS=''
. Wartość domyślna parametruJAX_PLATFORMS
to''
, ale czasami zostaje zastąpiona.
Testowanie za pomocą JAX
Oto kilka podstawowych przykładów testowych, które możesz wypróbować:
# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2
# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0
# pmap
arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))
# single device: [0]
# 4 devices: [6 7 8 9]
(wkrótce dodamy do Twojej wtyczki instrukcje dotyczące testowania jednostek Jax).
Przykład: wtyczka JAX CUDA
- Wdrożenie interfejsu PJRT C API za pomocą kodu (pjrt_c_api_gpu.h).
- Skonfiguruj punkt wejścia pakietu (setup.py).
- Zaimplementuj metodę initial() (__init__.py).
- Można ją przetestować za pomocą dowolnych testów Jax pod kątem CUDA. ```