PJRT C++ Device API की खास जानकारी

बैकग्राउंड

PJRT एक जैसा डिवाइस एपीआई है. हमें इसे एमएल के इकोसिस्टम में जोड़ना है. हमारा लॉन्ग टर्म विज़न यह है कि:

  1. फ़्रेमवर्क (JAX, TF वगैरह) PJRT को कॉल करेंगे. इसमें डिवाइस के हिसाब से लागू होने वाली ऐसी सुविधाएं होती हैं जिनके बारे में फ़्रेमवर्क को पता नहीं होता;
  2. हर डिवाइस, PJRT API लागू करने पर फ़ोकस करता है. साथ ही, यह फ़्रेमवर्क के लिए अपारदर्शी हो सकता है.

PJRT, C API और C++ API, दोनों की सुविधा देता है. किसी भी लेयर में प्लग इन किया जा सकता है. C++ API, कुछ कॉन्सेप्ट को अलग करने के लिए क्लास का इस्तेमाल करता है. हालांकि, यह XLA डेटा टाइप से भी ज़्यादा जुड़ा हुआ है. इस पेज पर, C++ API के बारे में जानकारी दी गई है.

PJRT कॉम्पोनेंट

PJRT कॉम्पोनेंट

PjRtClient

पूरा रेफ़रंस pjrt_client.h > PjRtClient पर देखें.

क्लाइंट, डिवाइस और फ़्रेमवर्क के बीच होने वाले सभी कम्यूनिकेशन को मैनेज करते हैं. साथ ही, कम्यूनिकेशन में इस्तेमाल होने वाली सभी स्थितियों को शामिल करते हैं. इनके पास PJRT प्लगिन के साथ इंटरैक्ट करने के लिए, एपीआई का एक सामान्य सेट होता है. साथ ही, इनके पास किसी प्लगिन के लिए डिवाइसों और मेमोरी स्पेस का मालिकाना हक होता है.

PjRtDevice

पूरे रेफ़रंस pjrt_client.h > PjRtDevice और pjrt_device_description.h पर देखें

डिवाइस क्लास का इस्तेमाल, किसी एक डिवाइस के बारे में बताने के लिए किया जाता है. किसी डिवाइस में डिवाइस का ब्यौरा होता है. इससे यह पता चलता है कि डिवाइस किस तरह का है (जीपीयू/सीपीयू/एक्सपीयू की पहचान करने के लिए यूनीक हैश). साथ ही, इससे डिवाइसों के ग्रिड में डिवाइस की जगह की जानकारी मिलती है. यह जानकारी स्थानीय और वैश्विक, दोनों स्तरों पर मिलती है.

डिवाइसों को यह भी पता होता है कि वे किस मेमोरी स्पेस से जुड़े हैं और उनका मालिकाना हक किसके पास है.

किसी डिवाइस को उससे जुड़े असल डेटा के बफ़र के बारे में ज़रूरी नहीं है कि पता हो. हालांकि, वह उससे जुड़े मेमोरी स्पेस को देखकर इसका पता लगा सकता है.

PjRtMemorySpace

पूरा रेफ़रंस pjrt_client.h > PjRtMemorySpace पर देखें.

मेमोरी स्पेस का इस्तेमाल, मेमोरी की जगह की जानकारी देने के लिए किया जा सकता है. इन्हें अनपिन किया जा सकता है. साथ ही, ये किसी भी डिवाइस पर उपलब्ध हो सकते हैं. हालांकि, इन्हें किसी डिवाइस से ऐक्सेस किया जा सकता है. इसके अलावा, इन्हें पिन भी किया जा सकता है और ये किसी खास डिवाइस पर उपलब्ध होने चाहिए.

मेमोरी स्पेस को, डेटा के अपने बफ़र और उन डिवाइसों के बारे में पता होता है जिनसे मेमोरी स्पेस जुड़ा होता है. साथ ही, उसे उस क्लाइंट के बारे में भी पता होता है जिसका वह हिस्सा होता है.

PjRtBuffer

पूरा रेफ़रंस pjrt_client.h > PjRtBuffer पर देखें.

बफ़र, किसी डिवाइस पर डेटा को ऐसे फ़ॉर्मैट में सेव करता है जिसे प्लगिन में आसानी से इस्तेमाल किया जा सकता है. जैसे, MLIR एलिमेंट एट्रिब्यूट या मालिकाना हक वाला टेंसर फ़ॉर्मैट. कोई फ़्रेमवर्क, किसी डिवाइस को xla::Literal के तौर पर डेटा भेज सकता है.जैसे, मॉड्यूल के इनपुट आर्ग्युमेंट के लिए. इसे डिवाइस की मेमोरी में क्लोन (या उधार लिया गया) किया जाना चाहिए. जब बफ़र की ज़रूरत नहीं होती, तब फ़्रेमवर्क Delete तरीके को लागू करता है, ताकि बफ़र को हटाया जा सके.

बफ़र को पता होता है कि वह मेमोरी के किस हिस्से का हिस्सा है. साथ ही, वह यह भी पता लगा सकता है कि कौनसे डिवाइस इसे ऐक्सेस कर सकते हैं. हालांकि, बफ़र को यह पता नहीं होता कि उसके डिवाइस कौनसे हैं.

फ़्रेमवर्क से कम्यूनिकेट करने के लिए, बफ़र को यह पता होता है कि xla::Literal टाइप में कैसे बदला जाए और इससे वापस कैसे बदला जाए:

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::Future<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

बफ़र बनाने के लिए इस्तेमाल किए जाने वाले एपीआई में बफ़र सिमैंटिक्स होते हैं. इनसे यह तय करने में मदद मिलती है कि होस्ट बफ़र से लिटरल डेटा शेयर किया जा सकता है, कॉपी किया जा सकता है या बदला जा सकता है.

आखिर में, अगर बफ़र को फ़्रेमवर्क लेयर x = jit(foo)(10) में किसी वैरिएबल को असाइन किया जाता है, तो हो सकता है कि बफ़र को अपने एक्ज़ीक्यूशन के स्कोप से ज़्यादा समय तक चलना पड़े. ऐसे मामलों में, बफ़र बाहरी रेफ़रंस बनाने की अनुमति देते हैं. ये रेफ़रंस, बफ़र में मौजूद डेटा के लिए कुछ समय के लिए मालिकाना हक वाला पॉइंटर उपलब्ध कराते हैं. साथ ही, बुनियादी डेटा को समझने के लिए मेटाडेटा (dtype / dim साइज़) भी उपलब्ध कराते हैं.

PjRtCompiler

पूरा रेफ़रंस pjrt_compiler.h > PjRtCompiler पर देखें.

PjRtCompiler क्लास, XLA बैकएंड के लिए लागू करने से जुड़ी ज़रूरी जानकारी देती है. हालांकि, किसी प्लगिन के लिए इसे लागू करना ज़रूरी नहीं है. सैद्धांतिक तौर पर, PjRtCompiler या PjRtClient::Compile तरीके की ज़िम्मेदारी यह होती है कि वह एक इनपुट मॉड्यूल ले और PjRtLoadedExecutable दिखाए.

PjRtExecutable / PjRtLoadedExecutable

पूरी जानकारी pjrt_executable.h > PjRtExecutable और pjrt_client.h > PjRtLoadedExecutable पर देखें.

PjRtExecutable को यह पता होता है कि कंपाइल किए गए आर्टफ़ैक्ट और एक्ज़ीक्यूशन के विकल्पों को कैसे लिया जाए. साथ ही, उन्हें कैसे क्रम से लगाया/हटाया जाए, ताकि एक्ज़ीक्यूटेबल को सेव किया जा सके और ज़रूरत के मुताबिक लोड किया जा सके.

PjRtLoadedExecutable, मेमोरी में कंपाइल किया गया ऐसा एक्ज़ीक्यूटेबल होता है जो इनपुट आर्ग्युमेंट को एक्ज़ीक्यूट करने के लिए तैयार होता है. यह PjRtExecutable का सबक्लास होता है.

एक्ज़ीक्यूटेबल, क्लाइंट के Execute के किसी एक तरीके के ज़रिए इंटरफ़ेस किए जाते हैं:

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

Execute को कॉल करने से पहले, फ़्रेमवर्क सभी ज़रूरी डेटा को PjRtBuffers पर ट्रांसफ़र कर देगा. हालांकि, यह डेटा फ़्रेमवर्क के पास वापस आ जाएगा, ताकि वह इसका रेफ़रंस दे सके. इसके बाद, इन बफ़र को Execute तरीके के लिए आर्ग्युमेंट के तौर पर उपलब्ध कराया जाता है.

PJRT के कॉन्सेप्ट

फ़्यूचर और एसिंक कंप्यूटेशन

अगर किसी प्लगिन का कोई हिस्सा एसिंक्रोनस तरीके से लागू किया जाता है, तो उसे फ़्यूचर को सही तरीके से लागू करना ज़रूरी है.

इस प्रोग्राम पर ध्यान दें:

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

एसिंक प्लगिन, कंप्यूटेशन x को कतार में लगा सकता है. साथ ही, तुरंत एक ऐसा बफ़र दिखा सकता है जिसे अभी पढ़ा नहीं जा सकता. हालांकि, एक्ज़ीक्यूशन के बाद इसे पढ़ा जा सकेगा. x के बाद, ज़रूरी कैलकुलेशन को कतार में लगाने का काम जारी रखा जा सकता है. इसके लिए, x की ज़रूरत नहीं होती. इसमें अन्य PJRT डिवाइसों पर एक्ज़ीक्यूशन भी शामिल है. x की वैल्यू की ज़रूरत होने पर, एक्ज़ीक्यूशन तब तक ब्लॉक रहेगा, जब तक बफ़र, GetReadyFuture से मिले फ़्यूचर के ज़रिए खुद को तैयार नहीं कर लेता.

फ़्यूचर का इस्तेमाल यह तय करने के लिए किया जा सकता है कि कोई ऑब्जेक्ट कब उपलब्ध होगा. इसमें डिवाइस और बफ़र शामिल हैं.

बेहतर कॉन्सेप्ट

बुनियादी एपीआई लागू करने के अलावा, JAX की सुविधाओं को बढ़ाया जा सकता है. इससे प्लगिन के लिए उपलब्ध सुविधाओं का दायरा बढ़ जाएगा. ये सभी ऑप्ट-इन सुविधाएं हैं. इसका मतलब है कि सामान्य JIT और एक्ज़ीक्यूट वर्कफ़्लो, इनके बिना भी काम करेगा. हालांकि, प्रोडक्शन क्वालिटी वाली पाइपलाइन के लिए, PJRT API के साथ काम करने वाली इनमें से किसी भी सुविधा के लिए, सहायता के स्तर पर कुछ विचार किया जाना चाहिए:

  • मेमोरी स्पेस
  • मनपसंद लेआउट
  • कम्यूनिकेशन से जुड़ी कार्रवाइयां, जैसे कि भेजना/पाना
  • होस्ट ऑफ़लोडिंग
  • शार्डिंग

PJRT फ़्रेमवर्क और डिवाइस के बीच सामान्य कम्यूनिकेशन

लॉग का उदाहरण

यहां PJRT प्लगिन को लोड करने और y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1) को एक्ज़ीक्यूट करने के लिए कॉल किए गए तरीकों का लॉग दिया गया है. इस मामले में, हम JAX को StableHLO Reference PJRT प्लगिन के साथ इंटरैक्ट करते हुए लॉग करते हैं.

लॉग का उदाहरण

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)