Integrasi plugin PJRT

Dokumen ini berfokus pada rekomendasi tentang cara berintegrasi dengan PJRT, dan cara menguji integrasi PJRT dengan JAX.

Cara berintegrasi dengan PJRT

Langkah 1: Terapkan antarmuka PJRT C API

Opsi A: Anda dapat menerapkan PJRT C API secara langsung.

Opsi B: Jika Anda dapat mem-build dengan kode C++ di repo xla (melalui forking atau bazel), Anda juga dapat menerapkan PJRT C++ API dan menggunakan wrapper C→C++:

  1. Terapkan klien PJRT C++ yang mewarisi dari klien PJRT dasar (dan class PJRT terkait). Berikut beberapa contoh klien PJRT C++: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
  2. Terapkan beberapa metode C API yang bukan bagian dari klien PJRT C++:
    • PJRT_Client_Create. Berikut adalah beberapa contoh kode pseudo (dengan asumsi GetPluginPjRtClient menampilkan klien PJRT C++ yang diimplementasikan di atas):
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;
}
}  // namespace my_plugin

Perhatikan bahwa PJRT_Client_Create dapat menggunakan opsi yang diteruskan dari framework. Berikut adalah contoh cara klien GPU menggunakan fitur ini.

Dengan wrapper, Anda tidak perlu menerapkan C API yang tersisa.

Langkah 2: Terapkan GetPjRtApi

Anda perlu menerapkan metode GetPjRtApi yang menampilkan PJRT_Api* yang berisi pointer fungsi ke penerapan PJRT C API. Berikut adalah contoh yang mengasumsikan penerapan melalui wrapper (mirip dengan pjrt_c_api_cpu.cc):

const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}

Langkah 3: Uji implementasi C API

Anda dapat memanggil RegisterPjRtCApiTestFactory untuk menjalankan serangkaian kecil pengujian untuk perilaku PJRT C API dasar.

Cara menggunakan plugin PJRT dari JAX

Langkah 1: Siapkan JAX

Anda dapat menggunakan JAX nightly

pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>

atau mem-build JAX dari sumber.

Untuk saat ini, Anda harus mencocokkan versi jaxlib dengan versi PJRT C API. Biasanya cukup menggunakan versi jaxlib nightly dari hari yang sama dengan commit TF yang digunakan untuk mem-build plugin, misalnya

pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

Anda juga dapat mem-build jaxlib dari sumber tepat pada commit XLA yang Anda build (petunjuk).

Kami akan segera mulai mendukung kompatibilitas ABI.

Langkah 2: Gunakan namespace jax_plugins atau siapkan entry_point

Ada dua opsi agar plugin Anda dapat ditemukan oleh JAX.

  1. Menggunakan paket namespace (ref). Tentukan modul yang unik secara global di bawah paket namespace jax_plugins (yaitu, cukup buat direktori jax_plugins dan tentukan modul Anda di bawahnya). Berikut adalah contoh struktur direktori:
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
  1. Menggunakan metadata paket (ref). Jika mem-build paket melalui pyproject.toml atau setup.py, iklankan nama modul plugin Anda dengan menyertakan titik entri di bagian grup jax_plugins yang mengarah ke nama modul lengkap Anda. Berikut adalah contoh melalui pyproject.toml atau setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
  "jax_plugins": [
    "my_plugin = my_plugin",
  ],
}

Berikut adalah contoh cara openxla-pjrt-plugin diterapkan menggunakan Opsi 2: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120

Langkah 3: Terapkan metode initialize()

Anda perlu menerapkan metode initialize() di modul python untuk mendaftarkan plugin, misalnya:

import os
import jax._src.xla_bridge as xb

def initialize():
  path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
  xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)

Lihat di sini untuk mengetahui cara menggunakan xla_bridge.register_plugin. Saat ini, metode ini bersifat pribadi. API publik akan dirilis pada masa mendatang.

Anda dapat menjalankan baris di bawah untuk memverifikasi bahwa plugin terdaftar dan menampilkan error jika tidak dapat dimuat.

jax.config.update("jax_platforms", "my_plugin")

JAX dapat memiliki beberapa backend/plugin. Ada beberapa opsi untuk memastikan plugin Anda digunakan sebagai backend default:

  • Opsi 1: jalankan jax.config.update("jax_platforms", "my_plugin") di awal program.
  • Opsi 2: tetapkan ENV JAX_PLATFORMS=my_plugin.
  • Opsi 3: tetapkan prioritas yang cukup tinggi saat memanggil xb.register_plugin (nilai defaultnya adalah 400 yang lebih tinggi dari backend lain yang ada). Perhatikan bahwa backend dengan prioritas tertinggi hanya akan digunakan saat JAX_PLATFORMS=''. Nilai default JAX_PLATFORMS adalah '', tetapi terkadang akan ditimpa.

Cara menguji dengan JAX

Beberapa kasus pengujian dasar yang dapat dicoba:

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap

arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))

# single device: [0]

# 4 devices: [6 7 8 9]

(Kami akan segera menambahkan petunjuk untuk menjalankan pengujian unit jax terhadap plugin Anda!)

Untuk mengetahui contoh plugin PJRT lainnya, lihat Contoh PJRT.