PJRT প্লাগইন ইন্টিগ্রেশন

পটভূমি

PJRT হল ইউনিফর্ম ডিভাইস API যা আমরা ML ইকোসিস্টেমে যোগ করতে চাই। দীর্ঘমেয়াদী দৃষ্টিভঙ্গি হল: (1) ফ্রেমওয়ার্ক (JAX, TF, ইত্যাদি) PJRT কে কল করবে, যার ডিভাইস-নির্দিষ্ট বাস্তবায়ন রয়েছে যা কাঠামোর জন্য অস্বচ্ছ; (2) প্রতিটি ডিভাইস PJRT API বাস্তবায়নে ফোকাস করে এবং ফ্রেমওয়ার্কের জন্য অস্বচ্ছ হতে পারে।

এই ডকটি কীভাবে PJRT-এর সাথে একীভূত করা যায় এবং JAX-এর সাথে PJRT একীকরণ কীভাবে পরীক্ষা করা যায় সে সম্পর্কে সুপারিশগুলির উপর ফোকাস করে৷

কিভাবে PJRT এর সাথে একীভূত করা যায়

ধাপ 1: PJRT C API ইন্টারফেস প্রয়োগ করুন

বিকল্প A : আপনি সরাসরি PJRT C API প্রয়োগ করতে পারেন।

বিকল্প B : আপনি যদি xla রেপোতে (ফর্কিং বা বেজেলের মাধ্যমে) C++ কোডের বিপরীতে তৈরি করতে সক্ষম হন, তাহলে আপনি PJRT C++ API প্রয়োগ করতে পারেন এবং C→C++ র‌্যাপার ব্যবহার করতে পারেন:

  1. বেস PJRT ক্লায়েন্ট (এবং সম্পর্কিত PJRT ক্লাস) থেকে উত্তরাধিকারসূত্রে পাওয়া একটি C++ PJRT ক্লায়েন্ট বাস্তবায়ন করুন। এখানে C++ PJRT ক্লায়েন্টের কিছু উদাহরণ রয়েছে: pjrt_stream_executor_client.h , tfrt_cpu_pjrt_client.h
  2. C++ PJRT ক্লায়েন্টের অংশ নয় এমন কয়েকটি C API পদ্ধতি প্রয়োগ করুন:
    • PJRT_Client_Create । নীচে কিছু নমুনা সিউডো কোড দেওয়া হল (ধরে নিলাম GetPluginPjRtClient উপরে প্রয়োগ করা একটি C++ PJRT ক্লায়েন্ট প্রদান করে):
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;
}
}  // namespace my_plugin

নোট PJRT_Client_Create ফ্রেমওয়ার্ক থেকে পাস করা বিকল্পগুলি নিতে পারে। একটি GPU ক্লায়েন্ট কীভাবে এই বৈশিষ্ট্যটি ব্যবহার করে তার একটি উদাহরণ এখানে রয়েছে।

  • [ঐচ্ছিক] PJRT_TopologyDescription_Create
  • [ঐচ্ছিক] PJRT_Plugin_Initialize । এটি একটি এককালীন প্লাগইন সেটআপ, যা অন্য কোনো ফাংশন কল করার আগে ফ্রেমওয়ার্ক দ্বারা কল করা হবে।
  • [ঐচ্ছিক] PJRT_Plugin_Attributes

র‍্যাপারের সাথে, আপনাকে অবশিষ্ট C API বাস্তবায়ন করতে হবে না।

ধাপ 2: GetPjRtApi প্রয়োগ করুন

আপনাকে GetPjRtApi একটি পদ্ধতি প্রয়োগ করতে হবে যা PJRT C API বাস্তবায়নে ফাংশন পয়েন্টার ধারণকারী PJRT_Api* প্রদান করে। নীচে একটি উদাহরণ রয়েছে যা র‍্যাপারের মাধ্যমে বাস্তবায়নের অনুরূপ ( pjrt_c_api_cpu.cc এর অনুরূপ):

const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}

ধাপ 3: পরীক্ষা C API বাস্তবায়ন

মৌলিক PJRT C API আচরণের জন্য পরীক্ষার একটি ছোট সেট চালানোর জন্য আপনি RegisterPjRtCApiTestFactory কল করতে পারেন।

কিভাবে JAX থেকে একটি PJRT প্লাগইন ব্যবহার করবেন

ধাপ 1: JAX সেট আপ করুন

আপনি হয় রাতে JAX ব্যবহার করতে পারেন

pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>

অথবা উৎস থেকে JAX তৈরি করুন

আপাতত, আপনাকে জ্যাক্সলিব সংস্করণের সাথে PJRT C API সংস্করণের সাথে মেলাতে হবে। টিএফ যে প্রতিশ্রুতির বিরুদ্ধে আপনি আপনার প্লাগইন তৈরি করছেন, সেই দিন থেকে সাধারণত জ্যাক্সলিব রাতের সংস্করণ ব্যবহার করা যথেষ্ট, যেমন

pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

আপনি ঠিক XLA প্রতিশ্রুতিতে উত্স থেকে একটি জ্যাক্সলিব তৈরি করতে পারেন যা আপনি ( নির্দেশাবলী ) এর বিরুদ্ধে তৈরি করছেন।

আমরা শীঘ্রই ABI সামঞ্জস্যতা সমর্থন শুরু করব।

ধাপ 2: jax_plugins নামস্থান ব্যবহার করুন বা entry_point সেট আপ করুন

JAX দ্বারা আপনার প্লাগইন আবিষ্কার করার জন্য দুটি বিকল্প রয়েছে৷

  1. নামস্থান প্যাকেজ ব্যবহার করে ( রেফ )। jax_plugins নেমস্পেস প্যাকেজের অধীনে একটি বিশ্বব্যাপী অনন্য মডিউল সংজ্ঞায়িত করুন (অর্থাৎ শুধু একটি jax_plugins ডিরেক্টরি তৈরি করুন এবং এর নীচে আপনার মডিউলটি সংজ্ঞায়িত করুন)। এখানে একটি উদাহরণ ডিরেক্টরি গঠন আছে:
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
  1. প্যাকেজ মেটাডেটা ব্যবহার করে ( রেফ )। যদি pyproject.toml বা setup.py এর মাধ্যমে একটি প্যাকেজ তৈরি করেন, তাহলে jax_plugins গ্রুপের অধীনে একটি এন্ট্রি-পয়েন্ট অন্তর্ভুক্ত করে আপনার প্লাগইন মডিউল নামের বিজ্ঞাপন দিন যা আপনার সম্পূর্ণ মডিউল নামের দিকে নির্দেশ করে। এখানে pyproject.toml বা setup.py এর মাধ্যমে একটি উদাহরণ রয়েছে:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
  "jax_plugins": [
    "my_plugin = my_plugin",
  ],
}

বিকল্প 2 ব্যবহার করে openxla-pjrt-plugin কীভাবে প্রয়োগ করা হয় তার উদাহরণ এখানে রয়েছে: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt- প্লাগইন/টান/120

ধাপ 3: একটি ইনিশিয়ালাইজ() পদ্ধতি প্রয়োগ করুন

প্লাগইনটি নিবন্ধন করার জন্য আপনাকে আপনার পাইথন মডিউলে একটি ইনিশিয়ালাইজ() পদ্ধতি প্রয়োগ করতে হবে, উদাহরণস্বরূপ:

import os
import jax._src.xla_bridge as xb

def initialize():
  path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
  xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)

কিভাবে xla_bridge.register_plugin ব্যবহার করতে হয় সে সম্পর্কে অনুগ্রহ করে এখানে পড়ুন। এটি বর্তমানে একটি ব্যক্তিগত পদ্ধতি। ভবিষ্যতে একটি পাবলিক API প্রকাশ করা হবে।

প্লাগইনটি নিবন্ধিত হয়েছে কিনা তা যাচাই করতে আপনি নীচের লাইনটি চালাতে পারেন এবং এটি লোড করা না হলে একটি ত্রুটি উত্থাপন করতে পারেন৷

jax.config.update("jax_platforms", "my_plugin")

JAX এর একাধিক ব্যাকএন্ড/প্লাগইন থাকতে পারে। আপনার প্লাগইনটি ডিফল্ট ব্যাকএন্ড হিসাবে ব্যবহার করা হয়েছে তা নিশ্চিত করার জন্য কয়েকটি বিকল্প রয়েছে:

  • বিকল্প 1: প্রোগ্রামের শুরুতে jax.config.update("jax_platforms", "my_plugin") চালান।
  • বিকল্প 2: ENV JAX_PLATFORMS=my_plugin সেট করুন।
  • বিকল্প 3: xb.register_plugin কল করার সময় একটি উচ্চ যথেষ্ট অগ্রাধিকার সেট করুন (ডিফল্ট মান হল 400 যা অন্যান্য বিদ্যমান ব্যাকএন্ডের চেয়ে বেশি)। নোট করুন সর্বোচ্চ অগ্রাধিকার সহ ব্যাকএন্ড শুধুমাত্র তখনই ব্যবহার করা হবে যখন JAX_PLATFORMS=''JAX_PLATFORMS এর ডিফল্ট মান হল '' কিন্তু কখনও কখনও এটি ওভাররাইট হয়ে যাবে৷

কিভাবে JAX দিয়ে পরীক্ষা করবেন

কিছু মৌলিক পরীক্ষার ক্ষেত্রে চেষ্টা করার জন্য:

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap

arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))

# single device: [0]

# 4 devices: [6 7 8 9]

(আমরা শীঘ্রই আপনার প্লাগইনের বিরুদ্ধে জ্যাক্স ইউনিট পরীক্ষা চালানোর জন্য নির্দেশাবলী যোগ করব!)

উদাহরণ: JAX CUDA প্লাগইন

  1. মোড়কের মাধ্যমে PJRT C API বাস্তবায়ন ( pjrt_c_api_gpu.h )।
  2. প্যাকেজের জন্য এন্ট্রি পয়েন্ট সেট আপ করুন ( setup.py )।
  3. একটি ইনিশিয়ালাইজ() পদ্ধতি প্রয়োগ করুন ( __init__.py )।
  4. CUDA-এর জন্য যেকোন জ্যাক্স পরীক্ষা দিয়ে পরীক্ষা করা যেতে পারে। ```