Latar belakang
PJRT adalah Device API seragam yang ingin kita tambahkan ke ekosistem ML. Visi jangka panjang adalah: (1) framework (JAX, TF, dll.) akan memanggil PJRT, yang memiliki implementasi khusus perangkat yang buram terhadap framework; (2) setiap perangkat berfokus pada penerapan PJRT API, dan dapat menjadi buram terhadap framework.
Dokumen ini berfokus pada rekomendasi tentang cara melakukan integrasi dengan PJRT, dan cara menguji integrasi PJRT dengan JAX.
Cara berintegrasi dengan PJRT
Langkah 1: Terapkan antarmuka PJRT C API
Opsi A: Anda dapat menerapkan PJRT C API secara langsung.
Opsi B: Jika Anda dapat mem-build dengan kode C++ di repo xla (melalui forking atau bazel), Anda juga dapat menerapkan PJRT C++ API dan menggunakan wrapper C→C++:
- Implementasikan klien PJRT C++ yang mewarisi dari klien PJRT dasar (dan class PJRT terkait). Berikut adalah beberapa contoh klien PJRT C++: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
- Implementasikan beberapa metode C API yang bukan bagian dari klien PJRT C++:
- PJRT_Client_Create. Berikut adalah beberapa contoh kode pseudo (dengan asumsi
GetPluginPjRtClient
menampilkan klien PJRT C++ yang diterapkan di atas):
- PJRT_Client_Create. Berikut adalah beberapa contoh kode pseudo (dengan asumsi
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
args->client = pjrt::CreateWrapperClient(std::move(client));
return nullptr;
}
} // namespace my_plugin
Perlu diperhatikan bahwa PJRT_Client_Create dapat mengambil opsi yang diteruskan dari framework. Berikut adalah contoh bagaimana klien GPU menggunakan fitur ini.
- [Opsional] PJRT_TopologyDescription_Create.
- [Opsional] PJRT_Plugin_Initialize. Ini adalah penyiapan plugin satu kali, yang akan dipanggil oleh framework sebelum fungsi lain dipanggil.
- [Opsional] PJRT_Plugin_Attributes.
Dengan wrapper, Anda tidak perlu menerapkan C API yang tersisa.
Langkah 2: Implementasikan GetPjRtApi
Anda perlu menerapkan metode GetPjRtApi
yang menampilkan PJRT_Api*
yang berisi pointer fungsi ke implementasi PJRT C API. Berikut adalah contoh dengan asumsi menerapkan melalui wrapper (mirip dengan pjrt_c_api_cpu.cc):
const PJRT_Api* GetPjrtApi() {
static const PJRT_Api pjrt_api =
pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
return &pjrt_api;
}
Langkah 3: Menguji implementasi C API
Anda dapat memanggil RegisterPjRtCApiTestFactory untuk menjalankan serangkaian kecil pengujian untuk perilaku PJRT C API dasar.
Cara menggunakan plugin PJRT dari JAX
Langkah 1: Siapkan JAX
Anda bisa menggunakan JAX setiap malam
pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>
atau buat JAX dari sumber.
Untuk saat ini, Anda harus mencocokkan versi jaxlib dengan versi PJRT C API. Biasanya cukup untuk menggunakan versi jaxlib Nightly dari hari yang sama dengan commit TF yang Anda gunakan untuk membuat plugin, misalnya
pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
Anda juga dapat mem-build jaxlib dari sumber persis dengan commit XLA yang digunakan untuk membangun (instructions).
Kami akan segera mulai mendukung kompatibilitas ABI.
Langkah 2: Gunakan namespace jax_plugins atau siapkan entry_point
Ada dua opsi agar plugin Anda ditemukan oleh JAX.
- Menggunakan paket namespace (ref). Tentukan modul unik secara global pada paket namespace
jax_plugins
(yaitu cukup buat direktorijax_plugins
dan tentukan modul Anda di bawahnya). Berikut adalah contoh struktur direktori:
jax_plugins/
my_plugin/
__init__.py
my_plugin.so
- Menggunakan metadata paket (ref). Jika membuat paket melalui pyproject.toml atau setup.py, iklankan nama modul plugin Anda dengan menyertakan titik entri pada grup
jax_plugins
yang mengarah ke nama modul lengkap Anda. Berikut ini contohnya melalui pyproject.toml atau setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'
# use setup.py
entry_points={
"jax_plugins": [
"my_plugin = my_plugin",
],
}
Berikut adalah contoh cara implementasi openxla-pjrt-plugin menggunakan Opsi 2: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120
Langkah 3: Implementasikan metode drawable()
Anda perlu mengimplementasikan metode initialization() dalam modul python untuk mendaftarkan plugin, misalnya:
import os
import jax._src.xla_bridge as xb
def initialize():
path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)
Harap lihat di sini tentang cara menggunakan xla_bridge.register_plugin
. Saat ini, metode tersebut bersifat pribadi. API publik akan dirilis pada masa mendatang.
Anda dapat menjalankan baris di bawah untuk memverifikasi bahwa plugin terdaftar dan membuat error jika plugin tidak dapat dimuat.
jax.config.update("jax_platforms", "my_plugin")
JAX dapat memiliki beberapa backend/plugin. Ada beberapa opsi untuk memastikan plugin Anda digunakan sebagai backend default:
- Opsi 1: jalankan
jax.config.update("jax_platforms", "my_plugin")
di awal program. - Opsi 2: tetapkan ENV
JAX_PLATFORMS=my_plugin
. - Opsi 3: tetapkan prioritas yang cukup tinggi saat memanggil xb.register_plugin (nilai defaultnya adalah 400, yang lebih tinggi daripada backend lain yang ada). Perhatikan bahwa backend dengan prioritas tertinggi hanya akan digunakan saat
JAX_PLATFORMS=''
. Nilai defaultJAX_PLATFORMS
adalah''
, tetapi terkadang akan ditimpa.
Cara menguji dengan JAX
Beberapa kasus pengujian dasar yang dapat dicoba:
# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2
# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0
# pmap
arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))
# single device: [0]
# 4 devices: [6 7 8 9]
(Kami akan segera menambahkan petunjuk untuk menjalankan pengujian unit jax terhadap plugin Anda.)
Contoh: Plugin JAX CUDA
- Implementasi PJRT C API melalui wrapper (pjrt_c_api_gpu.h).
- Siapkan titik entri untuk paket (setup.py).
- Mengimplementasikan metode inisialisasi() (__init__.py).
- Dapat diuji dengan pengujian jax apa pun untuk CUDA. ```