PJRT प्लगिन इंटिग्रेशन

इस दस्तावेज़ में, PJRT के साथ इंटिग्रेट करने के तरीके और JAX के साथ PJRT इंटिग्रेशन की जांच करने के तरीके के सुझाव दिए गए हैं.

PJRT के साथ इंटिग्रेट करने का तरीका

पहला चरण: PJRT C API इंटरफ़ेस लागू करना

पहला विकल्प: PJRT C API को सीधे लागू किया जा सकता है.

दूसरा विकल्प: अगर आपके पास xla repo में C++ कोड के लिए, फ़ॉर्किंग या bazel की मदद से बिल्ड करने का विकल्प है, तो PJRT C++ API को लागू किया जा सकता है. साथ ही, C→C++ रैपर का इस्तेमाल किया जा सकता है:

  1. बेस PJRT क्लाइंट (और उससे जुड़ी PJRT क्लास) से इनहेरिट करने वाला C++ PJRT क्लाइंट लागू करें. यहां C++ PJRT क्लाइंट के कुछ उदाहरण दिए गए हैं: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
  2. C API के कुछ ऐसे तरीके लागू करें जो C++ PJRT क्लाइंट का हिस्सा नहीं हैं:
    • PJRT_Client_Create. यहां कुछ सैंपल सूडो कोड दिए गए हैं. इसमें यह माना गया है कि GetPluginPjRtClient, ऊपर लागू किए गए C++ PJRT क्लाइंट को दिखाता है:
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;
}
}  // namespace my_plugin

ध्यान दें कि PJRT_Client_Create फ़्रेमवर्क से मिले विकल्पों को इस्तेमाल कर सकता है. यहां एक उदाहरण दिया गया है, जिसमें बताया गया है कि जीपीयू क्लाइंट इस सुविधा का इस्तेमाल कैसे करता है.

  • [ज़रूरी नहीं] PJRT_TopologyDescription_Create.
  • [ज़रूरी नहीं] PJRT_Plugin_Initialize. यह प्लग इन एक बार सेट अप किया जाता है. किसी भी दूसरे फ़ंक्शन को कॉल करने से पहले, फ़्रेमवर्क इसे कॉल करेगा.
  • [ज़रूरी नहीं] PJRT_Plugin_Attributes.

रैपर की मदद से, आपको बाकी C API लागू करने की ज़रूरत नहीं है.

दूसरा चरण: GetPjRtApi लागू करना

आपको एक ऐसा तरीका GetPjRtApi लागू करना होगा जो PJRT C API के लागू होने के लिए, फ़ंक्शन पॉइंटर वाला PJRT_Api* दिखाता हो. यहां रैपर के ज़रिए लागू करने का उदाहरण दिया गया है. यह pjrt_c_api_cpu.cc जैसा ही है:

const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}

तीसरा चरण: C API को लागू करने की जांच करना

PJRT C API के बुनियादी व्यवहारों के लिए, टेस्ट का एक छोटा सेट चलाने के लिए, RegisterPjRtCApiTestFactory को कॉल किया जा सकता है.

JAX से PJRT प्लग इन इस्तेमाल करने का तरीका

पहला चरण: JAX सेट अप करना

JAX को हर रात इस्तेमाल किया जा सकता है

pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>

या सोर्स से JAX बनाएं.

फ़िलहाल, आपको jaxlib वर्शन को PJRT C API वर्शन से मैच करना होगा. आम तौर पर, उसी दिन के jaxlib के नाइटली वर्शन का इस्तेमाल करना काफ़ी होता है जिस दिन आपने TF के उस कमिट के लिए प्लग इन बनाया है जिसका इस्तेमाल किया जा रहा है. उदाहरण के लिए,

pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

सोर्स से jaxlib को उसी XLA कमिट पर भी बनाया जा सकता है जिस पर आपका बिल्ड किया जा रहा है (निर्देश).

हम जल्द ही एबीआई के साथ काम करने की सुविधा उपलब्ध कराएंगे.

दूसरा चरण: jax_plugins नेमस्पेस का इस्तेमाल करना या entry_point सेट अप करना

JAX आपके प्लग इन को ढूंढने के लिए, दो विकल्प इस्तेमाल करता है.

  1. नेमस्पेस पैकेज (ref) का इस्तेमाल करना. jax_plugins नेमस्पेस पैकेज में, दुनिया भर में यूनीक मॉड्यूल तय करें. इसके लिए, बस एक jax_plugins डायरेक्ट्री बनाएं और उसके नीचे अपना मॉड्यूल तय करें. यहां डायरेक्ट्री स्ट्रक्चर का उदाहरण दिया गया है:
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
  1. पैकेज के मेटाडेटा (ref) का इस्तेमाल करना. अगर pyproject.toml या setup.py की मदद से पैकेज बनाया जा रहा है, तो jax_plugins ग्रुप में एंट्री-पॉइंट शामिल करके, अपने प्लग इन मॉड्यूल के नाम का विज्ञापन करें. यह एंट्री-पॉइंट, आपके मॉड्यूल के पूरे नाम पर ले जाता है. pyproject.toml या setup.py का इस्तेमाल करके, यहां एक उदाहरण दिया गया है:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
  "jax_plugins": [
    "my_plugin = my_plugin",
  ],
}

यहां दूसरे विकल्प का इस्तेमाल करके, openxla-pjrt-plugin को लागू करने के उदाहरण दिए गए हैं: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120

तीसरा चरण: initialize() मेथड लागू करना

प्लग इन को रजिस्टर करने के लिए, आपको अपने Python मॉड्यूल में initialize() तरीका लागू करना होगा. उदाहरण के लिए:

import os
import jax._src.xla_bridge as xb

def initialize():
  path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
  xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)

xla_bridge.register_plugin का इस्तेमाल करने का तरीका जानने के लिए, कृपया यहां जाएं. फ़िलहाल, यह एक निजी तरीका है. आने वाले समय में, एक सार्वजनिक एपीआई रिलीज़ किया जाएगा.

प्लग इन रजिस्टर है या नहीं, इसकी पुष्टि करने के लिए नीचे दी गई लाइन को चलाया जा सकता है. अगर प्लग इन लोड नहीं हो पाता है, तो गड़बड़ी की सूचना दी जा सकती है.

jax.config.update("jax_platforms", "my_plugin")

JAX में कई बैकएंड/प्लग इन हो सकते हैं. आपके प्लग इन को डिफ़ॉल्ट बैकएंड के तौर पर इस्तेमाल करने के लिए, ये कुछ विकल्प हैं:

  • पहला विकल्प: प्रोग्राम की शुरुआत में jax.config.update("jax_platforms", "my_plugin") चलाएं.
  • दूसरा विकल्प: ENV JAX_PLATFORMS=my_plugin सेट करें.
  • तीसरा विकल्प: xb.register_plugin को कॉल करते समय ज़्यादा प्राथमिकता सेट करें. डिफ़ॉल्ट वैल्यू 400 है, जो अन्य मौजूदा बैकएंड से ज़्यादा है. ध्यान दें कि सबसे ज़्यादा प्राथमिकता वाले बैकएंड का इस्तेमाल सिर्फ़ तब किया जाएगा, जब JAX_PLATFORMS=''. JAX_PLATFORMS की डिफ़ॉल्ट वैल्यू '' है, लेकिन कभी-कभी यह बदल जाएगी.

JAX की मदद से जांच करने का तरीका

आज़माने के लिए कुछ बुनियादी टेस्ट केस:

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap

arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))

# single device: [0]

# 4 devices: [6 7 8 9]

(हम जल्द ही आपके प्लग इन के लिए, jax यूनिट टेस्ट चलाने के निर्देश जोड़ेंगे!)

PJRT प्लग इन के ज़्यादा उदाहरणों के लिए, PJRT के उदाहरण देखें.