পটভূমি
PJRT হল ইউনিফর্ম ডিভাইস API যা আমরা ML ইকোসিস্টেমে যোগ করতে চাই। দীর্ঘমেয়াদী দৃষ্টিভঙ্গি হল:
- ফ্রেমওয়ার্ক (JAX, TF, ইত্যাদি) PJRT কে কল করবে, যার ডিভাইস-নির্দিষ্ট বাস্তবায়ন রয়েছে যা ফ্রেমওয়ার্কের জন্য অস্বচ্ছ;
- প্রতিটি ডিভাইস PJRT API বাস্তবায়নে ফোকাস করে এবং ফ্রেমওয়ার্কের জন্য অস্বচ্ছ হতে পারে।
PJRT একটি C API এবং C++ API উভয়ই অফার করে। উভয় স্তরে প্লাগ ইন করা ঠিক আছে, C++ API কিছু ধারণাকে বিমূর্ত করার জন্য ক্লাস ব্যবহার করে, তবে XLA ডেটাটাইপগুলির সাথে আরও শক্তিশালী সম্পর্ক রয়েছে। এই পৃষ্ঠাটি C++ API-এ ফোকাস করে।
PJRT উপাদান
PjRtClient
pjrt_client.h > PjRtClient
এ সম্পূর্ণ রেফারেন্স।
ক্লায়েন্ট ডিভাইস এবং ফ্রেমওয়ার্কের মধ্যে সমস্ত যোগাযোগ পরিচালনা করে এবং যোগাযোগে ব্যবহৃত সমস্ত অবস্থাকে এনক্যাপসুলেট করে। একটি PJRT প্লাগইনের সাথে ইন্টারঅ্যাক্ট করার জন্য তাদের কাছে API-এর একটি সাধারণ সেট রয়েছে এবং তারা একটি প্রদত্ত প্লাগইনের জন্য ডিভাইস এবং মেমরি স্পেসগুলির মালিক।
PjRtDevice
pjrt_client.h > PjRtDevice
, এবং pjrt_device_description.h
এ সম্পূর্ণ রেফারেন্স
একটি ডিভাইস ক্লাস একটি একক ডিভাইস বর্ণনা করতে ব্যবহৃত হয়। একটি ডিভাইসের একটি ডিভাইসের বিবরণ রয়েছে যাতে এটির ধরনের (GPU/CPU/xPU শনাক্ত করার জন্য অনন্য হ্যাশ), এবং স্থানীয়ভাবে এবং বিশ্বব্যাপী ডিভাইসের একটি গ্রিডের মধ্যে অবস্থান।
ডিভাইসগুলি তাদের সম্পর্কিত মেমরি স্পেস এবং এটির মালিকানাধীন ক্লায়েন্টও জানে।
একটি ডিভাইস অগত্যা এটির সাথে সম্পর্কিত প্রকৃত ডেটার বাফারগুলি জানে না , তবে এটি তার সম্পর্কিত মেমরি স্পেসগুলি দেখে তা বের করতে পারে।
PjRtMemorySpace
pjrt_client.h > PjRtMemorySpace
এ সম্পূর্ণ রেফারেন্স।
মেমরির স্থান বর্ণনা করতে মেমরি স্পেস ব্যবহার করা যেতে পারে। এগুলি হয় আনপিন করা যেতে পারে, এবং যে কোনও জায়গায় বসবাসের জন্য বিনামূল্যে তবে একটি ডিভাইস থেকে অ্যাক্সেসযোগ্য, অথবা সেগুলি পিন করা যেতে পারে এবং একটি নির্দিষ্ট ডিভাইসে থাকতে হবে৷
মেমরি স্পেসগুলি ডেটার তাদের যুক্ত বাফারগুলি এবং ডিভাইসগুলি (বহুবচন) যেগুলির সাথে একটি মেমরি স্পেস যুক্ত, সেইসাথে এটি যে ক্লায়েন্টের একটি অংশ তা জানে৷
PjRtBuffer
pjrt_client.h > PjRtBuffer
এ সম্পূর্ণ রেফারেন্স।
একটি বাফার কিছু বিন্যাসে একটি ডিভাইসে ডেটা ধারণ করে যা প্লাগইনের ভিতরে কাজ করা সহজ হবে, যেমন একটি MLIR উপাদান attr বা একটি মালিকানাধীন টেনসর বিন্যাস। একটি ফ্রেমওয়ার্ক একটি xla::Literal
আকারে একটি ডিভাইসে ডেটা পাঠানোর চেষ্টা করতে পারে, যেমন মডিউলটিতে একটি ইনপুট আর্গুমেন্টের জন্য, যা অবশ্যই ডিভাইসের মেমরিতে ক্লোন (বা ধার করা) হতে হবে। একবার বাফারের আর প্রয়োজন হয় না পরিষ্কার করার জন্য ফ্রেমওয়ার্ক দ্বারা Delete
পদ্ধতিটি আহ্বান করা হয়।
একটি বাফার জানে যে মেমরি স্পেস এটির একটি অংশ, এবং ট্রানজিটিভভাবে বুঝতে পারে কোন ডিভাইসগুলি এটি অ্যাক্সেস করতে সক্ষম, কিন্তু বাফাররা তাদের ডিভাইসগুলি অগত্যা জানে না।
ফ্রেমওয়ার্কের সাথে যোগাযোগের জন্য, বাফাররা জানে কিভাবে xla::Literal
টাইপ থেকে এবং থেকে রূপান্তর করতে হয়:
// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}
// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}
বাফার তৈরির জন্য এপিআইগুলিতে বাফার শব্দার্থবিদ্যা রয়েছে যা হোস্ট বাফার থেকে আক্ষরিক ডেটা ভাগ করা বা অনুলিপি করা বা পরিবর্তন করা যেতে পারে কিনা তা নির্দেশ করতে সহায়তা করে।
সবশেষে, একটি বাফারের কার্য সম্পাদনের সুযোগের চেয়ে বেশি সময় স্থায়ী হতে পারে, যদি এটি ফ্রেমওয়ার্ক লেয়ার x = jit(foo)(10)
একটি ভেরিয়েবলের জন্য নির্ধারিত হয়, এই ক্ষেত্রে বাফারগুলি বহিরাগত রেফারেন্স তৈরি করার অনুমতি দেয় যা একটি অস্থায়ী মালিকানাধীন পয়েন্টার প্রদান করে। অন্তর্নিহিত ডেটা ব্যাখ্যা করার জন্য মেটাডেটা (dtype/dim sizes) সহ বাফার দ্বারা ধারণ করা ডেটাতে।
PjRtকম্পাইলার
pjrt_compiler.h > PjRtCompiler
এ সম্পূর্ণ রেফারেন্স।
PjRtCompiler
ক্লাস XLA ব্যাকএন্ডের জন্য কার্যকর বাস্তবায়নের বিশদ প্রদান করে, কিন্তু একটি প্লাগইন বাস্তবায়নের জন্য প্রয়োজনীয় নয়। তত্ত্বগতভাবে, একটি PjRtCompiler
, বা PjRtClient::Compile
পদ্ধতির দায়িত্ব হল একটি ইনপুট মডিউল নেওয়া এবং একটি PjRtLoadedExecutable
ফেরত দেওয়া।
PjRtExecutable/PjRtLoadedExecutable
pjrt_executable.h > PjRtExecutable
, এবং pjrt_client.h > PjRtLoadedExecutable
এ সম্পূর্ণ রেফারেন্স।
একটি PjRtExecutable
জানে কিভাবে একটি সংকলিত শিল্পকর্ম এবং সম্পাদনের বিকল্পগুলি নিতে হয় এবং সেগুলিকে সিরিয়ালাইজ/ডিসারিয়ালাইজ করতে হয় যাতে একটি এক্সিকিউটেবল সংরক্ষণ করা যায় এবং প্রয়োজন অনুসারে লোড করা যায়।
PjRtLoadedExecutable
হল ইন-মেমরি কম্পাইল করা এক্সিকিউটেবল যা কার্যকর করার জন্য ইনপুট আর্গুমেন্টের জন্য প্রস্তুত, এটি PjRtExecutable
এর একটি সাবক্লাস।
এক্সিকিউটেবলগুলিকে ক্লায়েন্টের Execute
পদ্ধতিগুলির একটির মাধ্যমে ইন্টারফেস করা হয়:
// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}
// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
Execute
কল করার আগে ফ্রেমওয়ার্কটি এক্সিকিউটিং ক্লায়েন্টের মালিকানাধীন PjRtBuffers
এ সমস্ত প্রয়োজনীয় ডেটা স্থানান্তর করবে, কিন্তু ফ্রেমওয়ার্কের রেফারেন্সের জন্য ফিরে আসবে। এই বাফারগুলি তারপর Execute
পদ্ধতিতে আর্গুমেন্ট হিসাবে সরবরাহ করা হয়।
PJRT ধারণা
PjRtFutures এবং Async কম্পিউটেশন
যদি একটি প্লাগইনের কোনো অংশ অ্যাসিঙ্ক্রোনাসভাবে প্রয়োগ করা হয়, তাহলে এটি অবশ্যই সঠিকভাবে ফিউচার বাস্তবায়ন করবে।
নিম্নলিখিত প্রোগ্রাম বিবেচনা করুন:
@jax.jit
def foo(x): return x + 1
x = foo(1)
# [...] other logic not using `x`
print(x + 1)
একটি async প্লাগইন গণনা x
সারিবদ্ধ করতে সক্ষম হবে এবং অবিলম্বে একটি বাফার ফেরত দেবে যা এখনও পড়ার জন্য প্রস্তুত নয়, তবে সম্পাদন এটিকে পপুলেট করবে। এক্সিকিউশন x
পরে প্রয়োজনীয় কম্পিউটেশন সারিবদ্ধ করা চালিয়ে যেতে পারে, যার জন্য x
প্রয়োজন হয় না, অন্যান্য PJRT ডিভাইসে এক্সিকিউশন সহ। একবার x
এর মান প্রয়োজন হলে, যতক্ষণ না বাফারটি GetReadyFuture
দ্বারা প্রত্যাবর্তিত ভবিষ্যতের মাধ্যমে নিজেকে প্রস্তুত বলে ঘোষণা না করে ততক্ষণ পর্যন্ত এক্সিকিউশন ব্লক করা হবে।
ডিভাইস এবং বাফার সহ কোন বস্তু কখন উপলব্ধ হবে তা নির্ধারণ করতে ফিউচার উপযোগী হতে পারে।
উন্নত ধারণা
বেস এপিআইগুলি বাস্তবায়নের বাইরে প্রসারিত করা JAX এর বৈশিষ্ট্যগুলিকে প্রসারিত করবে যা একটি প্লাগইন দ্বারা ব্যবহার করা যেতে পারে। এগুলি সমস্তই অপ্ট-ইন বৈশিষ্ট্য এই অর্থে যে সাধারণ JIT এ এবং কার্যপ্রবাহ কার্যকর করবে এগুলি ছাড়াই কাজ করবে, তবে একটি উত্পাদন মানের পাইপলাইনের জন্য কিছু চিন্তাভাবনা সম্ভবত PJRT API দ্বারা সমর্থিত এই বৈশিষ্ট্যগুলির যে কোনওটির জন্য সমর্থনের ডিগ্রিতে রাখা উচিত:
- মেমরি স্পেস
- কাস্টম লেআউট
- সেন্ড/রিকভির মত যোগাযোগ ব্যবস্থা
- হোস্ট অফলোডিং
- শেয়ারিং
সাধারণ PJRT ফ্রেমওয়ার্ক-ডিভাইস যোগাযোগ
উদাহরণ লগ
PJRT প্লাগইন লোড করতে এবং y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)
চালানোর জন্য নিম্নলিখিত পদ্ধতিগুলির একটি লগ দেওয়া হল। এই ক্ষেত্রে আমরা StableHLO রেফারেন্স PJRT প্লাগইনের সাথে ইন্টারঅ্যাক্ট করে JAX লগ করি।
উদাহরণ লগ
//////////////////////////////////
// Load the plugin
//////////////////////////////////
I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////
I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
// ...
return %3 : tensor<i32>
}
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630
//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)