الخلفية
PJRT هي واجهة برمجة تطبيقات الجهاز الموحّدة التي نريد إضافتها إلى منظومة تعلُّم الآلة. وتتمثّل الرؤية طويلة المدى في ما يلي: (1) أن أطر العمل (JAX، وTF، وما إلى ذلك) ستستدعي PJRT، والتي تشمل عمليات تنفيذ خاصة بالأجهزة لا تتوافق مع أُطر العمل. و(2) يركّز كل جهاز على تنفيذ واجهات برمجة تطبيقات PJRT، ويمكن أن يكون معتمًا لأُطر العمل.
يركز هذا المستند على التوصيات حول كيفية الدمج مع PJRT، وكيفية اختبار دمج PJRT مع JAX.
كيفية الدمج مع PJRT
الخطوة 1: تنفيذ واجهة PJRT C API
الخيار أ: يمكنك تنفيذ واجهة برمجة التطبيقات PJRT C API مباشرةً.
الخيار "ب": إذا كان بإمكانك الإنشاء باستخدام كود C++ في xla repo (من خلال الشوكة أو bazel)، يمكنك أيضًا تنفيذ واجهة برمجة التطبيقات PJRT C++ واستخدام برنامج تضمين C→C++:
- نفِّذ برنامج C++ PJRT الذي يتم اكتسابه من عميل PJRT الأساسي (وفئات PJRT ذات الصلة). في ما يلي بعض الأمثلة على عميل C++ PJRT: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
- تنفيذ بعض طرق واجهة برمجة التطبيقات C API التي لا تشكّل جزءًا من عميل C++ PJRT:
- PJRT_Client_Create: في ما يلي بعض نماذج الرموز الزائفة (بافتراض أنّ
GetPluginPjRtClient
تعرض عميل C++ PJRT الذي تم تنفيذه أعلاه):
- PJRT_Client_Create: في ما يلي بعض نماذج الرموز الزائفة (بافتراض أنّ
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
args->client = pjrt::CreateWrapperClient(std::move(client));
return nullptr;
}
} // namespace my_plugin
ملاحظة: يمكن لـ PJRT_Client_Create أخذ الخيارات التي يتم تمريرها من إطار العمل. هنا مثال على كيفية استخدام برنامج وحدة معالجة الرسومات لهذه الميزة.
- [اختياري] PJRT_TopologyDescription_Create.
- [اختياري] PJRT_Plugin_Initialize هذا إعداد للمكوّن الإضافي لمرة واحدة، والذي سيتم استدعاؤه بواسطة إطار العمل قبل استدعاء أي دوال أخرى.
- [اختياري] PJRT_Plugin_Attributes:
باستخدام برنامج تضمين، لن تحتاج إلى تنفيذ واجهات برمجة تطبيقات C المتبقية.
الخطوة 2: تنفيذ GetPjRtApi
يجب تنفيذ الطريقة GetPjRtApi
التي تعرض PJRT_Api*
تحتوي على مؤشرات الدوال إلى عمليات تنفيذ PJRT C API. في ما يلي مثال بافتراض التنفيذ من خلال برنامج تضمين (على غرار pjrt_c_api_cpu.cc):
const PJRT_Api* GetPjrtApi() {
static const PJRT_Api pjrt_api =
pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
return &pjrt_api;
}
الخطوة 3: اختبار عمليات تنفيذ C API
يمكنك استدعاء RegisterPjRtCApiTestFactory لإجراء مجموعة صغيرة من الاختبارات لمعرفة سلوكيات واجهة برمجة التطبيقات PJRT C API الأساسية.
كيفية استخدام مكوّن PJRT الإضافي من JAX
الخطوة 1: إعداد JAX
يمكنك استخدام JAX كل ليلة
pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>
في الوقت الحالي، يجب مطابقة إصدار jaxlib مع إصدار PJRT C API. عادةً ما يكفي استخدام إصدار jaxlib في الليل من نفس يوم تنفيذ TF الذي تنشئ المكون الإضافي عليه، على سبيل المثال:
pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
يمكنك أيضًا إنشاء jaxlib من المصدر في التزام XLA الذي يتم إنشاؤه وفقًا له (instructions).
سنبدأ قريبًا بإتاحة التوافق مع واجهة التطبيق الثنائية (ABI).
الخطوة 2: استخدام مساحة اسم jax_Plugins أو إعداد enter_point
هناك خياران لاكتشاف المكون الإضافي بواسطة JAX.
- استخدام حزم مساحات الاسم (ref). حدِّد وحدة فريدة بشكل عام ضمن حزمة مساحة الاسم
jax_plugins
(أي أنشِئ دليلjax_plugins
وحدِّد وحدتك أسفلها). في ما يلي مثال على بنية الدليل:
jax_plugins/
my_plugin/
__init__.py
my_plugin.so
- استخدام البيانات الوصفية للحزمة (ref) عند إنشاء حزمة عبر pyproject.toml أو setup.py، يمكنك الإعلان عن اسم وحدة المكوّنات الإضافية عن طريق تضمين نقطة دخول ضمن مجموعة
jax_plugins
تشير إلى اسم الوحدة بالكامل. فيما يلي مثال عبر pyproject.toml أو setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'
# use setup.py
entry_points={
"jax_plugins": [
"my_plugin = my_plugin",
],
}
إليك أمثلة على كيفية تنفيذ openxla-pjrt-extension باستخدام الخيار 2: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120
الخطوة 3: تنفيذ طريقة firstize()
تحتاج إلى تنفيذ طريقة firstize() في وحدة بايثون لتسجيل المكون الإضافي، على سبيل المثال:
import os
import jax._src.xla_bridge as xb
def initialize():
path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)
يُرجى الرجوع إلى هذه الصفحة للتعرّف على كيفية استخدام "xla_bridge.register_plugin
". وهي حاليًا طريقة خاصة. سيتم إطلاق واجهة برمجة تطبيقات عامة في المستقبل.
يمكنك تشغيل السطر أدناه للتحقق من تسجيل المكوّن الإضافي وعرض رسالة خطأ إذا تعذّر تحميله.
jax.config.update("jax_platforms", "my_plugin")
قد يتضمّن JAX خلفيات أو مكوّنات إضافية متعدّدة. هناك عدد قليل من الخيارات لضمان استخدام المكوّن الإضافي كخلفية تلقائية:
- الخيار 1: شغِّل
jax.config.update("jax_platforms", "my_plugin")
في بداية البرنامج. - الخيار 2: ضبط ENV
JAX_PLATFORMS=my_plugin
- الخيار 3: اضبط أولوية عالية بما فيه الكفاية عند طلب xb.register_extension (القيمة التلقائية 400 وهي أعلى من الخلفيات الحالية الأخرى). تجدر الإشارة إلى أنّه لن يتم استخدام الخلفية ذات الأولوية القصوى إلا عند
JAX_PLATFORMS=''
. القيمة التلقائية للسمةJAX_PLATFORMS
هي''
، ولكن يتم استبدالها في بعض الأحيان.
كيفية الاختبار باستخدام JAX
إليك بعض حالات الاختبار الأساسية التي يمكنك تجربتها:
# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2
# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0
# pmap
arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))
# single device: [0]
# 4 devices: [6 7 8 9]
(سنضيف قريبًا تعليمات حول تشغيل اختبارات وحدة jax على المكوّن الإضافي الخاص بك!)
مثال: المكوّن الإضافي JAX CUDA
- تنفيذ PJRT C API من خلال برنامج تضمين (pjrt_c_api_gpu.h).
- عليك إعداد نقطة الدخول للحزمة (setup.py).
- نفِّذ طريقة originize() (__init__.py).
- يمكن اختباره مع أي اختبارات jax لـ CUDA. ```