دمج المكوّن الإضافي PJRT

تركّز هذه المستندة على الاقتراحات حول كيفية الدمج مع PJRT، و كيفية اختبار دمج PJRT مع JAX.

كيفية الدمج مع PJRT

الخطوة 1: تنفيذ واجهة برمجة التطبيقات PJRT C API

الخيار "أ": يمكنك تنفيذ واجهة برمجة التطبيقات PJRT C API مباشرةً.

الخيار ب: إذا كان بإمكانك إنشاء الإصدار باستخدام رمز C++ في مستودع xla (من خلال إنشاء نسخة فرعية أو استخدام bazel)، يمكنك أيضًا تنفيذ واجهة برمجة التطبيقات PJRT C++ واستخدام حزمة C→C++:

  1. تنفيذ برنامج عميل PJRT لبرنامج C++ يرث من برنامج عميل PJRT الأساسي (وفصول PJRT ذات الصلة) في ما يلي بعض أمثلة عملاء PJRT بتنسيق C++: pjrt_stream_executor_client.h وtfrt_cpu_pjrt_client.h.
  2. تنفيذ بعض طرق واجهة برمجة التطبيقات C التي لا تشكّل جزءًا من برنامج C++ PJRT:
    • PJRT_Client_Create. في ما يلي بعض الأمثلة على الرموز البرمجية الوصفية (على افتراض أنّ GetPluginPjRtClient يعرض عميل PJRT من C++ تم تنفيذه أعلاه):
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;
}
}  // namespace my_plugin

يُرجى العِلم أنّ PJRT_Client_Create يمكن أن يقبل خيارات تم تمريرها من إطار العمل. في ما يلي مثال على كيفية استخدام برنامج GPU لهذه الميزة.

باستخدام الغلاف، لن تحتاج إلى تنفيذ واجهات برمجة التطبيقات C المتبقية.

الخطوة 2: تنفيذ GetPjRtApi

عليك تنفيذ طريقة GetPjRtApi تُعرِض PJRT_Api* يحتوي على مؤشرات دوال إلى عمليات تنفيذ PJRT C API. في ما يلي مثال يفترض التنفيذ من خلال الغلاف (على غرار pjrt_c_api_cpu.cc):

const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}

الخطوة 3: اختبار عمليات تنفيذ واجهة برمجة التطبيقات C

يمكنك استدعاء RegisterPjRtCApiTestFactory لإجراء مجموعة صغيرة من الاختبارات لسلوكيات PJRT C API الأساسية.

كيفية استخدام مكوّن إضافي لـ PJRT من JAX

الخطوة 1: إعداد JAX

يمكنك استخدام إصدار JAX الليلي

pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>

أو إنشاء JAX من المصدر.

في الوقت الحالي، عليك مطابقة إصدار jaxlib مع إصدار PJRT C API. يكفي عادةً استخدام إصدار jaxlib ليلي من اليوم نفسه الذي تم فيه إجراء عملية الدفع في TF التي يتم إنشاء المكوّن الإضافي وفقًا لها، على سبيل المثال:

pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

يمكنك أيضًا إنشاء مكتبة jaxlib من المصدر في الإصدار المحدّد من XLA الذي يتمّ إنشاء المكتبة استنادًا إليه (التعليمات).

سنبدأ قريبًا بتوفير التوافق مع ABI.

الخطوة 2: استخدام مساحة الاسم jax_plugins أو إعداد entry_point

هناك خياران لاكتشاف JAX للمكوّن الإضافي.

  1. استخدام حِزم مساحة الاسم (ref) حدِّد وحدة فريدة على مستوى العالم ضمن حزمة مساحة الاسم jax_plugins (أي ما عليك سوى إنشاء دليل jax_plugins وتحديد وحدتك تحته). في ما يلي مثال على بنية الدليل:
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
  1. باستخدام البيانات الوصفية للحزمة (ref). في حال إنشاء حزمة من خلال pyproject.toml أو setup.py، يمكنك الإعلان عن اسم وحدة المكوّن الإضافي من خلال تضمين نقطة دخول ضمن مجموعة jax_plugins تشير إلى اسم الوحدة الكامل. في ما يلي مثال على ذلك باستخدام pyproject.toml أو setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
  "jax_plugins": [
    "my_plugin = my_plugin",
  ],
}

في ما يلي أمثلة على كيفية تنفيذ openxla-pjrt-plugin باستخدام الخيار 2: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120

الخطوة 3: تنفيذ طريقة initialize()‎

عليك تنفيذ طريقة initialize()‎ في وحدة Python لتسجيل المكوّن الإضافي، على سبيل المثال:

import os
import jax._src.xla_bridge as xb

def initialize():
  path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
  xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)

يُرجى الاطّلاع على هذا الرابط لمعرفة كيفية استخدام xla_bridge.register_plugin. وهي طريقة خاصة حاليًا. سيتم إصدار واجهة برمجة تطبيقات عامة في المستقبل.

يمكنك تنفيذ السطر أدناه للتأكّد من تسجيل المكوّن الإضافي وعرض خطأ في حال تعذّر تحميله.

jax.config.update("jax_platforms", "my_plugin")

قد يكون لدى JAX عدة أنظمة خلفية/مكوّنات إضافية. هناك بضعة خيارات لضمان استخدام المكوّن الإضافي كنظام أساسي تلقائي:

  • الخيار 1: تنفيذ jax.config.update("jax_platforms", "my_plugin") في بداية البرنامج
  • الخيار الثاني: ضبط ENV JAX_PLATFORMS=my_plugin.
  • الخيار 3: ضبط أولوية عالية بما يكفي عند استدعاء xb.register_plugin (القيمة التلقائية هي 400 وهي أعلى من الخلفيات الحالية الأخرى). يُرجى العِلم أنّه لن يتم استخدام الخلفية ذات الأولوية القصوى إلا عند JAX_PLATFORMS=''. القيمة التلقائية للسياسة JAX_PLATFORMS هي ''، ولكن سيتم أحيانًا استبدالها.

كيفية الاختبار باستخدام JAX

في ما يلي بعض حالات الاختبار الأساسية التي يمكنك تجربتها:

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap

arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))

# single device: [0]

# 4 devices: [6 7 8 9]

(سنضيف تعليمات لتشغيل اختبارات وحدة Jax على المكوّن الإضافي قريبًا).

لمزيد من الأمثلة على مكوّنات PJRT الإضافية، يُرجى الاطّلاع على أمثلة على مكوّنات PJRT الإضافية.