نظرة عامة على واجهة برمجة التطبيقات لأجهزة PJRT C++

الخلفية

PJRT هي واجهة برمجة تطبيقات موحّدة للأجهزة نريد إضافتها إلى منظومة تعلُّم الآلة. تتمثّل الرؤية طويلة الأمد في ما يلي:

  1. ستطلب الأُطر (JAX وTF وما إلى ذلك) PJRT، التي تتضمّن عمليات تنفيذ خاصة بالجهاز وغير شفافة للأُطر.
  2. يركّز كل جهاز على تنفيذ واجهات برمجة تطبيقات PJRT، ويمكن أن يكون غير شفاف للأُطر.

توفّر PJRT واجهة برمجة تطبيقات C وواجهة برمجة تطبيقات C++. يمكنك استخدام أي من الطبقتين، إذ تستخدم واجهة برمجة التطبيقات C++‎ فئات لتجريد بعض المفاهيم، ولكنها ترتبط أيضًا بشكل أكبر بأنواع بيانات XLA. تركّز هذه الصفحة على واجهة برمجة التطبيقات C++.

مكوّنات PJRT

مكوّنات PJRT

PjRtClient

يمكنك الاطّلاع على المرجع الكامل على pjrt_client.h > PjRtClient.

تتولّى البرامج جميع عمليات التواصل بين الجهاز والإطار، كما أنّها تغلف جميع الحالات المستخدَمة في التواصل. تتضمّن هذه الأجهزة مجموعة عامة من واجهات برمجة التطبيقات للتفاعل مع مكوّن إضافي في PJRT، وهي تملك الأجهزة ومساحات الذاكرة لمكوّن إضافي معيّن.

PjRtDevice

المراجع الكاملة على pjrt_client.h > PjRtDevice وpjrt_device_description.h

يتم استخدام فئة الجهاز لوصف جهاز واحد. يتضمّن الجهاز وصفًا يساعد في تحديد نوعه (تجزئة فريدة لتحديد وحدة معالجة الرسومات أو وحدة المعالجة المركزية أو وحدة المعالجة الموسّعة)، وموقعه الجغرافي ضمن شبكة من الأجهزة على المستويَين المحلي والعالمي.

تعرف الأجهزة أيضًا مساحات الذاكرة المرتبطة بها والعميل الذي يملكها.

لا يعرف الجهاز بالضرورة مخازن البيانات الفعلية المرتبطة به، ولكن يمكنه معرفة ذلك من خلال البحث في مساحات الذاكرة المرتبطة به.

PjRtMemorySpace

يمكنك الاطّلاع على المرجع الكامل على pjrt_client.h > PjRtMemorySpace.

يمكن استخدام مساحات الذاكرة لوصف موقع جغرافي للذاكرة. يمكن إما إلغاء تثبيتها، ما يتيح لها أن تكون متاحة على أي جهاز، أو يمكن تثبيتها، ما يعني أنّها يجب أن تكون متاحة على جهاز معيّن.

تعرف مساحات الذاكرة المخازن المؤقتة المرتبطة بها من البيانات، والأجهزة (بصيغة الجمع) المرتبطة بها، بالإضافة إلى التطبيق الذي تشكّل جزءًا منه.

PjRtBuffer

يمكنك الاطّلاع على المرجع الكامل على pjrt_client.h > PjRtBuffer.

يحتوي المخزن المؤقت على بيانات على جهاز بتنسيق يسهل التعامل معه داخل المكوّن الإضافي، مثل سمة عناصر MLIR أو تنسيق موتر خاص. قد يحاول إطار العمل إرسال بيانات إلى جهاز على شكل xla::Literal، أي كمعلَمة إدخال للوحدة النمطية، ويجب استنساخها (أو استعارتها) إلى ذاكرة الجهاز. وبعد أن يصبح المخزن المؤقت غير مطلوب، يستدعي إطار العمل الطريقة Delete لتنظيفه.

تعرف المخزن المؤقت مساحة الذاكرة التي يشكّل جزءًا منها، ويمكنه بشكل غير مباشر تحديد الأجهزة التي يمكنها الوصول إليه، ولكن لا تعرف المخازن المؤقتة بالضرورة أجهزتها.

للتواصل مع إطارات العمل، تعرف المخازن المؤقتة كيفية التحويل من وإلى نوع xla::Literal:

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::Future<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

تحتوي واجهات برمجة التطبيقات لإنشاء مخزن مؤقت على دلالات المخزن المؤقت التي تساعد في تحديد ما إذا كان يمكن مشاركة البيانات الحرفية من المخزن المؤقت للمضيف أو نسخها أو تعديلها.

أخيرًا، قد تحتاج المخزن المؤقت إلى مدة أطول من نطاق تنفيذه، إذا تم تعيينه لمتغير في طبقة إطار العمل x = jit(foo)(10)، وفي هذه الحالات، تسمح المخازن المؤقتة بإنشاء مراجع خارجية توفّر مؤشرًا مملوكًا مؤقتًا للبيانات التي يحتفظ بها المخزن المؤقت، بالإضافة إلى البيانات الوصفية (نوع البيانات / أحجام الأبعاد) لتفسير البيانات الأساسية.

PjRtCompiler

يمكنك الاطّلاع على المرجع الكامل على pjrt_compiler.h > PjRtCompiler.

يوفّر الصف PjRtCompiler تفاصيل تنفيذ مفيدة لخلفيات XLA، ولكن ليس من الضروري أن ينفّذ المكوّن الإضافي هذا الصف. من الناحية النظرية، تتمثل مسؤولية PjRtCompiler أو طريقة PjRtClient::Compile في تلقّي وحدة إدخال وعرض PjRtLoadedExecutable.

PjRtExecutable / PjRtLoadedExecutable

يمكنك الاطّلاع على المرجع الكامل على pjrt_executable.h > PjRtExecutable وpjrt_client.h > PjRtLoadedExecutable.

تعرف PjRtExecutable كيفية أخذ عنصر تم تجميعه وخيارات التنفيذ وتسلسله/إلغاء تسلسله حتى يمكن تخزين ملف تنفيذي وتحميله حسب الحاجة.

PjRtLoadedExecutable هو الملف التنفيذي المجمّع والمخزّن في الذاكرة والذي يكون جاهزًا لتنفيذ وسيطات الإدخال، وهو فئة فرعية من PjRtExecutable.

يتم التفاعل مع الملفات التنفيذية من خلال إحدى طرق Execute الخاصة بالعميل:

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

قبل استدعاء Execute، سينقل إطار العمل جميع البيانات المطلوبة إلى PjRtBuffers الذي يملكه العميل المنفِّذ، ولكن سيتم إرجاعها إلى إطار العمل للرجوع إليها. يتم بعد ذلك تقديم هذه المخازن المؤقتة كوسيطات إلى الطريقة Execute.

مفاهيم PJRT

العقود الآجلة والحسابات غير المتزامنة

إذا تم تنفيذ أي جزء من مكوّن إضافي بشكل غير متزامن، يجب تنفيذ عمليات مستقبلية بشكل سليم.

ضع في اعتبارك البرنامج التالي:

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

سيتمكّن المكوّن الإضافي غير المتزامن من وضع عملية الحساب في قائمة الانتظار x، ثم إرجاع مخزن مؤقت على الفور، مع العلم أنّه لن يكون جاهزًا للقراءة بعد، ولكن سيتم ملؤه أثناء التنفيذ. يمكن أن يستمر التنفيذ في إضافة العمليات الحسابية اللازمة إلى قائمة الانتظار بعد x، والتي لا تتطلب x، بما في ذلك التنفيذ على أجهزة PJRT الأخرى. عند الحاجة إلى قيمة x، سيتم حظر التنفيذ إلى أن يعلن المخزن المؤقت عن جاهزيته من خلال القيمة المستقبلية التي تعرضها GetReadyFuture.

يمكن أن تكون القيم المستقبلية مفيدة لتحديد الوقت الذي يصبح فيه أحد العناصر متاحًا، بما في ذلك الأجهزة والمخازن المؤقتة.

المفاهيم المتقدّمة

سيؤدي التوسّع إلى ما هو أبعد من تنفيذ واجهات برمجة التطبيقات الأساسية إلى توسيع ميزات JAX التي يمكن أن تستخدمها إحدى الإضافات. جميع هذه الميزات اختيارية، بمعنى أنّ سير عمل JIT والتنفيذ النموذجي سيعمل بدونها، ولكن للحصول على مسار إنتاج عالي الجودة، من المفترض أن يتم التفكير في مستوى التوافق مع أي من هذه الميزات التي توفّرها واجهات برمجة التطبيقات PJRT:

  • مساحات الذاكرة
  • التنسيقات المخصصة
  • عمليات التواصل، مثل الإرسال والاستلام
  • نقل المهام إلى المضيف
  • التقسيم إلى أجزاء

التواصل النموذجي بين إطار عمل PJRT والجهاز

مثال على السجلّ

في ما يلي سجلّ للطُرق التي تم استدعاؤها لتحميل إضافة PJRT وتنفيذ y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1). في هذه الحالة، نسجّل تفاعل JAX مع المكوّن الإضافي StableHLO Reference PJRT.

مثال على السجلّ

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)