Intégration du plug-in PJRT

Ce document se concentre sur les recommandations concernant l'intégration à PJRT et sur la façon de tester l'intégration de PJRT à JAX.

Intégrer avec PJRT

Étape 1: Implémentez l'interface API C PJRT

Option A: Vous pouvez implémenter directement l'API C PJRT.

Option B: Si vous pouvez compiler avec du code C++ dans le dépôt xla (via le forking ou Bazel), vous pouvez également implémenter l'API C++ PJRT et utiliser le wrapper C→C++:

  1. Implémentez un client PJRT C++ héritant du client PJRT de base (et des classes PJRT associées). Voici quelques exemples de clients PJRT C++: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
  2. Implémentez quelques méthodes d'API C qui ne font pas partie du client PJRT C++ :
    • PJRT_Client_Create. Vous trouverez ci-dessous un exemple de pseudo-code (en supposant que GetPluginPjRtClient renvoie un client PJRT C++ implémenté ci-dessus):
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;
}
}  // namespace my_plugin

Remarque : PJRT_Client_Create peut accepter des options transmises par le framework. Voici un exemple d'utilisation de cette fonctionnalité par un client GPU.

Avec le wrapper, vous n'avez pas besoin d'implémenter les autres API C.

Étape 2: Implémentez GetPjRtApi

Vous devez implémenter une méthode GetPjRtApi qui renvoie un PJRT_Api* contenant des pointeurs de fonction vers les implémentations de l'API C PJRT. Vous trouverez ci-dessous un exemple supposant une implémentation via un wrapper (similaire à pjrt_c_api_cpu.cc):

const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}

Étape 3: Tester les implémentations d'API C

Vous pouvez appeler RegisterPjRtCApiTestFactory pour exécuter un petit ensemble de tests des comportements de base de l'API C PJRT.

Utiliser un plug-in PJRT depuis JAX

Étape 1: Configurer JAX

Vous pouvez utiliser la version nocturne de JAX

pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>

ou compiler JAX à partir de la source.

Pour le moment, vous devez faire correspondre la version de jaxlib à la version de l'API C PJRT. Il est généralement suffisant d'utiliser une version nocturne de jaxlib du même jour que le commit TF sur lequel vous compilez votre plug-in, par exemple :

pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

Vous pouvez également créer un jaxlib à partir de la source exactement à partir du commit XLA que vous utilisez pour la compilation (instructions).

Nous allons bientôt prendre en charge la compatibilité ABI.

Étape 2: Utilisez l'espace de noms jax_plugins ou configurez entry_point

Il existe deux options pour que votre plug-in soit détecté par JAX.

  1. En utilisant des packages d'espace de noms (référence). Définissez un module unique au niveau mondial sous le package d'espace de noms jax_plugins (c'est-à-dire, créez simplement un répertoire jax_plugins et définissez votre module en dessous). Voici un exemple de structure de répertoires:
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
  1. À l'aide des métadonnées du package (référence). Si vous créez un package via pyproject.toml ou setup.py, annoncez le nom du module de votre plug-in en incluant un point d'entrée sous le groupe jax_plugins qui pointe vers le nom complet du module. Voici un exemple via pyproject.toml ou setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
  "jax_plugins": [
    "my_plugin = my_plugin",
  ],
}

Voici des exemples d'implémentation de openxla-pjrt-plugin à l'aide de l'option 2: https://github.com/openxla/openxla-pjrt-plugin/pull/119,,https://github.com/openxla/openxla-pjrt-plugin/pull/120.

Étape 3: Implémentez une méthode initialize()

Vous devez implémenter une méthode initialize() dans votre module Python pour enregistrer le plug-in, par exemple:

import os
import jax._src.xla_bridge as xb

def initialize():
  path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
  xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)

Pour savoir comment utiliser xla_bridge.register_plugin, cliquez ici. Il s'agit actuellement d'une méthode privée. Une API publique sera publiée à l'avenir.

Vous pouvez exécuter la ligne ci-dessous pour vérifier que le plug-in est enregistré et générer une erreur s'il ne peut pas être chargé.

jax.config.update("jax_platforms", "my_plugin")

JAX peut avoir plusieurs backends/plug-ins. Vous avez plusieurs options pour vous assurer que votre plug-in est utilisé comme backend par défaut:

  • Option 1: exécuter jax.config.update("jax_platforms", "my_plugin") au début du programme.
  • Option 2: définir ENV JAX_PLATFORMS=my_plugin.
  • Option 3: définir une priorité suffisamment élevée lors de l'appel de xb.register_plugin (la valeur par défaut est 400, qui est supérieure à celle des autres backends existants). Notez que le backend ayant la priorité la plus élevée n'est utilisé que lorsque JAX_PLATFORMS=''. La valeur par défaut de JAX_PLATFORMS est '', mais elle est parfois écrasée.

Tester avec JAX

Voici quelques scénarios de test de base à essayer:

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap

arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))

# single device: [0]

# 4 devices: [6 7 8 9]

(Nous ajouterons bientôt des instructions pour exécuter les tests unitaires jax sur votre plug-in.)

Pour obtenir d'autres exemples de plug-ins PJRT, consultez la section Exemples PJRT.