Integrazione del plug-in PJRT

Questo documento è incentrato sui consigli su come eseguire l'integrazione con PJRT e su come testare l'integrazione di PJRT con JAX.

Come eseguire l'integrazione con PJRT

Passaggio 1: implementa l'interfaccia API C PJRT

Opzione A: puoi implementare direttamente l'API C PJRT.

Opzione B: se riesci a eseguire il build in base al codice C++ nel repo xla (tramite forking o bazel), puoi anche implementare l'API PJRT C++ e utilizzare il wrapper C→C++:

  1. Implementa un client PJRT C++ che eredita dal client PJRT di base (e dalle classi PJRT correlate). Ecco alcuni esempi di client PJRT C++: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
  2. Implementa alcuni metodi dell'API C che non fanno parte del client PJRT C++:
    • PJRT_Client_Create. Di seguito è riportato un pseudocodice di esempio (supponendo che GetPluginPjRtClient restituisca un client PJRT C++) implementato sopra:
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;
}
}  // namespace my_plugin

Nota: PJRT_Client_Create può accettare le opzioni passate dal framework. Qui è riportato un esempio di come un client GPU utilizza questa funzionalità.

Con il wrapper, non è necessario implementare le API C rimanenti.

Passaggio 2: implementa GetPjRtApi

Devi implementare un metodo GetPjRtApi che restituisce un PJRT_Api* contenente puntatori di funzione alle implementazioni dell'API C PJRT. Di seguito è riportato un esempio che presuppone l'implementazione tramite wrapper (simile a pjrt_c_api_cpu.cc):

const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}

Passaggio 3: testa le implementazioni dell'API C

Puoi chiamare RegisterPjRtCApiTestFactory per eseguire un piccolo insieme di test per i comportamenti di base dell'API C PJRT.

Come utilizzare un plug-in PJRT da JAX

Passaggio 1: configura JAX

Puoi utilizzare la versione nightly di JAX

pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>

oppure crea JAX dal codice sorgente.

Per il momento, devi associare la versione jaxlib alla versione dell'API C PJRT. Di solito è sufficiente utilizzare una versione notturna di jaxlib dello stesso giorno del commit di TF su cui stai creando il plug-in, ad esempio

pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

Puoi anche creare un file jaxlib dal codice sorgente esattamente nel commit XLA in base al quale stai eseguendo la compilazione (istruzioni).

A breve inizieremo a supportare la compatibilità ABI.

Passaggio 2: utilizza lo spazio dei nomi jax_plugins o configura entry_point

Esistono due opzioni per consentire a JAX di rilevare il plug-in.

  1. Utilizzo dei pacchetti dello spazio dei nomi (ref). Definisci un modulo univoco a livello globale nel pacchetto dello spazio dei nomi jax_plugins (ovvero crea una directory jax_plugins e definisci il modulo sotto di essa). Ecco un esempio di struttura di directory:
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
  1. Utilizzo dei metadati del pacchetto (ref). Se crei un pacchetto tramite pyproject.toml o setup.py, pubblicizza il nome del modulo del plug-in includendo un punto di contatto nel gruppo jax_plugins che rimandi al nome completo del modulo. Ecco un esempio tramite pyproject.toml o setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
  "jax_plugins": [
    "my_plugin = my_plugin",
  ],
}

Ecco alcuni esempi di come openxla-pjrt-plugin viene implementato utilizzando l'opzione 2: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120

Passaggio 3: implementa un metodo initialize()

Per registrare il plug-in, devi implementare un metodo initialize() nel modulo Python, ad esempio:

import os
import jax._src.xla_bridge as xb

def initialize():
  path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
  xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)

Consulta questa pagina per scoprire come utilizzare xla_bridge.register_plugin. Al momento è un metodo privato. In futuro verrà rilasciata un'API pubblica.

Puoi eseguire la riga riportata di seguito per verificare che il plug-in sia registrato e generare un errore se non è possibile caricarlo.

jax.config.update("jax_platforms", "my_plugin")

JAX può avere più backend/plug-in. Esistono alcune opzioni per assicurarti che il plug-in venga utilizzato come backend predefinito:

  • Opzione 1: esegui jax.config.update("jax_platforms", "my_plugin") all'inizio del programma.
  • Opzione 2: imposta ENV JAX_PLATFORMS=my_plugin.
  • Opzione 3: imposta una priorità sufficientemente elevata quando chiami xb.register_plugin (il valore predefinito è 400, superiore a quello degli altri backend esistenti). Tieni presente che il backend con la priorità più alta verrà utilizzato solo quando JAX_PLATFORMS=''. Il valore predefinito di JAX_PLATFORMS è '', ma a volte viene sovrascritto.

Come eseguire test con JAX

Ecco alcuni scenari di test di base da provare:

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap

arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))

# single device: [0]

# 4 devices: [6 7 8 9]

A breve aggiungeremo le istruzioni per eseguire i test di unità Jax sul tuo plug-in.

Per altri esempi di plug-in PJRT, consulta la pagina Esempi di PJRT.