PJRT C++ ডিভাইস API ওভারভিউ

পটভূমি

PJRT হল একটি অভিন্ন ডিভাইস API যা আমরা ML ইকোসিস্টেমে যোগ করতে চাই। দীর্ঘমেয়াদী দৃষ্টিভঙ্গি হল:

  1. ফ্রেমওয়ার্ক (JAX, TF, ইত্যাদি) PJRT কল করবে, যার ডিভাইস-নির্দিষ্ট বাস্তবায়ন রয়েছে যা ফ্রেমওয়ার্কের জন্য অস্বচ্ছ;
  2. প্রতিটি ডিভাইস PJRT API বাস্তবায়নের উপর দৃষ্টি নিবদ্ধ করে এবং ফ্রেমওয়ার্কের জন্য অস্বচ্ছ হতে পারে।

PJRT একটি C API এবং C++ API উভয়ই অফার করে। যেকোনো স্তরে প্লাগ ইন করা ঠিক আছে, C++ API কিছু ধারণাকে বিমূর্ত করার জন্য ক্লাস ব্যবহার করে, তবে XLA ডেটাটাইপের সাথেও এর শক্তিশালী সম্পর্ক রয়েছে। এই পৃষ্ঠাটি C++ API এর উপর আলোকপাত করে।

পিজেআরটি উপাদান

পিজেআরটি উপাদান

PjRtClient সম্পর্কে

সম্পূর্ণ রেফারেন্স pjrt_client.h > PjRtClient এ।

ক্লায়েন্টরা ডিভাইস এবং ফ্রেমওয়ার্কের মধ্যে সমস্ত যোগাযোগ পরিচালনা করে এবং যোগাযোগে ব্যবহৃত সমস্ত অবস্থা ধারণ করে। তাদের কাছে একটি PJRT প্লাগইনের সাথে ইন্টারঅ্যাক্ট করার জন্য API-এর একটি জেনেরিক সেট রয়েছে এবং তারা একটি প্রদত্ত প্লাগইনের জন্য ডিভাইস এবং মেমরি স্পেসের মালিক।

PjRtDevice সম্পর্কে

সম্পূর্ণ রেফারেন্স pjrt_client.h > PjRtDevice , এবং pjrt_device_description.h এ পাবেন।

একটি ডিভাইস বর্ণনা করার জন্য একটি ডিভাইস ক্লাস ব্যবহার করা হয়। একটি ডিভাইসের একটি ডিভাইসের বর্ণনা থাকে যা তার ধরণ (GPU/CPU/xPU সনাক্ত করার জন্য অনন্য হ্যাশ) এবং স্থানীয় এবং বিশ্বব্যাপী ডিভাইসের গ্রিডের মধ্যে অবস্থান সনাক্ত করতে সহায়তা করে।

ডিভাইসগুলি তাদের সংশ্লিষ্ট মেমরি স্পেস এবং এটি কোন ক্লায়েন্টের মালিকানাধীন তাও জানে।

একটি ডিভাইস অগত্যা তার সাথে সম্পর্কিত প্রকৃত ডেটার বাফারগুলি জানে না , তবে এটি তার সম্পর্কিত মেমরি স্পেসগুলি দেখে এটি বের করতে পারে।

PjRtMemorySpace সম্পর্কে

সম্পূর্ণ রেফারেন্স pjrt_client.h > PjRtMemorySpace এ।

মেমোরি স্পেস ব্যবহার করে মেমোরির অবস্থান বর্ণনা করা যেতে পারে। এগুলি হয় আনপিন করা যেতে পারে, এবং যেকোনো জায়গায় লাইভ করার জন্য বিনামূল্যে কিন্তু একটি ডিভাইস থেকে অ্যাক্সেসযোগ্য, অথবা এগুলি পিন করা যেতে পারে এবং একটি নির্দিষ্ট ডিভাইসে লাইভ থাকতে হবে।

মেমোরি স্পেসগুলি তাদের সম্পর্কিত ডেটা বাফার এবং একটি মেমোরি স্পেস যে ডিভাইসগুলির সাথে যুক্ত (বহুবচন), সেইসাথে এটি যে ক্লায়েন্টের অংশ তা জানে।

PjRtBuffer সম্পর্কে

সম্পূর্ণ রেফারেন্স pjrt_client.h > PjRtBuffer এ।

একটি বাফার ডিভাইসে এমন কিছু ফর্ম্যাটে ডেটা ধারণ করে যা প্লাগইনের ভিতরে কাজ করা সহজ হবে, যেমন একটি MLIR উপাদান attr বা একটি মালিকানাধীন টেনসর ফর্ম্যাট। একটি ফ্রেমওয়ার্ক একটি ডিভাইসে xla::Literal আকারে ডেটা পাঠানোর চেষ্টা করতে পারে, অর্থাৎ মডিউলে একটি ইনপুট আর্গুমেন্টের জন্য, যা অবশ্যই ডিভাইসের মেমরিতে ক্লোন (বা ধার) করতে হবে। একবার একটি বাফারের আর প্রয়োজন না হলে ফ্রেমওয়ার্ক পরিষ্কার করার জন্য Delete পদ্ধতিটি ব্যবহার করে।

একটি বাফার জানে যে এটি কোন মেমরি স্পেসের অংশ, এবং ট্রানজিটলি নির্ধারণ করতে পারে যে কোন ডিভাইসগুলি এটি অ্যাক্সেস করতে সক্ষম, কিন্তু বাফারগুলি অগত্যা তাদের ডিভাইসগুলি জানে না।

ফ্রেমওয়ার্কের সাথে যোগাযোগের জন্য, বাফাররা জানে কিভাবে xla::Literal টাইপে এবং থেকে রূপান্তর করতে হয়:

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::Future<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

বাফার তৈরির জন্য API গুলিতে বাফার সেমান্টিক্স থাকে যা হোস্ট বাফার থেকে আক্ষরিক ডেটা ভাগ করা, অনুলিপি করা বা রূপান্তর করা যেতে পারে কিনা তা নির্ধারণ করতে সহায়তা করে।

পরিশেষে, একটি বাফার তার কার্যকরীকরণের পরিধির চেয়ে বেশি সময় ধরে থাকতে পারে, যদি এটি ফ্রেমওয়ার্ক লেয়ার x = jit(foo)(10) একটি ভেরিয়েবলের জন্য নির্ধারিত হয়, এই ক্ষেত্রে বাফারগুলি বহিরাগত রেফারেন্স তৈরি করতে দেয় যা বাফার দ্বারা ধারণ করা ডেটার জন্য একটি অস্থায়ীভাবে মালিকানাধীন পয়েন্টার প্রদান করে, অন্তর্নিহিত ডেটা ব্যাখ্যা করার জন্য মেটাডেটা (dtype / dim আকার) সহ।

PjRtকম্পাইলার

সম্পূর্ণ রেফারেন্স pjrt_compiler.h > PjRtCompiler এ।

PjRtCompiler ক্লাসটি XLA ব্যাকএন্ডের জন্য কার্যকর বাস্তবায়নের বিবরণ প্রদান করে, কিন্তু একটি প্লাগইন বাস্তবায়নের জন্য এটি প্রয়োজনীয় নয়। তত্ত্ব অনুসারে, একটি PjRtCompiler , অথবা PjRtClient::Compile পদ্ধতির দায়িত্ব হল একটি ইনপুট মডিউল নেওয়া এবং একটি PjRtLoadedExecutable ফেরত দেওয়া।

PjRtExecutable / PjRtLoadedExecutable

সম্পূর্ণ রেফারেন্স pjrt_executable.h > PjRtExecutable , এবং pjrt_client.h > PjRtLoadedExecutable এ পাবেন।

একজন PjRtExecutable জানে কিভাবে একটি কম্পাইল করা আর্টিফ্যাক্ট এবং এক্সিকিউশন অপশন নিতে হয় এবং সেগুলিকে সিরিয়ালাইজ/ডিসিরিয়ালাইজ করতে হয় যাতে একটি এক্সিকিউটেবল প্রয়োজন অনুসারে সংরক্ষণ এবং লোড করা যায়।

PjRtLoadedExecutable হল ইন-মেমরি কম্পাইল করা এক্সিকিউটেবল যা ইনপুট আর্গুমেন্ট এক্সিকিউট করার জন্য প্রস্তুত, এটি PjRtExecutable এর একটি সাবক্লাস।

এক্সিকিউটেবলগুলি ক্লায়েন্টের Execute পদ্ধতিগুলির একটির মাধ্যমে ইন্টারফেস করা হয়:

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

Execute কল করার আগে ফ্রেমওয়ার্কটি সমস্ত প্রয়োজনীয় ডেটা এক্সিকিউটিভ ক্লায়েন্টের মালিকানাধীন PjRtBuffers এ স্থানান্তর করবে, কিন্তু ফ্রেমওয়ার্কটিকে রেফারেন্সের জন্য ফেরত পাঠানো হবে। এই বাফারগুলি তারপর Execute পদ্ধতিতে আর্গুমেন্ট হিসাবে সরবরাহ করা হয়।

পিজেআরটি ধারণা

ফিউচার এবং অ্যাসিঙ্ক গণনা

যদি প্লাগইনের কোনও অংশ অ্যাসিঙ্ক্রোনাসভাবে বাস্তবায়িত হয়, তবে এটিকে অবশ্যই ফিউচার সঠিকভাবে বাস্তবায়ন করতে হবে

নিম্নলিখিত প্রোগ্রামটি বিবেচনা করুন:

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

একটি অ্যাসিঙ্ক প্লাগইন x কম্পিউটেশন এনকিউ করতে সক্ষম হবে এবং তাৎক্ষণিকভাবে এমন একটি বাফার ফিরিয়ে দেবে যা এখনও পড়ার জন্য প্রস্তুত নয়, কিন্তু এক্সিকিউশন এটিকে পূর্ণ করবে। এক্সিকিউশন x পরে প্রয়োজনীয় কম্পিউটেশন এনকিউ করতে পারে, যার জন্য x প্রয়োজন হয় না, অন্যান্য PJRT ডিভাইসে এক্সিকিউশন সহ। একবার x এর মান প্রয়োজন হলে, এক্সিকিউশন ব্লক হবে যতক্ষণ না বাফার GetReadyFuture দ্বারা ফেরত আসা ভবিষ্যতের মাধ্যমে নিজেকে প্রস্তুত ঘোষণা করে।

ডিভাইস এবং বাফার সহ, কোনও বস্তু কখন উপলব্ধ হবে তা নির্ধারণ করতে ফিউচার কার্যকর হতে পারে।

উন্নত ধারণা

বেস API গুলি বাস্তবায়নের বাইরে গেলে JAX এর বৈশিষ্ট্যগুলি প্রসারিত হবে যা একটি প্লাগইন দ্বারা ব্যবহার করা যেতে পারে। এগুলি সবই অপ্ট-ইন বৈশিষ্ট্য এই অর্থে যে সাধারণ JIT এবং এক্সিকিউট ওয়ার্কফ্লো এগুলি ছাড়াই কাজ করবে, তবে একটি উৎপাদন মানের পাইপলাইনের জন্য PJRT API গুলি দ্বারা সমর্থিত এই বৈশিষ্ট্যগুলির যেকোনো একটির জন্য সমর্থনের মাত্রা সম্পর্কে কিছু চিন্তাভাবনা করা উচিত:

  • মেমোরি স্পেস
  • কাস্টম লেআউট
  • যোগাযোগের বিকল্প যেমন send/recv
  • হোস্ট অফলোডিং
  • ভাগ করা

সাধারণ PJRT ফ্রেমওয়ার্ক-ডিভাইস যোগাযোগ

উদাহরণ লগ

PJRT প্লাগইন লোড করার এবং y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1) কার্যকর করার জন্য যে পদ্ধতিগুলি বলা হয় তার একটি লগ নিচে দেওয়া হল। এই ক্ষেত্রে আমরা StableHLO রেফারেন্স PJRT প্লাগইনের সাথে ইন্টারঅ্যাক্ট করে JAX লগ করি।

উদাহরণ লগ

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)