Hintergrund
PJRT ist die einheitliche Device API, die wir der ML-Umgebung hinzufügen möchten. Die langfristige Vision ist folgende: (1) Frameworks (JAX, TF usw.) rufen PJRT auf, das gerätespezifische Implementierungen hat, die für die Frameworks nicht transparent sind; (2) jedes Gerät konzentriert sich auf die Implementierung von PJRT APIs und kann für die Frameworks undurchsichtig sein.
Dieses Dokument konzentriert sich auf die Empfehlungen zur Integration mit PJRT und zum Testen der PJRT-Integration mit JAX.
Einbindung in PJRT
Schritt 1: PJRT C API-Schnittstelle implementieren
Option A: Sie können die PJRT C API direkt implementieren.
Option B: Wenn Sie Builds mit C++-Code im xla-Repository erstellen können (über Forking oder Baizel), können Sie auch die PJRT C++ API implementieren und den Wrapper „C→C++“ verwenden:
- Implementieren Sie einen C++ PJRT-Client, der den PJRT-Basisclient (und verwandte PJRT-Klassen) übernimmt. Hier sind einige Beispiele für einen C++-PJRT-Client: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
- Implementieren Sie einige C API-Methoden, die nicht Teil des C++ PJRT-Clients sind:
- PJRT_Client_Create. Unten sehen Sie ein Beispiel für einen Pseudocode (sofern
GetPluginPjRtClient
einen oben implementierten C++ PJRT-Client zurückgibt):
- PJRT_Client_Create. Unten sehen Sie ein Beispiel für einen Pseudocode (sofern
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
args->client = pjrt::CreateWrapperClient(std::move(client));
return nullptr;
}
} // namespace my_plugin
Hinweis: PJRT_Client_Create kann Optionen übernehmen, die vom Framework übergeben werden. Hier finden Sie ein Beispiel dafür, wie ein GPU-Client diese Funktion verwendet.
- [Optional] PJRT_TopologyDescription_Create.
- [Optional] PJRT_Plugin_Initialize (PJRT_Plugin_Initialize) Dies ist ein einmaliges Plug-in-Setup, das vom Framework vor allen anderen Funktionen aufgerufen wird.
- [Optional] PJRT_Plugin_Attributes
Mit dem Wrapper müssen die verbleibenden C APIs nicht implementiert werden.
Schritt 2: GetPjRtApi implementieren
Sie müssen die Methode GetPjRtApi
implementieren, die ein PJRT_Api*
mit Funktionsverweisen auf PJRT C API-Implementierungen zurückgibt. Im Folgenden finden Sie ein Beispiel mit einer Implementierung über Wrapper (ähnlich wie pjrt_c_api_cpu.cc):
const PJRT_Api* GetPjrtApi() {
static const PJRT_Api pjrt_api =
pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
return &pjrt_api;
}
Schritt 3: C API-Implementierungen testen
Sie können RegisterPjRtCApiTestFactory aufrufen, um eine kleine Reihe von Tests für grundlegende PJRT C API-Verhalten auszuführen.
PJRT-Plug-in von JAX verwenden
Schritt 1: JAX einrichten
Sie können entweder jede Nacht JAX verwenden
pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>
oder JAX aus dem Quellcode erstellen.
Vorerst müssen Sie die jaxlib-Version der PJRT C API-Version zuordnen. In der Regel reicht es aus, eine jaxlib-nächtliche Version vom selben Tag wie der TF-Commit zu verwenden, für den Sie das Plug-in erstellen, z.B.
pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
Sie können auch genau mit dem XLA-Commit ein jaxlib aus dem Quellcode erstellen (instructions).
Die ABI-Kompatibilität wird bald unterstützt.
Schritt 2: Namespace „jax_plugins“ verwenden oder „entry_point“ einrichten
Es gibt zwei Möglichkeiten, wie Ihr Plug-in von JAX erkannt werden kann.
- Verwenden von Namespace-Paketen (ref) Definieren Sie ein global eindeutiges Modul unter dem
jax_plugins
-Namespace-Paket (d.h. erstellen Sie einfach einjax_plugins
-Verzeichnis und definieren Sie das Modul darunter). Hier ist ein Beispiel für eine Verzeichnisstruktur:
jax_plugins/
my_plugin/
__init__.py
my_plugin.so
- Paketmetadaten verwenden (ref) Wenn Sie ein Paket mit pyproject.toml oder „setup.py“ erstellen, geben Sie Ihren Plug-in-Modulnamen an, indem Sie in der Gruppe
jax_plugins
einen Einstiegspunkt hinzufügen, der auf Ihren vollständigen Modulnamen verweist. Hier ist ein Beispiel mit pyproject.toml oder setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'
# use setup.py
entry_points={
"jax_plugins": [
"my_plugin = my_plugin",
],
}
Die folgenden Beispiele zeigen, wie das openxla-pjrt-Plug-in mit Option 2 implementiert wird: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120
Schritt 3: Methode „initialisieren()“ implementieren
Zum Registrieren des Plug-ins müssen Sie in Ihrem Python-Modul eine Methode "initialisieren" implementieren. Beispiel:
import os
import jax._src.xla_bridge as xb
def initialize():
path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)
Weitere Informationen zur Verwendung von xla_bridge.register_plugin
findest du hier. Dies ist derzeit eine private Methode. Zukünftig wird eine öffentliche API veröffentlicht.
Sie können die folgende Zeile ausführen, um zu überprüfen, ob das Plug-in registriert ist, und einen Fehler ausgeben, wenn es nicht geladen werden kann.
jax.config.update("jax_platforms", "my_plugin")
JAX kann mehrere Back-Ends/Plug-ins haben. Es gibt einige Möglichkeiten, um sicherzustellen, dass Ihr Plug-in als Standard-Back-End verwendet wird:
- Option 1: Führen Sie
jax.config.update("jax_platforms", "my_plugin")
zu Beginn des Programms aus. - Option 2: Legen Sie ENV
JAX_PLATFORMS=my_plugin
fest. - Option 3: Legen Sie eine ausreichend hohe Priorität beim Aufrufen von xb.register_plugin fest (der Standardwert ist 400, was höher ist als bei anderen vorhandenen Back-Ends). Das Back-End mit der höchsten Priorität wird nur bei
JAX_PLATFORMS=''
verwendet. Der Standardwert vonJAX_PLATFORMS
ist''
, wird aber manchmal überschrieben.
Testen mit JAX
Einige einfache Testfälle, die Sie ausprobieren können:
# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2
# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0
# pmap
arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))
# single device: [0]
# 4 devices: [6 7 8 9]
Wir werden in Kürze eine Anleitung zum Ausführen der Jax-Einheitentests für Ihr Plug-in hinzufügen.
Beispiel: JAX CUDA-Plug-in
- PJRT C API-Implementierung über Wrapper (pjrt_c_api_gpu.h)
- Richten Sie den Einstiegspunkt für das Paket ein (setup.py).
- Implementieren Sie die Methode initial() (__init__.py).
- Kann mit allen Jax-Tests für CUDA getestet werden. ```