الخلفية
PJRT هي واجهة برمجة تطبيقات Device API الموحّدة التي نريد إضافتها إلى منظومة الذكاء الاصطناعي المتكاملة. تهدف الرؤية الطويلة المدى إلى:
- ستستدعي الإطارات (JAX وTF وما إلى ذلك) PJRT، التي تتضمّن تنفيذات خاصة بالأجهزة وغير شفافة للإطارات.
- يركز كل جهاز على تنفيذ واجهات برمجة تطبيقات PJRT، ويمكن أن يكون غير شفاف بالنسبة إلى الأُطر.
يوفّر PJRT واجهتَي برمجة تطبيقات C وC++. لا بأس بالربط في أي من الطبقتين، إذ تستخدم واجهة برمجة التطبيقات C++ فئات لتبسيط بعض المفاهيم، ولكنها ترتبط أيضًا بأنواع بيانات XLA بشكل أقوى. تركّز هذه الصفحة على واجهة برمجة التطبيقات C++.
مكوّنات PJRT
PjRtClient
يمكنك الاطّلاع على المرجع الكامل على pjrt_client.h > PjRtClient
.
تدير التطبيقات جميع عمليات التواصل بين الجهاز والإطار، و تُغلِّف جميع الحالات المستخدَمة في عملية التواصل. لديهم مجموعة عامة من واجهات برمجة التطبيقات للتفاعل مع مكوّن إضافي من PJRT، وهم يملكون الأجهزة ومساحات الذاكرة لمكوّن إضافي معيّن.
PjRtDevice
المراجع الكاملة على pjrt_client.h > PjRtDevice
وpjrt_device_description.h
تُستخدَم فئة الجهاز لوصف جهاز واحد. يحتوي كل جهاز على وصف للمساعدة في تحديد نوعه (تجزئة فريدة لتحديد وحدة معالجة الرسومات/وحدة المعالجة المركزية/وحدة المعالجة المتعددة)، ومقعده ضمن شبكة من الأجهزة على مستوى الجهاز والمستوى العالمي.
وتعرف الأجهزة أيضًا مساحات الذاكرة المرتبطة بها والعميل الذي يملكها.
لا يعرف الجهاز بالضرورة ذاكرات التخزين المؤقت للبيانات الفعلية المرتبطة به، ولكن يمكنه معرفة ذلك من خلال البحث في مساحات الذاكرة المرتبطة به.
PjRtMemorySpace
يمكنك الاطّلاع على المرجع الكامل على pjrt_client.h > PjRtMemorySpace
.
يمكن استخدام مساحات الذاكرة لوصف موقع ذاكرة. يمكن إما إلغاء تثبيت هذه التطبيقات، ونشرها في أي مكان، ولكن يمكن الوصول إليها من جهاز، أو تثبيتها على جهاز معيّن.
تعرف مساحات الذاكرة وحدات التخزين المؤقتة المرتبطة بها للبيانات والأجهزة (الجمع) التي تكون مساحة الذاكرة مرتبطة بها، بالإضافة إلى العميل الذي تكون جزءًا منه.
PjRtBuffer
يمكنك الاطّلاع على المرجع الكامل على pjrt_client.h > PjRtBuffer
.
يخزّن المخزن المؤقت البيانات على جهاز بتنسيق يسهُل التعامل معه
داخل المكوّن الإضافي، مثل سمة عناصر MLIR أو تنسيق مصفوفة تعامُل خاصة.
قد يحاول إطار العمل إرسال البيانات إلى جهاز في شكل xla::Literal
،
أي لوسيطة إدخال إلى الوحدة، والتي يجب نسخها (أو استعارتها) لتحميلها إلى
ذاكرة الجهاز. بعد أن يصبح المخزن المؤقت غير مطلوب، يُستخدَم الإطار العملي لطريقة Delete
لتنظيفه.
يعرف المخزن المؤقت مساحة الذاكرة التي يشكّل جزءًا منها، ويمكنه بشكلٍ تبادلي معرفة الأجهزة التي يمكنها الوصول إليه، ولكن لا يعرف المخزن المؤقت بالضرورة أجهزته.
للتواصل مع إطارات العمل، تعرف الفواصل كيفية التحويل إلى نوع
xla::Literal
والعكس:
// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}
// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}
تحتوي واجهات برمجة التطبيقات لإنشاء مخزن مؤقت على دلالات المخزن المؤقت التي تساعد في تحديد ما إذا كان يمكن مشاركة البيانات الحرفية من المخزن المؤقت للمضيف أو نسخها أو تغييرها.
أخيرًا، قد يحتاج المخزن المؤقت إلى الاستمرار لفترة أطول من نطاق تنفيذه، إذا تم تحديده
لمتغيّر في طبقة الإطار x = jit(foo)(10)
. وفي هذه الحالات، تسمح ملفّات التخزين المؤقت بإنشاء مراجع خارجية توفّر مؤشرًا يملكه مؤقتًا
للبيانات التي يحتفظ بها المخزن المؤقت، بالإضافة إلى البيانات الوصفية (أنواع البيانات / أحجام السمات)
لتفسير البيانات الأساسية.
PjRtCompiler
يمكنك الاطّلاع على المرجع الكامل على pjrt_compiler.h > PjRtCompiler
.
توفّر فئة PjRtCompiler
تفاصيل تنفيذ مفيدة لخلفيات XLA
، ولكنّها ليست ضرورية لتنفيذ المكوّن الإضافي. من الناحية النظرية، تتمثل
مسؤولية PjRtCompiler
أو طريقة PjRtClient::Compile
في
أخذ وحدة إدخال وإرجاع PjRtLoadedExecutable
.
PjRtExecutable / PjRtLoadedExecutable
المرجع الكامل على pjrt_executable.h > PjRtExecutable
وpjrt_client.h > PjRtLoadedExecutable
.
يعرف PjRtExecutable
كيفية أخذ عنصر مُجمَّع وخيارات التنفيذ
وتسلسلها أو عكس تسلسلها حتى يمكن تخزين ملف قابل للتنفيذ وتحميله
حسب الحاجة.
PjRtLoadedExecutable
هو ملف تنفيذي مجمَّع في الذاكرة وجاهز
لتشغيل مَعلمات الإدخال، وهو فئة فرعية من PjRtExecutable
.
يتم التفاعل مع الملفات التنفيذية من خلال إحدى طرق Execute
الخاصة بالعميل:
// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}
// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
قبل استدعاء Execute
، سينقل إطار العمل جميع البيانات المطلوبة إلى
PjRtBuffers
التي يملكها العميل الذي ينفذ الإجراء، ولكن يتم إرجاعها لإطار العمل كي تتم
مراجعتها. ويتم بعد ذلك تقديم هذه المخازن المؤقتة كوسيطات لطريقة Execute
.
مفاهيم PJRT
PjRtFutures والعمليات الحسابية غير المتزامنة
إذا تم تنفيذ أي جزء من المكوّن الإضافي بشكل غير متزامن، يجب أن يتم تنفيذ الوظائف المستقبلية بشكل صحيح.
فكِّر في البرنامج التالي:
@jax.jit
def foo(x): return x + 1
x = foo(1)
# [...] other logic not using `x`
print(x + 1)
يمكن أن يضيف المكوّن الإضافي غير المتزامن عملية الحساب x
إلى "قائمة الانتظار"، ويعرض على الفور
مخزنًا مؤقتًا غير جاهز للقراءة بعد، ولكن سيؤدي التنفيذ إلى ملء
هذا المخزن المؤقت. يمكن أن يستمر التنفيذ في إضافة العمليات الحسابية اللازمة إلى "قائمة الانتظار" بعد x
، والتي
لا تتطلّب x
، بما في ذلك التنفيذ على أجهزة PJRT الأخرى. بعد الحاجة إلى قيمة
x
، سيتم حظر التنفيذ إلى أن يعلن المخزن المؤقت عن استعداده من خلال
المستقبل الذي يعرضه GetReadyFuture
.
يمكن أن تكون العناصر المستقبلية مفيدة لتحديد وقت توفّر عنصر معيّن، بما في ذلك الأجهزة ووحدات التخزين المؤقت.
المفاهيم المتقدّمة
سيؤدي تجاوز تنفيذ واجهات برمجة التطبيقات الأساسية إلى توسيع ميزات JAX التي يمكن استخدامها من خلال إحدى الإضافات. هذه الميزات كلها اختيارية، أي أنّه في سياق سير العمل العادي لـ JIT وتنفيذه، يمكن تنفيذهما بدونها، ولكن بالنسبة إلى مسار تدفق برمجي بجودة الإنتاج، من المحتمل أن يتم التفكير في درجة الدعم لأي من هذه الميزات المتوافقة مع واجهات برمجة التطبيقات PJRT:
- مساحات الذاكرة
- التنسيقات المخصصة
- عمليات التواصل، مثل الإرسال/الاستلام
- تفريغ الموارد على المضيف
- التقسيم إلى أجزاء
نموذج التواصل بين إطار عمل PJRT والجهاز
مثال على السجلّ
في ما يلي سجلّ بالطرق التي تمّ استدعاؤها لتحميل المكوّن الإضافي PJRT و ejecutang y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)
. في هذه الحالة،
نسجِّل تفاعل JAX مع المكوّن الإضافي PJRT في مرجع StableHLO.
مثال على السجلّ
////////////////////////////////// // Load the plugin ////////////////////////////////// I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400) I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8) I device.cc:104] StablehloReferenceDevice(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I client_cpp_pjrt.cc:86] devices(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I device.cc:168] description(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:86] Attributes(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:148] memory_spaces(0x23bac4e0) Creating PJRT Client from client I client_cpp_pjrt.cc:108] platform_version(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:57] id(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:61] process_index(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:153] default_memory_space(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) ////////////////////////////////// // RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)` ////////////////////////////////// I executable.cc:309] num_partitions(0x240bab70) I executable.cc:305] num_replicas(0x240bab70) I executable.cc:309] num_partitions(0x240bab70) I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400) I buffer.cc:285] CreateMlirBufferFromLiteral I buffer.cc:98] CreateFromLiteral I buffer.cc:99] CreateFromLiteral: s32[] 2 I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:102] CreateFromLiteral -> 0x240bb050 I buffer.cc:158] device(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I executable.cc:328] GetHloModules(0x240bab70) I executable.cc:240] Execute(0x240bab70) I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x240bb050) I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32> I executable.cc:205] EvalModule: module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) { // ... return %3 : tensor<i32> } } I executable.cc:206] Inputs: [dense<2> : tensor<i32>] I executable.cc:213] Results: [dense<2> : tensor<i32>] I device.cc:153] default_memory_space(0x23bac4e0) I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x22cea630) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630 ////////////////////////////////// // RUN: `print(y)` ////////////////////////////////// I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references. I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x2404d560) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050 I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x240b6010) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x23b2db60) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60 I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630)