Dokumen ini berfokus pada rekomendasi tentang cara berintegrasi dengan PJRT, dan cara menguji integrasi PJRT dengan JAX.
Cara berintegrasi dengan PJRT
Langkah 1: Terapkan antarmuka PJRT C API
Opsi A: Anda dapat menerapkan PJRT C API secara langsung.
Opsi B: Jika Anda dapat mem-build dengan kode C++ di repo xla (melalui forking atau bazel), Anda juga dapat menerapkan PJRT C++ API dan menggunakan wrapper C→C++:
- Terapkan klien PJRT C++ yang mewarisi dari klien PJRT dasar (dan class PJRT terkait). Berikut beberapa contoh klien PJRT C++: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
- Terapkan beberapa metode C API yang bukan bagian dari klien PJRT C++:
- PJRT_Client_Create. Berikut adalah beberapa contoh kode pseudo (dengan asumsi
GetPluginPjRtClient
menampilkan klien PJRT C++ yang diimplementasikan di atas):
- PJRT_Client_Create. Berikut adalah beberapa contoh kode pseudo (dengan asumsi
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
args->client = pjrt::CreateWrapperClient(std::move(client));
return nullptr;
}
} // namespace my_plugin
Perhatikan bahwa PJRT_Client_Create dapat menggunakan opsi yang diteruskan dari framework. Berikut adalah contoh cara klien GPU menggunakan fitur ini.
- [Opsional] PJRT_TopologyDescription_Create.
- [Opsional] PJRT_Plugin_Initialize. Ini adalah penyiapan plugin satu kali, yang akan dipanggil oleh framework sebelum fungsi lainnya dipanggil.
- [Opsional] PJRT_Plugin_Attributes.
Dengan wrapper, Anda tidak perlu menerapkan C API yang tersisa.
Langkah 2: Terapkan GetPjRtApi
Anda perlu menerapkan metode GetPjRtApi
yang menampilkan PJRT_Api*
yang berisi pointer fungsi ke penerapan PJRT C API. Berikut adalah contoh yang mengasumsikan penerapan melalui wrapper (mirip dengan pjrt_c_api_cpu.cc):
const PJRT_Api* GetPjrtApi() {
static const PJRT_Api pjrt_api =
pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
return &pjrt_api;
}
Langkah 3: Uji implementasi C API
Anda dapat memanggil RegisterPjRtCApiTestFactory untuk menjalankan serangkaian kecil pengujian untuk perilaku PJRT C API dasar.
Cara menggunakan plugin PJRT dari JAX
Langkah 1: Siapkan JAX
Anda dapat menggunakan JAX nightly
pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>
atau mem-build JAX dari sumber.
Untuk saat ini, Anda harus mencocokkan versi jaxlib dengan versi PJRT C API. Biasanya cukup menggunakan versi jaxlib nightly dari hari yang sama dengan commit TF yang digunakan untuk mem-build plugin, misalnya
pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
Anda juga dapat mem-build jaxlib dari sumber tepat pada commit XLA yang Anda build (petunjuk).
Kami akan segera mulai mendukung kompatibilitas ABI.
Langkah 2: Gunakan namespace jax_plugins atau siapkan entry_point
Ada dua opsi agar plugin Anda dapat ditemukan oleh JAX.
- Menggunakan paket namespace (ref). Tentukan modul yang unik secara global di bawah paket namespace
jax_plugins
(yaitu, cukup buat direktorijax_plugins
dan tentukan modul Anda di bawahnya). Berikut adalah contoh struktur direktori:
jax_plugins/
my_plugin/
__init__.py
my_plugin.so
- Menggunakan metadata paket (ref). Jika mem-build paket melalui pyproject.toml atau setup.py, iklankan nama modul plugin Anda dengan menyertakan titik entri di bagian grup
jax_plugins
yang mengarah ke nama modul lengkap Anda. Berikut adalah contoh melalui pyproject.toml atau setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'
# use setup.py
entry_points={
"jax_plugins": [
"my_plugin = my_plugin",
],
}
Berikut adalah contoh cara openxla-pjrt-plugin diterapkan menggunakan Opsi 2: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120
Langkah 3: Terapkan metode initialize()
Anda perlu menerapkan metode initialize() di modul python untuk mendaftarkan plugin, misalnya:
import os
import jax._src.xla_bridge as xb
def initialize():
path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)
Lihat di sini untuk mengetahui cara menggunakan xla_bridge.register_plugin
. Saat ini, metode ini bersifat pribadi. API publik akan dirilis pada masa mendatang.
Anda dapat menjalankan baris di bawah untuk memverifikasi bahwa plugin terdaftar dan menampilkan error jika tidak dapat dimuat.
jax.config.update("jax_platforms", "my_plugin")
JAX dapat memiliki beberapa backend/plugin. Ada beberapa opsi untuk memastikan plugin Anda digunakan sebagai backend default:
- Opsi 1: jalankan
jax.config.update("jax_platforms", "my_plugin")
di awal program. - Opsi 2: tetapkan ENV
JAX_PLATFORMS=my_plugin
. - Opsi 3: tetapkan prioritas yang cukup tinggi saat memanggil xb.register_plugin (nilai defaultnya adalah 400 yang lebih tinggi dari backend lain yang ada). Perhatikan bahwa backend dengan prioritas tertinggi hanya akan digunakan saat
JAX_PLATFORMS=''
. Nilai defaultJAX_PLATFORMS
adalah''
, tetapi terkadang akan ditimpa.
Cara menguji dengan JAX
Beberapa kasus pengujian dasar yang dapat dicoba:
# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2
# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0
# pmap
arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))
# single device: [0]
# 4 devices: [6 7 8 9]
(Kami akan segera menambahkan petunjuk untuk menjalankan pengujian unit jax terhadap plugin Anda!)
Untuk mengetahui contoh plugin PJRT lainnya, lihat Contoh PJRT.