الخلفية
PJRT هي واجهة برمجة تطبيقات Device API الموحّدة التي نريد إضافتها إلى منظومة الذكاء الاصطناعي المتكاملة. تهدف الرؤية الطويلة المدى إلى:
- ستستدعي إطارات العمل (JAX، وTF، وغير ذلك) PJRT، الذي يتضمن عمليات تنفيذ خاصة بالجهاز لا تتوافق مع أطر العمل؛
- يركز كل جهاز على تنفيذ واجهات برمجة تطبيقات PJRT، ويمكن أن يكون غير شفاف بالنسبة إلى الأُطر.
يوفّر PJRT واجهتَي برمجة تطبيقات C وC++. لا بأس بالربط في أي من الطبقتين، إذ تستخدم واجهة برمجة التطبيقات C++ فئات لتبسيط بعض المفاهيم، ولكنها ترتبط أيضًا بأنواع بيانات XLA بشكل أقوى. تركّز هذه الصفحة على واجهة برمجة التطبيقات C++.
مكوّنات PJRT
PjRtClient
يمكنك الاطّلاع على المرجع الكامل على pjrt_client.h > PjRtClient
.
يدير العملاء جميع الاتصالات بين الجهاز وإطار العمل، ويلخصون جميع الحالات المستخدمة في الاتصال. لديهم مجموعة عامة من واجهات برمجة التطبيقات للتفاعل مع مكوّن إضافي من PJRT، وهم يملكون الأجهزة ومساحات الذاكرة لمكوّن إضافي معيّن.
PjRtDevice
المراجع الكاملة على pjrt_client.h > PjRtDevice
وpjrt_device_description.h
تُستخدَم فئة الجهاز لوصف جهاز واحد. يحتوي كل جهاز على وصف للمساعدة في تحديد نوعه (تجزئة فريدة لتحديد وحدة معالجة الرسومات/وحدة المعالجة المركزية/وحدة المعالجة المتعددة)، ومقعده ضمن شبكة من الأجهزة على مستوى الجهاز والمستوى العالمي.
وتعرف الأجهزة أيضًا مساحات الذاكرة المرتبطة بها والعميل الذي يملكها.
لا يعرف الجهاز بالضرورة ذاكرات التخزين المؤقت للبيانات الفعلية المرتبطة به، ولكن يمكنه معرفة ذلك من خلال البحث في مساحات الذاكرة المرتبطة به.
PjRtMemorySpace
يمكنك الاطّلاع على المرجع الكامل على pjrt_client.h > PjRtMemorySpace
.
يمكن استخدام مساحات الذاكرة لوصف موقع ذاكرة. يمكن إما إلغاء تثبيت هذه التطبيقات، ونشرها في أي مكان، ولكن يمكن الوصول إليها من جهاز، أو تثبيتها على جهاز معيّن.
تعرف مساحات الذاكرة المخازن الاحتياطية المرتبطة بها من البيانات والأجهزة (الجمع) التي ترتبط بها مساحة الذاكرة، بالإضافة إلى العميل الذي تشكل جزءًا منه.
PjRtBuffer
يمكنك الاطّلاع على المرجع الكامل على pjrt_client.h > PjRtBuffer
.
يخزّن المخزن المؤقت البيانات على جهاز بتنسيق يسهُل التعامل معه
داخل المكوّن الإضافي، مثل سمة عناصر MLIR أو تنسيق مصفوفة تعامُل خاصة.
قد يحاول إطار العمل إرسال البيانات إلى جهاز في شكل xla::Literal
،
أي لوسيطة إدخال إلى الوحدة، والتي يجب نسخها (أو استعارتها) لتحميلها إلى
ذاكرة الجهاز. وعندما لا تعود هناك حاجة إلى مورد احتياطي، يتم استدعاء طريقة Delete
من خلال إطار العمل لإخلاء المساحة.
يعرف المخزن المؤقت مساحة الذاكرة التي هي جزء منها، وبإمكانه اكتشاف الأجهزة القادرة على الوصول إليه، غير أن التخزين المؤقت لا يعرف بالضرورة أجهزته.
للتواصل مع إطارات العمل، تعرف الفواصل كيفية التحويل إلى نوع
xla::Literal
والعكس:
// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}
// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}
تحتوي واجهات برمجة التطبيقات لإنشاء مخزن مؤقت على دلالات المخزن المؤقت التي تساعد في تحديد ما إذا كان يمكن مشاركة البيانات الحرفية من المخزن المؤقت للمضيف أو نسخها أو تغييرها.
أخيرًا، قد يحتاج المخزن المؤقت إلى الاستمرار لفترة أطول من نطاق تنفيذه، إذا تم تحديده
لمتغيّر في طبقة الإطار x = jit(foo)(10)
. وفي هذه الحالات، تسمح ملفّات التخزين المؤقت بإنشاء مراجع خارجية توفّر مؤشرًا يملكه مؤقتًا
للبيانات التي يحتفظ بها المخزن المؤقت، بالإضافة إلى البيانات الوصفية (أنواع البيانات / أحجام السمات)
لتفسير البيانات الأساسية.
PjRtCompiler
يمكنك الاطّلاع على المرجع الكامل على pjrt_compiler.h > PjRtCompiler
.
توفّر الفئة PjRtCompiler
تفاصيل مفيدة لتنفيذ خلفيات XLA، ولكنها ليست ضرورية لتنفيذ مكوّن إضافي. من الناحية النظرية، تتمثل
مسؤولية PjRtCompiler
أو طريقة PjRtClient::Compile
في
أخذ وحدة إدخال وإرجاع PjRtLoadedExecutable
.
PjRtExecutable / PjRtLoadedExecutable
يمكنك الاطّلاع على المرجع الكامل على pjrt_executable.h > PjRtExecutable
،
وpjrt_client.h > PjRtLoadedExecutable
.
يعرف PjRtExecutable
كيفية أخذ عنصر مُجمَّع وخيارات التنفيذ
وتسلسلها أو عكس تسلسلها حتى يمكن تخزين ملف قابل للتنفيذ وتحميله
حسب الحاجة.
PjRtLoadedExecutable
هو ملف تنفيذي مجمَّع في الذاكرة وجاهز
لتشغيل مَعلمات الإدخال، وهو فئة فرعية من PjRtExecutable
.
يتم التفاعل مع الملفات التنفيذية من خلال إحدى طرق Execute
الخاصة بالعميل:
// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}
// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
قبل استدعاء Execute
، سينقل إطار العمل جميع البيانات المطلوبة إلى
PjRtBuffers
التي يملكها العميل الذي ينفذ الإجراء، ولكن يتم إرجاعها لإطار العمل كي تتم
مراجعتها. ويتم بعد ذلك تقديم هذه المخازن المؤقتة كوسيطات لطريقة Execute
.
مفاهيم PJRT
PjRtFutures والعمليات الحسابية غير المتزامنة
إذا تم تنفيذ أي جزء من المكوّن الإضافي بشكل غير متزامن، يجب أن يتم تنفيذ الوظائف المستقبلية بشكل صحيح.
فكِّر في البرنامج التالي:
@jax.jit
def foo(x): return x + 1
x = foo(1)
# [...] other logic not using `x`
print(x + 1)
يمكن أن يضيف المكوّن الإضافي غير المتزامن عملية الحساب x
إلى "قائمة الانتظار"، ويعرض على الفور
مخزنًا مؤقتًا غير جاهز للقراءة بعد، ولكن سيؤدي التنفيذ إلى ملء
هذا المخزن المؤقت. يمكن أن يستمر التنفيذ في إضافة العمليات الحسابية اللازمة إلى "قائمة الانتظار" بعد x
، والتي
لا تتطلّب x
، بما في ذلك التنفيذ على أجهزة PJRT الأخرى. عند الحاجة إلى قيمة
x
، سيتم حظر التنفيذ إلى أن يعلن المخزن المؤقت عن جاهزيته عبر
المستقبل الذي يعرضه GetReadyFuture
.
يمكن أن تكون العناصر المستقبلية مفيدة لتحديد وقت توفّر عنصر معيّن، بما في ذلك الأجهزة ووحدات التخزين المؤقت.
المفاهيم المتقدّمة
أما إذا تجاوز مجرد تنفيذ واجهات برمجة التطبيقات الأساسية، فسيؤدي ذلك إلى توسيع ميزات JAX التي يمكن استخدامها بواسطة المكوّنات الإضافية. جميع هذه الميزات تتطلّب موافقة المستخدم، عِلمًا أنّه في نماذج JIT المعتادة وسير العمل الذي يمكن تنفيذه بدون تلك الميزات، لا شكّ في أنّه بالنسبة إلى مسار جودة الإنتاج، يجب مراعاة درجة الدعم لأي من هذه الميزات المتاحة في واجهات برمجة تطبيقات PJRT:
- مساحات الذاكرة
- التنسيقات المخصصة
- عمليات التواصل مثل الإرسال/الاستلام
- تفريغ الموارد على المضيف
- التقسيم إلى أجزاء
نموذج التواصل بين إطار عمل PJRT والجهاز
مثال على السجل
في ما يلي سجلّ بالطرق المستدعاة لتحميل المكوّن الإضافي PJRT وتنفيذ y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)
. في هذه الحالة،
نسجِّل تفاعل JAX مع المكوّن الإضافي PJRT في مرجع StableHLO.
مثال على السجلّ
////////////////////////////////// // Load the plugin ////////////////////////////////// I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400) I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8) I device.cc:104] StablehloReferenceDevice(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I client_cpp_pjrt.cc:86] devices(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I device.cc:168] description(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:86] Attributes(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:148] memory_spaces(0x23bac4e0) Creating PJRT Client from client I client_cpp_pjrt.cc:108] platform_version(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:57] id(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:61] process_index(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:153] default_memory_space(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) ////////////////////////////////// // RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)` ////////////////////////////////// I executable.cc:309] num_partitions(0x240bab70) I executable.cc:305] num_replicas(0x240bab70) I executable.cc:309] num_partitions(0x240bab70) I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400) I buffer.cc:285] CreateMlirBufferFromLiteral I buffer.cc:98] CreateFromLiteral I buffer.cc:99] CreateFromLiteral: s32[] 2 I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:102] CreateFromLiteral -> 0x240bb050 I buffer.cc:158] device(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I executable.cc:328] GetHloModules(0x240bab70) I executable.cc:240] Execute(0x240bab70) I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x240bb050) I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32> I executable.cc:205] EvalModule: module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) { // ... return %3 : tensor<i32> } } I executable.cc:206] Inputs: [dense<2> : tensor<i32>] I executable.cc:213] Results: [dense<2> : tensor<i32>] I device.cc:153] default_memory_space(0x23bac4e0) I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x22cea630) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630 ////////////////////////////////// // RUN: `print(y)` ////////////////////////////////// I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references. I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x2404d560) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050 I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x240b6010) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x23b2db60) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60 I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630)