PJRT-Plug-in-Integration

Hintergrund

PJRT ist die einheitliche Device API, die wir der ML-Umgebung hinzufügen möchten. Die langfristige Vision ist folgende: (1) Frameworks (JAX, TF usw.) rufen PJRT auf, das gerätespezifische Implementierungen hat, die für die Frameworks nicht transparent sind; (2) jedes Gerät konzentriert sich auf die Implementierung von PJRT APIs und kann für die Frameworks undurchsichtig sein.

Dieses Dokument konzentriert sich auf die Empfehlungen zur Integration mit PJRT und zum Testen der PJRT-Integration mit JAX.

Einbindung in PJRT

Schritt 1: PJRT C API-Schnittstelle implementieren

Option A: Sie können die PJRT C API direkt implementieren.

Option B: Wenn Sie Builds mit C++-Code im xla-Repository erstellen können (über Forking oder Baizel), können Sie auch die PJRT C++ API implementieren und den Wrapper „C→C++“ verwenden:

  1. Implementieren Sie einen C++ PJRT-Client, der den PJRT-Basisclient (und verwandte PJRT-Klassen) übernimmt. Hier sind einige Beispiele für einen C++-PJRT-Client: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
  2. Implementieren Sie einige C API-Methoden, die nicht Teil des C++ PJRT-Clients sind:
    • PJRT_Client_Create. Unten sehen Sie ein Beispiel für einen Pseudocode (sofern GetPluginPjRtClient einen oben implementierten C++ PJRT-Client zurückgibt):
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;
}
}  // namespace my_plugin

Hinweis: PJRT_Client_Create kann Optionen übernehmen, die vom Framework übergeben werden. Hier finden Sie ein Beispiel dafür, wie ein GPU-Client diese Funktion verwendet.

Mit dem Wrapper müssen die verbleibenden C APIs nicht implementiert werden.

Schritt 2: GetPjRtApi implementieren

Sie müssen die Methode GetPjRtApi implementieren, die ein PJRT_Api* mit Funktionsverweisen auf PJRT C API-Implementierungen zurückgibt. Im Folgenden finden Sie ein Beispiel mit einer Implementierung über Wrapper (ähnlich wie pjrt_c_api_cpu.cc):

const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}

Schritt 3: C API-Implementierungen testen

Sie können RegisterPjRtCApiTestFactory aufrufen, um eine kleine Reihe von Tests für grundlegende PJRT C API-Verhalten auszuführen.

PJRT-Plug-in von JAX verwenden

Schritt 1: JAX einrichten

Sie können entweder jede Nacht JAX verwenden

pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>

oder JAX aus dem Quellcode erstellen.

Vorerst müssen Sie die jaxlib-Version der PJRT C API-Version zuordnen. In der Regel reicht es aus, eine jaxlib-nächtliche Version vom selben Tag wie der TF-Commit zu verwenden, für den Sie das Plug-in erstellen, z.B.

pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

Sie können auch genau mit dem XLA-Commit ein jaxlib aus dem Quellcode erstellen (instructions).

Die ABI-Kompatibilität wird bald unterstützt.

Schritt 2: Namespace „jax_plugins“ verwenden oder „entry_point“ einrichten

Es gibt zwei Möglichkeiten, wie Ihr Plug-in von JAX erkannt werden kann.

  1. Verwenden von Namespace-Paketen (ref) Definieren Sie ein global eindeutiges Modul unter dem jax_plugins-Namespace-Paket (d.h. erstellen Sie einfach ein jax_plugins-Verzeichnis und definieren Sie das Modul darunter). Hier ist ein Beispiel für eine Verzeichnisstruktur:
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
  1. Paketmetadaten verwenden (ref) Wenn Sie ein Paket mit pyproject.toml oder „setup.py“ erstellen, geben Sie Ihren Plug-in-Modulnamen an, indem Sie in der Gruppe jax_plugins einen Einstiegspunkt hinzufügen, der auf Ihren vollständigen Modulnamen verweist. Hier ist ein Beispiel mit pyproject.toml oder setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
  "jax_plugins": [
    "my_plugin = my_plugin",
  ],
}

Die folgenden Beispiele zeigen, wie das openxla-pjrt-Plug-in mit Option 2 implementiert wird: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120

Schritt 3: Methode „initialisieren()“ implementieren

Zum Registrieren des Plug-ins müssen Sie in Ihrem Python-Modul eine Methode "initialisieren" implementieren. Beispiel:

import os
import jax._src.xla_bridge as xb

def initialize():
  path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
  xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)

Weitere Informationen zur Verwendung von xla_bridge.register_plugin findest du hier. Dies ist derzeit eine private Methode. Zukünftig wird eine öffentliche API veröffentlicht.

Sie können die folgende Zeile ausführen, um zu überprüfen, ob das Plug-in registriert ist, und einen Fehler ausgeben, wenn es nicht geladen werden kann.

jax.config.update("jax_platforms", "my_plugin")

JAX kann mehrere Back-Ends/Plug-ins haben. Es gibt einige Möglichkeiten, um sicherzustellen, dass Ihr Plug-in als Standard-Back-End verwendet wird:

  • Option 1: Führen Sie jax.config.update("jax_platforms", "my_plugin") zu Beginn des Programms aus.
  • Option 2: Legen Sie ENV JAX_PLATFORMS=my_plugin fest.
  • Option 3: Legen Sie eine ausreichend hohe Priorität beim Aufrufen von xb.register_plugin fest (der Standardwert ist 400, was höher ist als bei anderen vorhandenen Back-Ends). Das Back-End mit der höchsten Priorität wird nur bei JAX_PLATFORMS='' verwendet. Der Standardwert von JAX_PLATFORMS ist '', wird aber manchmal überschrieben.

Testen mit JAX

Einige einfache Testfälle, die Sie ausprobieren können:

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap

arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))

# single device: [0]

# 4 devices: [6 7 8 9]

Wir werden in Kürze eine Anleitung zum Ausführen der Jax-Einheitentests für Ihr Plug-in hinzufügen.

Beispiel: JAX CUDA-Plug-in

  1. PJRT C API-Implementierung über Wrapper (pjrt_c_api_gpu.h)
  2. Richten Sie den Einstiegspunkt für das Paket ein (setup.py).
  3. Implementieren Sie die Methode initial() (__init__.py).
  4. Kann mit allen Jax-Tests für CUDA getestet werden. ```