PJRT प्लगिन इंटिग्रेशन

बैकग्राउंड

PJRT एक यूनिफ़ॉर्म Device API है, जिसे हम एमएल नेटवर्क में जोड़ना चाहते हैं. आने वाले समय के लिए, यह तरीका है: (1) फ़्रेमवर्क (JAX, TF वगैरह) में PJRT को कॉल किया जाएगा, जिसमें डिवाइस के हिसाब से लागू करने की प्रोसेस, फ़्रेमवर्क के लिए ओपेक नहीं होती हैं; (2) हर डिवाइस, PJRT एपीआई लागू करने पर फ़ोकस करता है और फ़्रेमवर्क के लिए ओपेक हो सकता है.

इस दस्तावेज़ में, PJRT के साथ इंटिग्रेट करने और JAX के साथ PJRT इंटिग्रेशन की जांच करने के तरीके के बारे में सुझाव दिए गए हैं.

PJRT के साथ कैसे इंटिग्रेट करें

पहला चरण: PJRT C API इंटरफ़ेस लागू करना

पहला विकल्प: PJRT C API को सीधे लागू किया जा सकता है.

दूसरा विकल्प: अगर xla repo (Forking या bazel के ज़रिए) में C++ कोड बनाने की सुविधा मौजूद है, तो PJRT C++ एपीआई को भी लागू किया जा सकता है और C→C++ रैपर का इस्तेमाल किया जा सकता है:

  1. base PJRT क्लाइंट (और उससे जुड़ी PJRT क्लास) से इनहेरिट किया गया C++ PJRT क्लाइंट लागू करें. यहां C++ PJRT क्लाइंट के कुछ उदाहरण दिए गए हैं: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
  2. C API के कुछ ऐसे तरीके लागू करें जो C++ PJRT क्लाइंट का हिस्सा न हों:
    • PJRT_Client_Create का इस्तेमाल करें. नीचे pseudo code के कुछ सैंपल दिए गए हैं (यह मानते हुए कि GetPluginPjRtClient ऊपर लागू किया गया C++ PJRT क्लाइंट दिखाता है):
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;
}
}  // namespace my_plugin

ध्यान दें कि PJRT_Client_Create, फ़्रेमवर्क से पास किए गए विकल्पों को ले सकता है. जीपीयू क्लाइंट इस सुविधा का इस्तेमाल कैसे करता है, इसका उदाहरण यहां दिया गया है.

  • [ज़रूरी नहीं] PJRT_TopologyDescription_Create.
  • [ज़रूरी नहीं] PJRT_Plugin_Initialize. यह एक बार इस्तेमाल होने वाला प्लगिन सेटअप है. किसी दूसरे फ़ंक्शन को कॉल करने से पहले, फ़्रेमवर्क इसे कॉल करेगा.
  • [ज़रूरी नहीं] PJRT_Plugin_Attributes.

रैपर के साथ, आपको बाकी C API को लागू करने की ज़रूरत नहीं होती.

दूसरा चरण: GetPjRtApi लागू करना

आपको GetPjRtApi ऐसा तरीका लागू करना होगा जो PJRT C API को लागू करने के लिए, फ़ंक्शन पॉइंटर वाला PJRT_Api* दिखाए. नीचे एक उदाहरण दिया गया है. इसमें, रैपर की मदद से प्रोसेस करने का तरीका बताया गया है (pjrt_c_api_cpu.cc की तरह):

const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}

तीसरा चरण: C API को लागू करने की जांच करना

अगर आपको PJRT C API के बुनियादी व्यवहार की जांच के लिए छोटा सेट चलाना है, तो RegisterPjRtCApiTestFactory को कॉल करें.

JAX के PJRT प्लगिन को इस्तेमाल करने का तरीका

पहला चरण: JAX सेट अप करना

आप एक रात में JAX का इस्तेमाल कर सकते हैं

pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>

या सोर्स से JAX बनाएं.

फ़िलहाल, आपको jaxlib के वर्शन और PJRT C API वर्शन का मिलान करना होगा. आम तौर पर, उसी दिन से jaxlib का jaxlib वर्शन इस्तेमाल करना काफ़ी होता है जिस दिन टीएफ़ के लिए तय किया गया वर्शन इस्तेमाल करके प्लगिन बनाया जा रहा है. उदाहरण के लिए,

pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

आप जिस XLA प्रतिबद्धता के लिए बना रहे हैं (instructions) उसी पर सोर्स से एक jaxlib भी बनाया जा सकता है.

हम जल्द ही एबीआई के साथ काम करना शुरू करेंगे.

दूसरा चरण: jax_प्लगिन नेमस्पेस का इस्तेमाल करें या एंट्री_point सेट अप करें

JAX पर आपके प्लग इन को खोजने के दो विकल्प हैं.

  1. नेमस्पेस पैकेज का इस्तेमाल करना (ref). jax_plugins नेमस्पेस पैकेज में, दुनिया भर में इस्तेमाल होने वाला यूनीक मॉड्यूल तय करें. उदाहरण के लिए, सिर्फ़ jax_plugins डायरेक्ट्री बनाएं और उसके नीचे अपना मॉड्यूल तय करें. यहां डायरेक्ट्री के स्ट्रक्चर का एक उदाहरण दिया गया है:
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
  1. पैकेज मेटाडेटा (ref) का इस्तेमाल करना. अगर pyproject.toml या setup.py से पैकेज बना रहे हैं, तो अपने प्लगिन मॉड्यूल के नाम का विज्ञापन करें. इसके लिए, jax_plugins ग्रुप में एक ऐसा एंट्री-पॉइंट शामिल करें जो आपके पूरे मॉड्यूल के नाम की ओर इशारा करता हो. यहां pyproject.toml या setup.py का एक उदाहरण दिया गया है:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
  "jax_plugins": [
    "my_plugin = my_plugin",
  ],
}

यहां उदाहरण दिए गए हैं कि दूसरे विकल्प का इस्तेमाल करके Openxla-pjrt-plugin को कैसे लागू किया जाता है: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120

तीसरा चरण: इनीशियलाइज़ेशन() तरीके को लागू करना

प्लग इन को रजिस्टर करने के लिए, आपको अपने Python मॉड्यूल में इनीशियलाइज़() तरीका लागू करना होगा, उदाहरण के लिए:

import os
import jax._src.xla_bridge as xb

def initialize():
  path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
  xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)

xla_bridge.register_plugin के इस्तेमाल का तरीका जानने के लिए, कृपया यहां जाएं. फ़िलहाल, यह एक निजी तरीका है. आने वाले समय में, सार्वजनिक एपीआई को रिलीज़ किया जाएगा.

नीचे दी गई लाइन को देखकर यह पुष्टि की जा सकती है कि प्लगिन रजिस्टर हो गया है. साथ ही, अगर यह लोड न हो पाए, तो गड़बड़ी की जानकारी दें.

jax.config.update("jax_platforms", "my_plugin")

JAX में एक से ज़्यादा बैकएंड/प्लगिन हो सकते हैं. आपके प्लग इन का डिफ़ॉल्ट बैकएंड के रूप में इस्तेमाल किया जा रहा है या नहीं, यह पक्का करने के लिए यहां कुछ विकल्प दिए गए हैं:

  • पहला विकल्प: प्रोग्राम की शुरुआत में jax.config.update("jax_platforms", "my_plugin") चलाएं.
  • दूसरा विकल्प: ENV सेट करें JAX_PLATFORMS=my_plugin.
  • तीसरा विकल्प: xb.register_plugin को कॉल करते समय, काफ़ी ऊंची प्राथमिकता सेट करें (डिफ़ॉल्ट वैल्यू 400 है, जो दूसरे मौजूदा बैकएंड से ज़्यादा है). ध्यान दें कि सबसे ज़्यादा प्राथमिकता वाला बैकएंड, JAX_PLATFORMS='' होने पर ही इस्तेमाल किया जाएगा. JAX_PLATFORMS की डिफ़ॉल्ट वैल्यू '' है, लेकिन कभी-कभी यह वैल्यू बदल भी सकती है.

JAX की जांच करने का तरीका

आज़माने के लिए कुछ बुनियादी टेस्ट केस:

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap

arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))

# single device: [0]

# 4 devices: [6 7 8 9]

(हम जल्द ही आपके प्लगिन के लिए jax यूनिट टेस्ट चलाने के निर्देश जोड़ेंगे!)

उदाहरण: JAX CUDA प्लगिन

  1. रैपर (pjrt_c_api_gpu.h) के ज़रिए PJRT सी एपीआई लागू करना.
  2. पैकेज के लिए एंट्री पॉइंट सेट अप करें (setup.py).
  3. शुरू करें() तरीके (__init__.py) को लागू करें.
  4. इसकी जांच CUDA के लिए किसी भी jax जांच के साथ की जा सकती है. ```