बैकग्राउंड
PJRT एक यूनिफ़ॉर्म Device API है, जिसे हम एमएल नेटवर्क में जोड़ना चाहते हैं. आने वाले समय के लिए, यह तरीका है: (1) फ़्रेमवर्क (JAX, TF वगैरह) में PJRT को कॉल किया जाएगा, जिसमें डिवाइस के हिसाब से लागू करने की प्रोसेस, फ़्रेमवर्क के लिए ओपेक नहीं होती हैं; (2) हर डिवाइस, PJRT एपीआई लागू करने पर फ़ोकस करता है और फ़्रेमवर्क के लिए ओपेक हो सकता है.
इस दस्तावेज़ में, PJRT के साथ इंटिग्रेट करने और JAX के साथ PJRT इंटिग्रेशन की जांच करने के तरीके के बारे में सुझाव दिए गए हैं.
PJRT के साथ कैसे इंटिग्रेट करें
पहला चरण: PJRT C API इंटरफ़ेस लागू करना
पहला विकल्प: PJRT C API को सीधे लागू किया जा सकता है.
दूसरा विकल्प: अगर xla repo (Forking या bazel के ज़रिए) में C++ कोड बनाने की सुविधा मौजूद है, तो PJRT C++ एपीआई को भी लागू किया जा सकता है और C→C++ रैपर का इस्तेमाल किया जा सकता है:
- base PJRT क्लाइंट (और उससे जुड़ी PJRT क्लास) से इनहेरिट किया गया C++ PJRT क्लाइंट लागू करें. यहां C++ PJRT क्लाइंट के कुछ उदाहरण दिए गए हैं: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
- C API के कुछ ऐसे तरीके लागू करें जो C++ PJRT क्लाइंट का हिस्सा न हों:
- PJRT_Client_Create का इस्तेमाल करें. नीचे pseudo code के कुछ सैंपल दिए गए हैं (यह मानते हुए कि
GetPluginPjRtClient
ऊपर लागू किया गया C++ PJRT क्लाइंट दिखाता है):
- PJRT_Client_Create का इस्तेमाल करें. नीचे pseudo code के कुछ सैंपल दिए गए हैं (यह मानते हुए कि
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
args->client = pjrt::CreateWrapperClient(std::move(client));
return nullptr;
}
} // namespace my_plugin
ध्यान दें कि PJRT_Client_Create, फ़्रेमवर्क से पास किए गए विकल्पों को ले सकता है. जीपीयू क्लाइंट इस सुविधा का इस्तेमाल कैसे करता है, इसका उदाहरण यहां दिया गया है.
- [ज़रूरी नहीं] PJRT_TopologyDescription_Create.
- [ज़रूरी नहीं] PJRT_Plugin_Initialize. यह एक बार इस्तेमाल होने वाला प्लगिन सेटअप है. किसी दूसरे फ़ंक्शन को कॉल करने से पहले, फ़्रेमवर्क इसे कॉल करेगा.
- [ज़रूरी नहीं] PJRT_Plugin_Attributes.
रैपर के साथ, आपको बाकी C API को लागू करने की ज़रूरत नहीं होती.
दूसरा चरण: GetPjRtApi लागू करना
आपको GetPjRtApi
ऐसा तरीका लागू करना होगा जो PJRT C API को लागू करने के लिए, फ़ंक्शन पॉइंटर वाला PJRT_Api*
दिखाए. नीचे एक उदाहरण दिया गया है. इसमें, रैपर की मदद से प्रोसेस करने का तरीका बताया गया है (pjrt_c_api_cpu.cc की तरह):
const PJRT_Api* GetPjrtApi() {
static const PJRT_Api pjrt_api =
pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
return &pjrt_api;
}
तीसरा चरण: C API को लागू करने की जांच करना
अगर आपको PJRT C API के बुनियादी व्यवहार की जांच के लिए छोटा सेट चलाना है, तो RegisterPjRtCApiTestFactory को कॉल करें.
JAX के PJRT प्लगिन को इस्तेमाल करने का तरीका
पहला चरण: JAX सेट अप करना
आप एक रात में JAX का इस्तेमाल कर सकते हैं
pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>
फ़िलहाल, आपको jaxlib के वर्शन और PJRT C API वर्शन का मिलान करना होगा. आम तौर पर, उसी दिन से jaxlib का jaxlib वर्शन इस्तेमाल करना काफ़ी होता है जिस दिन टीएफ़ के लिए तय किया गया वर्शन इस्तेमाल करके प्लगिन बनाया जा रहा है. उदाहरण के लिए,
pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
आप जिस XLA प्रतिबद्धता के लिए बना रहे हैं (instructions) उसी पर सोर्स से एक jaxlib भी बनाया जा सकता है.
हम जल्द ही एबीआई के साथ काम करना शुरू करेंगे.
दूसरा चरण: jax_प्लगिन नेमस्पेस का इस्तेमाल करें या एंट्री_point सेट अप करें
JAX पर आपके प्लग इन को खोजने के दो विकल्प हैं.
- नेमस्पेस पैकेज का इस्तेमाल करना (ref).
jax_plugins
नेमस्पेस पैकेज में, दुनिया भर में इस्तेमाल होने वाला यूनीक मॉड्यूल तय करें. उदाहरण के लिए, सिर्फ़jax_plugins
डायरेक्ट्री बनाएं और उसके नीचे अपना मॉड्यूल तय करें. यहां डायरेक्ट्री के स्ट्रक्चर का एक उदाहरण दिया गया है:
jax_plugins/
my_plugin/
__init__.py
my_plugin.so
- पैकेज मेटाडेटा (ref) का इस्तेमाल करना. अगर pyproject.toml या setup.py से पैकेज बना रहे हैं, तो अपने प्लगिन मॉड्यूल के नाम का विज्ञापन करें. इसके लिए,
jax_plugins
ग्रुप में एक ऐसा एंट्री-पॉइंट शामिल करें जो आपके पूरे मॉड्यूल के नाम की ओर इशारा करता हो. यहां pyproject.toml या setup.py का एक उदाहरण दिया गया है:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'
# use setup.py
entry_points={
"jax_plugins": [
"my_plugin = my_plugin",
],
}
यहां उदाहरण दिए गए हैं कि दूसरे विकल्प का इस्तेमाल करके Openxla-pjrt-plugin को कैसे लागू किया जाता है: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120
तीसरा चरण: इनीशियलाइज़ेशन() तरीके को लागू करना
प्लग इन को रजिस्टर करने के लिए, आपको अपने Python मॉड्यूल में इनीशियलाइज़() तरीका लागू करना होगा, उदाहरण के लिए:
import os
import jax._src.xla_bridge as xb
def initialize():
path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)
xla_bridge.register_plugin
के इस्तेमाल का तरीका जानने के लिए, कृपया यहां जाएं. फ़िलहाल, यह एक निजी तरीका है. आने वाले समय में, सार्वजनिक एपीआई को रिलीज़ किया जाएगा.
नीचे दी गई लाइन को देखकर यह पुष्टि की जा सकती है कि प्लगिन रजिस्टर हो गया है. साथ ही, अगर यह लोड न हो पाए, तो गड़बड़ी की जानकारी दें.
jax.config.update("jax_platforms", "my_plugin")
JAX में एक से ज़्यादा बैकएंड/प्लगिन हो सकते हैं. आपके प्लग इन का डिफ़ॉल्ट बैकएंड के रूप में इस्तेमाल किया जा रहा है या नहीं, यह पक्का करने के लिए यहां कुछ विकल्प दिए गए हैं:
- पहला विकल्प: प्रोग्राम की शुरुआत में
jax.config.update("jax_platforms", "my_plugin")
चलाएं. - दूसरा विकल्प: ENV सेट करें
JAX_PLATFORMS=my_plugin
. - तीसरा विकल्प: xb.register_plugin को कॉल करते समय, काफ़ी ऊंची प्राथमिकता सेट करें (डिफ़ॉल्ट वैल्यू 400 है, जो दूसरे मौजूदा बैकएंड से ज़्यादा है). ध्यान दें कि सबसे ज़्यादा प्राथमिकता वाला बैकएंड,
JAX_PLATFORMS=''
होने पर ही इस्तेमाल किया जाएगा.JAX_PLATFORMS
की डिफ़ॉल्ट वैल्यू''
है, लेकिन कभी-कभी यह वैल्यू बदल भी सकती है.
JAX की जांच करने का तरीका
आज़माने के लिए कुछ बुनियादी टेस्ट केस:
# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2
# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0
# pmap
arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))
# single device: [0]
# 4 devices: [6 7 8 9]
(हम जल्द ही आपके प्लगिन के लिए jax यूनिट टेस्ट चलाने के निर्देश जोड़ेंगे!)
उदाहरण: JAX CUDA प्लगिन
- रैपर (pjrt_c_api_gpu.h) के ज़रिए PJRT सी एपीआई लागू करना.
- पैकेज के लिए एंट्री पॉइंट सेट अप करें (setup.py).
- शुरू करें() तरीके (__init__.py) को लागू करें.
- इसकी जांच CUDA के लिए किसी भी jax जांच के साथ की जा सकती है. ```