ادغام پلاگین PJRT

زمینه

PJRT یک Device API است که می خواهیم به اکوسیستم ML اضافه کنیم. چشم انداز بلندمدت این است که: (1) فریم ورک ها (JAX، TF، و غیره) PJRT را فراخوانی می کنند، که پیاده سازی های مخصوص دستگاه دارد که نسبت به فریم ورک ها غیر شفاف هستند. (2) هر دستگاه بر پیاده سازی API های PJRT تمرکز دارد و می تواند نسبت به چارچوب ها غیر شفاف باشد.

این سند بر توصیه‌هایی در مورد نحوه ادغام با PJRT و نحوه آزمایش ادغام PJRT با JAX تمرکز دارد.

نحوه ادغام با PJRT

مرحله 1: رابط PJRT C API را پیاده سازی کنید

گزینه A : می توانید PJRT C API را مستقیماً پیاده سازی کنید.

گزینه B : اگر می‌توانید بر اساس کد C++ در مخزن xla بسازید (از طریق forking یا bazel)، همچنین می‌توانید API PJRT C++ را پیاده‌سازی کنید و از wrapper C→C++ استفاده کنید:

  1. یک کلاینت C++ PJRT را که از کلاینت پایه PJRT (و کلاس های PJRT مرتبط) به ارث می برد، پیاده سازی کنید. در اینجا چند نمونه از کلاینت C++ PJRT آورده شده است: pjrt_stream_executor_client.h ، tfrt_cpu_pjrt_client.h .
  2. چند روش C API را که بخشی از کلاینت C++ PJRT نیستند، اجرا کنید:
    • PJRT_Client_Create . در زیر چند کد شبه نمونه وجود دارد (با فرض اینکه GetPluginPjRtClient یک کلاینت C++ PJRT پیاده سازی شده در بالا را برمی گرداند):
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;
}
}  // namespace my_plugin

توجه داشته باشید که PJRT_Client_Create می‌تواند گزینه‌هایی را که از فریم‌ورک منتقل می‌شوند، بگیرد. در اینجا مثالی از نحوه استفاده یک کلاینت GPU از این ویژگی آورده شده است.

با استفاده از wrapper ، نیازی به پیاده سازی C API های باقی مانده ندارید.

مرحله 2: پیاده سازی GetPjRtApi

شما باید یک متد GetPjRtApi را پیاده سازی کنید که یک PJRT_Api* حاوی نشانگرهای تابع را به پیاده سازی های PJRT C API برمی گرداند. در زیر یک مثال با فرض پیاده سازی از طریق wrapper (مشابه pjrt_c_api_cpu.cc ) آورده شده است:

const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}

مرحله 3: اجرای C API را آزمایش کنید

می توانید با RegisterPjRtCApiTestFactory تماس بگیرید تا مجموعه کوچکی از تست ها را برای رفتارهای اساسی PJRT C API اجرا کنید.

نحوه استفاده از پلاگین PJRT از JAX

مرحله 1: JAX را راه اندازی کنید

می توانید شبانه از JAX استفاده کنید

pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>

یا JAX را از منبع بسازید .

در حال حاضر، باید نسخه jaxlib را با نسخه PJRT C API مطابقت دهید. معمولاً کافی است از یک نسخه شبانه jaxlib از همان روزی که در TF commit که پلاگین خود را با آن می‌سازید استفاده کنید، به عنوان مثال،

pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

شما همچنین می توانید یک jaxlib را از منبع دقیقاً در commit XLA که در برابر آن می سازید بسازید ( دستورالعمل ).

ما به زودی از سازگاری ABI پشتیبانی خواهیم کرد.

مرحله 2: از فضای نام jax_plugins یا راه اندازی enter_point استفاده کنید

دو گزینه برای کشف افزونه شما توسط JAX وجود دارد.

  1. استفاده از بسته های فضای نام ( رجوع کنید ). یک ماژول منحصر به فرد جهانی را در بسته فضای نام jax_plugins تعریف کنید (یعنی فقط یک دایرکتوری jax_plugins ایجاد کنید و ماژول خود را در زیر آن تعریف کنید). در اینجا یک نمونه ساختار دایرکتوری آمده است:
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
  1. استفاده از فراداده بسته ( رجوع کنید ). اگر بسته ای را از طریق pyproject.toml یا setup.py می سازید، نام ماژول پلاگین خود را با قرار دادن یک نقطه ورودی در زیر گروه jax_plugins که به نام کامل ماژول شما اشاره می کند، تبلیغ کنید. در اینجا یک مثال از طریق pyproject.toml یا setup.py آورده شده است:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
  "jax_plugins": [
    "my_plugin = my_plugin",
  ],
}

در اینجا نمونه هایی از نحوه پیاده سازی openxla-pjrt-plugin با استفاده از گزینه 2 آورده شده است: https://github.com/openxla/openxla-pjrt-plugin/pull/119، https://github.com/openxla/openxla-pjrt- plugin/pull/120

مرحله 3: یک متد ()initialize را پیاده سازی کنید

برای ثبت پلاگین باید یک متد ()initialize را در ماژول پایتون خود پیاده سازی کنید، به عنوان مثال:

import os
import jax._src.xla_bridge as xb

def initialize():
  path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
  xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)

لطفاً در مورد نحوه استفاده از xla_bridge.register_plugin به اینجا مراجعه کنید. در حال حاضر یک روش خصوصی است. یک API عمومی در آینده منتشر خواهد شد.

می‌توانید خط زیر را اجرا کنید تا تأیید کنید که افزونه ثبت شده است و اگر بارگذاری نشد، خطایی ایجاد کنید.

jax.config.update("jax_platforms", "my_plugin")

JAX ممکن است چندین بک‌اند/پلاگین داشته باشد. چند گزینه برای اطمینان از استفاده از افزونه شما به عنوان پشتیبان پیش فرض وجود دارد:

  • گزینه 1: jax.config.update("jax_platforms", "my_plugin") را در ابتدای برنامه اجرا کنید.
  • گزینه 2: ENV JAX_PLATFORMS=my_plugin را تنظیم کنید.
  • گزینه 3: هنگام فراخوانی xb.register_plugin یک اولویت به اندازه کافی بالا تنظیم کنید (مقدار پیش فرض 400 است که بالاتر از سایر باطن های موجود است). توجه داشته باشید که باطن با بالاترین اولویت فقط زمانی استفاده می شود که JAX_PLATFORMS='' . مقدار پیش فرض JAX_PLATFORMS '' است اما گاهی اوقات بازنویسی می شود.

چگونه با JAX تست کنیم

چند مورد آزمایشی اساسی که باید امتحان کنید:

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap

arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))

# single device: [0]

# 4 devices: [6 7 8 9]

(به زودی دستورالعمل‌هایی برای اجرای تست‌های jax unit در برابر افزونه شما اضافه خواهیم کرد!)

مثال: افزونه JAX CUDA

  1. اجرای PJRT C API از طریق wrapper ( pjrt_c_api_gpu.h ).
  2. نقطه ورود بسته را تنظیم کنید ( setup.py ).
  3. یک متد ()initialize ( __init__.py ) را پیاده سازی کنید.
  4. می توان با هر تست جکس برای CUDA تست کرد. ```