Contesto
PJRT è l'API uniforme del dispositivo che vogliamo aggiungere all'ecosistema ML. La visione a lungo termine è che: (1) i framework (JAX, TF, ecc.) chiameranno PJRT, che ha implementazioni specifiche del dispositivo opache ai framework; (2) ciascun dispositivo si concentra sull’implementazione delle API PJRT e può essere opaco rispetto ai framework.
Questo documento si concentra sulle raccomandazioni su come eseguire l'integrazione con PJRT e su come testare l'integrazione di PJRT con JAX.
Come eseguire l'integrazione con PJRT
Passaggio 1: implementa l'interfaccia dell'API PJRT C
Opzione A: puoi implementare direttamente l'API PJRT C.
Opzione B: se sei in grado di creare sulla base del codice C++ nel repository xla (tramite forking o bazel), puoi anche implementare l'API PJRT C++ e utilizzare il wrapper C→C++:
- Implementare un client PJRT C++ che eredita dal client PJRT di base (e dalle relative classi PJRT). Ecco alcuni esempi di client PJRT C++: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
- Implementa alcuni metodi dell'API C che non fanno parte del client PJRT C++:
- PJRT_Client_Create. Di seguito è riportato un esempio di pseudocodice (supponendo che
GetPluginPjRtClient
restituisca un client PJRT C++ implementato sopra):
- PJRT_Client_Create. Di seguito è riportato un esempio di pseudocodice (supponendo che
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
args->client = pjrt::CreateWrapperClient(std::move(client));
return nullptr;
}
} // namespace my_plugin
Nota: PJRT_Client_Create può utilizzare opzioni trasferite dal framework. Qui è riportato un esempio di come un client GPU utilizza questa funzionalità.
- [Facoltativo] PJRT_TopologyDescription_Create.
- [Facoltativo] PJRT_Plugin_Initialize. Si tratta di una configurazione di plug-in da eseguire una sola volta, che verrà chiamata dal framework prima di chiamare qualsiasi altra funzione.
- [Facoltativo] PJRT_Plugin_Attributes.
Con il wrapper, non è necessario implementare le API C rimanenti.
Passaggio 2: implementa GetPjRtApi
Devi implementare un metodo GetPjRtApi
che restituisca un puntatore a funzione contenente PJRT_Api*
alle implementazioni dell'API C PJRT. Di seguito è riportato un esempio in cui si presuppone l'implementazione tramite wrapper (simile a pjrt_c_api_cpu.cc):
const PJRT_Api* GetPjrtApi() {
static const PJRT_Api pjrt_api =
pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
return &pjrt_api;
}
Passaggio 3: testa le implementazioni dell'API C
Puoi chiamare RegisterPjRtCApiTestFactory per eseguire un piccolo insieme di test sui comportamenti di base dell'API PJRT C.
Come utilizzare un plug-in PJRT di JAX
Passaggio 1: configura JAX
Puoi usare JAX di notte
pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>
oppure crea JAX dal codice sorgente.
Per ora, devi associare la versione jaxlib alla versione dell'API PJRT C. Generalmente è sufficiente utilizzare una versione notturna jaxlib dello stesso giorno del commit TF per cui stai creando il plug-in, ad esempio:
pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
Puoi anche creare un file jaxlib dal codice sorgente esattamente al commit XLA per cui stai creando (instructions).
Inizieremo a supportare la compatibilità con ABI.
Passaggio 2: utilizza lo spazio dei nomi jax_plugins o configura entry_point
Esistono due opzioni per consentire a JAX di trovare il tuo plug-in.
- Utilizzo di pacchetti dello spazio dei nomi (ref). Definisci un modulo univoco a livello globale nel pacchetto dello spazio dei nomi
jax_plugins
(ad esempio, crea una directoryjax_plugins
e definisci il modulo in basso). Ecco un esempio di struttura di directory:
jax_plugins/
my_plugin/
__init__.py
my_plugin.so
- Utilizzo dei metadati del pacchetto (ref). Se crei un pacchetto tramite pyproject.toml o setup.py, pubblicizza il nome del modulo del plug-in includendo un punto di ingresso sotto il gruppo
jax_plugins
che rimanda al nome completo del modulo. Ecco un esempio tramite pyproject.toml o setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'
# use setup.py
entry_points={
"jax_plugins": [
"my_plugin = my_plugin",
],
}
Ecco alcuni esempi di come openxla-pjrt-plugin viene implementato utilizzando l'opzione 2: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120
Passaggio 3: implementa un metodo startize()
Per registrare il plug-in, devi implementare un metodo startize() nel modulo Python, ad esempio:
import os
import jax._src.xla_bridge as xb
def initialize():
path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)
Consulta questa pagina per informazioni sull'utilizzo di xla_bridge.register_plugin
. Attualmente è un metodo privato. In futuro verrà rilasciata un'API pubblica.
Puoi eseguire la riga seguente per verificare che il plug-in sia registrato e segnalare un errore se non può essere caricato.
jax.config.update("jax_platforms", "my_plugin")
JAX potrebbe avere più backend/plugin. Esistono alcune opzioni per assicurarsi che il plug-in venga utilizzato come backend predefinito:
- Opzione 1: esegui
jax.config.update("jax_platforms", "my_plugin")
all'inizio del programma. - Opzione 2: imposta ENV
JAX_PLATFORMS=my_plugin
. - Opzione 3: imposta una priorità sufficientemente elevata quando chiami xb.register_plugin (il valore predefinito è 400, maggiore di quello degli altri backend esistenti). Tieni presente che il backend con la massima priorità verrà utilizzato solo quando
JAX_PLATFORMS=''
. Il valore predefinito diJAX_PLATFORMS
è''
, ma a volte viene sovrascritto.
Come eseguire il test con JAX
Alcuni scenari di test di base da provare:
# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2
# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0
# pmap
arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))
# single device: [0]
# 4 devices: [6 7 8 9]
Presto aggiungeremo le istruzioni per eseguire i test delle unità jax sul tuo plug-in.
Esempio: plug-in JAX CUDA
- Implementazione dell'API PJRT C tramite wrapper (pjrt_c_api_gpu.h).
- Configura il punto di ingresso per il pacchetto (setup.py).
- Implementa un metodo startize() (__init__.py).
- Può essere testato con qualsiasi test jax per CUDA. ```