PJRT C++ ডিভাইস API ওভারভিউ

পটভূমি

PJRT হল ইউনিফর্ম ডিভাইস API যা আমরা ML ইকোসিস্টেমে যোগ করতে চাই। দীর্ঘমেয়াদী দৃষ্টিভঙ্গি হল:

  1. ফ্রেমওয়ার্ক (JAX, TF, ইত্যাদি) PJRT কে কল করবে, যার ডিভাইস-নির্দিষ্ট বাস্তবায়ন রয়েছে যা ফ্রেমওয়ার্কের জন্য অস্বচ্ছ;
  2. প্রতিটি ডিভাইস PJRT API বাস্তবায়নে ফোকাস করে এবং ফ্রেমওয়ার্কের জন্য অস্বচ্ছ হতে পারে।

PJRT একটি C API এবং C++ API উভয়ই অফার করে। উভয় স্তরে প্লাগ ইন করা ঠিক আছে, C++ API কিছু ধারণাকে বিমূর্ত করার জন্য ক্লাস ব্যবহার করে, তবে XLA ডেটাটাইপগুলির সাথে আরও শক্তিশালী সম্পর্ক রয়েছে। এই পৃষ্ঠাটি C++ API-এ ফোকাস করে।

PJRT উপাদান

PJRT উপাদান

PjRtClient

pjrt_client.h > PjRtClient এ সম্পূর্ণ রেফারেন্স।

ক্লায়েন্ট ডিভাইস এবং ফ্রেমওয়ার্কের মধ্যে সমস্ত যোগাযোগ পরিচালনা করে এবং যোগাযোগে ব্যবহৃত সমস্ত অবস্থাকে এনক্যাপসুলেট করে। একটি PJRT প্লাগইনের সাথে ইন্টারঅ্যাক্ট করার জন্য তাদের কাছে API-এর একটি সাধারণ সেট রয়েছে এবং তারা একটি প্রদত্ত প্লাগইনের জন্য ডিভাইস এবং মেমরি স্পেসগুলির মালিক।

PjRtDevice

pjrt_client.h > PjRtDevice , এবং pjrt_device_description.h এ সম্পূর্ণ রেফারেন্স

একটি ডিভাইস ক্লাস একটি একক ডিভাইস বর্ণনা করতে ব্যবহৃত হয়। একটি ডিভাইসের একটি ডিভাইসের বিবরণ রয়েছে যাতে এটির ধরনের (GPU/CPU/xPU শনাক্ত করার জন্য অনন্য হ্যাশ), এবং স্থানীয়ভাবে এবং বিশ্বব্যাপী ডিভাইসের একটি গ্রিডের মধ্যে অবস্থান।

ডিভাইসগুলি তাদের সম্পর্কিত মেমরি স্পেস এবং এটির মালিকানাধীন ক্লায়েন্টও জানে।

একটি ডিভাইস অগত্যা এটির সাথে সম্পর্কিত প্রকৃত ডেটার বাফারগুলি জানে না , তবে এটি তার সম্পর্কিত মেমরি স্পেসগুলি দেখে তা বের করতে পারে।

PjRtMemorySpace

pjrt_client.h > PjRtMemorySpace এ সম্পূর্ণ রেফারেন্স।

মেমরির স্থান বর্ণনা করতে মেমরি স্পেস ব্যবহার করা যেতে পারে। এগুলি হয় আনপিন করা যেতে পারে, এবং যে কোনও জায়গায় বসবাসের জন্য বিনামূল্যে তবে একটি ডিভাইস থেকে অ্যাক্সেসযোগ্য, অথবা সেগুলি পিন করা যেতে পারে এবং একটি নির্দিষ্ট ডিভাইসে থাকতে হবে৷

মেমরি স্পেসগুলি ডেটার তাদের যুক্ত বাফারগুলি এবং ডিভাইসগুলি (বহুবচন) যেগুলির সাথে একটি মেমরি স্পেস যুক্ত, সেইসাথে এটি যে ক্লায়েন্টের একটি অংশ তা জানে৷

PjRtBuffer

pjrt_client.h > PjRtBuffer এ সম্পূর্ণ রেফারেন্স।

একটি বাফার কিছু বিন্যাসে একটি ডিভাইসে ডেটা ধারণ করে যা প্লাগইনের ভিতরে কাজ করা সহজ হবে, যেমন একটি MLIR উপাদান attr বা একটি মালিকানাধীন টেনসর বিন্যাস। একটি ফ্রেমওয়ার্ক একটি xla::Literal আকারে একটি ডিভাইসে ডেটা পাঠানোর চেষ্টা করতে পারে, যেমন মডিউলটিতে একটি ইনপুট আর্গুমেন্টের জন্য, যা অবশ্যই ডিভাইসের মেমরিতে ক্লোন (বা ধার করা) হতে হবে। একবার বাফারের আর প্রয়োজন হয় না পরিষ্কার করার জন্য ফ্রেমওয়ার্ক দ্বারা Delete পদ্ধতিটি আহ্বান করা হয়।

একটি বাফার জানে যে মেমরি স্পেস এটির একটি অংশ, এবং ট্রানজিটিভভাবে বুঝতে পারে কোন ডিভাইসগুলি এটি অ্যাক্সেস করতে সক্ষম, কিন্তু বাফাররা তাদের ডিভাইসগুলি অগত্যা জানে না।

ফ্রেমওয়ার্কের সাথে যোগাযোগের জন্য, বাফাররা জানে কিভাবে xla::Literal টাইপ থেকে এবং থেকে রূপান্তর করতে হয়:

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

বাফার তৈরির জন্য এপিআইগুলিতে বাফার শব্দার্থবিদ্যা রয়েছে যা হোস্ট বাফার থেকে আক্ষরিক ডেটা ভাগ করা বা অনুলিপি করা বা পরিবর্তন করা যেতে পারে কিনা তা নির্দেশ করতে সহায়তা করে।

সবশেষে, একটি বাফারের কার্য সম্পাদনের সুযোগের চেয়ে বেশি সময় স্থায়ী হতে পারে, যদি এটি ফ্রেমওয়ার্ক লেয়ার x = jit(foo)(10) একটি ভেরিয়েবলের জন্য নির্ধারিত হয়, এই ক্ষেত্রে বাফারগুলি বহিরাগত রেফারেন্স তৈরি করার অনুমতি দেয় যা একটি অস্থায়ী মালিকানাধীন পয়েন্টার প্রদান করে। অন্তর্নিহিত ডেটা ব্যাখ্যা করার জন্য মেটাডেটা (dtype/dim sizes) সহ বাফার দ্বারা ধারণ করা ডেটাতে।

PjRtকম্পাইলার

pjrt_compiler.h > PjRtCompiler এ সম্পূর্ণ রেফারেন্স।

PjRtCompiler ক্লাস XLA ব্যাকএন্ডের জন্য কার্যকর বাস্তবায়নের বিশদ প্রদান করে, কিন্তু একটি প্লাগইন বাস্তবায়নের জন্য প্রয়োজনীয় নয়। তত্ত্বগতভাবে, একটি PjRtCompiler , বা PjRtClient::Compile পদ্ধতির দায়িত্ব হল একটি ইনপুট মডিউল নেওয়া এবং একটি PjRtLoadedExecutable ফেরত দেওয়া।

PjRtExecutable/PjRtLoadedExecutable

pjrt_executable.h > PjRtExecutable , এবং pjrt_client.h > PjRtLoadedExecutable এ সম্পূর্ণ রেফারেন্স।

একটি PjRtExecutable জানে কিভাবে একটি সংকলিত শিল্পকর্ম এবং সম্পাদনের বিকল্পগুলি নিতে হয় এবং সেগুলিকে সিরিয়ালাইজ/ডিসারিয়ালাইজ করতে হয় যাতে একটি এক্সিকিউটেবল সংরক্ষণ করা যায় এবং প্রয়োজন অনুসারে লোড করা যায়।

PjRtLoadedExecutable হল ইন-মেমরি কম্পাইল করা এক্সিকিউটেবল যা কার্যকর করার জন্য ইনপুট আর্গুমেন্টের জন্য প্রস্তুত, এটি PjRtExecutable এর একটি সাবক্লাস।

এক্সিকিউটেবলগুলিকে ক্লায়েন্টের Execute পদ্ধতিগুলির একটির মাধ্যমে ইন্টারফেস করা হয়:

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

Execute কল করার আগে ফ্রেমওয়ার্কটি এক্সিকিউটিং ক্লায়েন্টের মালিকানাধীন PjRtBuffers এ সমস্ত প্রয়োজনীয় ডেটা স্থানান্তর করবে, কিন্তু ফ্রেমওয়ার্কের রেফারেন্সের জন্য ফিরে আসবে। এই বাফারগুলি তারপর Execute পদ্ধতিতে আর্গুমেন্ট হিসাবে সরবরাহ করা হয়।

PJRT ধারণা

PjRtFutures এবং Async কম্পিউটেশন

যদি একটি প্লাগইনের কোনো অংশ অ্যাসিঙ্ক্রোনাসভাবে প্রয়োগ করা হয়, তাহলে এটি অবশ্যই সঠিকভাবে ফিউচার বাস্তবায়ন করবে।

নিম্নলিখিত প্রোগ্রাম বিবেচনা করুন:

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

একটি async প্লাগইন গণনা x সারিবদ্ধ করতে সক্ষম হবে এবং অবিলম্বে একটি বাফার ফেরত দেবে যা এখনও পড়ার জন্য প্রস্তুত নয়, তবে সম্পাদন এটিকে পপুলেট করবে। এক্সিকিউশন x পরে প্রয়োজনীয় কম্পিউটেশন সারিবদ্ধ করা চালিয়ে যেতে পারে, যার জন্য x প্রয়োজন হয় না, অন্যান্য PJRT ডিভাইসে এক্সিকিউশন সহ। একবার x এর মান প্রয়োজন হলে, যতক্ষণ না বাফারটি GetReadyFuture দ্বারা প্রত্যাবর্তিত ভবিষ্যতের মাধ্যমে নিজেকে প্রস্তুত বলে ঘোষণা না করে ততক্ষণ পর্যন্ত এক্সিকিউশন ব্লক করা হবে।

ডিভাইস এবং বাফার সহ কোন বস্তু কখন উপলব্ধ হবে তা নির্ধারণ করতে ফিউচার উপযোগী হতে পারে।

উন্নত ধারণা

বেস এপিআইগুলি বাস্তবায়নের বাইরে প্রসারিত করা JAX এর বৈশিষ্ট্যগুলিকে প্রসারিত করবে যা একটি প্লাগইন দ্বারা ব্যবহার করা যেতে পারে। এগুলি সমস্তই অপ্ট-ইন বৈশিষ্ট্য এই অর্থে যে সাধারণ JIT এ এবং কার্যপ্রবাহ কার্যকর করবে এগুলি ছাড়াই কাজ করবে, তবে একটি উত্পাদন মানের পাইপলাইনের জন্য কিছু চিন্তাভাবনা সম্ভবত PJRT API দ্বারা সমর্থিত এই বৈশিষ্ট্যগুলির যে কোনওটির জন্য সমর্থনের ডিগ্রিতে রাখা উচিত:

  • মেমরি স্পেস
  • কাস্টম লেআউট
  • সেন্ড/রিকভির মত যোগাযোগ ব্যবস্থা
  • হোস্ট অফলোডিং
  • শেয়ারিং

সাধারণ PJRT ফ্রেমওয়ার্ক-ডিভাইস যোগাযোগ

উদাহরণ লগ

PJRT প্লাগইন লোড করতে এবং y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1) চালানোর জন্য নিম্নলিখিত পদ্ধতিগুলির একটি লগ দেওয়া হল। এই ক্ষেত্রে আমরা StableHLO রেফারেন্স PJRT প্লাগইনের সাথে ইন্টারঅ্যাক্ট করে JAX লগ করি।

উদাহরণ লগ

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)