Questo documento è incentrato sui consigli su come eseguire l'integrazione con PJRT e su come testare l'integrazione di PJRT con JAX.
Come eseguire l'integrazione con PJRT
Passaggio 1: implementa l'interfaccia API C PJRT
Opzione A: puoi implementare direttamente l'API C PJRT.
Opzione B: se riesci a eseguire il build in base al codice C++ nel repo xla (tramite forking o bazel), puoi anche implementare l'API PJRT C++ e utilizzare il wrapper C→C++:
- Implementa un client PJRT C++ che eredita dal client PJRT di base (e dalle classi PJRT correlate). Ecco alcuni esempi di client PJRT C++: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
- Implementa alcuni metodi dell'API C che non fanno parte del client PJRT C++:
- PJRT_Client_Create. Di seguito è riportato un pseudocodice di esempio (supponendo che
GetPluginPjRtClient
restituisca un client PJRT C++) implementato sopra:
- PJRT_Client_Create. Di seguito è riportato un pseudocodice di esempio (supponendo che
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
args->client = pjrt::CreateWrapperClient(std::move(client));
return nullptr;
}
} // namespace my_plugin
Nota: PJRT_Client_Create può accettare le opzioni passate dal framework. Qui è riportato un esempio di come un client GPU utilizza questa funzionalità.
- [Facoltativo] PJRT_TopologyDescription_Create.
- [Facoltativo] PJRT_Plugin_Initialize. Si tratta di una configurazione una tantum del plug-in, che verrà chiamata dal framework prima di qualsiasi altra funzione.
- [Facoltativo] PJRT_Plugin_Attributes.
Con il wrapper, non è necessario implementare le API C rimanenti.
Passaggio 2: implementa GetPjRtApi
Devi implementare un metodo GetPjRtApi
che restituisce un PJRT_Api*
contenente puntatori di funzione alle implementazioni dell'API C PJRT. Di seguito è riportato un esempio che presuppone l'implementazione tramite wrapper (simile a pjrt_c_api_cpu.cc):
const PJRT_Api* GetPjrtApi() {
static const PJRT_Api pjrt_api =
pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
return &pjrt_api;
}
Passaggio 3: testa le implementazioni dell'API C
Puoi chiamare RegisterPjRtCApiTestFactory per eseguire un piccolo insieme di test per i comportamenti di base dell'API C PJRT.
Come utilizzare un plug-in PJRT da JAX
Passaggio 1: configura JAX
Puoi utilizzare la versione nightly di JAX
pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>
oppure crea JAX dal codice sorgente.
Per il momento, devi associare la versione jaxlib alla versione dell'API C PJRT. Di solito è sufficiente utilizzare una versione notturna di jaxlib dello stesso giorno del commit di TF su cui stai creando il plug-in, ad esempio
pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
Puoi anche creare un file jaxlib dal codice sorgente esattamente nel commit XLA in base al quale stai eseguendo la compilazione (istruzioni).
A breve inizieremo a supportare la compatibilità ABI.
Passaggio 2: utilizza lo spazio dei nomi jax_plugins o configura entry_point
Esistono due opzioni per consentire a JAX di rilevare il plug-in.
- Utilizzo dei pacchetti dello spazio dei nomi (ref). Definisci un modulo univoco a livello globale nel pacchetto dello spazio dei nomi
jax_plugins
(ovvero crea una directoryjax_plugins
e definisci il modulo sotto di essa). Ecco un esempio di struttura di directory:
jax_plugins/
my_plugin/
__init__.py
my_plugin.so
- Utilizzo dei metadati del pacchetto (ref). Se crei un pacchetto tramite pyproject.toml o setup.py, pubblicizza il nome del modulo del plug-in includendo un punto di contatto nel gruppo
jax_plugins
che rimandi al nome completo del modulo. Ecco un esempio tramite pyproject.toml o setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'
# use setup.py
entry_points={
"jax_plugins": [
"my_plugin = my_plugin",
],
}
Ecco alcuni esempi di come openxla-pjrt-plugin viene implementato utilizzando l'opzione 2: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120
Passaggio 3: implementa un metodo initialize()
Per registrare il plug-in, devi implementare un metodo initialize() nel modulo Python, ad esempio:
import os
import jax._src.xla_bridge as xb
def initialize():
path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)
Consulta questa pagina per scoprire come utilizzare xla_bridge.register_plugin
. Al momento è un metodo privato. In futuro verrà rilasciata un'API pubblica.
Puoi eseguire la riga riportata di seguito per verificare che il plug-in sia registrato e generare un errore se non è possibile caricarlo.
jax.config.update("jax_platforms", "my_plugin")
JAX può avere più backend/plug-in. Esistono alcune opzioni per assicurarti che il plug-in venga utilizzato come backend predefinito:
- Opzione 1: esegui
jax.config.update("jax_platforms", "my_plugin")
all'inizio del programma. - Opzione 2: imposta ENV
JAX_PLATFORMS=my_plugin
. - Opzione 3: imposta una priorità sufficientemente elevata quando chiami xb.register_plugin (il valore predefinito è 400, superiore a quello degli altri backend esistenti). Tieni presente che il backend con la priorità più alta verrà utilizzato solo quando
JAX_PLATFORMS=''
. Il valore predefinito diJAX_PLATFORMS
è''
, ma a volte viene sovrascritto.
Come eseguire test con JAX
Ecco alcuni scenari di test di base da provare:
# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2
# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0
# pmap
arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))
# single device: [0]
# 4 devices: [6 7 8 9]
A breve aggiungeremo le istruzioni per eseguire i test di unità Jax sul tuo plug-in.
Per altri esempi di plug-in PJRT, consulta la pagina Esempi di PJRT.