Integración del complemento de PJRT

Información general

PJRT es la API uniforme de dispositivos que queremos agregar al ecosistema de AA. La visión a largo plazo es que (1) los frameworks (JAX, TF, etc.) llamarán a PJRT, que tiene implementaciones específicas del dispositivo que son opacas para los frameworks; (2) cada dispositivo se enfoca en implementar APIs de PJRT y puede ser opaca para los frameworks.

Este documento se centra en las recomendaciones sobre cómo realizar la integración con PJRT y cómo probar la integración de PJRT con JAX.

Cómo realizar la integración con PJRT

Paso 1: Implementa la interfaz de la API de PJRT C

Opción A: Puedes implementar la API de PJRT C directamente.

Opción B: Si puedes compilar con código C++ en el repositorio de xla (mediante la bifurcación o Bazel), también puedes implementar la API de PJRT C++ y usar el wrapper C→C++:

  1. Implementa un cliente PJRT de C++ que herede del cliente base PJRT (y de las clases PJRT relacionadas). Estos son algunos ejemplos de un cliente PJRT de C++: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
  2. Implementa algunos métodos de la API de C que no forman parte del cliente de C++ PJRT:
    • PJRT_Client_Create. A continuación, se muestra un pseudocódigo de ejemplo (suponiendo que GetPluginPjRtClient muestra un cliente C++ PJRT implementado anteriormente):
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;
}
}  // namespace my_plugin

Ten en cuenta que PJRT_Client_Create puede tomar opciones pasadas desde el framework. Aquí hay un ejemplo de cómo un cliente de GPU usa esta función.

Con el wrapper, no necesitas implementar las APIs de C restantes.

Paso 2: Implementa GetPjRtApi

Debes implementar un método GetPjRtApi que muestre un PJRT_Api* que contenga punteros de función a las implementaciones de la API de PJRT C. A continuación, se muestra un ejemplo de la implementación a través de un wrapper (similar a pjrt_c_api_cpu.cc):

const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}

Paso 3: Prueba las implementaciones de la API de C

Puedes llamar a RegisterPjRtCApiTestFactory para ejecutar un conjunto pequeño de pruebas de comportamientos básicos de la API de PJRT C.

Cómo usar un complemento PJRT de JAX

Paso 1: Configura JAX

Puedes usar JAX todas las noches

pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>

o compila JAX desde el origen.

Por ahora, debes hacer coincidir la versión de jaxlib con la versión de la API de PJRT C. Por lo general, es suficiente usar una versión nocturna de jaxlib del mismo día que la confirmación de TF en la que compilas tu complemento, p.ej.,

pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

También puedes compilar un jaxlib desde la fuente en la confirmación de XLA exacta en función de la que compilas (instructions).

Pronto comenzaremos a admitir la compatibilidad con ABI.

Paso 2: Usa el espacio de nombres jax_plugins o configura input_point

Hay dos opciones para que JAX descubra tu complemento.

  1. Usar paquetes de espacio de nombres (ref). Define un módulo único a nivel global en el paquete de espacio de nombres jax_plugins (es decir, solo crea un directorio jax_plugins y define tu módulo debajo). A continuación, se muestra un ejemplo de una estructura de directorio:
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
  1. Uso de metadatos del paquete (ref). Si compilas un paquete mediante pyproject.toml o setup.py, anuncia el nombre del módulo del complemento incluyendo un punto de entrada debajo del grupo jax_plugins que apunte al nombre completo del módulo. Aquí tienes un ejemplo a través de pyproject.toml o setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
  "jax_plugins": [
    "my_plugin = my_plugin",
  ],
}

Estos son ejemplos de cómo se implementa openxla-pjrt-plugin con la opción 2: https://github.com/openxla/openxla-pjrt-plugin/pull/119,,https://github.com/openxla/openxla-pjrt-plugin/pull/120.

Paso 3: Implementa un método inicializa()

Debes implementar un método inicializa() en tu módulo de Python para registrar el complemento, por ejemplo:

import os
import jax._src.xla_bridge as xb

def initialize():
  path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
  xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)

Consulta aquí para obtener información sobre cómo usar xla_bridge.register_plugin. Actualmente, es un método privado. Más adelante se lanzará una API pública.

Puedes ejecutar la siguiente línea para verificar que el complemento esté registrado y generar un error si no se puede cargar.

jax.config.update("jax_platforms", "my_plugin")

JAX puede tener varios backends o complementos. Existen algunas opciones para garantizar que tu complemento se use como backend predeterminado:

  • Opción 1: Ejecuta jax.config.update("jax_platforms", "my_plugin") al comienzo del programa.
  • Opción 2: Establece ENV JAX_PLATFORMS=my_plugin.
  • Opción 3: Establece una prioridad lo suficientemente alta cuando llames a xb.register_plugin (el valor predeterminado es 400, que es mayor que el de otros backends existentes). Ten en cuenta que el backend con prioridad más alta solo se usará cuando JAX_PLATFORMS=''. El valor predeterminado de JAX_PLATFORMS es '', pero a veces se reemplazará.

Cómo realizar pruebas con JAX

Estos son algunos casos de prueba básicos:

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap

arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))

# single device: [0]

# 4 devices: [6 7 8 9]

(Pronto agregaremos instrucciones para ejecutar las pruebas de unidades Jax en tu complemento)

Ejemplo: Complemento JAX CUDA

  1. Implementación de la API de PJRT C a través de un wrapper (pjrt_c_api_gpu.h).
  2. Configura el punto de entrada para el paquete (setup.py).
  3. Implementa un método inicializa() (__init__.py).
  4. Se puede probar con cualquier prueba de Jax para CUDA. ```