ייצוג של חלוקה לקטעים

רקע

מטרת הייצוג של חלוקה למקטעים היא לציין איך מתבצעת חלוקת המטריצה למקטעים ביחס לקבוצה של מכשירים זמינים.

הייצוג של חלוקת המטא-נתונים יכול להיות:

  • המשתמש מציין אותם באופן ידני כאילוצים של חלוקה לפלחים על קלט, פלט או נתונים ביניים.
  • הטרנספורמציה מתבצעת לכל פעולה בתהליך ההפצה של חלוקת המטא-נתונים.

סקירה כללית

המבנה הבסיסי

רשת לוגית היא תצוגה רב-ממדית של מכשירים, שמוגדרת לפי רשימה של שמות גדלים של צירים.

הייצוג המוצע של חלוקת המשנה קשור לרשת לוגית ספציפית לפי השם שלה, והוא יכול להפנות רק לשמות של צירים מהרשת הזו. חלוקת הטנסור לחלקים קובעת לאורך אילו צירים (של רשת לוגית ספציפית) כל מאפיין של הטנסור מחולק, בסדר מגדול לקטן. הטנזור משכפל לאורך כל הצירים האחרים של הרשת.

נבחן את הייצוג של חלוקה לקטעים באמצעות טינסור פשוט של דרגה 2 ו-4 מכשירים.

קודם משנים את הצורה של 4 המכשירים [0, 1, 2, 3] למערך דו-מימדי [[0, 1], [2, 3]] כדי ליצור רשת עם 2 צירים:

@mesh_xy = <["x"=2, "y"=2]>

לאחר מכן אפשר לפצל את הטנסור [[a, b], [c, d]] של הרמה 2 באופן הבא:

ייצוג חלוקה לקטעים של טינסור מסדר 2

רכיבים מרכזיים אחרים

  • מאפיינים פתוחים/סגורים – המאפיינים יכולים להיות פתוחים – אפשר לפצל אותם לעוד קטעים בצירים זמינים, או סגורים – הם קבועים ואי אפשר לשנות אותם.
  • צירים שמתבצעת בהם רפליקה באופן מפורש – כל הצירים שלא משמשים לחלוקה של מאפיין עוברים רפליקה באופן משתמע, אבל אפשר לציין בצירוף לחלוקה צירים שמתבצעת בהם רפליקה באופן מפורש, ולכן לא ניתן להשתמש בהם לחלוקה של מאפיין בשלב מאוחר יותר.
  • פיצול צירים וצירים משניים – אפשר לפצל ציר רשת (מלא) למספר צירים משניים, שאפשר להשתמש בהם בנפרד כדי לפצל מאפיין או ליצור רפליקה מפורשת.
  • רשתות רשתות לוגיות מרובות – אפשר לשייך חלוקות שונות לרשתות רשתות לוגיות שונות, שיכולות להיות להן צירים שונים או אפילו סדר שונה של מזהי מכשירים לוגיים.
  • עדיפויות – כדי לפצל תוכנית באופן מצטבר, אפשר לצרף עדיפויות לחלוקות של מאפיינים. העדיפויות קובעות את הסדר שבו אילוצים של חלוקה לפי מאפיין יועברו לאורך המודול.
  • חלוקה של מאפיינים לקטעים (shards) לפי חלוקה – אפשר לפצל מאפיין לפי צירים שמכפלת הגדלים שלהם לא מחלקת את גודל המאפיין.

תכנון מפורט

בקטע הזה נסביר על המבנה הבסיסי ועל כל אחד מהרכיבים המרכזיים.

המבנה הבסיסי

חלוקות המאפיינים מאפשרות לנו לדעת לאורך אילו צירים (או צירי משנה) מתבצעת חלוקה של כל מאפיין בטרנספורמר, מהציר הראשי לציר המשני. כל שאר הצירים שלא מחלקים מאפיין מסוים עוברים שכפול משתמע (או שכפול מפורש).

נתחיל בדוגמה פשוטה ונרחיב אותה ככל שנמשיך לתאר תכונות נוספות.

@mesh_xy = <["x"=2, "y"=4, "z"=2]>

// The 1st tensor dimension is sharded along axis "x" and the 2nd tensor dimension is
// sharded along axis "z" then further along axis "y". The local shape of this tensor (i.e. the shape on a single device), would be tensor<2x1xf32>.
sharding<@mesh_xy, [{"x"}, {"z", "y"}]> : tensor<4x8xf32>

Invariants

  • מספר חלוקות המאפיינים חייב להתאים לדרג (rank) של הטנזור.
  • כל שמות הצירים חייבים להתקיים ברשת שמופיעה בהפניה.
  • צירים או צירי משנה יכולים להופיע פעם אחת בלבד בייצוג של חלוקת המשנה (כל אחד מהם מחלק מאפיין או מבוצע בו רפליקה מפורשת).

מאפיינים פתוחים/סגורים

כל מאפיין של טינסור יכול להיות פתוח או סגור.

פתיחה

מאפיין פתוח זמין להפצה כדי לפצל אותו לעוד צירים, כלומר חלוקת המאפיין לשברים שצוינה לא חייבת להיות חלוקת המאפיין לשברים הסופית. המצב הזה דומה (אבל לא זהה) ל-

אם המאפיין פתוח, מוסיפים ? אחרי הצירים שבהם המאפיין כבר מפוצל (ראו דוגמה בהמשך).

סגור

מאפיין סגור הוא מאפיין שלא זמין להפצה כדי להוסיף לו חלוקה לפלחים נוספת. כלומר, חלוקת המאפיין שצוינה היא חלוקת המאפיין הסופית, ואי אפשר לשנות אותה. דוגמה נפוצה לכך היא האופן שבו GSPMD (בדרך כלל) לא משנה את ארגומנטים הקלט/הפלט של מודול, או האופן שבו ב-jax.jit, הערכים של in_shardings שצוינו על ידי המשתמש הם סטטיים – הם לא יכולים להשתנות.

אפשר להרחיב את הדוגמה שלמעלה כך שתכלול מאפיין פתוח ומאפיין סגור.

@mesh_xy = <["x"=2, "y"=4, "z"=2]>

// The 1st dimension is closed, therefore it can't be further sharded and {"x"}
// will remain its sharding. The 2nd dimension is open, and can therefore be
// further sharded during propagation, e.g. by "y".
sharding<@mesh_xy, [{"x"}, {"z", ?}]> : tensor<4x8xf32>

צירים שמתבצעת להם רפליקה באופן מפורש

קבוצה מפורשת של צירים שבהם מתבצע שכפול של טינסור. אפשר לקבוע שמתבצעת רפליקה משתמעת של טינסור שלא פוצל לפי ציר (כמו jax.sharding.PartitionSpec היום), אבל אם מציינים זאת במפורש, אפשר לוודא שההעברה לא תשתמש בצירים האלה כדי לפצל עוד יותר מאפיין פתוח עם הצירים האלה. באמצעות שכפול משתמע, אפשר לפצל טרנספורמציה לטנזורים נוספים. אבל עם שכפול מפורש, אי אפשר לחלק את הטנזור לאורך הציר הזה.

הסדר של צירים ששכפלתם לא משפיע על אופן האחסון של הנתונים של הטנזור. עם זאת, מטעמי עקביות בלבד, הצירים יישמרו בסדר שבו הם צוינו ברשת ברמה העליונה. לדוגמה, אם הרשת היא:

@mesh_xy = <["c"=2, "a"=2, "b"=2]>

ואנחנו רוצים שהצירים "a" ו-"c" ישוחזרו במפורש, הסדר צריך להיות:

replicated={"c", "a"}

אפשר להרחיב את הדוגמה שלמעלה כך שתכלול ציר שמתבצעת לו רפליקה באופן מפורש.

@mesh_xyz = <["x"=2, "y"=4, "z"=2]>

// Since "y" is explicitly replicated, it can't be used to shard the 2nd
// dimension that is open. However, "z" is implicitly replicated so it can be
// used to shard that dimension. The local shape of this tensor (i.e. the
// shape on a single device), would // be tensor<2x8xf32>.
sharding<@mesh_xyz, [{"x"}, {?}], replicated={"y"}> : tensor<4x8xf32>

פיצול צירים וצירים משניים

כדי ליצור רשת לוגית של n צירים, מעצבים מחדש מערך של מכשירי ממדים אחדים למערך של n ממדים, שבו כל מאפיין יוצר ציר עם שם שהוגדר על ידי המשתמש.

אפשר לבצע את אותו תהליך במהדר כדי לפצל ציר בגודל k ל-m צירי משנה, על ידי שינוי צורת הרשת מ-[...,k,...] ל-[...,k1,...,km,...].

למה בחרנו לעשות זאת?

כדי להבין את המניע מאחורי פיצול צירים, נבחן את הדוגמה הבאה:

@mesh_x = <["x"=4]>

%arg0 : tensor<8xf32> {sdy.sharding=<@mesh_x, [{"x"}]>}
%0 = reshape %arg0 : (tensor<8xf32>) -> tensor<2x4xf32>

אנחנו רוצים לפצל את התוצאה של שינוי הצורה כך שלא תהיה צורך בתקשורת (כלומר, להשאיר את הנתונים במקום שבו הם נמצאים). מכיוון שהגודל של "x" גדול מהמאפיין הראשון של התוצאה, צריך לפצל את הציר לשני צירי משנה, "x.0" ו-"x.1", בגודל 2 כל אחד, ולחלק את המאפיין הראשון ל-"x.0" ואת המאפיין השני ל-"x.1".

חלוקה לפלחים של קלט/פלט של פונקציות

יכול להיות שבמהלך ההעברה, קלט או פלט של הפונקציה הראשית יתחלקו לפי ציר משנה. זה יכול להיות בעיה במסגרות מסוימות, שבהן אי אפשר להביע חלוקות כאלה כדי להחזיר אותן למשתמש (לדוגמה, ב-JAX אי אפשר להביע צירים משניים באמצעות jax.sharding.NamedSharding).

יש לנו כמה אפשרויות לטיפול במקרים כאלה:

  • לאפשר את הפיצול ולהחזיר אותו בפורמט אחר (למשל jax.sharding.PositionalSharding במקום jax.sharding.NamedSharding ב-JAX).
  • אסור להשתמש ב-Disallow ובצירי משנה של all-gather שמחלקים את הקלט/הפלט.

בשלב זה אנחנו מאפשרים צירים משניים בקלט/פלט בצינור עיבוד הנתונים להפצה. נשמח לדעת אם יש לך דרך להשבית את האפשרות הזו.

ייצוג

בדומה לאופן שבו אפשר להפנות לצירים מלאים ספציפיים מהעיגול לפי השם שלהם, אפשר להפנות לצירים משנה ספציפיים לפי הגודל שלהם ולפי המכפלה של כל הגדלים של צירי המשנה (באותו שם ציר) שמשמאל להם (שחשובים להם) .

כדי לחלץ ציר משנה ספציפי בגודל k מציר מלא "x" בגודל n, אנחנו משנים את הצורה של n (במערך) ל-[m, k, n/(m*k)] ומשתמשים במאפיין השני כציר המשנה. כך אפשר לציין ציר משנה באמצעות שני מספרים, m ו-k, ואנחנו משתמשים בסימון המצומצם הבא כדי לציין צירי משנה: "x":(m)k.

  • m>=1 הוא הגודל המקדים של ציר המשנה הזה (m צריך להיות מחלק של n). הגודל המקדים הוא המכפלה של כל גדלי צירי המשנה שמשמאל לציר המשנה הזה (אם הוא שווה ל-1, סימן שאין אף אחד. אם הוא גדול מ-1, הוא תואם לציר משנה יחיד או לכמה צירי משנה).

  • הערך של k>1 הוא הגודל בפועל של ציר המשנה הזה (k צריך להיות מחלק של n).

  • n/(m*k) הוא post-size. הוא המכפלה של כל הגדלים של צירי המשנה שמימין לציר המשנה הזה (אם הוא שווה ל-1, סימן שאין צירי משנה. אם הוא גדול מ-1, הוא מתאים לציר משנה יחיד או לכמה צירי משנה).

עם זאת, מספר הצירים המשניים האחרים לא משנה כשמשתמשים בציר משני ספציפי "x":(m)k, ואין צורך להפנות לציר משני אחר בחלוקת הטנסור אם הוא לא מחלק מאפיין או שהוא משכפל באופן מפורש.

נחזור לדוגמה בקטע הסיבה. אפשר לפצל את התוצאה כך:

@mesh_x = <["x"=4]>

%arg0 : tensor<8xf32> {sdy.sharding=<@mesh_x, [{"x"}]>}
%0 = reshape %arg0 {sdy.sharding_per_value=<[<@mesh_x, [{"x":(1)2}, {"x":(2)2}]>]>}
    : (tensor<8xf32>) -> tensor<2x4xf32>

הנה דוגמה נוספת לציר מפוצל שבו נעשה שימוש רק בחלק מצירי המשנה שלו.

@mesh_xyz = <["x"=2, "y"=8, "z"=2]>

// Axis "y" is effectively split into 3 sub-axes denoted as
//   "y":(1)2, "y":(2)2, "y":(4)2
// in order, but only "y":(2)2 is used, to shard the 2nd dimension. The local
// shape of this tensor (i.e. the shape on a single device), would be
// tensor<2x4xf32>.
sharding<@mesh_xyz, [{"x"}, {"y":(2)2}]> : tensor<4x8xf32>

באופן דומה, שתי החלקות הנתונים הבאות זהות מבחינה סמנטית. אפשר לחשוב על mesh_xy כפיצול של mesh_full.

@mesh_full = <"devices"=8>
@mesh_xy = <"x"=4, "y"=2>

sharding<@mesh_xy, [{"x"},{ "y"}]> : tensor<4x4xf32>
sharding<@mesh_full, [{"devices":(1)4}, {"devices":(4)2}]> : tensor<4x4xf32>

צירים משניים שמתבצעת להם רפליקה באופן מפורש

בנוסף לשימוש בצירים משניים כדי לפצל מאפיינים, אפשר גם לסמן אותם ככאלה שצריך להעתיק אותם באופן מפורש. אנחנו מאפשרים זאת בייצוג כי צירי משנה מתנהגים בדיוק כמו צירים מלאים. כלומר, כשמחלקים מאפיין לפי ציר משנה של ציר "x", צירי המשנה האחרים של "x" עוברים שכפול מרומז, ולכן אפשר לשחזר אותם באופן מפורש כדי לציין שציר משנה חייב להישאר כפול ולא ניתן להשתמש בו כדי לפצל מאפיין.

לדוגמה:

@mesh_xyz = <["x"=2, "y"=8, "z"=2]>

// Sub-axis "y":(1)2 is explicitly replicated and "y":(4)2 is implicitly replicated.
sharding<@mesh_xyz, [{"x"}, {"y":(2)2}], replicated={"y":(1)2}> : tensor<4x8xf32>

יש למיין צירי משנה משכפלים של אותו ציר מלא בסדר עולה לפי הגודל שהוגדר מראש, לדוגמה:

replicated={"y":(4)2, "x", "y":(1)2} ~> replicated={"x", "y":(1)2, "y":(4)2}

Invariants

  • צירי משנה שמצוינים בחלוקה של טינסור לא יכולים לחפוף, למשל "x":(1)4 ו-"x":(2)4 חופפים.

  • צירי משנה שמצוינים בחלוקה של טינסור חייבים להיות גדולים ככל האפשר. כלומר, אם בחלוקה של מאפיין יש שני צירי משנה סמוכים A ו-B בסדר, או אם צירי המשנה A ו-B מועתקים באופן מפורש, הם לא יכולים להיות רצופים, למשל "x":(1)2 ו-"x":(2)4, כי אפשר להחליף אותם ב-"x":(1)8 יחיד.

כמה רשתות לוגיות

רשת לוגית אחת היא תצוגה רב-ממדית של מכשירים. יכול להיות שנצטרך כמה תצוגות של המכשירים כדי לייצג את החלוקה לפלחים, במיוחד במקרים של הקצאות שרירותיות של מכשירים.

לדוגמה, jax.sharding.PositionalSharding לא כולל רשת לוגית משותפת אחת. בשלב זה, GSPMD תומך בכך באמצעות HloSharding, שבו הייצוג יכול להיות רשימה מסודרת של מכשירים וגדלים של מאפיינים, אבל אי אפשר לייצג אותו באמצעות חלוקת הצירים שמתוארת למעלה.

כדי להתגבר על המגבלה הזו ולטפל במקרים קיצוניים קיימים, אנחנו מגדירים כמה רשתות לוגיות ברמה העליונה של התוכנית. לכל רשת יכול להיות מספר שונה של צירים עם שמות שונים, וגם הקצאה שרירותית משלה לאותה קבוצה של מכשירים. כלומר, כל רשת מתייחסת לאותה קבוצה של מכשירים (לפי המזהה הלוגי הייחודי שלהם) אבל בסדר שרירותי, בדומה לייצוג של GSMPD.

כל ייצוג של חלוקה לקטעים מקושר לרשת לוגית ספציפית, ולכן הוא יפנה רק לצירים מהרשת הזו.

אפשר להשתמש בטנסור שהוקצה לרשת לוגית אחת בפעולה שהוקצתה לרשת אחרת, על ידי חלוקה מחדש של הטנסור באופן פשוט כך שיתאים לרשת היעד. ב-GSPMD, זהו בדרך כלל הפתרון לבעיות של רשתות רשתות מתנגשות.

בהמשך מפורטות שתי דוגמאות:

משתמשים יכולים לציין כמה רשתות עם צירים שונים בשמות שונים (למשל באמצעות jax.sharding.NamedSharding), שיש להם את אותו סדר של מכשירים. בדוגמה הזו, הערך של <@mesh_0, "b"> זהה לערך של <@mesh_1, "z">..

@mesh_0 = {<["a"=4, "b"=2]>, device_ids=[0, 1, 2, 3, 4, 5, 6, 7]}
@mesh_1 = {<["x"=2, "y"=2, "z"=2]>, device_ids=[0, 1, 2, 3, 4, 5, 6, 7]}

עדיפויות

העדיפות מאפשרת לתת עדיפות להחלטות מסוימות לגבי חלוקה ותפוצה על פני החלטות אחרות, ומאפשרת חלוקה מצטברת של תוכנית.

עדיפויות הן ערכים שמצורפים לחלק מהמימדים או לכל המימדים של ייצוג חלוקה (לאקסים משוחזרים אין עדיפויות).

לדוגמה:

@mesh_xy = <["w"=6, "x"=2, "y"=4, "z"=2]>

//                                    |-> y is implicitly p0
%arg4 : sharding<@mesh_xy, [{"x"}p1, {"y"}, {"z",?}p2], replicated={} }>

העדיפויות נותנות למשתמשים שליטה מפורטת יותר על ההעברה, למשל, קודם ביצוע פעולות במקביל באצווה, אחר כך Megatron ולבסוף חלוקה לפלחים של ZeRO. כך אפשר להבטיח בצורה טובה יותר מה מחולק למחיצות, ולאפשר ניפוי באגים טוב יותר באמצעות אסטרטגיות פירוט מפורט יותר של חלוקה למחיצות (אפשר לראות איך נראה התוכנית אחרי שמשתמשים רק ב-megatron בנפרד).

אנחנו מאפשרים לצרף עדיפות לכל חלוקה של מאפיין (0 כברירת מחדל). המשמעות היא שכל החלוקות עם העדיפות <i יועברו לכל התוכנית לפני החלוקות עם העדיפות i.

גם אם לחלוקה לפלחים יש מאפיין פתוח עם עדיפות נמוכה יותר, למשל: {"z",?}p2, הוא לא ישתנה על ידי חלוקה אחרת של טינסור עם עדיפות גבוהה יותר במהלך ההעברה. עם זאת, אפשר לפצל עוד מאפיין פתוח כזה אחרי שכל הפיצולים בעדיפות גבוהה יותר מופצים.

במילים אחרות, העדיפויות לא קובעות איזה חלוקה למקטעים של מאפיינים חשובה יותר מאחרת – הן קובעות את הסדר שבו קבוצות נפרדות של חלוקות למקטעים של מאפיינים צריכות להופיע בתוכנית כולה, ואת האופן שבו צריך לפתור קונפליקטים בטנסורים ביניים ללא הערות.

Invariants

  • רמות העדיפות מתחילות ב-0 (העדיפות הגבוהה ביותר) וממשיכות לגדול (כדי לאפשר למשתמשים להוסיף ולהסיר עדיפות בקלות, אנחנו מאפשרים פערים בין רמות העדיפות. לדוגמה, נעשה שימוש ב-p0 וב-p2 אבל לא ב-p1).

  • חלוקה למקטעים של מאפיין סגור ריק (כלומר, {}), לא צריכה להיות עדיפות, כי לא תהיה לה השפעה.

חלוקה של מאפיינים לקטעים – חלוקה לחלקים

אפשר לפצל מימד בגודל d לפי צירים שמכפלת הגדלים שלהם היא n, כך ש-d לא מתחלק ב-n (בפועל, יהיה צורך להוסיף מילוי למימד).

לדוגמה:

@mesh_xy = <["x"=8, "y"=2, "z"=3]>

sharding<@mesh_xy, [{"x"}, {"y"}, {"z"}]> : tensor<7x3x8xf32>

דקדוק

כל רשת לוגית מוגדרת באופן הבא:

@mesh_name = <mesh_axis_1,...,mesh_axis_n>

mesh_axis ::= axis_name=axis_size

axis_name ::= str
axis_size ::= int

לייצוג של חלוקה לפלחים יהיה המבנה הבא עבור טינסור בעל דרגה r:

sharding<@mesh_name, dim_shardings, replicated=replicated_axes}

mesh_name ::= str
dim_shardings ::= [dim_sharding_1,...,dim_sharding_r]
replicated_axes ::= {axis_1,...,axis_m}

dim_sharding ::=
  {axis_1,...,axis_k} |  // closed dimension
  {axis_1,...,axis_k,?}  // open dimension

axis ::=
  axis_name  |   // a full axis
  sub_axis             // a sub axis

axis_name ::= str

sub_axis ::= axis_name:(pre_size)size
pre_size ::= int
size ::= int