Intégration du plug-in PJRT

Contexte

PJRT est l'API uniforme d'appareil que nous souhaitons ajouter à l'écosystème de ML. La vision à long terme est la suivante: (1) les frameworks (JAX, TF, etc.) appelleront PJRT, qui dispose d'implémentations spécifiques à l'appareil qui sont opaques par rapport aux frameworks ; (2) chaque appareil se concentre sur l'implémentation des API PJRT et peut être opaque pour les frameworks.

Ce document se concentre sur les recommandations sur la manière d'intégrer PJRT et de tester l'intégration de PJRT avec JAX.

Intégrer PJRT

Étape 1: Implémenter l'interface de l'API PJRT C

Option A: vous pouvez implémenter directement l'API PJRT C.

Option B: si vous pouvez compiler avec du code C++ dans le dépôt xla (via la duplication ou bazel), vous pouvez également implémenter l'API PJRT C++ et utiliser le wrapper C→C++:

  1. Implémentez un client PJRT C++ héritant du client PJRT de base (et des classes PJRT associées). Voici quelques exemples de client PJRT C++: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
  2. Implémentez quelques méthodes d'API C qui ne font pas partie du client PJRT C++ :
    • PJRT_Client_Create : Voici un exemple de pseudo-code (en supposant que GetPluginPjRtClient renvoie un client PJRT C++ implémenté ci-dessus):
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;
}
}  // namespace my_plugin

Notez que PJRT_Client_Create peut utiliser les options transmises à partir du framework. Voici un exemple d'utilisation de cette fonctionnalité par un client GPU.

Avec le wrapper, vous n'avez pas besoin d'implémenter les autres API C.

Étape 2: Implémenter GetPjRtApi

Vous devez implémenter une méthode GetPjRtApi qui renvoie un PJRT_Api* contenant des pointeurs de fonction vers les implémentations de l'API PJRT C. Vous trouverez ci-dessous un exemple d'implémentation via un wrapper (semblable à pjrt_c_api_cpu.cc):

const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}

Étape 3: Tester les implémentations de l'API C

Vous pouvez appeler RegisterPjRtCApiTestFactory pour exécuter un petit ensemble de tests sur les comportements de base de l'API PJRT C.

Utiliser un plug-in PJRT de JAX

Étape 1: Configurer JAX

Vous pouvez soit utiliser JAX,

pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>

ou compilez JAX à partir de la source.

Pour l'instant, vous devez faire correspondre la version de jaxlib à la version de l'API PJRT C. Il est généralement suffisant d'utiliser une version nocturne jaxlib du même jour que le commit TF avec lequel vous compilez votre plug-in, par exemple

pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

Vous pouvez également compiler un jaxlib à partir de la source au niveau du commit XLA utilisé pour la compilation (instructions).

La compatibilité avec les ABI sera bientôt disponible.

Étape 2: Utiliser l'espace de noms jax_plugins ou configurer le point d'entrée

Deux options permettent à JAX de faire découvrir votre plug-in.

  1. Utiliser des packages d'espace de noms (ref) Définissez un module unique sous le package d'espace de noms jax_plugins (par exemple, créez simplement un répertoire jax_plugins et définissez votre module en dessous). Voici un exemple de structure de répertoire:
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
  1. Avec les métadonnées de package (ref). Si vous créez un package via pyproject.toml ou setup.py, annoncez le nom du module de votre plug-in en incluant un point d'entrée sous le groupe jax_plugins qui pointe vers le nom complet de votre module. Voici un exemple via pyproject.toml ou setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
  "jax_plugins": [
    "my_plugin = my_plugin",
  ],
}

Voici des exemples d'implémentation d'openxla-pjrt-plugin avec l'option 2: https://github.com/openxla/openxla-pjrt-plugin/pull/119,,https://github.com/openxla/openxla-pjrt-plugin/pull/120

Étape 3: Implémenter une méthode initial()

Vous devez implémenter une méthode "initial()" dans votre module Python pour enregistrer le plug-in, par exemple:

import os
import jax._src.xla_bridge as xb

def initialize():
  path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
  xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)

Consultez cette page pour découvrir comment utiliser xla_bridge.register_plugin. Il s'agit actuellement d'une méthode privée. Une API publique sera disponible à l'avenir.

Vous pouvez exécuter la ligne ci-dessous pour vérifier que le plug-in est enregistré et générer une erreur s'il ne peut pas être chargé.

jax.config.update("jax_platforms", "my_plugin")

JAX peut disposer de plusieurs backends/plug-ins. Plusieurs options s'offrent à vous pour vous assurer que votre plug-in est utilisé comme backend par défaut:

  • Option 1: exécutez jax.config.update("jax_platforms", "my_plugin") au début du programme.
  • Option 2: définissez ENV JAX_PLATFORMS=my_plugin.
  • Option 3: définissez une priorité suffisamment élevée lorsque vous appelez xb.register_plugin (la valeur par défaut est 400, ce qui est supérieur à celui des autres backends existants). Notez que le backend avec la priorité la plus élevée n'est utilisé que lorsque JAX_PLATFORMS=''. La valeur par défaut de JAX_PLATFORMS est '', mais elle est parfois remplacée.

Effectuer des tests avec JAX

Voici quelques scénarios de test de base à essayer:

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap

arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))

# single device: [0]

# 4 devices: [6 7 8 9]

(Nous ajouterons bientôt des instructions pour exécuter les tests unitaires Jax avec votre plug-in !)

Exemple: plug-in JAX CUDA

  1. Implémentation de l'API PJRT via un wrapper (pjrt_c_api_gpu.h)
  2. Configurez le point d'entrée du package (setup.py).
  3. Implémentez une méthode "initial()" (__init__.py).
  4. Peut être testé avec n'importe quel test Jax pour CUDA. ```