Integrasi plugin PJRT

Latar belakang

PJRT adalah Device API seragam yang ingin kita tambahkan ke ekosistem ML. Visi jangka panjang adalah: (1) framework (JAX, TF, dll.) akan memanggil PJRT, yang memiliki implementasi khusus perangkat yang buram terhadap framework; (2) setiap perangkat berfokus pada penerapan PJRT API, dan dapat menjadi buram terhadap framework.

Dokumen ini berfokus pada rekomendasi tentang cara melakukan integrasi dengan PJRT, dan cara menguji integrasi PJRT dengan JAX.

Cara berintegrasi dengan PJRT

Langkah 1: Terapkan antarmuka PJRT C API

Opsi A: Anda dapat menerapkan PJRT C API secara langsung.

Opsi B: Jika Anda dapat mem-build dengan kode C++ di repo xla (melalui forking atau bazel), Anda juga dapat menerapkan PJRT C++ API dan menggunakan wrapper C→C++:

  1. Implementasikan klien PJRT C++ yang mewarisi dari klien PJRT dasar (dan class PJRT terkait). Berikut adalah beberapa contoh klien PJRT C++: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
  2. Implementasikan beberapa metode C API yang bukan bagian dari klien PJRT C++:
    • PJRT_Client_Create. Berikut adalah beberapa contoh kode pseudo (dengan asumsi GetPluginPjRtClient menampilkan klien PJRT C++ yang diterapkan di atas):
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;
}
}  // namespace my_plugin

Perlu diperhatikan bahwa PJRT_Client_Create dapat mengambil opsi yang diteruskan dari framework. Berikut adalah contoh bagaimana klien GPU menggunakan fitur ini.

Dengan wrapper, Anda tidak perlu menerapkan C API yang tersisa.

Langkah 2: Implementasikan GetPjRtApi

Anda perlu menerapkan metode GetPjRtApi yang menampilkan PJRT_Api* yang berisi pointer fungsi ke implementasi PJRT C API. Berikut adalah contoh dengan asumsi menerapkan melalui wrapper (mirip dengan pjrt_c_api_cpu.cc):

const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}

Langkah 3: Menguji implementasi C API

Anda dapat memanggil RegisterPjRtCApiTestFactory untuk menjalankan serangkaian kecil pengujian untuk perilaku PJRT C API dasar.

Cara menggunakan plugin PJRT dari JAX

Langkah 1: Siapkan JAX

Anda bisa menggunakan JAX setiap malam

pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>

atau buat JAX dari sumber.

Untuk saat ini, Anda harus mencocokkan versi jaxlib dengan versi PJRT C API. Biasanya cukup untuk menggunakan versi jaxlib Nightly dari hari yang sama dengan commit TF yang Anda gunakan untuk membuat plugin, misalnya

pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

Anda juga dapat mem-build jaxlib dari sumber persis dengan commit XLA yang digunakan untuk membangun (instructions).

Kami akan segera mulai mendukung kompatibilitas ABI.

Langkah 2: Gunakan namespace jax_plugins atau siapkan entry_point

Ada dua opsi agar plugin Anda ditemukan oleh JAX.

  1. Menggunakan paket namespace (ref). Tentukan modul unik secara global pada paket namespace jax_plugins (yaitu cukup buat direktori jax_plugins dan tentukan modul Anda di bawahnya). Berikut adalah contoh struktur direktori:
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
  1. Menggunakan metadata paket (ref). Jika membuat paket melalui pyproject.toml atau setup.py, iklankan nama modul plugin Anda dengan menyertakan titik entri pada grup jax_plugins yang mengarah ke nama modul lengkap Anda. Berikut ini contohnya melalui pyproject.toml atau setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
  "jax_plugins": [
    "my_plugin = my_plugin",
  ],
}

Berikut adalah contoh cara implementasi openxla-pjrt-plugin menggunakan Opsi 2: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120

Langkah 3: Implementasikan metode drawable()

Anda perlu mengimplementasikan metode initialization() dalam modul python untuk mendaftarkan plugin, misalnya:

import os
import jax._src.xla_bridge as xb

def initialize():
  path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
  xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)

Harap lihat di sini tentang cara menggunakan xla_bridge.register_plugin. Saat ini, metode tersebut bersifat pribadi. API publik akan dirilis pada masa mendatang.

Anda dapat menjalankan baris di bawah untuk memverifikasi bahwa plugin terdaftar dan membuat error jika plugin tidak dapat dimuat.

jax.config.update("jax_platforms", "my_plugin")

JAX dapat memiliki beberapa backend/plugin. Ada beberapa opsi untuk memastikan plugin Anda digunakan sebagai backend default:

  • Opsi 1: jalankan jax.config.update("jax_platforms", "my_plugin") di awal program.
  • Opsi 2: tetapkan ENV JAX_PLATFORMS=my_plugin.
  • Opsi 3: tetapkan prioritas yang cukup tinggi saat memanggil xb.register_plugin (nilai defaultnya adalah 400, yang lebih tinggi daripada backend lain yang ada). Perhatikan bahwa backend dengan prioritas tertinggi hanya akan digunakan saat JAX_PLATFORMS=''. Nilai default JAX_PLATFORMS adalah '', tetapi terkadang akan ditimpa.

Cara menguji dengan JAX

Beberapa kasus pengujian dasar yang dapat dicoba:

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap

arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))

# single device: [0]

# 4 devices: [6 7 8 9]

(Kami akan segera menambahkan petunjuk untuk menjalankan pengujian unit jax terhadap plugin Anda.)

Contoh: Plugin JAX CUDA

  1. Implementasi PJRT C API melalui wrapper (pjrt_c_api_gpu.h).
  2. Siapkan titik entri untuk paket (setup.py).
  3. Mengimplementasikan metode inisialisasi() (__init__.py).
  4. Dapat diuji dengan pengujian jax apa pun untuk CUDA. ```