In diesem Dokument geht es hauptsächlich um Empfehlungen zur Einbindung in PJRT und zum Testen der PJRT-Integration mit JAX.
PJRT einbinden
Schritt 1: PJRT C API-Schnittstelle implementieren
Option A: Sie können die PJRT C API direkt implementieren.
Option B: Wenn Sie über Forking oder Bazel einen Build mit C++-Code im xla-Repository ausführen können, können Sie auch die PJRT C++ API implementieren und den C-zu-C++-Wrapper verwenden:
- Implementiere einen C++-PJRT-Client, der vom PJRT-Basisclient (und den zugehörigen PJRT-Klassen) erbt. Hier einige Beispiele für C++-PJRT-Clients: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
- Implementiere einige C API-Methoden, die nicht Teil des C++ PJRT-Clients sind:
- PJRT_Client_Create. Unten findest du ein Pseudocode-Beispiel. Dabei wird davon ausgegangen, dass
GetPluginPjRtClient
einen oben implementierten C++-PJRT-Client zurückgibt:
- PJRT_Client_Create. Unten findest du ein Pseudocode-Beispiel. Dabei wird davon ausgegangen, dass
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
args->client = pjrt::CreateWrapperClient(std::move(client));
return nullptr;
}
} // namespace my_plugin
Hinweis: PJRT_Client_Create kann Optionen annehmen, die vom Framework übergeben werden. Hier ist ein Beispiel dafür, wie ein GPU-Client diese Funktion verwendet.
- [Optional] PJRT_TopologyDescription_Create.
- [Optional] PJRT_Plugin_Initialize. Dies ist eine einmalige Plugin-Einrichtung, die vom Framework aufgerufen wird, bevor andere Funktionen aufgerufen werden.
- [Optional] PJRT_Plugin_Attributes.
Mit dem Wrapper müssen Sie die verbleibenden C-APIs nicht implementieren.
Schritt 2: GetPjRtApi implementieren
Sie müssen eine Methode GetPjRtApi
implementieren, die eine PJRT_Api*
mit Funktionszeigern auf PJRT-C-API-Implementierungen zurückgibt. Unten findest du ein Beispiel für die Implementierung über einen Wrapper (ähnlich wie pjrt_c_api_cpu.cc):
const PJRT_Api* GetPjrtApi() {
static const PJRT_Api pjrt_api =
pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
return &pjrt_api;
}
Schritt 3: C API-Implementierungen testen
Sie können RegisterPjRtCApiTestFactory aufrufen, um eine kleine Reihe von Tests für grundlegende PJRT C API-Verhaltensweisen auszuführen.
PJRT-Plug-in von JAX verwenden
Schritt 1: JAX einrichten
Sie können entweder JAX nightly
pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>
oder JAX aus dem Quellcode erstellen.
Derzeit müssen Sie die jaxlib-Version mit der PJRT C API-Version abgleichen. In der Regel reicht es aus, eine jaxlib-Nachtversion vom selben Tag wie der TF-Commit zu verwenden, anhand dessen Sie Ihr Plug-in erstellen, z.B.
pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
Sie können auch eine jaxlib aus der Quelle genau mit dem XLA-Commit erstellen, auf das Sie aufbauen (Anleitung).
Wir werden die ABI-Kompatibilität bald unterstützen.
Schritt 2: Namespace „jax_plugins“ verwenden oder „entry_point“ einrichten
Es gibt zwei Möglichkeiten, wie Ihr Plug-in von JAX gefunden werden kann.
- Namespace-Pakete verwenden (ref) Definieren Sie ein global eindeutiges Modul unter dem Namespace-Paket
jax_plugins
. Erstellen Sie dazu einfach einjax_plugins
-Verzeichnis und definieren Sie Ihr Modul darunter. Hier ist eine Beispielverzeichnisstruktur:
jax_plugins/
my_plugin/
__init__.py
my_plugin.so
- Paketmetadaten verwenden (ref) Wenn Sie ein Paket über pyproject.toml oder setup.py erstellen, geben Sie den Namen Ihres Plug-in-Moduls an, indem Sie unter der Gruppe
jax_plugins
einen Einstiegspunkt angeben, der auf den vollständigen Modulnamen verweist. Hier ein Beispiel über pyproject.toml oder setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'
# use setup.py
entry_points={
"jax_plugins": [
"my_plugin = my_plugin",
],
}
Hier sind Beispiele dafür, wie das openxla-pjrt-plugin mit Option 2 implementiert wird: https://github.com/openxla/openxla-pjrt-plugin/pull/119,, https://github.com/openxla/openxla-pjrt-plugin/pull/120
Schritt 3: Methode „initialize()“ implementieren
Sie müssen in Ihrem Python-Modul die Methode „initialize()“ implementieren, um das Plug-in zu registrieren. Beispiel:
import os
import jax._src.xla_bridge as xb
def initialize():
path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)
Informationen zur Verwendung von xla_bridge.register_plugin
Sie ist derzeit eine private Methode. Eine öffentliche API wird in Zukunft veröffentlicht.
Mit der folgenden Zeile können Sie prüfen, ob das Plug-in registriert ist, und einen Fehler ausgeben, wenn es nicht geladen werden kann.
jax.config.update("jax_platforms", "my_plugin")
JAX kann mehrere Back-Ends/Plug-ins haben. Es gibt einige Möglichkeiten, dafür zu sorgen, dass dein Plug-in als Standard-Backend verwendet wird:
- Option 1:
jax.config.update("jax_platforms", "my_plugin")
zu Beginn des Programms ausführen - Option 2: ENV
JAX_PLATFORMS=my_plugin
festlegen - Option 3: Lege beim Aufrufen von xb.register_plugin eine ausreichend hohe Priorität fest. Der Standardwert ist 400, was höher ist als bei anderen vorhandenen Backends. Das Backend mit der höchsten Priorität wird nur verwendet, wenn
JAX_PLATFORMS=''
. Der Standardwert vonJAX_PLATFORMS
ist''
, wird aber manchmal überschrieben.
Mit JAX testen
Einige grundlegende Testfälle:
# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2
# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0
# pmap
arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))
# single device: [0]
# 4 devices: [6 7 8 9]
(Wir fügen demnächst eine Anleitung zum Ausführen der Jax-Unit-Tests für dein Plug-in hinzu.)
Weitere Beispiele für PJRT-Plug-ins finden Sie unter PJRT-Beispiele.