Contexto
PJRT é a API uniforme de dispositivos que queremos adicionar ao ecossistema de ML. A visão de longo prazo é que: (1) as estruturas (JAX, TF etc.) chamarão o PJRT, que tem implementações específicas de dispositivo que são opacas para as estruturas; (2) cada dispositivo se concentra na implementação de APIs PJRT e pode ser opaca para as estruturas.
Este documento tem como foco as recomendações sobre como fazer a integração com o PJRT e como testar a integração dele com o JAX.
Como fazer a integração com o PJRT
Etapa 1: implementar a interface da API PJRT C
Opção A: você pode implementar a API PJRT C diretamente.
Opção B: se você consegue criar com o código C++ no repositório xla (via bifurcação ou bazel), também é possível implementar a API PJRT C++ e usar o wrapper C→C++:
- Implementar um cliente PJRT em C++ que herda do cliente PJRT de base (e as classes PJRT relacionadas). Confira alguns exemplos de cliente PJRT em C++: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
- Implemente alguns métodos da API C que não fazem parte do cliente PJRT do C++:
- PJRT_Client_Create. Confira abaixo um exemplo de pseudocódigo, supondo que
GetPluginPjRtClient
retorne um cliente PJRT C++ implementado acima:
- PJRT_Client_Create. Confira abaixo um exemplo de pseudocódigo, supondo que
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
args->client = pjrt::CreateWrapperClient(std::move(client));
return nullptr;
}
} // namespace my_plugin
PJRT_Client_Create pode receber opções transmitidas do framework. Confira um exemplo de como um cliente de GPU usa esse recurso.
- [Opcional] PJRT_TopologyDescription_Create.
- [Opcional] PJRT_Plugin_Initialize. Essa é uma configuração de plug-in única que será chamada pelo framework antes que qualquer outra função seja chamada.
- [Opcional] PJRT_Plugin_Attributes.
Com o wrapper, não é necessário implementar as APIs C restantes.
Etapa 2: implementar GetPjRtApi
Você precisa implementar um método GetPjRtApi
que retorna um PJRT_Api*
contendo ponteiros de função para implementações da API PJRT C. Confira abaixo um exemplo que pressupõe a implementação por meio de um wrapper (semelhante a pjrt_c_api_cpu.cc):
const PJRT_Api* GetPjrtApi() {
static const PJRT_Api pjrt_api =
pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
return &pjrt_api;
}
Etapa 3: testar as implementações da API C
Você pode chamar RegisterPjRtCApiTestFactory para executar um pequeno conjunto de testes para comportamentos básicos da API PJRT C.
Como usar um plug-in PJRT do JAX
Etapa 1: configurar o JAX
Você pode usar o JAX todas as noites
pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>
ou crie o JAX a partir da origem.
Por enquanto, você precisa fazer a correspondência entre a versão do jaxlib e a da API PJRT C. Geralmente, é suficiente usar uma versão noturna do jaxlib no mesmo dia da confirmação do TF em que você está criando o plug-in, por exemplo,
pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
Também é possível criar um jaxlib da origem exatamente no commit do XLA em que você está criando (instructions).
Começaremos a oferecer suporte à compatibilidade com ABI em breve.
Etapa 2: usar o namespace jax_plugins ou configurar o input_point
Há duas opções para o seu plug-in ser descoberto pelo JAX.
- Como usar pacotes de namespace (ref). Defina um módulo globalmente exclusivo no pacote de namespace
jax_plugins
. Ou seja, crie um diretóriojax_plugins
e defina seu módulo abaixo dele. Veja um exemplo de estrutura de diretórios:
jax_plugins/
my_plugin/
__init__.py
my_plugin.so
- Usando metadados de pacote (ref). Se você estiver criando um pacote usando pyproject.toml ou setup.py, anuncie o nome do módulo do plug-in incluindo um ponto de entrada no grupo
jax_plugins
que aponte para o nome completo dele. Veja um exemplo via pyproject.toml ou setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'
# use setup.py
entry_points={
"jax_plugins": [
"my_plugin = my_plugin",
],
}
Veja exemplos de como o openxla-pjrt-plugin é implementado usando a Opção 2: https://github.com/openxla/openxla-pjrt-plugin/pull/119,,https://github.com/openxla/openxla-pjrt-plugin/pull/120
Etapa 3: implementar um método inicializado()
Você precisa implementar um método inicializado() no módulo Python para registrar o plug-in, por exemplo:
import os
import jax._src.xla_bridge as xb
def initialize():
path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)
Consulte este link sobre como usar o xla_bridge.register_plugin
. Atualmente, é um método particular. Uma API pública será lançada no futuro.
Você pode executar a linha abaixo para verificar se o plug-in está registrado e gerar um erro se não for possível carregá-lo.
jax.config.update("jax_platforms", "my_plugin")
O JAX pode ter vários back-ends/plug-ins. Há algumas opções para garantir que o plug-in seja usado como back-end padrão:
- Opção 1: executar
jax.config.update("jax_platforms", "my_plugin")
no início do programa. - Opção 2: definir ENV
JAX_PLATFORMS=my_plugin
- Opção 3: defina uma prioridade alta o suficiente ao chamar xb.register_plugin. O valor padrão é 400, que é maior do que os outros back-ends atuais. O back-end com a prioridade mais alta será usado somente quando
JAX_PLATFORMS=''
. O valor padrão deJAX_PLATFORMS
é''
, mas às vezes ele é substituído.
Como testar com o JAX
Confira alguns casos de teste básicos:
# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2
# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0
# pmap
arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))
# single device: [0]
# 4 devices: [6 7 8 9]
Em breve, adicionaremos instruções para executar os testes de unidade de jax no seu plug-in.
Exemplo: plug-in JAX CUDA
- Implementação da API PJRT C pelo wrapper (pjrt_c_api_gpu.h).
- Configure o ponto de entrada do pacote (setup.py).
- Implemente um método rename() (__init__.py).
- Pode ser testado com qualquer teste jax para CUDA. ```