סקירה כללית על PJRT C++ Device API

רקע

PJRT הוא ממשק ה-API אחיד למכשירים שאנחנו רוצים להוסיף לסביבת ה-ML. החזון לטווח ארוך הוא:

  1. מסגרות (JAX,‏ TF וכו') יקראו ל-PJRT, שיש לו הטמעות ספציפיות למכשיר שהמסגרות לא יכולות לראות.
  2. כל מכשיר מתמקד בהטמעת ממשקי PJRT API, והוא יכול להיות אטום למסגרות.

PJRT מציע גם ממשק API ל-C וגם ממשק API ל-C++. אפשר לחבר את הקוד בשתי השכבות. ב-API ל-C++‎ נעשה שימוש במדדים כדי ליצור מושגים מופשטים, אבל יש לו גם קשרים חזקים יותר לסוגי הנתונים של XLA. הדף הזה מתמקד ב-API ל-C++‎.

רכיבי PJRT

רכיבי PJRT

PjRtClient

מידע מלא זמין בכתובת pjrt_client.h > PjRtClient.

לקוחות מנהלים את כל התקשורת בין המכשיר ל-Framework, ומקפידים על אנקפסולציה של כל המצבים שבהם נעשה שימוש בתקשורת. יש להם קבוצה גנרית של ממשקי API לאינטראקציה עם פלאגין PJRT, והם הבעלים של המכשירים ומרחבי הזיכרון של פלאגין נתון.

PjRtDevice

המקורות המלאים מופיעים במאמרים pjrt_client.h > PjRtDevice ו-pjrt_device_description.h

סיווג מכשיר משמש לתיאור מכשיר אחד. לכל מכשיר יש תיאור שמאפשר לזהות את הסוג שלו (גיבוב ייחודי לזיהוי GPU/CPU/xPU) ואת המיקום שלו בתוך רשת של מכשירים, גם ברמה המקומית וגם ברמה הגלובלית.

למכשירים יש גם מידע על מרחבי הזיכרון המשויכים אליהם ועל הלקוח שבבעלותם.

המכשיר לא בהכרח יודע מהם מאגרי הנתונים בפועל שמשויכים אליו, אבל הוא יכול לברר זאת על ידי בדיקה של מרחבי הזיכרון המשויכים אליו.

PjRtMemorySpace

מידע מלא זמין בכתובת pjrt_client.h > PjRtMemorySpace.

אפשר להשתמש במרחבי זיכרון כדי לתאר מיקום של זיכרון. אפשר לבטל את ההצמדה שלהם, והם יוכלו להופיע בכל מקום, אבל תהיה גישה אליהם רק מהמכשיר שבו הם מוצמדים. לחלופין, אפשר להצמיד אותם, והם יופיעו רק במכשיר ספציפי.

במרחבי הזיכרון ידועים מאגרי הנתונים המשויכים אליהם, המכשירים (רבים) שמרחב הזיכרון משויך אליהם, וכן הלקוח שהוא חלק ממנו.

PjRtBuffer

מידע מלא זמין בכתובת pjrt_client.h > PjRtBuffer.

מאגר נתונים זמני שומר נתונים במכשיר בפורמט כלשהו שקל לעבוד איתו בתוך הפלאגין, כמו מאפיין של רכיבי MLIR או פורמט קניינית של טינסור. מסגרת עשויה לנסות לשלוח נתונים למכשיר בצורת xla::Literal, כלומר עבור ארגומנט קלט למודול, שצריך לשכפל (או ללוות) לזיכרון של המכשיר. כשאין יותר צורך במאגר, ה-framework מפעיל את השיטה Delete כדי לנקות אותו.

למאגר זיכרון יש ידע על מרחב הזיכרון שהוא חלק ממנו, והוא יכול להבין באופן עקיף אילו מכשירים יכולים לגשת אליו, אבל לא בטוח שהוא יודע מהם המכשירים האלה.

כדי לתקשר עם מסגרות, מאגרי נתונים יודעים איך לבצע המרה לסוג xla::Literal וממנו:

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

לממשקי ה-API ליצירת מאגר זמני יש סמנטיקה של מאגר זמני, שמאפשרת לקבוע אם אפשר לשתף, להעתיק או לשנות נתונים מילוליים ממאגר הזמני המארח.

לבסוף, יכול להיות שצריך שמאגר נתונים יחזיק מעמד יותר מטווח הביצוע שלו, אם הוא מוקצה למשתנה בשכבת המסגרת x = jit(foo)(10). במקרים כאלה, מאגרי נתונים מאפשרים ליצור הפניות חיצוניות שמספקות פוינטר בבעלות זמנית לנתונים שמאוחסנים במאגר הנתונים, יחד עם מטא-נתונים (dtype / dim sizes) לצורך פרשנות הנתונים הבסיסיים.

PjRtCompiler

מידע מלא זמין בכתובת pjrt_compiler.h > PjRtCompiler.

המחלקה PjRtCompiler מספקת פרטים שימושיים על ההטמעה לקצוות העורפיים של XLA, אבל לא נדרשת הטמעה של פלאגין. בתיאוריה, התפקיד של PjRtCompiler, או של השיטה PjRtClient::Compile, הוא לקבל מודול קלט ולהחזיר PjRtLoadedExecutable.

PjRtExecutable / PjRtLoadedExecutable

מידע מלא זמין במאמרים pjrt_executable.h > PjRtExecutable ו-pjrt_client.h > PjRtLoadedExecutable.

PjRtExecutable יודע איך לקחת קובץ ארטיפקט מקודד ואפשרויות הפעלה ולבצע להם סריאליזציה/דה-סריאליזציה, כדי שאפשר יהיה לאחסן ולטעון קובץ הפעלה לפי הצורך.

PjRtLoadedExecutable הוא קובץ ההפעלה המהדר בזיכרון, שמוכנה להפעלה עם ארגומנטים של קלט. הוא תת-סוג של PjRtExecutable.

המערכת משקללת קובצי הפעלה באחת מהשיטות Execute של הלקוח:

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

לפני הקריאה ל-Execute, המסגרת תעביר את כל הנתונים הנדרשים ל-PjRtBuffers שבבעלות הלקוח שמבצע את הפעולה, אבל הנתונים יחזרו למסגרת לצורך עיון. לאחר מכן מאגרי הנתונים האלה מוצגים כארגומנטים ל-method‏ Execute.

מושגים של PJRT

PjRtFutures ומחשוב אסינכרוני

אם חלק כלשהו של פלאגין מיושם באופן אסינכרוני, חובה להטמיע בו נכסי Future בצורה תקינה.

נבחן את התוכנית הבאה:

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

פלאגין אסינכרוני יוכל להוסיף את החישוב x לתור, ולהחזיר מיד מאגר נתונים זמני שעדיין לא מוכן לקריאה, אבל הביצוע יאכלס אותו. לאחר x, המערכת יכולה להמשיך להוסיף לתור חישובים נחוצים שלא דורשים את x, כולל ביצוע במכשירי PJRT אחרים. ברגע שיהיה צורך בערך של x, הביצוע ייחסם עד שהמאגר הזמני יכריז שהוא מוכן דרך העתיד שיוחזר על ידי GetReadyFuture.

אפשר להשתמש ב-Futures כדי לקבוע מתי אובייקט יהפוך לזמין, כולל מכשירי אחסון זמני.

מושגים מתקדמים

הרחבה מעבר להטמעת ממשקי ה-API הבסיסיים תרחיב את התכונות של JAX שבהן אפשר להשתמש בפלאגין. כל התכונות האלה מאפשרות הצטרפות, כי ב-JIT הטיפוסי ובתהליכי העבודה יפעלו בלי אותן, אבל לצינור עיבוד נתונים של איכות ייצור, סביר להניח שצריך להשקיע מחשבה מסוימת במידת התמיכה בכל אחת מהתכונות הבאות שנתמכות על ידי ממשקי ה-API של PJRT:

  • מרחבי זיכרון
  • פריסות בהתאמה אישית
  • פעולות תקשורת כמו שליחה/קבלה
  • העברת עומסים מהמארח
  • פיצול

תקשורת אופיינית בין מסגרת PJRT למכשיר

יומן לדוגמה

זהו יומן של השיטות שהופעלו כדי לטעון את הפלאגין PJRT ולהריץ את y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1). במקרה הזה, אנחנו מתעדים את האינטראקציה של JAX עם הפלאגין PJRT של StableHLO.

יומן לדוגמה

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)