En este documento, se centran las recomendaciones sobre cómo realizar la integración con PJRT y cómo probar la integración de PJRT con JAX.
Cómo realizar la integración con PJRT
Paso 1: Implementa la interfaz de la API de PJRT C
Opción A: Puedes implementar la API de PJRT C directamente.
Opción B: Si puedes compilar con código C++ en el repositorio de xla (a través de la bifurcación o bazel), también puedes implementar la API de PJRT C++ y usar el wrapper C→C++:
- Implementa un cliente de PJRT de C++ que herede del cliente de PJRT base (y las clases de PJRT relacionadas). Estos son algunos ejemplos de clientes de PJRT de C++: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
- Implementa algunos métodos de la API de C que no forman parte del cliente de PJRT de C++:
- PJRT_Client_Create. A continuación, se muestra un ejemplo de pseudocódigo (suponiendo que
GetPluginPjRtClient
devuelve un cliente de PJRT de C++ implementado anteriormente):
- PJRT_Client_Create. A continuación, se muestra un ejemplo de pseudocódigo (suponiendo que
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
args->client = pjrt::CreateWrapperClient(std::move(client));
return nullptr;
}
} // namespace my_plugin
Ten en cuenta que PJRT_Client_Create puede aceptar opciones que se pasan desde el framework. Este es un ejemplo de cómo un cliente de GPU usa esta función.
- [Opcional] PJRT_TopologyDescription_Create.
- [Opcional] PJRT_Plugin_Initialize. Esta es una configuración de complemento única, a la que el framework llamará antes de llamar a cualquier otra función.
- [Opcional] PJRT_Plugin_Attributes.
Con el wrapper, no necesitas implementar las APIs de C restantes.
Paso 2: Implementa GetPjRtApi
Debes implementar un método GetPjRtApi
que devuelva un PJRT_Api*
que contenga punteros de función a implementaciones de la API de PJRT C. A continuación, se muestra un ejemplo que supone la implementación a través de un wrapper (similar a pjrt_c_api_cpu.cc):
const PJRT_Api* GetPjrtApi() {
static const PJRT_Api pjrt_api =
pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
return &pjrt_api;
}
Paso 3: Prueba las implementaciones de la API de C
Puedes llamar a RegisterPjRtCApiTestFactory para ejecutar un pequeño conjunto de pruebas de los comportamientos básicos de la API de PJRT C.
Cómo usar un complemento de PJRT desde JAX
Paso 1: Configura JAX
Puedes usar JAX todas las noches.
pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>
o compila JAX desde la fuente.
Por ahora, debes hacer coincidir la versión de jaxlib con la versión de la API de PJRT C. Por lo general, es suficiente con usar una versión nocturna de jaxlib del mismo día que la confirmación de TF con la que compilas tu complemento, p.ej.,
pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
También puedes compilar un jaxlib desde la fuente en la confirmación de XLA exacta para la que compilas (instrucciones).
Pronto comenzaremos a admitir la compatibilidad con ABI.
Paso 2: Usa el espacio de nombres jax_plugins o configura entry_point
Existen dos opciones para que JAX descubra tu complemento.
- Con paquetes de espacio de nombres (ref). Define un módulo único a nivel global en el paquete del espacio de nombres
jax_plugins
(es decir, crea un directoriojax_plugins
y define tu módulo debajo de él). Este es un ejemplo de estructura de directorios:
jax_plugins/
my_plugin/
__init__.py
my_plugin.so
- Con los metadatos del paquete (ref). Si compilas un paquete a través de pyproject.toml o setup.py, incluye un punto de entrada en el grupo
jax_plugins
que apunte al nombre completo del módulo para promocionarlo. Este es un ejemplo a través de pyproject.toml o setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'
# use setup.py
entry_points={
"jax_plugins": [
"my_plugin = my_plugin",
],
}
Estos son ejemplos de cómo se implementa openxla-pjrt-plugin con la opción 2: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120
Paso 3: Implementa un método initialize()
Debes implementar un método initialize() en tu módulo de Python para registrar el complemento, por ejemplo:
import os
import jax._src.xla_bridge as xb
def initialize():
path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)
Consulta aquí para obtener información sobre cómo usar xla_bridge.register_plugin
. Actualmente, es un método privado. En el futuro, se lanzará una API pública.
Puedes ejecutar la siguiente línea para verificar que el complemento esté registrado y mostrar un error si no se puede cargar.
jax.config.update("jax_platforms", "my_plugin")
JAX puede tener varios backends o complementos. Existen algunas opciones para garantizar que tu complemento se use como backend predeterminado:
- Opción 1: Ejecuta
jax.config.update("jax_platforms", "my_plugin")
al principio del programa. - Opción 2: Establecer ENV
JAX_PLATFORMS=my_plugin
- Opción 3: Establece una prioridad lo suficientemente alta cuando llames a xb.register_plugin (el valor predeterminado es 400, que es más alto que otros backends existentes). Ten en cuenta que el backend con la prioridad más alta se usará solo cuando
JAX_PLATFORMS=''
. El valor predeterminado deJAX_PLATFORMS
es''
, pero a veces se reemplaza.
Cómo realizar pruebas con JAX
Estos son algunos casos de prueba básicos que puedes probar:
# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2
# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0
# pmap
arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))
# single device: [0]
# 4 devices: [6 7 8 9]
(Pronto agregaremos instrucciones para ejecutar las pruebas de unidades de Jax en tu complemento).
Para obtener más ejemplos de complementos de PJRT, consulta Ejemplos de PJRT.