تركّز هذه المستندة على الاقتراحات حول كيفية الدمج مع PJRT، و كيفية اختبار دمج PJRT مع JAX.
كيفية الدمج مع PJRT
الخطوة 1: تنفيذ واجهة برمجة التطبيقات PJRT C API
الخيار "أ": يمكنك تنفيذ واجهة برمجة التطبيقات PJRT C API مباشرةً.
الخيار ب: إذا كان بإمكانك إنشاء الإصدار باستخدام رمز C++ في مستودع xla (من خلال إنشاء نسخة فرعية أو استخدام bazel)، يمكنك أيضًا تنفيذ واجهة برمجة التطبيقات PJRT C++ واستخدام حزمة C→C++:
- تنفيذ برنامج عميل PJRT لبرنامج C++ يرث من برنامج عميل PJRT الأساسي (وفصول PJRT ذات الصلة) في ما يلي بعض أمثلة عملاء PJRT بتنسيق C++: pjrt_stream_executor_client.h وtfrt_cpu_pjrt_client.h.
- تنفيذ بعض طرق واجهة برمجة التطبيقات C التي لا تشكّل جزءًا من برنامج C++ PJRT:
- PJRT_Client_Create. في ما يلي بعض الأمثلة على الرموز البرمجية الوصفية (على افتراض أنّ
GetPluginPjRtClient
يعرض عميل PJRT من C++ تم تنفيذه أعلاه):
- PJRT_Client_Create. في ما يلي بعض الأمثلة على الرموز البرمجية الوصفية (على افتراض أنّ
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
args->client = pjrt::CreateWrapperClient(std::move(client));
return nullptr;
}
} // namespace my_plugin
يُرجى العِلم أنّ PJRT_Client_Create يمكن أن يقبل خيارات تم تمريرها من إطار العمل. في ما يلي مثال على كيفية استخدام برنامج GPU لهذه الميزة.
- [اختيارية] PJRT_TopologyDescription_Create
- [اختياري] PJRT_Plugin_Initialize هذا الإعداد للإضافة يتم لمرة واحدة، وسيستدعيه إطار العمل قبل استدعاء أي دوال أخرى.
- [اختيارية] PJRT_Plugin_Attributes
باستخدام الغلاف، لن تحتاج إلى تنفيذ واجهات برمجة التطبيقات C المتبقية.
الخطوة 2: تنفيذ GetPjRtApi
عليك تنفيذ طريقة GetPjRtApi
تُعرِض PJRT_Api*
يحتوي على مؤشرات دوال إلى عمليات تنفيذ PJRT C API. في ما يلي مثال يفترض التنفيذ من خلال الغلاف (على غرار pjrt_c_api_cpu.cc):
const PJRT_Api* GetPjrtApi() {
static const PJRT_Api pjrt_api =
pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
return &pjrt_api;
}
الخطوة 3: اختبار عمليات تنفيذ واجهة برمجة التطبيقات C
يمكنك استدعاء RegisterPjRtCApiTestFactory لإجراء مجموعة صغيرة من الاختبارات لسلوكيات PJRT C API الأساسية.
كيفية استخدام مكوّن إضافي لـ PJRT من JAX
الخطوة 1: إعداد JAX
يمكنك استخدام إصدار JAX الليلي
pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>
في الوقت الحالي، عليك مطابقة إصدار jaxlib مع إصدار PJRT C API. يكفي عادةً استخدام إصدار jaxlib ليلي من اليوم نفسه الذي تم فيه إجراء عملية الدفع في TF التي يتم إنشاء المكوّن الإضافي وفقًا لها، على سبيل المثال:
pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
يمكنك أيضًا إنشاء مكتبة jaxlib من المصدر في الإصدار المحدّد من XLA الذي يتمّ إنشاء المكتبة استنادًا إليه (التعليمات).
سنبدأ قريبًا بتوفير التوافق مع ABI.
الخطوة 2: استخدام مساحة الاسم jax_plugins أو إعداد entry_point
هناك خياران لاكتشاف JAX للمكوّن الإضافي.
- استخدام حِزم مساحة الاسم (ref) حدِّد وحدة فريدة على مستوى العالم ضمن حزمة مساحة الاسم
jax_plugins
(أي ما عليك سوى إنشاء دليلjax_plugins
وتحديد وحدتك تحته). في ما يلي مثال على بنية الدليل:
jax_plugins/
my_plugin/
__init__.py
my_plugin.so
- باستخدام البيانات الوصفية للحزمة (ref). في حال إنشاء حزمة من خلال pyproject.toml أو setup.py، يمكنك الإعلان عن اسم وحدة المكوّن الإضافي من خلال تضمين نقطة دخول ضمن مجموعة
jax_plugins
تشير إلى اسم الوحدة الكامل. في ما يلي مثال على ذلك باستخدام pyproject.toml أو setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'
# use setup.py
entry_points={
"jax_plugins": [
"my_plugin = my_plugin",
],
}
في ما يلي أمثلة على كيفية تنفيذ openxla-pjrt-plugin باستخدام الخيار 2: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120
الخطوة 3: تنفيذ طريقة initialize()
عليك تنفيذ طريقة initialize() في وحدة Python لتسجيل المكوّن الإضافي، على سبيل المثال:
import os
import jax._src.xla_bridge as xb
def initialize():
path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)
يُرجى الاطّلاع على هذا الرابط لمعرفة كيفية استخدام xla_bridge.register_plugin
. وهي طريقة خاصة حاليًا. سيتم إصدار واجهة برمجة تطبيقات عامة في المستقبل.
يمكنك تنفيذ السطر أدناه للتأكّد من تسجيل المكوّن الإضافي وعرض خطأ في حال تعذّر تحميله.
jax.config.update("jax_platforms", "my_plugin")
قد يكون لدى JAX عدة أنظمة خلفية/مكوّنات إضافية. هناك بضعة خيارات لضمان استخدام المكوّن الإضافي كنظام أساسي تلقائي:
- الخيار 1: تنفيذ
jax.config.update("jax_platforms", "my_plugin")
في بداية البرنامج - الخيار الثاني: ضبط ENV
JAX_PLATFORMS=my_plugin
. - الخيار 3: ضبط أولوية عالية بما يكفي عند استدعاء xb.register_plugin (القيمة التلقائية هي 400 وهي أعلى من الخلفيات الحالية الأخرى). يُرجى العِلم أنّه لن يتم استخدام الخلفية ذات الأولوية القصوى إلا عند
JAX_PLATFORMS=''
. القيمة التلقائية للسياسةJAX_PLATFORMS
هي''
، ولكن سيتم أحيانًا استبدالها.
كيفية الاختبار باستخدام JAX
في ما يلي بعض حالات الاختبار الأساسية التي يمكنك تجربتها:
# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2
# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0
# pmap
arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))
# single device: [0]
# 4 devices: [6 7 8 9]
(سنضيف تعليمات لتشغيل اختبارات وحدة Jax على المكوّن الإضافي قريبًا).
لمزيد من الأمثلة على مكوّنات PJRT الإضافية، يُرجى الاطّلاع على أمثلة على مكوّنات PJRT الإضافية.