دمج المكوّن الإضافي PJRT

الخلفية

PJRT هي واجهة برمجة تطبيقات الجهاز الموحّدة التي نريد إضافتها إلى منظومة تعلُّم الآلة. وتتمثّل الرؤية طويلة المدى في ما يلي: (1) أن أطر العمل (JAX، وTF، وما إلى ذلك) ستستدعي PJRT، والتي تشمل عمليات تنفيذ خاصة بالأجهزة لا تتوافق مع أُطر العمل. و(2) يركّز كل جهاز على تنفيذ واجهات برمجة تطبيقات PJRT، ويمكن أن يكون معتمًا لأُطر العمل.

يركز هذا المستند على التوصيات حول كيفية الدمج مع PJRT، وكيفية اختبار دمج PJRT مع JAX.

كيفية الدمج مع PJRT

الخطوة 1: تنفيذ واجهة PJRT C API

الخيار أ: يمكنك تنفيذ واجهة برمجة التطبيقات PJRT C API مباشرةً.

الخيار "ب": إذا كان بإمكانك الإنشاء باستخدام كود C++ في xla repo (من خلال الشوكة أو bazel)، يمكنك أيضًا تنفيذ واجهة برمجة التطبيقات PJRT C++ واستخدام برنامج تضمين C→C++:

  1. نفِّذ برنامج C++ PJRT الذي يتم اكتسابه من عميل PJRT الأساسي (وفئات PJRT ذات الصلة). في ما يلي بعض الأمثلة على عميل C++ PJRT: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
  2. تنفيذ بعض طرق واجهة برمجة التطبيقات C API التي لا تشكّل جزءًا من عميل C++ PJRT:
    • PJRT_Client_Create: في ما يلي بعض نماذج الرموز الزائفة (بافتراض أنّ GetPluginPjRtClient تعرض عميل C++ PJRT الذي تم تنفيذه أعلاه):
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;
}
}  // namespace my_plugin

ملاحظة: يمكن لـ PJRT_Client_Create أخذ الخيارات التي يتم تمريرها من إطار العمل. هنا مثال على كيفية استخدام برنامج وحدة معالجة الرسومات لهذه الميزة.

باستخدام برنامج تضمين، لن تحتاج إلى تنفيذ واجهات برمجة تطبيقات C المتبقية.

الخطوة 2: تنفيذ GetPjRtApi

يجب تنفيذ الطريقة GetPjRtApi التي تعرض PJRT_Api* تحتوي على مؤشرات الدوال إلى عمليات تنفيذ PJRT C API. في ما يلي مثال بافتراض التنفيذ من خلال برنامج تضمين (على غرار pjrt_c_api_cpu.cc):

const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}

الخطوة 3: اختبار عمليات تنفيذ C API

يمكنك استدعاء RegisterPjRtCApiTestFactory لإجراء مجموعة صغيرة من الاختبارات لمعرفة سلوكيات واجهة برمجة التطبيقات PJRT C API الأساسية.

كيفية استخدام مكوّن PJRT الإضافي من JAX

الخطوة 1: إعداد JAX

يمكنك استخدام JAX كل ليلة

pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>

أو إنشاء JAX من المصدر.

في الوقت الحالي، يجب مطابقة إصدار jaxlib مع إصدار PJRT C API. عادةً ما يكفي استخدام إصدار jaxlib في الليل من نفس يوم تنفيذ TF الذي تنشئ المكون الإضافي عليه، على سبيل المثال:

pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

يمكنك أيضًا إنشاء jaxlib من المصدر في التزام XLA الذي يتم إنشاؤه وفقًا له (instructions).

سنبدأ قريبًا بإتاحة التوافق مع واجهة التطبيق الثنائية (ABI).

الخطوة 2: استخدام مساحة اسم jax_Plugins أو إعداد enter_point

هناك خياران لاكتشاف المكون الإضافي بواسطة JAX.

  1. استخدام حزم مساحات الاسم (ref). حدِّد وحدة فريدة بشكل عام ضمن حزمة مساحة الاسم jax_plugins (أي أنشِئ دليل jax_plugins وحدِّد وحدتك أسفلها). في ما يلي مثال على بنية الدليل:
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
  1. استخدام البيانات الوصفية للحزمة (ref) عند إنشاء حزمة عبر pyproject.toml أو setup.py، يمكنك الإعلان عن اسم وحدة المكوّنات الإضافية عن طريق تضمين نقطة دخول ضمن مجموعة jax_plugins تشير إلى اسم الوحدة بالكامل. فيما يلي مثال عبر pyproject.toml أو setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
  "jax_plugins": [
    "my_plugin = my_plugin",
  ],
}

إليك أمثلة على كيفية تنفيذ openxla-pjrt-extension باستخدام الخيار 2: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120

الخطوة 3: تنفيذ طريقة firstize()

تحتاج إلى تنفيذ طريقة firstize() في وحدة بايثون لتسجيل المكون الإضافي، على سبيل المثال:

import os
import jax._src.xla_bridge as xb

def initialize():
  path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
  xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)

يُرجى الرجوع إلى هذه الصفحة للتعرّف على كيفية استخدام "xla_bridge.register_plugin". وهي حاليًا طريقة خاصة. سيتم إطلاق واجهة برمجة تطبيقات عامة في المستقبل.

يمكنك تشغيل السطر أدناه للتحقق من تسجيل المكوّن الإضافي وعرض رسالة خطأ إذا تعذّر تحميله.

jax.config.update("jax_platforms", "my_plugin")

قد يتضمّن JAX خلفيات أو مكوّنات إضافية متعدّدة. هناك عدد قليل من الخيارات لضمان استخدام المكوّن الإضافي كخلفية تلقائية:

  • الخيار 1: شغِّل jax.config.update("jax_platforms", "my_plugin") في بداية البرنامج.
  • الخيار 2: ضبط ENV JAX_PLATFORMS=my_plugin
  • الخيار 3: اضبط أولوية عالية بما فيه الكفاية عند طلب xb.register_extension (القيمة التلقائية 400 وهي أعلى من الخلفيات الحالية الأخرى). تجدر الإشارة إلى أنّه لن يتم استخدام الخلفية ذات الأولوية القصوى إلا عند JAX_PLATFORMS=''. القيمة التلقائية للسمة JAX_PLATFORMS هي ''، ولكن يتم استبدالها في بعض الأحيان.

كيفية الاختبار باستخدام JAX

إليك بعض حالات الاختبار الأساسية التي يمكنك تجربتها:

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap

arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))

# single device: [0]

# 4 devices: [6 7 8 9]

(سنضيف قريبًا تعليمات حول تشغيل اختبارات وحدة jax على المكوّن الإضافي الخاص بك!)

مثال: المكوّن الإضافي JAX CUDA

  1. تنفيذ PJRT C API من خلال برنامج تضمين (pjrt_c_api_gpu.h).
  2. عليك إعداد نقطة الدخول للحزمة (setup.py).
  3. نفِّذ طريقة originize() (__init__.py).
  4. يمكن اختباره مع أي اختبارات jax لـ CUDA. ```