PJRT C++ Device API की खास जानकारी

बैकग्राउंड

PJRT एक यूनिफ़ॉर्म डिवाइस एपीआई है, जिसे हमें एमएल (मशीन लर्निंग) नेटवर्क के साथ जोड़ना है. लंबे समय के लिए, हमारा मकसद यह है:

  1. फ़्रेमवर्क (JAX, TF वगैरह) PJRT को कॉल करेंगे. इसमें डिवाइस के हिसाब से लागू किए गए ऐसे तरीके होते हैं जो फ़्रेमवर्क के लिए साफ़ तौर पर नहीं दिखते;
  2. हर डिवाइस, PJRT API को लागू करने पर फ़ोकस करता है. साथ ही, यह फ़्रेमवर्क के लिए पारदर्शी नहीं हो सकता.

PJRT, C API और C++ API, दोनों की सुविधा देता है. दोनों लेयर में प्लग इन करना ठीक है. C++ एपीआई, कुछ कॉन्सेप्ट को अलग करने के लिए क्लास का इस्तेमाल करता है. हालांकि, इसका संबंध XLA डेटाटाइप से भी है. इस पेज पर, C++ एपीआई के बारे में बताया गया है.

PJRT कॉम्पोनेंट

PJRT कॉम्पोनेंट

PjRtClient

पूरा रेफ़रंस pjrt_client.h > PjRtClient पर देखें.

क्लाइंट, डिवाइस और फ़्रेमवर्क के बीच होने वाली सभी बातचीत को मैनेज करते हैं. साथ ही, बातचीत में इस्तेमाल होने वाली सभी स्थितियों को एन्कैप्सुलेट करते हैं. उनके पास PJRT प्लग इन के साथ इंटरैक्ट करने के लिए, एपीआई का एक सामान्य सेट होता है. साथ ही, उनके पास किसी प्लग इन के लिए डिवाइस और मेमोरी स्पेस का मालिकाना हक होता है.

PjRtDevice

पूरे रेफ़रंस pjrt_client.h > PjRtDevice और pjrt_device_description.h पर देखें

डिवाइस क्लास का इस्तेमाल, किसी एक डिवाइस के बारे में बताने के लिए किया जाता है. किसी डिवाइस की जानकारी में, उसके टाइप (GPU/CPU/xPU की पहचान करने के लिए यूनीक हैश) और डिवाइसों के ग्रिड में उसकी जगह की जानकारी होती है. यह जानकारी, स्थानीय और दुनिया भर में, दोनों जगहों पर उपलब्ध होती है.

डिवाइसों को यह भी पता होता है कि उनसे कौनसे मेमोरी स्पेस जुड़े हैं और उनका मालिकाना हक किस क्लाइंट के पास है.

ज़रूरी नहीं है कि किसी डिवाइस को उससे जुड़े असल डेटा के बफ़र के बारे में पता हो. हालांकि, वह इससे जुड़ी मेमोरी स्पेस को देखकर यह पता लगा सकता है.

PjRtMemorySpace

पूरा रेफ़रंस pjrt_client.h > PjRtMemorySpace पर देखें.

'यादें' स्पेस का इस्तेमाल, यादों की जगह के बारे में बताने के लिए किया जा सकता है. इन्हें अनपिन किया जा सकता है और इन्हें किसी भी डिवाइस पर ऐक्सेस किया जा सकता है. इसके अलावा, इन्हें पिन भी किया जा सकता है और इन्हें किसी खास डिवाइस पर ऐक्सेस किया जा सकता है.

मेमोरी स्पेस में, डेटा के बफ़र और उन डिवाइसों (बहुवचन) के बारे में जानकारी होती है जिनसे मेमोरी स्पेस जुड़ा होता है. साथ ही, उस क्लाइंट के बारे में भी जानकारी होती है जिसका वह हिस्सा होता है.

PjRtBuffer

पूरा रेफ़रंस pjrt_client.h > PjRtBuffer पर देखें.

बफ़र, डिवाइस पर डेटा को किसी ऐसे फ़ॉर्मैट में सेव करता है जिससे प्लग इन में काम करना आसान हो. जैसे, MLIR एलिमेंट एट्रिब्यूट या मालिकाना टेंसर फ़ॉर्मैट. कोई फ़्रेमवर्क, xla::Literal के तौर पर किसी डिवाइस पर डेटा भेजने की कोशिश कर सकता है. उदाहरण के लिए, मॉड्यूल के इनपुट आर्ग्युमेंट के लिए, जिसे डिवाइस की मेमोरी में क्लोन (या उधार) किया जाना चाहिए. जब बफ़र की ज़रूरत नहीं होती, तो फ़्रेमवर्क Delete तरीके को शुद्धिकरण के लिए इस्तेमाल करता है.

बफ़र को उस मेमोरी स्पेस की जानकारी होती है जिसका वह हिस्सा होता है. साथ ही, यह पता लगाया जा सकता है कि कौनसे डिवाइस उसे ऐक्सेस कर सकते हैं. हालांकि, यह ज़रूरी नहीं है कि बफ़र को अपने डिवाइसों की जानकारी हो.

फ़्रेमवर्क के साथ कम्यूनिकेट करने के लिए, बफ़र को xla::Literal टाइप में बदलने और उससे बदलने का तरीका पता होता है:

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

बफ़र बनाने के लिए एपीआई में बफ़र सेमेटिक्स होता है. इससे यह तय करने में मदद मिलती है कि होस्ट बफ़र का लिटरल डेटा शेयर किया जा सकता है या कॉपी किया जा सकता है या उसमें बदलाव किया जा सकता है.

आखिर में, अगर बफ़र को फ़्रेमवर्क लेयर x = jit(foo)(10) में किसी वैरिएबल को असाइन किया गया है, तो हो सकता है कि उसे एक्सीक्यूट करने के दायरे से ज़्यादा समय तक बनाए रखने की ज़रूरत पड़े. इन मामलों में, बफ़र बाहरी रेफ़रंस बनाने की अनुमति देते हैं. ये रेफ़रंस, बफ़र में मौजूद डेटा के लिए, कुछ समय के लिए मालिकाना हक वाला पॉइंटर देते हैं. साथ ही, इनमें डेटा को समझने के लिए मेटाडेटा (dtype / dim sizes) भी शामिल होता है.

PjRtCompiler

पूरा रेफ़रंस pjrt_compiler.h > PjRtCompiler पर देखें.

PjRtCompiler क्लास, XLA बैकएंड के लिए लागू करने से जुड़ी ज़रूरी जानकारी देती है. हालांकि, प्लग इन को लागू करने के लिए, इसकी ज़रूरत नहीं होती. सिद्धांत रूप से, PjRtCompiler या PjRtClient::Compile तरीके की ज़िम्मेदारी, इनपुट मॉड्यूल लेकर PjRtLoadedExecutable दिखाना है.

PjRtExecutable / PjRtLoadedExecutable

पूरा रेफ़रंस pjrt_executable.h > PjRtExecutable और pjrt_client.h > PjRtLoadedExecutable पर देखें.

PjRtExecutable, किसी कंपाइल किए गए आर्टफ़ैक्ट और उसे चलाने के विकल्पों को क्रम से लगाने/अक्रम से हटाने का तरीका जानता है, ताकि ज़रूरत के हिसाब से किसी एक्सीक्यूटेबल को सेव और लोड किया जा सके.

PjRtLoadedExecutable, मेमोरी में कंपाइल किया गया एक ऐसा प्रोग्राम है जो इनपुट आर्ग्युमेंट को एक्ज़ीक्यूट करने के लिए तैयार है. यह PjRtExecutable का सबक्लास है.

क्लाइंट के Execute तरीकों में से किसी एक के ज़रिए, एक्सीक्यूटेबल को इंटरफ़ेस किया जाता है:

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

Execute को कॉल करने से पहले, फ़्रेमवर्क ज़रूरी डेटा को PjRtBuffers पर ट्रांसफ़र कर देगा. PjRtBuffers का मालिकाना हक, उस क्लाइंट के पास होता है जिसने फ़्रेमवर्क को चालू किया है. हालांकि, फ़्रेमवर्क के लिए यह डेटा, रेफ़रंस के तौर पर वापस लाया जाता है. इसके बाद, इन बफ़र को Execute तरीके के लिए आर्ग्युमेंट के तौर पर दिया जाता है.

PJRT कॉन्सेप्ट

PjRtFutures और एक साथ काम न करने वाले कैलकुलेशन

अगर प्लग इन का कोई हिस्सा असिंक्रोनस तरीके से लागू किया जाता है, तो उसे फ़्यूचर को सही तरीके से लागू करना ज़रूरी है.

इस प्रोग्राम पर ध्यान दें:

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

एक असाइन किए गए प्लग इन, कैलकुलेशन x को सूची में जोड़ सकता है और तुरंत एक बफ़र दिखा सकता है, जो अभी पढ़ने के लिए तैयार नहीं है. हालांकि, इसे लागू करने पर यह बफ़र भर जाएगा. x के बाद भी, ज़रूरी कैलकुलेशन को सूची में जोड़ा जा सकता है. इसके लिए, x की ज़रूरत नहीं होती. इसमें, PJRT के दूसरे डिवाइसों पर कैलकुलेशन को लागू करना भी शामिल है. x की वैल्यू की ज़रूरत पड़ने पर, एक्ज़ीक्यूशन तब तक ब्लॉक रहेगा, जब तक कि बफ़र GetReadyFuture से रिटर्न की गई फ़्यूचर वैल्यू के ज़रिए खुद को तैयार नहीं बता देता.

फ़्यूचर का इस्तेमाल करके यह पता लगाया जा सकता है कि कोई ऑब्जेक्ट कब उपलब्ध होगा. इसमें डिवाइस और बफ़र भी शामिल हैं.

बेहतर कॉन्सेप्ट

बुनियादी एपीआई लागू करने के अलावा, JAX की उन सुविधाओं को भी इस्तेमाल किया जा सकता है जिनका इस्तेमाल प्लगिन के ज़रिए किया जा सकता है. ये सभी सुविधाएं ऑप्ट-इन की सुविधाएं हैं. इसका मतलब है कि आम तौर पर, JIT और एक्सीक्यूट वर्कफ़्लो इनके बिना काम करेगा. हालांकि, प्रोडक्शन क्वालिटी वाली पाइपलाइन के लिए, PJRT API के साथ काम करने वाली इनमें से किसी भी सुविधा के लिए, सहायता की डिग्री पर कुछ विचार किया जाना चाहिए:

  • यादें सेव करने के लिए स्पेस
  • मनपसंद लेआउट
  • भेजने/पाने जैसे कम्यूनिकेशन ऑपरेशन
  • होस्ट को ऑफ़लोड करना
  • डेटा का बंटवारा

PJRT फ़्रेमवर्क-डिवाइस के बीच आम तौर पर होने वाला कम्यूनिकेशन

लॉग का उदाहरण

यहां PJRT प्लग इन को लोड करने और y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1) को लागू करने के लिए, इस्तेमाल किए गए तरीकों का लॉग दिया गया है. इस मामले में, हम StableHLO Reference PJRT प्लग इन के साथ इंटरैक्ट करने वाले JAX को लॉग करते हैं.

लॉग का उदाहरण

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)