پس زمینه
PJRT یک Device API است که می خواهیم به اکوسیستم ML اضافه کنیم. چشم انداز بلند مدت این است که:
- فریمورکها (JAX، TF، و غیره) PJRT را فراخوانی میکنند، که دارای پیادهسازیهای خاص دستگاه است که نسبت به چارچوبها غیرشفاف هستند.
- هر دستگاه بر پیاده سازی API های PJRT تمرکز می کند و می تواند نسبت به فریم ورک ها غیر شفاف باشد.
PJRT هر دو C API و C++ API را ارائه می دهد. وصل کردن در هر یک از لایهها مشکلی ندارد، C++ API از کلاسها برای انتزاع برخی مفاهیم استفاده میکند، اما همچنین پیوندهای قویتری با انواع داده XLA دارد. این صفحه بر روی C++ API تمرکز دارد.
اجزای PJRT
PjRtClient
مرجع کامل در pjrt_client.h > PjRtClient
.
کلاینت ها تمام ارتباطات بین دستگاه و فریم ورک را مدیریت می کنند و تمام حالت های استفاده شده در ارتباط را در خود محصور می کنند. آنها مجموعهای از APIهای عمومی برای تعامل با یک پلاگین PJRT دارند و دستگاهها و فضای حافظه برای یک پلاگین معین را در اختیار دارند.
PjRtDevice
مراجع کامل در pjrt_client.h > PjRtDevice
و pjrt_device_description.h
یک کلاس دستگاه برای توصیف یک دستگاه واحد استفاده می شود. یک دستگاه دارای توضیحات دستگاه برای کمک به شناسایی نوع آن (هش منحصر به فرد برای شناسایی GPU/CPU/xPU) و مکان در شبکه ای از دستگاه ها به صورت محلی و جهانی است.
دستگاه ها همچنین فضای حافظه مرتبط خود و کلاینت متعلق به آن را می شناسند.
یک دستگاه لزوماً بافرهای داده های واقعی مرتبط با آن را نمی شناسد، اما می تواند با نگاه کردن به فضاهای حافظه مرتبط آن متوجه شود.
PjRtMemorySpace
مرجع کامل در pjrt_client.h > PjRtMemorySpace
.
از فضاهای حافظه می توان برای توصیف مکان حافظه استفاده کرد. اینها را میتوان برداشت و در هر جایی زندگی کرد اما از یک دستگاه قابل دسترسی است، یا میتوان آنها را پین کرد و باید روی دستگاه خاصی زندگی کنند.
فضاهای حافظه بافرهای مربوط به داده ها و دستگاه هایی (جمع) که فضای حافظه با آنها مرتبط است و همچنین کلاینت که بخشی از آن است را می شناسند.
PjRtBuffer
مرجع کامل در pjrt_client.h > PjRtBuffer
.
یک بافر داده ها را بر روی یک دستگاه در قالبی نگه می دارد که کار کردن با آن در داخل افزونه آسان است، مانند عناصر MLIR attr یا فرمت تانسور اختصاصی. یک چارچوب ممکن است سعی کند داده ها را به شکل xla::Literal
به یک دستگاه ارسال کند، یعنی برای یک آرگومان ورودی به ماژول، که باید شبیه سازی شود (یا قرض گرفته شود)، به حافظه دستگاه. هنگامی که دیگر به بافر نیاز نیست، متد Delete
توسط فریمورک برای پاکسازی فراخوانی می شود.
یک بافر فضای حافظه را که بخشی از آن است می داند و به طور گذرا می تواند تشخیص دهد که کدام دستگاه ها می توانند به آن دسترسی داشته باشند، اما بافرها لزوماً دستگاه های خود را نمی شناسند.
برای برقراری ارتباط با فریمورکها، بافرها میدانند که چگونه به xla::Literal
type تبدیل شوند:
// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}
// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}
APIها برای ایجاد بافر دارای Buffer Semantics هستند که به تعیین اینکه آیا داده های تحت اللفظی از بافر میزبان می توانند به اشتراک گذاشته شوند یا کپی شوند یا جهش پیدا کنند، کمک می کند.
در نهایت، اگر بافر به متغیری در لایه چارچوب x = jit(foo)(10)
نسبت داده شود، ممکن است یک بافر بیشتر از محدوده اجرایش دوام بیاورد، در این موارد بافرها اجازه ساختن مراجع خارجی را می دهند که یک اشاره گر موقتی را ارائه می کنند. به داده هایی که توسط بافر نگهداری می شود، همراه با ابرداده (dtype / اندازه های کم نور) برای تفسیر داده های اساسی.
PjRtCompiler
مرجع کامل در pjrt_compiler.h > PjRtCompiler
.
کلاس PjRtCompiler
جزئیات پیاده سازی مفیدی را برای Backend های XLA ارائه می دهد، اما برای پیاده سازی افزونه ضروری نیست. در تئوری، مسئولیت یک PjRtCompiler
یا روش PjRtClient::Compile
، گرفتن یک ماژول ورودی و برگرداندن یک PjRtLoadedExecutable
است.
PjRtExecutable / PjRtLoadedExecutable
مرجع کامل در pjrt_executable.h > PjRtExecutable
و pjrt_client.h > PjRtLoadedExecutable
.
یک PjRtExecutable
می داند که چگونه یک مصنوع کامپایل شده و گزینه های اجرا را بگیرد و آنها را سریالی/جدایی کند تا یک فایل اجرایی بتواند در صورت نیاز ذخیره و بارگذاری شود.
PjRtLoadedExecutable
یک فایل اجرایی کامپایل شده در حافظه است که برای اجرای آرگومان های ورودی آماده است، این یک زیر کلاس از PjRtExecutable
است.
فایل های اجرایی از طریق یکی از متدهای Execute
مشتری با آنها ارتباط برقرار می کنند:
// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}
// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
قبل از فراخوانی Execute
فریم ورک تمام دادههای مورد نیاز را به PjRtBuffers
متعلق به کلاینت اجراکننده منتقل میکند، اما برای چارچوب به مرجع بازگردانده میشود. سپس این بافرها به عنوان آرگومان های متد Execute
ارائه می شوند.
مفاهیم PJRT
PjRtFutures & Async Computations
اگر هر بخشی از یک افزونه به صورت ناهمزمان پیاده سازی شود، باید آتی را به درستی پیاده سازی کند.
برنامه زیر را در نظر بگیرید:
@jax.jit
def foo(x): return x + 1
x = foo(1)
# [...] other logic not using `x`
print(x + 1)
یک پلاگین غیر همگام میتواند محاسبات x
را در صف قرار دهد و فوراً بافری را که هنوز برای خواندن آماده نیست، بازگرداند، اما اجرا آن را پر میکند. اجرا میتواند محاسبات لازم را بعد از x
، که نیازی به x
ندارند، از جمله اجرا در سایر دستگاههای PJRT در صف قرار دهد. هنگامی که مقدار x
مورد نیاز باشد، اجرا تا زمانی که بافر خود را آماده اعلام کند از طریق آینده ای که توسط GetReadyFuture
برگردانده می شود مسدود می شود.
فیوچرها می توانند برای تعیین زمان در دسترس شدن یک شی، از جمله دستگاه ها و بافرها، مفید باشند.
مفاهیم پیشرفته
فراتر از اجرای API های پایه، ویژگی های JAX را که می تواند توسط یک افزونه استفاده شود، گسترش می دهد. همه اینها ویژگیهای انتخابی هستند به این معنا که در JIT معمولی و گردش کار بدون آنها کار میکند، اما برای یک خط لوله کیفیت تولید، احتمالاً باید در مورد درجه پشتیبانی از هر یک از این ویژگیها که توسط APIهای PJRT پشتیبانی میشود، فکر کرد:
- فضاهای حافظه
- طرح بندی های سفارشی
- عملیات ارتباطی مانند send/recv
- بارگذاری میزبان
- شاردینگ
ارتباط معمولی چارچوب-دستگاه PJRT
ثبت مثال
در زیر گزارشی از متدهایی است که برای بارگیری افزونه PJRT و اجرای y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)
فراخوانی شده است. در این مورد ما JAX را در تعامل با پلاگین StableHLO Reference PJRT ثبت می کنیم.
گزارش مثال
//////////////////////////////////
// Load the plugin
//////////////////////////////////
I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////
I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
// ...
return %3 : tensor<i32>
}
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630
//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)