רקע
PJRT הוא Device API אחיד שאנחנו רוצים להוסיף למערכת האקולוגית של ML. החזון לטווח הארוך הוא:
- מסגרות (JAX, TF וכו') יפעילו את PJRT, שיש לו הטמעות ספציפיות למכשיר שהן אטומות למסגרות.
- כל מכשיר מתמקד בהטמעה של ממשקי PJRT API, ויכול להיות שהוא לא שקוף למסגרות.
PJRT מציע גם C API וגם C++ API. אפשר להשתמש בשכבה כלשהי. ה-API של C++ משתמש במחלקות כדי לבצע הפשטה של חלק מהמושגים, אבל יש לו גם קשר חזק יותר לסוגי הנתונים של XLA. הדף הזה מתמקד ב-C++ API.
רכיבים של PJRT
PjRtClient
ההפניה המלאה זמינה בכתובת pjrt_client.h > PjRtClient.
הלקוחות מנהלים את כל התקשורת בין המכשיר לבין ה-Framework, ומכילים את כל המצב שמשמש בתקשורת. יש להם קבוצה גנרית של ממשקי API (API) לאינטראקציה עם תוסף PJRT, והם הבעלים של המכשירים ושל מרחבי הזיכרון של תוסף נתון.
PjRtDevice
סימוכין מלאים ב-pjrt_client.h > PjRtDevice,
וב-pjrt_device_description.h
מחלקת מכשירים משמשת לתיאור של מכשיר יחיד. למכשיר יש תיאור מכשיר שעוזר לזהות את הסוג שלו (גיבוב ייחודי לזיהוי GPU/CPU/xPU), ומיקום בתוך רשת של מכשירים באופן מקומי וגלובלי.
המכשירים גם יודעים מהם המרחבים שמשויכים להם ומי הלקוח שבבעלותו הם נמצאים.
מכשיר לא בהכרח יודע את המאגרים של הנתונים בפועל שמשויכים אליו, אבל הוא יכול להבין את זה על ידי עיון במרחבי הזיכרון שמשויכים אליו.
PjRtMemorySpace
ההפניה המלאה זמינה בכתובת pjrt_client.h > PjRtMemorySpace.
אפשר להשתמש במרחבי זיכרון כדי לתאר מיקום בזיכרון. אפשר לבטל את ההצמדה שלהם, ואז הם יכולים להיות בכל מקום אבל אפשר לגשת אליהם ממכשיר, או שאפשר להצמיד אותם ואז הם חייבים להיות במכשיר ספציפי.
אזורי הזיכרון יודעים מהם המאגרים המשויכים של הנתונים, המכשירים (ברבים) שאזור הזיכרון משויך אליהם והלקוח שהוא חלק ממנו.
PjRtBuffer
ההפניה המלאה זמינה בכתובת pjrt_client.h > PjRtBuffer.
מאגר נתונים זמני מכיל נתונים במכשיר בפורמט מסוים שקל לעבוד איתו בתוך התוסף, כמו מאפיין של רכיבי MLIR או פורמט טנסור קנייני.
יכול להיות שפריימוורק ינסה לשלוח נתונים למכשיר בצורה של xla::Literal, כלומר כארגומנט קלט למודול, שצריך לשכפל (או להשאיל) לזיכרון של המכשיר. כבר לא צריך את המאגר, המערכת מפעילה את השיטה Delete כדי לנקות אותו.
מאגר נתונים זמני יודע את מרחב הזיכרון שהוא חלק ממנו, ויכול להבין באופן עקיף לאילו מכשירים יש גישה אליו, אבל מאגרי נתונים זמניים לא בהכרח יודעים את המכשירים שלהם.
כדי לתקשר עם מסגרות, מאגרי נתונים יודעים איך להמיר לסוג xla::Literal וממנו:
// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}
// Buffer to Literal
xla::Future<> ToLiteral(xla::MutableLiteralBase* literal) override {...}
לממשקי API ליצירת מאגר זמני יש סמנטיקה של מאגר זמני, שעוזרת לקבוע אם אפשר לשתף, להעתיק או לשנות נתונים מילוליים מהמאגר הזמני של המארח.
לבסוף, יכול להיות שמאגר יצטרך להיות פעיל למשך זמן ארוך יותר מההיקף של ההרצה שלו, אם הוא מוקצה למשתנה בשכבת המסגרת x = jit(foo)(10). במקרים כאלה, מאגרים מאפשרים ליצור הפניות חיצוניות שמספקות מצביע בבעלות זמנית לנתונים שמוחזקים במאגר, יחד עם מטא-נתונים (dtype / dim sizes) לפרשנות של הנתונים הבסיסיים.
PjRtCompiler
ההפניה המלאה זמינה בכתובת pjrt_compiler.h > PjRtCompiler.
המחלקות PjRtCompiler מספקות פרטי הטמעה שימושיים עבור קצה העורף של XLA, אבל הן לא נחוצות להטמעה של תוסף. באופן תיאורטי, האחריות של PjRtCompiler או של השיטה PjRtClient::Compile היא לקבל מודול קלט ולהחזיר PjRtLoadedExecutable.
PjRtExecutable / PjRtLoadedExecutable
ההפניה המלאה מופיעה בכתובות pjrt_executable.h > PjRtExecutable ו-pjrt_client.h > PjRtLoadedExecutable.
PjRtExecutable יודע איך לקחת ארטיפקט שעבר קומפילציה ואפשרויות הפעלה, ולבצע סריאליזציה/דה-סריאליזציה שלהם כדי שאפשר יהיה לשמור קובץ הפעלה ולטעון אותו לפי הצורך.
PjRtLoadedExecutable הוא קובץ הפעלה שעבר קומפילציה בזיכרון ומוכן לקבל ארגומנטים של קלט כדי להפעיל אותו. הוא מחלקת משנה של PjRtExecutable.
הפעלת קובצי הפעלה מתבצעת באמצעות אחת מהשיטות של הלקוח Execute:
// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}
// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
לפני הקריאה ל-Execute, המסגרת תעביר את כל הנתונים הנדרשים אל Execute שבבעלות הלקוח המבצע, אבל יוחזרו למסגרת לצורך הפניה.PjRtBuffers לאחר מכן, המאגרים האלה מועברים כארגומנטים ל-method Execute.
מושגים ב-PJRT
חישובים אסינכרוניים וחוזים עתידיים
אם חלק מהפלאגין מיושם באופן אסינכרוני, חובה ליישם את העתידים בצורה נכונה.
נניח שיש לכם את התוכנית הבאה:
@jax.jit
def foo(x): return x + 1
x = foo(1)
# [...] other logic not using `x`
print(x + 1)
תוסף אסינכרוני יוכל להוסיף את החישוב x לתור, ולהחזיר מיד מאגר שלא מוכן עדיין לקריאה, אבל ההפעלה תאכלס אותו. הביצוע יכול להמשיך להוסיף ל-queue חישובים נדרשים אחרי x, שלא דורשים x, כולל ביצוע במכשירי PJRT אחרים. כשצריך את הערך של x, הביצוע ייחסם עד שהמאגר יצהיר על עצמו שהוא מוכן באמצעות האובייקט Future שמוחזר על ידי GetReadyFuture.
אפשר להשתמש ב-Futures כדי לקבוע מתי אובייקט הופך לזמין, כולל מכשירים ומאגרי נתונים זמניים.
מושגים מתקדמים
הוספת API מעבר ל-API הבסיסי תרחיב את הפיצ'רים של JAX שאפשר להשתמש בהם באמצעות פלאגין. כל התכונות האלה הן אופציונליות, כלומר תהליך עבודה רגיל של JIT והרצה יפעל בלעדיהן, אבל כדי ליצור צינור באיכות ייצור, כדאי לחשוב על מידת התמיכה בכל אחת מהתכונות האלה שנתמכות על ידי ממשקי ה-API של PJRT:
- מרחבי זיכרון
- פריסות בהתאמה אישית
- פעולות תקשורת כמו שליחה וקבלה
- העברת עומס מהמארח
- פיצול
תקשורת אופיינית בין מכשיר לבין מסגרת PJRT
יומן לדוגמה
הנה יומן של ה-methods שהופעלו כדי לטעון את הפלאגין PJRT ולהריץ את y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1). במקרה הזה, אנחנו מתעדים את האינטראקציה של JAX עם פלאגין StableHLO Reference PJRT.
יומן לדוגמה
////////////////////////////////// // Load the plugin ////////////////////////////////// I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400) I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8) I device.cc:104] StablehloReferenceDevice(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I client_cpp_pjrt.cc:86] devices(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I device.cc:168] description(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:86] Attributes(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:148] memory_spaces(0x23bac4e0) Creating PJRT Client from client I client_cpp_pjrt.cc:108] platform_version(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:57] id(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:61] process_index(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:153] default_memory_space(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) ////////////////////////////////// // RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)` ////////////////////////////////// I executable.cc:309] num_partitions(0x240bab70) I executable.cc:305] num_replicas(0x240bab70) I executable.cc:309] num_partitions(0x240bab70) I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400) I buffer.cc:285] CreateMlirBufferFromLiteral I buffer.cc:98] CreateFromLiteral I buffer.cc:99] CreateFromLiteral: s32[] 2 I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:102] CreateFromLiteral -> 0x240bb050 I buffer.cc:158] device(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I executable.cc:328] GetHloModules(0x240bab70) I executable.cc:240] Execute(0x240bab70) I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x240bb050) I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32> I executable.cc:205] EvalModule: module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) { // ... return %3 : tensor<i32> } } I executable.cc:206] Inputs: [dense<2> : tensor<i32>] I executable.cc:213] Results: [dense<2> : tensor<i32>] I device.cc:153] default_memory_space(0x23bac4e0) I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x22cea630) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630 ////////////////////////////////// // RUN: `print(y)` ////////////////////////////////// I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references. I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x2404d560) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050 I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x240b6010) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x23b2db60) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60 I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630)