बैकग्राउंड
PJRT एक यूनिफ़ॉर्म डिवाइस एपीआई है, जिसे हमें एमएल (मशीन लर्निंग) नेटवर्क के साथ जोड़ना है. लंबे समय के लिए, हमारा मकसद यह है:
- फ़्रेमवर्क (JAX, TF वगैरह) PJRT को कॉल करेंगे. इसमें डिवाइस के हिसाब से लागू किए गए ऐसे तरीके होते हैं जो फ़्रेमवर्क के लिए साफ़ तौर पर नहीं दिखते;
- हर डिवाइस, PJRT API को लागू करने पर फ़ोकस करता है. साथ ही, यह फ़्रेमवर्क के लिए पारदर्शी नहीं हो सकता.
PJRT, C API और C++ API, दोनों की सुविधा देता है. दोनों लेयर में प्लग इन करना ठीक है. C++ एपीआई, कुछ कॉन्सेप्ट को अलग करने के लिए क्लास का इस्तेमाल करता है. हालांकि, इसका संबंध XLA डेटाटाइप से भी है. इस पेज पर, C++ API के बारे में जानकारी दी गई है.
PJRT कॉम्पोनेंट
PjRtClient
पूरा रेफ़रंस pjrt_client.h > PjRtClient
पर देखें.
क्लाइंट, डिवाइस और फ़्रेमवर्क के बीच होने वाली सभी बातचीत को मैनेज करते हैं. साथ ही, बातचीत में इस्तेमाल होने वाली सभी स्थितियों को एन्कैप्सुलेट करते हैं. उनके पास PJRT प्लग इन के साथ इंटरैक्ट करने के लिए, एपीआई का एक सामान्य सेट होता है. साथ ही, उनके पास किसी प्लग इन के लिए डिवाइस और मेमोरी स्पेस का मालिकाना हक होता है.
PjRtDevice
पूरे रेफ़रंस pjrt_client.h > PjRtDevice
और pjrt_device_description.h
पर देखें
डिवाइस की कैटगरी का इस्तेमाल, किसी एक डिवाइस के बारे में बताने के लिए किया जाता है. किसी डिवाइस का ब्यौरा, डिवाइस के टाइप (जैसे जीपीयू/सीपीयू/xPU) की पहचान करने के लिए होता है. साथ ही, इसमें डिवाइस के ग्रिड में जगह की जानकारी भी शामिल होती है.
डिवाइसों को यह भी पता होता है कि उनसे कौनसे मेमोरी स्पेस जुड़े हैं और उनका मालिकाना हक किस क्लाइंट के पास है.
ज़रूरी नहीं है कि किसी डिवाइस को उससे जुड़े असल डेटा के बफ़र के बारे में पता हो. हालांकि, वह इससे जुड़ी मेमोरी स्पेस को देखकर यह पता लगा सकता है.
PjRtMemorySpace
पूरा रेफ़रंस pjrt_client.h > PjRtMemorySpace
पर देखें.
'यादें' स्पेस का इस्तेमाल, यादों की जगह के बारे में बताने के लिए किया जा सकता है. इन्हें अनपिन किया जा सकता है और इन्हें किसी भी डिवाइस पर ऐक्सेस किया जा सकता है. इसके अलावा, इन्हें पिन भी किया जा सकता है और इन्हें किसी खास डिवाइस पर ऐक्सेस किया जा सकता है.
मेमोरी स्पेस में, डेटा के बफ़र और उन डिवाइसों (बहुवचन) के बारे में जानकारी होती है जिनसे मेमोरी स्पेस जुड़ा होता है. साथ ही, उस क्लाइंट के बारे में भी जानकारी होती है जिसका वह हिस्सा होता है.
PjRtBuffer
पूरी जानकारी के लिए pjrt_client.h > PjRtBuffer
पर जाएं.
बफ़र, डिवाइस पर डेटा को ऐसे फ़ॉर्मैट में रखता है जिससे प्लग इन के साथ आसानी से काम किया जा सके. जैसे, MLIR एलिमेंट attr या मालिकाना हक वाला टेंसर फ़ॉर्मैट.
फ़्रेमवर्क, डिवाइस की मेमोरी में xla::Literal
के तौर पर डेटा भेजने की कोशिश कर सकता है. जैसे, मॉड्यूल के इनपुट आर्ग्युमेंट के लिए, जिसे डिवाइस की मेमोरी में क्लोन किया जाना या उधार लिया जाना चाहिए. जब बफ़र की ज़रूरत नहीं होती, तो फ़्रेमवर्क Delete
तरीके को शुद्धिकरण के लिए इस्तेमाल करता है.
बफ़र को उस मेमोरी स्पेस की जानकारी होती है जिसका वह हिस्सा होता है. साथ ही, यह पता लगाया जा सकता है कि कौनसे डिवाइस उसे ऐक्सेस कर सकते हैं. हालांकि, यह ज़रूरी नहीं है कि बफ़र को अपने डिवाइसों की जानकारी हो.
फ़्रेमवर्क के साथ कम्यूनिकेट करने के लिए, बफ़र को पता होता है कि उन्हें xla::Literal
टाइप में कैसे बदलना है और कैसे:
// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}
// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}
बफ़र बनाने के एपीआई में बफ़र सीमैंटिक होता है. इसकी मदद से यह तय किया जा सकता है कि होस्ट बफ़र के लिटरल डेटा को शेयर, कॉपी या उसमें बदलाव किया जा सकता है या नहीं.
आखिर में, अगर बफ़र को फ़्रेमवर्क लेयर x = jit(foo)(10)
में किसी वैरिएबल को असाइन किया गया है, तो हो सकता है कि उसे एक्सीक्यूट करने के दायरे से ज़्यादा समय तक बनाए रखने की ज़रूरत पड़े. इन मामलों में, बफ़र बाहरी रेफ़रंस बनाने की अनुमति देते हैं. ये रेफ़रंस, बफ़र में मौजूद डेटा के लिए, कुछ समय के लिए मालिकाना हक वाला पॉइंटर देते हैं. साथ ही, इनमें डेटा को समझने के लिए मेटाडेटा (dtype / dim sizes) भी शामिल होता है.
PjRtCompiler
पूरा रेफ़रंस pjrt_compiler.h > PjRtCompiler
पर देखें.
PjRtCompiler
क्लास, XLA बैकएंड के लिए लागू करने से जुड़ी ज़रूरी जानकारी देती है. हालांकि, प्लग इन को लागू करने के लिए, इसकी ज़रूरत नहीं होती. सिद्धांत के तौर पर, PjRtCompiler
या PjRtClient::Compile
तरीके की ज़िम्मेदारी इनपुट मॉड्यूल और PjRtLoadedExecutable
देना है.
PjRtExecutable / PjRtLoadedExecutable
पूरा रेफ़रंस pjrt_executable.h > PjRtExecutable
और pjrt_client.h > PjRtLoadedExecutable
पर देखें.
PjRtExecutable
, किसी कंपाइल किए गए आर्टफ़ैक्ट और उसे चलाने के विकल्पों को क्रम से लगाने/अक्रम से हटाने का तरीका जानता है, ताकि ज़रूरत के हिसाब से किसी एक्सीक्यूटेबल को सेव और लोड किया जा सके.
PjRtLoadedExecutable
, मेमोरी में कंपाइल किया गया एक ऐसा प्रोग्राम है जो इनपुट आर्ग्युमेंट को एक्ज़ीक्यूट करने के लिए तैयार है. यह PjRtExecutable
का सबक्लास है.
एक्ज़ीक्यूटेबल को क्लाइंट के Execute
वाले तरीकों में से किसी एक के ज़रिए इंटरफ़ेस के ज़रिए दिखाया जाता है:
// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}
// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
Execute
को कॉल करने से पहले, फ़्रेमवर्क सभी ज़रूरी डेटा को लागू करने वाले क्लाइंट के मालिकाना हक वाले PjRtBuffers
को ट्रांसफ़र कर देगा. हालांकि, इसे फ़्रेमवर्क के रेफ़रंस के तौर पर इस्तेमाल किया जाएगा. इसके बाद, इन बफ़र को Execute
तरीके के लिए आर्ग्युमेंट के तौर पर दिया जाता है.
PJRT कॉन्सेप्ट
PjRtFutures और Async कंप्यूटेशन
अगर प्लग इन का कोई हिस्सा असिंक्रोनस तरीके से लागू किया जाता है, तो उसे फ़्यूचर को सही तरीके से लागू करना ज़रूरी है.
इस प्रोग्राम पर ध्यान दें:
@jax.jit
def foo(x): return x + 1
x = foo(1)
# [...] other logic not using `x`
print(x + 1)
एक असाइन किए गए प्लग इन, कैलकुलेशन x
को सूची में जोड़ सकता है और तुरंत एक बफ़र दिखा सकता है, जो अभी पढ़ने के लिए तैयार नहीं है. हालांकि, इसे लागू करने पर यह बफ़र भर जाएगा. एक्ज़ीक्यूशन, x
के बाद ज़रूरी कंप्यूटेशन को क्यू में करना जारी रख सकता है. इसके लिए, x
की ज़रूरत नहीं होती. इसमें, अन्य PJRT डिवाइसों पर इसे एक्ज़ीक्यूट करना भी शामिल है. x
की वैल्यू की ज़रूरत पड़ने पर, एक्ज़ीक्यूशन तब तक ब्लॉक रहेगा, जब तक कि बफ़र GetReadyFuture
से रिटर्न की गई फ़्यूचर वैल्यू के ज़रिए खुद को तैयार नहीं बता देता.
फ़्यूचर का इस्तेमाल करके यह पता लगाया जा सकता है कि कोई ऑब्जेक्ट कब उपलब्ध होगा. इसमें डिवाइस और बफ़र भी शामिल हैं.
बेहतर कॉन्सेप्ट
बेस एपीआई को लागू करने के अलावा, JAX की ऐसी सुविधाएं भी बढ़ जाएंगी जिनका इस्तेमाल प्लगिन किया जा सकता है. ये सभी सुविधाएं ऑप्ट-इन की सुविधाएं हैं. इसका मतलब है कि आम तौर पर, JIT और एक्सीक्यूट वर्कफ़्लो इनके बिना काम करेगा. हालांकि, प्रोडक्शन क्वालिटी वाली पाइपलाइन के लिए, PJRT API के साथ काम करने वाली इनमें से किसी भी सुविधा के लिए, सहायता की डिग्री पर कुछ विचार किया जाना चाहिए:
- यादें सेव करने के लिए स्पेस
- मनपसंद लेआउट
- भेजने/पाने जैसे कम्यूनिकेशन ऑपरेशन
- होस्ट को ऑफ़लोड करना
- डेटा का बंटवारा
PJRT फ़्रेमवर्क-डिवाइस के बीच आम तौर पर होने वाला कम्यूनिकेशन
उदाहरण लॉग
PJRT प्लगिन को लोड करने और y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)
को लागू करने के लिए कॉल किए गए तरीकों का लॉग नीचे दिया गया है. इस मामले में, हम StableHLO Reference PJRT प्लग इन के साथ इंटरैक्ट करने वाले JAX को लॉग करते हैं.
लॉग का उदाहरण
////////////////////////////////// // Load the plugin ////////////////////////////////// I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400) I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8) I device.cc:104] StablehloReferenceDevice(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I client_cpp_pjrt.cc:86] devices(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I device.cc:168] description(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:86] Attributes(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:148] memory_spaces(0x23bac4e0) Creating PJRT Client from client I client_cpp_pjrt.cc:108] platform_version(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:57] id(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:61] process_index(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:153] default_memory_space(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) ////////////////////////////////// // RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)` ////////////////////////////////// I executable.cc:309] num_partitions(0x240bab70) I executable.cc:305] num_replicas(0x240bab70) I executable.cc:309] num_partitions(0x240bab70) I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400) I buffer.cc:285] CreateMlirBufferFromLiteral I buffer.cc:98] CreateFromLiteral I buffer.cc:99] CreateFromLiteral: s32[] 2 I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:102] CreateFromLiteral -> 0x240bb050 I buffer.cc:158] device(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I executable.cc:328] GetHloModules(0x240bab70) I executable.cc:240] Execute(0x240bab70) I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x240bb050) I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32> I executable.cc:205] EvalModule: module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) { // ... return %3 : tensor<i32> } } I executable.cc:206] Inputs: [dense<2> : tensor<i32>] I executable.cc:213] Results: [dense<2> : tensor<i32>] I device.cc:153] default_memory_space(0x23bac4e0) I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x22cea630) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630 ////////////////////////////////// // RUN: `print(y)` ////////////////////////////////// I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references. I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x2404d560) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050 I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x240b6010) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x23b2db60) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60 I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630)