Ce document se concentre sur les recommandations concernant l'intégration à PJRT et sur la façon de tester l'intégration de PJRT à JAX.
Intégrer avec PJRT
Étape 1: Implémentez l'interface API C PJRT
Option A: Vous pouvez implémenter directement l'API C PJRT.
Option B: Si vous pouvez compiler avec du code C++ dans le dépôt xla (via le forking ou Bazel), vous pouvez également implémenter l'API C++ PJRT et utiliser le wrapper C→C++:
- Implémentez un client PJRT C++ héritant du client PJRT de base (et des classes PJRT associées). Voici quelques exemples de clients PJRT C++: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
- Implémentez quelques méthodes d'API C qui ne font pas partie du client PJRT C++ :
- PJRT_Client_Create. Vous trouverez ci-dessous un exemple de pseudo-code (en supposant que
GetPluginPjRtClient
renvoie un client PJRT C++ implémenté ci-dessus):
- PJRT_Client_Create. Vous trouverez ci-dessous un exemple de pseudo-code (en supposant que
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
args->client = pjrt::CreateWrapperClient(std::move(client));
return nullptr;
}
} // namespace my_plugin
Remarque : PJRT_Client_Create peut accepter des options transmises par le framework. Voici un exemple d'utilisation de cette fonctionnalité par un client GPU.
- [Facultatif] PJRT_TopologyDescription_Create
- [Facultatif] PJRT_Plugin_Initialize. Il s'agit d'une configuration de plug-in unique, qui sera appelée par le framework avant l'appel d'autres fonctions.
- [Facultatif] PJRT_Plugin_Attributes
Avec le wrapper, vous n'avez pas besoin d'implémenter les autres API C.
Étape 2: Implémentez GetPjRtApi
Vous devez implémenter une méthode GetPjRtApi
qui renvoie un PJRT_Api*
contenant des pointeurs de fonction vers les implémentations de l'API C PJRT. Vous trouverez ci-dessous un exemple supposant une implémentation via un wrapper (similaire à pjrt_c_api_cpu.cc):
const PJRT_Api* GetPjrtApi() {
static const PJRT_Api pjrt_api =
pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
return &pjrt_api;
}
Étape 3: Tester les implémentations d'API C
Vous pouvez appeler RegisterPjRtCApiTestFactory pour exécuter un petit ensemble de tests des comportements de base de l'API C PJRT.
Utiliser un plug-in PJRT depuis JAX
Étape 1: Configurer JAX
Vous pouvez utiliser la version nocturne de JAX
pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>
ou compiler JAX à partir de la source.
Pour le moment, vous devez faire correspondre la version de jaxlib à la version de l'API C PJRT. Il est généralement suffisant d'utiliser une version nocturne de jaxlib du même jour que le commit TF sur lequel vous compilez votre plug-in, par exemple :
pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
Vous pouvez également créer un jaxlib à partir de la source exactement à partir du commit XLA que vous utilisez pour la compilation (instructions).
Nous allons bientôt prendre en charge la compatibilité ABI.
Étape 2: Utilisez l'espace de noms jax_plugins ou configurez entry_point
Il existe deux options pour que votre plug-in soit détecté par JAX.
- En utilisant des packages d'espace de noms (référence). Définissez un module unique au niveau mondial sous le package d'espace de noms
jax_plugins
(c'est-à-dire, créez simplement un répertoirejax_plugins
et définissez votre module en dessous). Voici un exemple de structure de répertoires:
jax_plugins/
my_plugin/
__init__.py
my_plugin.so
- À l'aide des métadonnées du package (référence). Si vous créez un package via pyproject.toml ou setup.py, annoncez le nom du module de votre plug-in en incluant un point d'entrée sous le groupe
jax_plugins
qui pointe vers le nom complet du module. Voici un exemple via pyproject.toml ou setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'
# use setup.py
entry_points={
"jax_plugins": [
"my_plugin = my_plugin",
],
}
Voici des exemples d'implémentation de openxla-pjrt-plugin à l'aide de l'option 2: https://github.com/openxla/openxla-pjrt-plugin/pull/119,,https://github.com/openxla/openxla-pjrt-plugin/pull/120.
Étape 3: Implémentez une méthode initialize()
Vous devez implémenter une méthode initialize() dans votre module Python pour enregistrer le plug-in, par exemple:
import os
import jax._src.xla_bridge as xb
def initialize():
path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)
Pour savoir comment utiliser xla_bridge.register_plugin
, cliquez ici. Il s'agit actuellement d'une méthode privée. Une API publique sera publiée à l'avenir.
Vous pouvez exécuter la ligne ci-dessous pour vérifier que le plug-in est enregistré et générer une erreur s'il ne peut pas être chargé.
jax.config.update("jax_platforms", "my_plugin")
JAX peut avoir plusieurs backends/plug-ins. Vous avez plusieurs options pour vous assurer que votre plug-in est utilisé comme backend par défaut:
- Option 1: exécuter
jax.config.update("jax_platforms", "my_plugin")
au début du programme. - Option 2: définir ENV
JAX_PLATFORMS=my_plugin
. - Option 3: définir une priorité suffisamment élevée lors de l'appel de xb.register_plugin (la valeur par défaut est 400, qui est supérieure à celle des autres backends existants). Notez que le backend ayant la priorité la plus élevée n'est utilisé que lorsque
JAX_PLATFORMS=''
. La valeur par défaut deJAX_PLATFORMS
est''
, mais elle est parfois écrasée.
Tester avec JAX
Voici quelques scénarios de test de base à essayer:
# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2
# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0
# pmap
arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))
# single device: [0]
# 4 devices: [6 7 8 9]
(Nous ajouterons bientôt des instructions pour exécuter les tests unitaires jax sur votre plug-in.)
Pour obtenir d'autres exemples de plug-ins PJRT, consultez la section Exemples PJRT.