نمای کلی API دستگاه PJRT C++

پیشینه

PJRT یک رابط برنامه‌نویسی یکپارچه برای دستگاه‌ها است که می‌خواهیم به اکوسیستم یادگیری ماشین اضافه کنیم. چشم‌انداز بلندمدت این است که:

  1. چارچوب‌ها (JAX، TF و غیره) PJRT را فراخوانی می‌کنند که پیاده‌سازی‌های مخصوص دستگاه دارد که برای چارچوب‌ها مبهم هستند؛
  2. هر دستگاه بر پیاده‌سازی APIهای PJRT تمرکز دارد و می‌تواند برای چارچوب‌ها مبهم باشد.

PJRT هم API مربوط به C و هم API مربوط به C++ را ارائه می‌دهد. اتصال به هر دو لایه اشکالی ندارد، API مربوط به C++ از کلاس‌ها برای انتزاعی کردن برخی مفاهیم استفاده می‌کند، اما پیوندهای قوی‌تری با انواع داده‌های XLA نیز دارد. این صفحه بر روی API مربوط به C++ تمرکز دارد.

قطعات PJRT

قطعات PJRT

PjRtClient

مرجع کامل در pjrt_client.h > PjRtClient .

کلاینت‌ها تمام ارتباطات بین دستگاه و چارچوب را مدیریت می‌کنند و تمام حالت‌های مورد استفاده در این ارتباطات را کپسوله‌سازی می‌کنند. آن‌ها مجموعه‌ای عمومی از APIها را برای تعامل با یک افزونه PJRT دارند و مالک دستگاه‌ها و فضاهای حافظه برای یک افزونه مشخص هستند.

PjRtDevice

منابع کامل در pjrt_client.h > PjRtDevice و pjrt_device_description.h

کلاس دستگاه برای توصیف یک دستگاه واحد استفاده می‌شود. هر دستگاه دارای یک توصیف دستگاه است که به شناسایی نوع آن (هش منحصر به فرد برای شناسایی GPU/CPU/xPU) و موقعیت آن در شبکه‌ای از دستگاه‌ها، چه به صورت محلی و چه به صورت جهانی، کمک می‌کند.

دستگاه‌ها همچنین فضاهای حافظه مرتبط با خود و کلاینتی که متعلق به آن است را می‌شناسند.

یک دستگاه لزوماً بافرهای داده‌های واقعی مرتبط با خود را نمی‌شناسد ، اما می‌تواند با بررسی فضاهای حافظه مرتبط با آن، این موضوع را تشخیص دهد.

فضای حافظه PjRt

مرجع کامل در pjrt_client.h > PjRtMemorySpace .

فضاهای حافظه می‌توانند برای توصیف مکانی از حافظه استفاده شوند. این فضاها می‌توانند بدون پین باشند و آزادانه در هر جایی قرار بگیرند اما از یک دستگاه قابل دسترسی باشند، یا می‌توانند پین شده باشند و باید روی یک دستگاه خاص قرار داشته باشند.

فضاهای حافظه، بافرهای داده مرتبط با خود و دستگاه‌هایی (جمع) که یک فضای حافظه به آنها مرتبط است و همچنین کلاینتی که بخشی از آن است را می‌شناسند.

PjRtBuffer

مرجع کامل در pjrt_client.h > PjRtBuffer .

یک بافر، داده‌ها را روی یک دستگاه در قالبی نگه می‌دارد که کار با آن در داخل افزونه آسان باشد، مانند عناصر MLIR attr یا یک قالب تانسور اختصاصی. یک چارچوب ممکن است سعی کند داده‌ها را به شکل xla::Literal به یک دستگاه ارسال کند، یعنی برای یک آرگومان ورودی به ماژول، که باید کلون (یا قرض گرفته شود) به حافظه دستگاه باشد. هنگامی که دیگر به بافر نیازی نباشد، روش Delete توسط چارچوب برای پاکسازی فراخوانی می‌شود.

یک بافر فضای حافظه‌ای که بخشی از آن است را می‌شناسد و به صورت انتقالی می‌تواند تشخیص دهد که کدام دستگاه‌ها قادر به دسترسی به آن هستند، اما بافرها لزوماً دستگاه‌های خود را نمی‌شناسند.

برای ارتباط با چارچوب‌ها، بافرها می‌دانند که چگونه به یک نوع xla::Literal تبدیل کنند و از آن برگردند:

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::Future<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

APIهای ایجاد بافر دارای Buffer Semantics هستند که به تعیین اینکه آیا داده‌های تحت‌اللفظی از بافر میزبان می‌توانند به اشتراک گذاشته شوند یا کپی شوند یا تغییر داده شوند، کمک می‌کنند.

در نهایت، اگر یک بافر به متغیری در لایه چارچوب x = jit(foo)(10) اختصاص داده شود، ممکن است به ماندگاری بیشتری نسبت به محدوده اجرای خود نیاز داشته باشد. در این موارد، بافرها امکان ایجاد ارجاعات خارجی را فراهم می‌کنند که یک اشاره‌گر موقت به داده‌های نگهداری شده توسط بافر، همراه با فراداده (dtype / dim sizes) برای تفسیر داده‌های زیربنایی ارائه می‌دهند.

کامپایلر PjRt

مرجع کامل در pjrt_compiler.h > PjRtCompiler .

کلاس PjRtCompiler جزئیات پیاده‌سازی مفیدی را برای بک‌اندهای XLA ارائه می‌دهد، اما برای پیاده‌سازی یک افزونه ضروری نیست. در تئوری، مسئولیت یک PjRtCompiler یا متد PjRtClient::Compile ، دریافت یک ماژول ورودی و بازگرداندن یک PjRtLoadedExecutable است.

PjRtExecutable / PjRtLoadedExecutable

مرجع کامل در pjrt_executable.h > PjRtExecutable و pjrt_client.h > PjRtLoadedExecutable .

یک PjRtExecutable می‌داند که چگونه یک مصنوع کامپایل شده و گزینه‌های اجرایی را دریافت کرده و آنها را سریالیزه/غیر سریالیزه کند تا یک فایل اجرایی بتواند در صورت نیاز ذخیره و بارگذاری شود.

PjRtLoadedExecutable یک فایل اجرایی کامپایل شده در حافظه است که آماده اجرا شدن توسط آرگومان‌های ورودی است و زیرکلاسی از PjRtExecutable محسوب می‌شود.

فایل‌های اجرایی از طریق یکی از متدهای Execute کلاینت با آنها ارتباط برقرار می‌کنند:

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

قبل از فراخوانی Execute چارچوب تمام داده‌های مورد نیاز را به PjRtBuffers متعلق به کلاینت در حال اجرا منتقل می‌کند، اما برای ارجاع به چارچوب بازگردانده می‌شود. سپس این بافرها به عنوان آرگومان به متد Execute ارائه می‌شوند.

مفاهیم PJRT

معاملات آتی و محاسبات ناهمزمان

اگر هر بخشی از یک افزونه به صورت غیرهمزمان پیاده‌سازی شود، باید به درستی آینده‌ها را پیاده‌سازی کند.

برنامه زیر را در نظر بگیرید:

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

یک افزونه‌ی ناهمگام می‌تواند محاسبه‌ی x را در صف قرار دهد و بلافاصله بافری را برگرداند که هنوز آماده‌ی خواندن نیست، اما اجرا آن را پر می‌کند. اجرا می‌تواند به صف کردن محاسبات لازم پس از x که به x نیاز ندارند، از جمله اجرا روی سایر دستگاه‌های PJRT، ادامه دهد. هنگامی که مقدار x مورد نیاز باشد، اجرا مسدود می‌شود تا زمانی که بافر از طریق آینده‌ی برگردانده شده توسط GetReadyFuture خود را آماده اعلام کند.

آینده‌ها می‌توانند برای تعیین زمان در دسترس قرار گرفتن یک شیء، از جمله دستگاه‌ها و بافرها، مفید باشند.

مفاهیم پیشرفته

فراتر رفتن از پیاده‌سازی APIهای پایه، ویژگی‌های JAX را که می‌توانند توسط یک افزونه استفاده شوند، گسترش می‌دهد. همه اینها ویژگی‌های اختیاری هستند به این معنا که در گردش کار معمول JIT و اجرا بدون آنها نیز کار می‌کنند، اما برای یک خط تولید با کیفیت، احتمالاً باید در مورد میزان پشتیبانی از هر یک از این ویژگی‌های پشتیبانی شده توسط APIهای PJRT فکر شود:

  • فضاهای حافظه
  • طرح‌بندی‌های سفارشی
  • عملیات ارتباطی مانند ارسال/دریافت
  • تخلیه بار میزبان
  • شاردینگ

ارتباط معمول چارچوب-دستگاه PJRT

مثال گزارش

در ادامه، لاگی از متدهایی که برای بارگذاری افزونه PJRT و اجرای y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1) فراخوانی می‌شوند، آمده است. در این حالت، ما JAX را در تعامل با افزونه StableHLO Reference PJRT لاگ می‌کنیم.

مثال لاگ

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)