Интеграция плагина PJRT

Фон

PJRT — это унифицированный API-интерфейс устройства, который мы хотим добавить в экосистему машинного обучения. В долгосрочной перспективе: (1) фреймворки (JAX, TF и ​​т. д.) будут вызывать PJRT, который имеет реализации для конкретных устройств, непрозрачные для фреймворков; (2) каждое устройство ориентировано на реализацию API-интерфейсов PJRT и может быть непрозрачным для платформ.

В этом документе основное внимание уделяется рекомендациям по интеграции с PJRT и тестированию интеграции PJRT с JAX.

Как интегрироваться с PJRT

Шаг 1. Реализуйте интерфейс PJRT C API.

Вариант А : вы можете напрямую реализовать PJRT C API.

Вариант Б. Если вы можете выполнять сборку кода C++ в репозитории xla (с помощью разветвления или bazel), вы также можете реализовать API PJRT C++ и использовать оболочку C→C++:

  1. Реализуйте клиент PJRT C++, унаследованный от базового клиента PJRT (и связанных классов PJRT). Вот несколько примеров клиента PJRT C++: pjrt_stream_executor_client.h , tfrt_cpu_pjrt_client.h .
  2. Реализуйте несколько методов C API, которые не являются частью клиента C++ PJRT:
    • PJRT_Client_Create . Ниже приведен пример псевдокода (при условии, что GetPluginPjRtClient возвращает клиент C++ PJRT, реализованный выше):
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;
}
}  // namespace my_plugin

Примечание. PJRT_Client_Create может принимать параметры, переданные из платформы. Вот пример того, как клиент GPU использует эту функцию.

С помощью оболочки вам не нужно реализовывать остальные C API.

Шаг 2. Реализуйте GetPjRtApi.

Вам необходимо реализовать метод GetPjRtApi , который возвращает PJRT_Api* , содержащий указатели функций на реализации PJRT C API. Ниже приведен пример, предполагающий реализацию через оболочку (аналогично pjrt_c_api_cpu.cc ):

const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}

Шаг 3. Тестирование реализации C API

Вы можете вызвать RegisterPjRtCApiTestFactory , чтобы запустить небольшой набор тестов для базового поведения PJRT C API.

Как использовать плагин PJRT от JAX

Шаг 1. Настройте JAX

Вы можете использовать JAX каждую ночь

pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>

или соберите JAX из исходного кода .

На данный момент вам нужно сопоставить версию jaxlib с версией PJRT C API. Обычно достаточно использовать ночную версию jaxlib, выпущенную в тот же день, что и фиксация TF, на основе которой вы создаете свой плагин, например

pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

Вы также можете собрать jaxlib из исходного кода именно с тем коммитом XLA, который вы создаете ( инструкции ).

Скоро мы начнем поддерживать совместимость с ABI.

Шаг 2. Используйте пространство имен jax_plugins или настройте точку входа.

JAX может обнаружить ваш плагин двумя способами.

  1. Использование пакетов пространства имен ( ссылка ). Определите глобально уникальный модуль в пакете пространства имен jax_plugins (т.е. просто создайте каталог jax_plugins и определите свой модуль под ним). Вот пример структуры каталогов:
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
  1. Использование метаданных пакета ( ссылка ). При сборке пакета с помощью pyproject.toml или setup.py объявите имя вашего модуля плагина, включив точку входа в группу jax_plugins , которая указывает на полное имя вашего модуля. Вот пример через pyproject.toml или setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
  "jax_plugins": [
    "my_plugin = my_plugin",
  ],
}

Вот примеры реализации плагина openxla-pjrt с использованием варианта 2: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt- плагин/тянуть/120

Шаг 3. Реализуйте метод инициализации().

Вам необходимо реализовать метод инициализации() в вашем модуле Python, чтобы зарегистрировать плагин, например:

import os
import jax._src.xla_bridge as xb

def initialize():
  path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
  xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)

Пожалуйста, обратитесь сюда , чтобы узнать, как использовать xla_bridge.register_plugin . В настоящее время это частный метод. Публичный API будет выпущен в будущем.

Вы можете запустить строку ниже, чтобы убедиться, что плагин зарегистрирован, и вызвать ошибку, если его невозможно загрузить.

jax.config.update("jax_platforms", "my_plugin")

JAX может иметь несколько серверов/плагинов. Есть несколько вариантов, позволяющих использовать ваш плагин в качестве бэкэнда по умолчанию:

  • Вариант 1: запустить jax.config.update("jax_platforms", "my_plugin") в начале программы.
  • Вариант 2: установите ENV JAX_PLATFORMS=my_plugin .
  • Вариант 3: установите достаточно высокий приоритет при вызове xb.register_plugin (значение по умолчанию — 400, что выше, чем у других существующих бэкэндов). Обратите внимание, что серверная часть с наивысшим приоритетом будет использоваться только в том случае, если JAX_PLATFORMS='' . Значением по умолчанию JAX_PLATFORMS является '' но иногда оно перезаписывается.

Как тестировать с помощью JAX

Некоторые основные тестовые примеры, которые стоит попробовать:

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap

arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))

# single device: [0]

# 4 devices: [6 7 8 9]

(Скоро мы добавим инструкции по запуску модульных тестов jax для вашего плагина!)

Пример: плагин JAX CUDA.

  1. Реализация PJRT C API через оболочку ( pjrt_c_api_gpu.h ).
  2. Настройте точку входа для пакета ( setup.py ).
  3. Реализуйте метод инициализации() ( __init__.py ).
  4. Можно протестировать любыми jax-тестами для CUDA. ```