پیشینه
PJRT یک رابط برنامهنویسی یکپارچه برای دستگاهها است که میخواهیم به اکوسیستم یادگیری ماشین اضافه کنیم. چشمانداز بلندمدت این است که:
- چارچوبها (JAX، TF و غیره) PJRT را فراخوانی میکنند که پیادهسازیهای مخصوص دستگاه دارد که برای چارچوبها مبهم هستند؛
- هر دستگاه بر پیادهسازی APIهای PJRT تمرکز دارد و میتواند برای چارچوبها مبهم باشد.
PJRT هم API مربوط به C و هم API مربوط به C++ را ارائه میدهد. اتصال به هر دو لایه اشکالی ندارد، API مربوط به C++ از کلاسها برای انتزاعی کردن برخی مفاهیم استفاده میکند، اما پیوندهای قویتری با انواع دادههای XLA نیز دارد. این صفحه بر روی API مربوط به C++ تمرکز دارد.
قطعات PJRT
PjRtClient
مرجع کامل در pjrt_client.h > PjRtClient .
کلاینتها تمام ارتباطات بین دستگاه و چارچوب را مدیریت میکنند و تمام حالتهای مورد استفاده در این ارتباطات را کپسولهسازی میکنند. آنها مجموعهای عمومی از APIها را برای تعامل با یک افزونه PJRT دارند و مالک دستگاهها و فضاهای حافظه برای یک افزونه مشخص هستند.
PjRtDevice
منابع کامل در pjrt_client.h > PjRtDevice و pjrt_device_description.h
کلاس دستگاه برای توصیف یک دستگاه واحد استفاده میشود. هر دستگاه دارای یک توصیف دستگاه است که به شناسایی نوع آن (هش منحصر به فرد برای شناسایی GPU/CPU/xPU) و موقعیت آن در شبکهای از دستگاهها، چه به صورت محلی و چه به صورت جهانی، کمک میکند.
دستگاهها همچنین فضاهای حافظه مرتبط با خود و کلاینتی که متعلق به آن است را میشناسند.
یک دستگاه لزوماً بافرهای دادههای واقعی مرتبط با خود را نمیشناسد ، اما میتواند با بررسی فضاهای حافظه مرتبط با آن، این موضوع را تشخیص دهد.
فضای حافظه PjRt
مرجع کامل در pjrt_client.h > PjRtMemorySpace .
فضاهای حافظه میتوانند برای توصیف مکانی از حافظه استفاده شوند. این فضاها میتوانند بدون پین باشند و آزادانه در هر جایی قرار بگیرند اما از یک دستگاه قابل دسترسی باشند، یا میتوانند پین شده باشند و باید روی یک دستگاه خاص قرار داشته باشند.
فضاهای حافظه، بافرهای داده مرتبط با خود و دستگاههایی (جمع) که یک فضای حافظه به آنها مرتبط است و همچنین کلاینتی که بخشی از آن است را میشناسند.
PjRtBuffer
مرجع کامل در pjrt_client.h > PjRtBuffer .
یک بافر، دادهها را روی یک دستگاه در قالبی نگه میدارد که کار با آن در داخل افزونه آسان باشد، مانند عناصر MLIR attr یا یک قالب تانسور اختصاصی. یک چارچوب ممکن است سعی کند دادهها را به شکل xla::Literal به یک دستگاه ارسال کند، یعنی برای یک آرگومان ورودی به ماژول، که باید کلون (یا قرض گرفته شود) به حافظه دستگاه باشد. هنگامی که دیگر به بافر نیازی نباشد، روش Delete توسط چارچوب برای پاکسازی فراخوانی میشود.
یک بافر فضای حافظهای که بخشی از آن است را میشناسد و به صورت انتقالی میتواند تشخیص دهد که کدام دستگاهها قادر به دسترسی به آن هستند، اما بافرها لزوماً دستگاههای خود را نمیشناسند.
برای ارتباط با چارچوبها، بافرها میدانند که چگونه به یک نوع xla::Literal تبدیل کنند و از آن برگردند:
// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}
// Buffer to Literal
xla::Future<> ToLiteral(xla::MutableLiteralBase* literal) override {...}
APIهای ایجاد بافر دارای Buffer Semantics هستند که به تعیین اینکه آیا دادههای تحتاللفظی از بافر میزبان میتوانند به اشتراک گذاشته شوند یا کپی شوند یا تغییر داده شوند، کمک میکنند.
در نهایت، اگر یک بافر به متغیری در لایه چارچوب x = jit(foo)(10) اختصاص داده شود، ممکن است به ماندگاری بیشتری نسبت به محدوده اجرای خود نیاز داشته باشد. در این موارد، بافرها امکان ایجاد ارجاعات خارجی را فراهم میکنند که یک اشارهگر موقت به دادههای نگهداری شده توسط بافر، همراه با فراداده (dtype / dim sizes) برای تفسیر دادههای زیربنایی ارائه میدهند.
کامپایلر PjRt
مرجع کامل در pjrt_compiler.h > PjRtCompiler .
کلاس PjRtCompiler جزئیات پیادهسازی مفیدی را برای بکاندهای XLA ارائه میدهد، اما برای پیادهسازی یک افزونه ضروری نیست. در تئوری، مسئولیت یک PjRtCompiler یا متد PjRtClient::Compile ، دریافت یک ماژول ورودی و بازگرداندن یک PjRtLoadedExecutable است.
PjRtExecutable / PjRtLoadedExecutable
مرجع کامل در pjrt_executable.h > PjRtExecutable و pjrt_client.h > PjRtLoadedExecutable .
یک PjRtExecutable میداند که چگونه یک مصنوع کامپایل شده و گزینههای اجرایی را دریافت کرده و آنها را سریالیزه/غیر سریالیزه کند تا یک فایل اجرایی بتواند در صورت نیاز ذخیره و بارگذاری شود.
PjRtLoadedExecutable یک فایل اجرایی کامپایل شده در حافظه است که آماده اجرا شدن توسط آرگومانهای ورودی است و زیرکلاسی از PjRtExecutable محسوب میشود.
فایلهای اجرایی از طریق یکی از متدهای Execute کلاینت با آنها ارتباط برقرار میکنند:
// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}
// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
قبل از فراخوانی Execute چارچوب تمام دادههای مورد نیاز را به PjRtBuffers متعلق به کلاینت در حال اجرا منتقل میکند، اما برای ارجاع به چارچوب بازگردانده میشود. سپس این بافرها به عنوان آرگومان به متد Execute ارائه میشوند.
مفاهیم PJRT
معاملات آتی و محاسبات ناهمزمان
اگر هر بخشی از یک افزونه به صورت غیرهمزمان پیادهسازی شود، باید به درستی آیندهها را پیادهسازی کند.
برنامه زیر را در نظر بگیرید:
@jax.jit
def foo(x): return x + 1
x = foo(1)
# [...] other logic not using `x`
print(x + 1)
یک افزونهی ناهمگام میتواند محاسبهی x را در صف قرار دهد و بلافاصله بافری را برگرداند که هنوز آمادهی خواندن نیست، اما اجرا آن را پر میکند. اجرا میتواند به صف کردن محاسبات لازم پس از x که به x نیاز ندارند، از جمله اجرا روی سایر دستگاههای PJRT، ادامه دهد. هنگامی که مقدار x مورد نیاز باشد، اجرا مسدود میشود تا زمانی که بافر از طریق آیندهی برگردانده شده توسط GetReadyFuture خود را آماده اعلام کند.
آیندهها میتوانند برای تعیین زمان در دسترس قرار گرفتن یک شیء، از جمله دستگاهها و بافرها، مفید باشند.
مفاهیم پیشرفته
فراتر رفتن از پیادهسازی APIهای پایه، ویژگیهای JAX را که میتوانند توسط یک افزونه استفاده شوند، گسترش میدهد. همه اینها ویژگیهای اختیاری هستند به این معنا که در گردش کار معمول JIT و اجرا بدون آنها نیز کار میکنند، اما برای یک خط تولید با کیفیت، احتمالاً باید در مورد میزان پشتیبانی از هر یک از این ویژگیهای پشتیبانی شده توسط APIهای PJRT فکر شود:
- فضاهای حافظه
- طرحبندیهای سفارشی
- عملیات ارتباطی مانند ارسال/دریافت
- تخلیه بار میزبان
- شاردینگ
ارتباط معمول چارچوب-دستگاه PJRT
مثال گزارش
در ادامه، لاگی از متدهایی که برای بارگذاری افزونه PJRT و اجرای y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1) فراخوانی میشوند، آمده است. در این حالت، ما JAX را در تعامل با افزونه StableHLO Reference PJRT لاگ میکنیم.
مثال لاگ
//////////////////////////////////
// Load the plugin
//////////////////////////////////
I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////
I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
// ...
return %3 : tensor<i32>
}
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630
//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)