סקירה כללית על PJRT C++ Device API

רקע

PJRT הוא ממשק ה-API אחיד למכשירים שאנחנו רוצים להוסיף לסביבת ה-ML. החזון לטווח ארוך הוא:

  1. מסגרות (JAX,‏ TF וכו') יקראו ל-PJRT, שיש לו הטמעות ספציפיות למכשיר שהמסגרות לא יכולות לראות.
  2. כל מכשיר מתמקד בהטמעת ממשקי PJRT API, והוא יכול להיות אטום למסגרות.

PJRT מציע גם ממשק API ל-C וגם ממשק API ל-C++. אפשר לחבר את הקוד בשתי השכבות. ב-API ל-C++‎ נעשה שימוש במדדים כדי ליצור מושגים מופשטים, אבל יש לו גם קשרים חזקים יותר לסוגי הנתונים של XLA. הדף הזה מתמקד ב-API ל-C++‎.

רכיבי PJRT

רכיבי PJRT

PjRtClient

מידע מלא זמין בכתובת pjrt_client.h > PjRtClient.

לקוחות מנהלים את כל התקשורת בין המכשיר ל-Framework, ומקפידים על אנקפסולציה של כל המצבים שבהם נעשה שימוש בתקשורת. יש להם קבוצה גנרית של ממשקי API ליצירת אינטראקציה עם פלאגין PJRT, והם הבעלים של המכשירים ומרחבי הזיכרון של פלאגין נתון.

PjRtDevice

המקורות המלאים מופיעים במאמרים pjrt_client.h > PjRtDevice ו-pjrt_device_description.h

סיווג מכשיר משמש לתיאור מכשיר יחיד. לכל מכשיר יש תיאור שמאפשר לזהות את הסוג שלו (גיבוב ייחודי לזיהוי GPU/CPU/xPU) ואת המיקום שלו בתוך רשת של מכשירים, גם ברמה המקומית וגם ברמה הגלובלית.

המכשירים גם יודעים מהם מרחבי הזיכרון המשויכים אליהם ואת הלקוח שבבעלותו הם נמצאים.

המכשיר לא בהכרח יודע מהם מאגרי הנתונים בפועל שמשויכים אליו, אבל הוא יכול לברר זאת על ידי בדיקה של מרחבי הזיכרון המשויכים אליו.

PjRtMemorySpace

מידע מלא זמין בכתובת pjrt_client.h > PjRtMemorySpace.

אפשר להשתמש במרחבי זיכרון כדי לתאר מיקום של זיכרון. אפשר לבטל את ההצמדה שלהם, והם יוכלו להופיע בכל מקום, אבל תהיה גישה אליהם רק מהמכשיר שבו הם מוצמדים. לחלופין, אפשר להצמיד אותם, והם יופיעו רק במכשיר ספציפי.

במרחבי הזיכרון ידועים מאגרי הנתונים המשויכים אליהם, המכשירים (רבים) שמרחב הזיכרון משויך אליהם, וכן הלקוח שהוא חלק ממנו.

PjRtBuffer

מידע מלא זמין בכתובת pjrt_client.h > PjRtBuffer.

מאגר נתונים זמני שומר נתונים במכשיר בפורמט כלשהו שקל לעבוד איתו בתוך הפלאגין, כמו מאפיין של רכיבי MLIR או פורמט קניינית של טינסור. מסגרת עשויה לנסות לשלוח נתונים למכשיר בצורת xla::Literal, כלומר עבור ארגומנט קלט למודול, שצריך לשכפל (או ללוות) לזיכרון של המכשיר. כשאין יותר צורך במאגר, ה-framework מפעיל את השיטה Delete כדי לנקות אותו.

למאגר זיכרון יש ידע על מרחב הזיכרון שהוא חלק ממנו, והוא יכול להבין באופן עקיף אילו מכשירים יכולים לגשת אליו, אבל לא בטוח שהוא יודע מהם המכשירים האלה.

כדי לתקשר עם מסגרות, מאגרי נתונים יודעים איך לבצע המרה לסוג xla::Literal וממנו:

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

לממשקי ה-API ליצירת מאגר זמני יש סמנטיקה של מאגר זמני, שמאפשרת לקבוע אם אפשר לשתף, להעתיק או לשנות נתונים מילוליים ממאגר הזמני המארח.

לבסוף, יכול להיות שצריך שמאגר נתונים יחזיק מעמד יותר מטווח הביצוע שלו, אם הוא מוקצה למשתנה בשכבת המסגרת x = jit(foo)(10). במקרים כאלה, מאגרי נתונים מאפשרים ליצור הפניות חיצוניות שמספקות פוינטר בבעלות זמנית לנתונים שמאוחסנים במאגר הנתונים, יחד עם מטא-נתונים (dtype / dim sizes) לצורך פרשנות הנתונים הבסיסיים.

PjRtCompiler

מידע מלא זמין בכתובת pjrt_compiler.h > PjRtCompiler.

הכיתה PjRtCompiler מספקת פרטי הטמעה שימושיים לקצוות עורפי של XLA, אבל אין צורך להטמיע אותה בפלאגין. בתיאוריה, התפקיד של PjRtCompiler, או של השיטה PjRtClient::Compile, הוא לקבל מודול קלט ולהחזיר PjRtLoadedExecutable.

PjRtExecutable / PjRtLoadedExecutable

מידע מלא זמין במאמרים pjrt_executable.h > PjRtExecutable ו-pjrt_client.h > PjRtLoadedExecutable.

PjRtExecutable יודע איך לקחת קובץ ארטיפקט מקודד ואפשרויות הפעלה ולבצע להם סריאליזציה/דה-סריאליזציה, כדי שאפשר יהיה לאחסן ולטעון קובץ הפעלה לפי הצורך.

PjRtLoadedExecutable הוא קובץ ההפעלה המהדר בזיכרון, שמוכנה להפעלה עם ארגומנטים של קלט. הוא תת-סוג של PjRtExecutable.

ממשק עם קובצי הפעלה מתבצע באמצעות אחת מהשיטות Execute של הלקוח:

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

לפני הקריאה ל-Execute, המסגרת תעביר את כל הנתונים הנדרשים ל-PjRtBuffers שבבעלות הלקוח שמבצע את הפעולה, אבל הנתונים יחזרו למסגרת לצורך עיון. לאחר מכן מאגרי הנתונים האלה מוצגים כארגומנטים ל-method‏ Execute.

מושגים של PJRT

PjRtFutures וחישוב אסינכרוני

אם חלק כלשהו של פלאגין מיושם באופן אסינכרוני, חובה להטמיע בו נכסי Future בצורה תקינה.

נבחן את התוכנית הבאה:

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

פלאגין אסינכררוני יוכל להוסיף את החישוב x לתור, ולהחזיר באופן מיידי מאגר נתונים שעדיין לא מוכן לקריאה, אבל הביצוע יאכלס אותו. לאחר x, המערכת יכולה להמשיך להוסיף לתור חישובים נחוצים שלא דורשים את x, כולל ביצוע במכשירי PJRT אחרים. כשהערך של x יידרש, הביצוע ייחסם עד שהמאגר יכריז שהוא מוכן באמצעות הערך העתידי שמוחזר על ידי GetReadyFuture.

אפשר להשתמש ב-Futures כדי לקבוע מתי אובייקט יהפוך לזמין, כולל מכשירי אחסון זמני.

מושגים מתקדמים

הרחבה מעבר להטמעת ממשקי ה-API הבסיסיים תרחיב את התכונות של JAX שבהן אפשר להשתמש בפלאגין. כל התכונות האלה הן אופציונליות, במובן ש-JIT ותהליך העבודה של ההפעלה יפעלו בלעדיהן, אבל צינור עיבוד נתונים באיכות ייצור צריך לקבל מחשבה בנוגע לרמת התמיכה בכל אחת מהתכונות האלה שנתמכות בממשקי PJRT API:

  • מרחבי זיכרון
  • פריסות בהתאמה אישית
  • פעולות תקשורת כמו שליחה/קבלה
  • העברת עומסים מהמארח
  • פיצול

תקשורת אופיינית בין מסגרת PJRT למכשיר

יומן לדוגמה

זהו יומן של השיטות שהופעלו כדי לטעון את הפלאגין PJRT ולהריץ את y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1). במקרה הזה, אנחנו מתעדים את האינטראקציה של JAX עם הפלאגין PJRT של StableHLO.

יומן לדוגמה

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)