نمای کلی API دستگاه PJRT C++

پس زمینه

PJRT یک Device API است که می خواهیم به اکوسیستم ML اضافه کنیم. چشم انداز بلند مدت این است که:

  1. فریم‌ورک‌ها (JAX، TF، و غیره) PJRT را فراخوانی می‌کنند، که دارای پیاده‌سازی‌های خاص دستگاه است که نسبت به چارچوب‌ها غیرشفاف هستند.
  2. هر دستگاه بر پیاده سازی API های PJRT تمرکز می کند و می تواند نسبت به فریم ورک ها غیر شفاف باشد.

PJRT هر دو C API و C++ API را ارائه می دهد. وصل کردن در هر یک از لایه‌ها مشکلی ندارد، C++ API از کلاس‌ها برای انتزاع برخی مفاهیم استفاده می‌کند، اما همچنین پیوندهای قوی‌تری با انواع داده XLA دارد. این صفحه بر روی C++ API تمرکز دارد.

اجزای PJRT

اجزای PJRT

PjRtClient

مرجع کامل در pjrt_client.h > PjRtClient .

کلاینت ها تمام ارتباطات بین دستگاه و فریم ورک را مدیریت می کنند و تمام حالت های استفاده شده در ارتباط را در خود محصور می کنند. آن‌ها مجموعه‌ای از APIهای عمومی برای تعامل با یک پلاگین PJRT دارند و دستگاه‌ها و فضای حافظه برای یک پلاگین معین را در اختیار دارند.

PjRtDevice

مراجع کامل در pjrt_client.h > PjRtDevice و pjrt_device_description.h

یک کلاس دستگاه برای توصیف یک دستگاه واحد استفاده می شود. یک دستگاه دارای توضیحات دستگاه برای کمک به شناسایی نوع آن (هش منحصر به فرد برای شناسایی GPU/CPU/xPU) و مکان در شبکه ای از دستگاه ها به صورت محلی و جهانی است.

دستگاه ها همچنین فضای حافظه مرتبط خود و کلاینت متعلق به آن را می شناسند.

یک دستگاه لزوماً بافرهای داده های واقعی مرتبط با آن را نمی شناسد، اما می تواند با نگاه کردن به فضاهای حافظه مرتبط آن متوجه شود.

PjRtMemorySpace

مرجع کامل در pjrt_client.h > PjRtMemorySpace .

از فضاهای حافظه می توان برای توصیف مکان حافظه استفاده کرد. این‌ها را می‌توان برداشت و در هر جایی زندگی کرد اما از یک دستگاه قابل دسترسی است، یا می‌توان آن‌ها را پین کرد و باید روی دستگاه خاصی زندگی کنند.

فضاهای حافظه بافرهای مربوط به داده ها و دستگاه هایی (جمع) که فضای حافظه با آنها مرتبط است و همچنین کلاینت که بخشی از آن است را می شناسند.

PjRtBuffer

مرجع کامل در pjrt_client.h > PjRtBuffer .

یک بافر داده ها را بر روی یک دستگاه در قالبی نگه می دارد که کار کردن با آن در داخل افزونه آسان است، مانند عناصر MLIR attr یا فرمت تانسور اختصاصی. یک چارچوب ممکن است سعی کند داده ها را به شکل xla::Literal به یک دستگاه ارسال کند، یعنی برای یک آرگومان ورودی به ماژول، که باید شبیه سازی شود (یا قرض گرفته شود)، به حافظه دستگاه. هنگامی که دیگر به بافر نیاز نیست، متد Delete توسط فریمورک برای پاکسازی فراخوانی می شود.

یک بافر فضای حافظه را که بخشی از آن است می داند و به طور گذرا می تواند تشخیص دهد که کدام دستگاه ها می توانند به آن دسترسی داشته باشند، اما بافرها لزوماً دستگاه های خود را نمی شناسند.

برای برقراری ارتباط با فریم‌ورک‌ها، بافرها می‌دانند که چگونه به xla::Literal type تبدیل شوند:

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

APIها برای ایجاد بافر دارای Buffer Semantics هستند که به تعیین اینکه آیا داده های تحت اللفظی از بافر میزبان می توانند به اشتراک گذاشته شوند یا کپی شوند یا جهش پیدا کنند، کمک می کند.

در نهایت، اگر بافر به متغیری در لایه چارچوب x = jit(foo)(10) نسبت داده شود، ممکن است یک بافر بیشتر از محدوده اجرایش دوام بیاورد، در این موارد بافرها اجازه ساختن مراجع خارجی را می دهند که یک اشاره گر موقتی را ارائه می کنند. به داده هایی که توسط بافر نگهداری می شود، همراه با ابرداده (dtype / اندازه های کم نور) برای تفسیر داده های اساسی.

PjRtCompiler

مرجع کامل در pjrt_compiler.h > PjRtCompiler .

کلاس PjRtCompiler جزئیات پیاده سازی مفیدی را برای Backend های XLA ارائه می دهد، اما برای پیاده سازی افزونه ضروری نیست. در تئوری، مسئولیت یک PjRtCompiler یا روش PjRtClient::Compile ، گرفتن یک ماژول ورودی و برگرداندن یک PjRtLoadedExecutable است.

PjRtExecutable / PjRtLoadedExecutable

مرجع کامل در pjrt_executable.h > PjRtExecutable و pjrt_client.h > PjRtLoadedExecutable .

یک PjRtExecutable می داند که چگونه یک مصنوع کامپایل شده و گزینه های اجرا را بگیرد و آنها را سریالی/جدایی کند تا یک فایل اجرایی بتواند در صورت نیاز ذخیره و بارگذاری شود.

PjRtLoadedExecutable یک فایل اجرایی کامپایل شده در حافظه است که برای اجرای آرگومان های ورودی آماده است، این یک زیر کلاس از PjRtExecutable است.

فایل های اجرایی از طریق یکی از متدهای Execute مشتری با آنها ارتباط برقرار می کنند:

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

قبل از فراخوانی Execute فریم ورک تمام داده‌های مورد نیاز را به PjRtBuffers متعلق به کلاینت اجراکننده منتقل می‌کند، اما برای چارچوب به مرجع بازگردانده می‌شود. سپس این بافرها به عنوان آرگومان های متد Execute ارائه می شوند.

مفاهیم PJRT

PjRtFutures & Async Computations

اگر هر بخشی از یک افزونه به صورت ناهمزمان پیاده سازی شود، باید آتی را به درستی پیاده سازی کند.

برنامه زیر را در نظر بگیرید:

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

یک پلاگین غیر همگام می‌تواند محاسبات x را در صف قرار دهد و فوراً بافری را که هنوز برای خواندن آماده نیست، بازگرداند، اما اجرا آن را پر می‌کند. اجرا می‌تواند محاسبات لازم را بعد از x ، که نیازی به x ندارند، از جمله اجرا در سایر دستگاه‌های PJRT در صف قرار دهد. هنگامی که مقدار x مورد نیاز باشد، اجرا تا زمانی که بافر خود را آماده اعلام کند از طریق آینده ای که توسط GetReadyFuture برگردانده می شود مسدود می شود.

فیوچرها می توانند برای تعیین زمان در دسترس شدن یک شی، از جمله دستگاه ها و بافرها، مفید باشند.

مفاهیم پیشرفته

فراتر از اجرای API های پایه، ویژگی های JAX را که می تواند توسط یک افزونه استفاده شود، گسترش می دهد. همه اینها ویژگی‌های انتخابی هستند به این معنا که در JIT معمولی و گردش کار بدون آن‌ها کار می‌کند، اما برای یک خط لوله کیفیت تولید، احتمالاً باید در مورد درجه پشتیبانی از هر یک از این ویژگی‌ها که توسط APIهای PJRT پشتیبانی می‌شود، فکر کرد:

  • فضاهای حافظه
  • طرح بندی های سفارشی
  • عملیات ارتباطی مانند send/recv
  • بارگذاری میزبان
  • شاردینگ

ارتباط معمولی چارچوب-دستگاه PJRT

ثبت مثال

در زیر گزارشی از متدهایی است که برای بارگیری افزونه PJRT و اجرای y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1) فراخوانی شده است. در این مورد ما JAX را در تعامل با پلاگین StableHLO Reference PJRT ثبت می کنیم.

گزارش مثال

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)