Integração do plug-in PJRT

Este documento se concentra nas recomendações sobre como integrar com o PJRT e como testar a integração do PJRT com o JAX.

Como fazer a integração com o PJRT

Etapa 1: implementar a interface da API PJRT C

Opção A: é possível implementar a API PJRT C diretamente.

Opção B: se você conseguir criar um build com base no código C++ no repo xla (por bifurcação ou bazel), também será possível implementar a API PJRT C++ e usar o wrapper C→C++:

  1. Implementar um cliente PJRT C++ que herda do cliente PJRT base (e classes PJRT relacionadas). Confira alguns exemplos de cliente PJRT em C++: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
  2. Implemente alguns métodos da API C que não fazem parte do cliente PJRT do C++:
    • PJRT_Client_Create. Confira abaixo um exemplo de pseudocódigo (assumindo que GetPluginPjRtClient retorna um cliente PJRT C++):
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;
}
}  // namespace my_plugin

A nota PJRT_Client_Create pode receber opções transmitidas do framework. Confira um exemplo de como um cliente de GPU usa esse recurso.

Com o wrapper, você não precisa implementar as APIs C restantes.

Etapa 2: implementar a GetPjRtApi

É necessário implementar um método GetPjRtApi que retorne um PJRT_Api* contendo ponteiros de função para implementações da API PJRT C. Confira abaixo um exemplo que pressupõe a implementação pelo wrapper (semelhante a pjrt_c_api_cpu.cc):

const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}

Etapa 3: testar implementações de API C

Chame RegisterPjRtCApiTestFactory para executar um pequeno conjunto de testes de comportamentos básicos da API PJRT C.

Como usar um plug-in PJRT do JAX

Etapa 1: configurar o JAX

Você pode usar o JAX diariamente

pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>

ou criar o JAX a partir da fonte.

Por enquanto, você precisa corresponder a versão do jaxlib à versão da API PJRT C. Normalmente, é suficiente usar uma versão noturna do jaxlib do mesmo dia do commit do TF em que você está criando o plug-in, por exemplo.

pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

Você também pode criar uma jaxlib a partir da fonte exatamente no commit da XLA que está sendo criado (instruções).

Em breve, vamos começar a oferecer suporte à compatibilidade com ABI.

Etapa 2: usar o namespace jax_plugins ou configurar o entry_point

Há duas opções para que o plug-in seja descoberto pelo JAX.

  1. Usando pacotes de namespace (ref). Defina um módulo globalmente exclusivo no pacote de namespace jax_plugins. Para isso, basta criar um diretório jax_plugins e definir o módulo abaixo dele. Confira um exemplo de estrutura de diretório:
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
  1. Usando metadados de pacote (ref). Se você estiver criando um pacote usando pyproject.toml ou setup.py, anuncie o nome do módulo do plug-in incluindo um ponto de entrada no grupo jax_plugins que aponte para o nome completo do módulo. Confira um exemplo usando pyproject.toml ou setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
  "jax_plugins": [
    "my_plugin = my_plugin",
  ],
}

Confira exemplos de como o openxla-pjrt-plugin é implementado usando a opção 2: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120

Etapa 3: implementar um método initialize()

Você precisa implementar um método initialize() no módulo Python para registrar o plug-in, por exemplo:

import os
import jax._src.xla_bridge as xb

def initialize():
  path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
  xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)

Consulte este link para saber como usar xla_bridge.register_plugin. No momento, é um método particular. Uma API pública será lançada no futuro.

Você pode executar a linha abaixo para verificar se o plug-in está registrado e gerar um erro se ele não puder ser carregado.

jax.config.update("jax_platforms", "my_plugin")

O JAX pode ter vários back-ends/plug-ins. Há algumas opções para garantir que o plug-in seja usado como o back-end padrão:

  • Opção 1: execute jax.config.update("jax_platforms", "my_plugin") no início do programa.
  • Opção 2: defina o ENV JAX_PLATFORMS=my_plugin.
  • Opção 3: defina uma prioridade alta ao chamar xb.register_plugin (o valor padrão é 400, maior do que outros back-ends existentes). O back-end com a maior prioridade só será usado quando JAX_PLATFORMS=''. O valor padrão de JAX_PLATFORMS é '', mas às vezes ele é substituído.

Como testar com o JAX

Alguns casos de teste básicos para tentar:

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap

arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))

# single device: [0]

# 4 devices: [6 7 8 9]

Em breve, vamos adicionar instruções para executar os testes de unidade do Jax no seu plug-in.

Para mais exemplos de plug-ins PJRT, consulte Exemplos de PJRT.