این سند بر توصیههایی در مورد نحوه ادغام با PJRT و نحوه آزمایش ادغام PJRT با JAX تمرکز دارد.
نحوه ادغام با PJRT
مرحله 1: رابط PJRT C API را پیاده سازی کنید
گزینه A : می توانید PJRT C API را مستقیماً پیاده سازی کنید.
گزینه B : اگر میتوانید بر اساس کد C++ در مخزن xla بسازید (از طریق forking یا bazel)، همچنین میتوانید API PJRT C++ را پیادهسازی کنید و از wrapper C→C++ استفاده کنید:
- یک کلاینت C++ PJRT را که از کلاینت پایه PJRT (و کلاس های PJRT مرتبط) به ارث می برد، پیاده سازی کنید. در اینجا چند نمونه از کلاینت C++ PJRT آورده شده است: pjrt_stream_executor_client.h ، tfrt_cpu_pjrt_client.h .
- چند روش C API را که بخشی از کلاینت C++ PJRT نیستند، اجرا کنید:
- PJRT_Client_Create . در زیر چند کد شبه نمونه وجود دارد (با فرض اینکه
GetPluginPjRtClient
یک کلاینت C++ PJRT پیاده سازی شده در بالا را برمی گرداند):
- PJRT_Client_Create . در زیر چند کد شبه نمونه وجود دارد (با فرض اینکه
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
args->client = pjrt::CreateWrapperClient(std::move(client));
return nullptr;
}
} // namespace my_plugin
توجه داشته باشید که PJRT_Client_Create میتواند گزینههایی را که از فریمورک منتقل میشوند، بگیرد. در اینجا مثالی از نحوه استفاده یک کلاینت GPU از این ویژگی آورده شده است.
- [اختیاری] PJRT_TopologyDescription_Create .
- [اختیاری] PJRT_Plugin_Initialize . این یک راهاندازی پلاگین یکباره است که قبل از فراخوانی هر عملکرد دیگری توسط فریمورک فراخوانی میشود.
- [اختیاری] PJRT_Plugin_Attributes .
با استفاده از wrapper ، نیازی به پیاده سازی C API های باقی مانده ندارید.
مرحله 2: پیاده سازی GetPjRtApi
شما باید یک متد GetPjRtApi
پیاده سازی کنید که یک PJRT_Api*
حاوی نشانگرهای تابع را به پیاده سازی های PJRT C API برمی گرداند. در زیر یک مثال با فرض پیاده سازی از طریق wrapper (مشابه pjrt_c_api_cpu.cc ) آورده شده است:
const PJRT_Api* GetPjrtApi() {
static const PJRT_Api pjrt_api =
pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
return &pjrt_api;
}
مرحله 3: اجرای C API را آزمایش کنید
می توانید با RegisterPjRtCApiTestFactory تماس بگیرید تا مجموعه کوچکی از تست ها را برای رفتارهای اساسی PJRT C API اجرا کنید.
نحوه استفاده از پلاگین PJRT از JAX
مرحله 1: JAX را راه اندازی کنید
می توانید شبانه از JAX استفاده کنید
pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>
در حال حاضر، باید نسخه jaxlib را با نسخه PJRT C API مطابقت دهید. معمولاً کافی است از یک نسخه شبانه jaxlib از همان روزی که در TF commit که پلاگین خود را با آن میسازید استفاده کنید، به عنوان مثال،
pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
شما همچنین می توانید یک jaxlib را از منبع دقیقاً در commit XLA که در برابر آن می سازید بسازید ( دستورالعمل ).
ما به زودی پشتیبانی از سازگاری ABI را آغاز خواهیم کرد.
مرحله 2: از فضای نام jax_plugins یا راه اندازی enter_point استفاده کنید
دو گزینه برای کشف افزونه شما توسط JAX وجود دارد.
- استفاده از بسته های فضای نام ( رجوع کنید ). یک ماژول منحصر به فرد جهانی را در بسته فضای نام
jax_plugins
تعریف کنید (یعنی فقط یک دایرکتوریjax_plugins
ایجاد کنید و ماژول خود را در زیر آن تعریف کنید). در اینجا یک نمونه ساختار دایرکتوری آمده است:
jax_plugins/
my_plugin/
__init__.py
my_plugin.so
- استفاده از فراداده بسته ( رجوع کنید ). اگر بسته ای را از طریق pyproject.toml یا setup.py می سازید، نام ماژول پلاگین خود را با قرار دادن یک نقطه ورودی در زیر گروه
jax_plugins
که به نام کامل ماژول شما اشاره می کند، تبلیغ کنید. در اینجا یک مثال از طریق pyproject.toml یا setup.py آورده شده است:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'
# use setup.py
entry_points={
"jax_plugins": [
"my_plugin = my_plugin",
],
}
در اینجا نمونه هایی از نحوه پیاده سازی openxla-pjrt-plugin با استفاده از گزینه 2 آورده شده است: https://github.com/openxla/openxla-pjrt-plugin/pull/119، https://github.com/openxla/openxla-pjrt- plugin/pull/120
مرحله 3: یک متد ()initialize را پیاده سازی کنید
برای ثبت پلاگین باید یک متد ()initialize را در ماژول پایتون خود پیاده سازی کنید، به عنوان مثال:
import os
import jax._src.xla_bridge as xb
def initialize():
path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)
لطفاً در مورد نحوه استفاده از xla_bridge.register_plugin
به اینجا مراجعه کنید. در حال حاضر یک روش خصوصی است. یک API عمومی در آینده منتشر خواهد شد.
میتوانید خط زیر را اجرا کنید تا تأیید کنید که افزونه ثبت شده است و اگر بارگذاری نشد، خطایی ایجاد کنید.
jax.config.update("jax_platforms", "my_plugin")
JAX ممکن است چندین بکاند/پلاگین داشته باشد. چند گزینه برای اطمینان از استفاده از افزونه شما به عنوان پشتیبان پیش فرض وجود دارد:
- گزینه 1:
jax.config.update("jax_platforms", "my_plugin")
را در ابتدای برنامه اجرا کنید. - گزینه 2: ENV
JAX_PLATFORMS=my_plugin
را تنظیم کنید. - گزینه 3: هنگام فراخوانی xb.register_plugin یک اولویت به اندازه کافی بالا تنظیم کنید (مقدار پیش فرض 400 است که بالاتر از سایر باطن های موجود است). توجه داشته باشید که باطن با بالاترین اولویت فقط زمانی استفاده می شود که
JAX_PLATFORMS=''
. مقدار پیش فرضJAX_PLATFORMS
''
است اما گاهی اوقات بازنویسی می شود.
چگونه با JAX تست کنیم
چند مورد آزمایشی اساسی که باید امتحان کنید:
# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2
# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0
# pmap
arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))
# single device: [0]
# 4 devices: [6 7 8 9]
(به زودی دستورالعملهایی برای اجرای تستهای jax unit در برابر افزونه شما اضافه خواهیم کرد!)
برای نمونههای بیشتر از پلاگینهای PJRT به مثالهای PJRT مراجعه کنید.