התפשטות

סקירה כללית

בהפצה של חלוקה למקטעים נעשה שימוש בחלוקות למקטעים שהמשתמשים ציינו כדי להסיק את החלוקות למקטעים של הטנזורים (או של מאפיין ספציפי של הטנזורים) שלא צוינו. הוא חוצה את זרימת הנתונים (שרשראות של הגדרות שימוש) של תרשים החישוב בשני הכיוונים עד שמגיעים לנקודה קבועה, כלומר, כבר אי אפשר לשנות את החלוקה לחלקים בלי לבטל את ההחלטות הקודמות לגבי חלוקה לחלקים.

אפשר לפרק את ההעברה לשלבים. בכל שלב בודקים פעולה ספציפית ומפיצים אותה בין טינסורים (אופרטנדים ותוצאות), על סמך המאפיינים של הפעולה. לדוגמה, אם מדובר במכפלת מטריצות, נפיץ בין המאפיין שאינו מתכווץ של lhs או של rhs למאפיין התואם של התוצאה, או בין המאפיין המתכווץ של lhs ושל rhs.

המאפיינים של פעולה קובעים את הקשר בין המאפיינים התואמים בקלט ובפלט שלה, וניתן להכליל אותם ככלל חלוקה (sharding) לכל פעולה.

בלי פתרון התנגשויות, שלב ההעברה (propagation) פשוט יעביר כמה שיותר נתונים תוך התעלמות מהצירים שנמצאים בהתנגשות. אנחנו מתייחסים לזה בתור צירי הפיצול הראשיים (הארוכים ביותר) התואמים.

תכנון מפורט

היררכיית פתרון התנגשויות

אנחנו משלבים כמה אסטרטגיות לפתרון סכסוכים בהיררכיה:

  1. עדיפויות מוגדרות על ידי משתמש. במאמר Sharding Representation, תיארנו איך אפשר לצרף תעדוף לפי חלוקה של מאפיינים כדי לאפשר חלוקה מצטברת של התוכנית, למשל, ביצוע מקביליות באצווה –>‏ megatron –>‏ חלוקה של ZeRO. כדי לעשות זאת, אנחנו מחילים את ההעברה (propagation) בחזרות (iterations) – בחזרה i אנחנו מעבירים את כל חלוקות המאפיינים שיש להן עדיפות <=i ומתעלים מכל שאר החלוקות. אנחנו גם מוודאים שההפצה לא תשנה חלוקות (shards) מוגדרות על ידי משתמשים עם עדיפות נמוכה יותר (>i), גם אם התעלמו מהן במחזורים קודמים.
  2. עדיפויות מבוססות-פעולה. אנחנו מפיצים את החלוקה לחלקים, על סמך סוג הפעולה. לפעולות 'מעבר דרך' (למשל, פעולות לפי רכיבים ושינוי צורה) יש את העדיפות הגבוהה ביותר, ולפעולות עם טרנספורמציה של צורה (למשל, dot ו-reduce) יש עדיפות נמוכה יותר.
  3. העברה אגרסיבית להפיץ את החלוקה לפלחים באמצעות אסטרטגיה אגרסיבית. האסטרטגיה הבסיסית מפיצה רק חלוקות ללא התנגשויות, ואילו האסטרטגיה האגרסיבית פותרת את ההתנגשויות. רמת אגרסיביות גבוהה יותר יכולה לצמצם את טביעת הרגל של הזיכרון, אבל על חשבון תקשורת פוטנציאלית.
  4. העברה בסיסית זוהי אסטרטגיית ההעברה הנמוכה ביותר בהיררכיה, שלא מבצעת פתרון של התנגשויות, ובמקום זאת מעבירה צירים שתואמים לכל המשתנים והתוצאות.

היררכיית ההעברה, שמציגה 4 סטאקים, מלמטה למעלה, עם התוויות הבאות: Basic Propagation,‏ Aggressive Propagation,‏ Operation Priority Propagation ו-User Priority Propagation.

אפשר לפרש את ההיררכיה הזו כמחזורים בתצוגת עץ. לדוגמה, לכל עדיפות משתמש מוחל העברה מלאה של עדיפות הפעולה.

כלל חלוקה של פעולות

כלל הפיצול מאפשר ליצור הפשטה של כל פעולה, שמספקת לאלגוריתם ההעברה בפועל את המידע הדרוש להעברת הפיצולים מאופרטורים לתוצאות או בין אופרטורים, וכו', בלי צורך להסיק מסקנות לגבי סוגי פעולות ספציפיים והמאפיינים שלהם. בעיקרון, מדובר בהוצאה מהמשוואה של הלוגיקה הספציפית לפעולה, ומתן ייצוג משותף (מבנה נתונים) לכל הפעולות למטרות העברה בלבד. בצורתה הפשוטה ביותר, היא מספקת רק את הפונקציה הזו:

GetOpShardingRule(Operation *) -> OpShardingRuleAttr

הכלל מאפשר לנו לכתוב את אלגוריתם ההפצה רק פעם אחת באופן כללי שמבוסס על מבנה הנתונים הזה (OpShardingRule), במקום לשכפל קטעי קוד דומים במספר רב של פעולות, וכך לצמצם באופן משמעותי את האפשרות לבאגים או להתנהגות לא עקבית בין פעולות.

נחזור לדוגמה של matmul.

אפשר לכתוב קידוד שמכיל את המידע הנדרש במהלך ההעברה, כלומר היחסים בין המאפיינים, בצורת סימון einsum:

(i, k), (k, j) -> (i, j)

בקידוד הזה, כל מאפיין ממופה לגורם יחיד.

איך ההעברה משתמשת במיפוי הזה: אם מאפיין של אופרטור/תוצאה מחולק לפלחים לאורך ציר, ההעברה תבדוק את הגורם של המאפיין הזה במיפוי הזה, ותחלק אופרטורים/תוצאות אחרים לפי המאפיין המתאים שלהם עם אותו גורם – ועשויה גם לשכפל אופרטורים/תוצאות אחרים שאין להם את הגורם הזה לאורך הציר הזה (בהתאם לדיון הקודם על רפליקה).

גורמים מורכבים: הרחבת הכלל לשינוי צורת הנתונים

בפעולות רבות, למשל matmul, צריך למפות כל מאפיין רק לגורם אחד. עם זאת, הוא לא מספיק לשינוי צורות.

הפונקציה הבאה לשינוי הצורה משלבת שני מאפיינים למאפיין אחד:

%out = mhlo.reshape(%in) : (tensor<2x4x32xf32>) -> tensor<8x32xf32>

כאן, המאפיינים 0 ו-1 של הקלט תואמים למאפיין 0 של הפלט. נניח שאנחנו מתחילים על ידי מתן גורמים לקלט:

(i,j,k) : i=2, j=4, k=32

אפשר לראות שאם רוצים להשתמש באותם גורמים בפלט, צריך מאפיין אחד שיפנה לכמה גורמים:

(i,j,k) -> ((ij), k) : i=2, j=4, k=32

אפשר לעשות את אותו הדבר אם שינוי הצורה גורם לפיצול של מאפיין:

%out = mhlo.reshape(%in) : (tensor<8x32xf32>) -> tensor<2x4x32xf32> ((ij), k) -> (i,j,k) : i=2, j=4, k=32

המאפיין בגודל 8 מורכב בעיקר מהגורמים 2 ו-4, ולכן אנחנו קוראים לגורמים האלה גורמים מסוג (i,j,k).

הגורמים האלה יכולים לפעול גם במקרים שבהם אין מאפיין מלא שתואם לאחד מהגורמים:

%out = mhlo.reshape(%in) : (tensor<8x4xf32>) -> tensor<2x16xf32> ((ij), k) -> (i,(jk)) : i=2, j=4, k=4

הדוגמה הזו מדגישה גם למה אנחנו צריכים לאחסן את גדלי הגורמים – כי אי אפשר להסיק אותם בקלות מהמאפיינים התואמים.

אלגוריתם העברה ליבה

העברת חלוקות לפי גורמים

ב-Shardy יש היררכיה של טינסורים, מאפיינים וגורמים. הם מייצגים נתונים ברמות שונות. גורם הוא מאפיין משני. זוהי היררכיה פנימית שמשמשת להפצה של חלוקה למחיצות. כל מאפיין יכול להתאים לגורם אחד או יותר. המיפוי בין המאפיין לגורם מוגדר על ידי OpShardingRule.

סכימה שמציגה את אלגוריתם ההפצה של Shardy.

Shardy מעביר את צירי הפיצול לפי גורמים במקום לפי מאפיינים. כדי לעשות זאת, יש לנו שלושה שלבים כפי שמוצג באיור הבא

  1. מעבר מ-DimSharding ל-FactorSharding בפרויקט
  2. הפיכת צירי חלוקה לזמינים במרחב של FactorSharding
  3. יצירת פרויקט של FactorSharding המעודכן כדי לקבל את DimSharding המעודכן

סכימה שמראה את ההפצה של חלוקה לקטעים ב-FactorSharding וב-DimSharding.

תצוגה חזותית של ההפצה של חלוקת שטחי האחסון לפי גורמים

נשתמש בטבלה הבאה כדי להמחיש את הבעיה ואת האלגוריתם של ההפצה באמצעות חלוקה לקטעים.

F0 F1 F2 צירים שמתבצעת להם רפליקה באופן מפורש
T0
T1
T2
  • כל עמודה מייצגת גורם. F0 הוא הגורם עם המדד 0. אנחנו מפיצים את החלוקה לגורמים (עמודות).
  • כל שורה מייצגת טינסור. T0 מתייחס לטרנספורמר עם המדד 0. טינסורים הם כל המשתנים והתוצאות שקשורים לפעולה ספציפית. לא ניתן לחפות בין צירים בשורה. אי אפשר להשתמש בציר (או בציר משנה) כדי לפצל טינסור אחד פעמים רבות. אם ציר מסוים מוכפל באופן מפורש, אי אפשר להשתמש בו כדי לחלק את הטנזור.

לכן, כל תא מייצג חלוקה של גורם. יכול להיות שגורם חסר בטנסורים חלקיים. בטבלה שבהמשך מופיע הערך של C = dot(A, B). התאים שמכילים את הערך N מצביעים על כך שהגורם לא נמצא בטנסור. לדוגמה, F2 נמצא ב-T1 וב-T2, אבל לא ב-T0.

C = dot(A, B) עמעום בקבוצות (batching) ברמה F0 F1 Non-contracting dim F2 Non-contracting dim F3 Contracting dim צירים שמתבצעת להם רפליקה באופן מפורש
T0 = A לא
T1 = B לא
T2 = C לא

איסוף והפצה של צירי חלוקה

כדי להמחיש את ההעברה, נשתמש בדוגמה פשוטה שמופיעה בהמשך.

F0 F1 F2 צירים שמתבצעת להם רפליקה באופן מפורש
T0 'a' ‎"f"‎
T1 "a",‏ "b" ‎"c", "d"‎ 'g'
T2 ‎"c", "e"‎

שלב 1. חיפוש צירים להעברה לאורך כל גורם (כלומר, צירי הפיצול הראשיים (הארוכים ביותר) שתואמים). בדוגמה הזו, אנחנו מעבירים את ["a", "b"] לאורך F0, מעבירים את ["c"] לאורך F1 ולא מעבירים דבר לאורך F2.

שלב 2. מרחיבים את חלוקות הגורמים כדי לקבל את התוצאה הבאה.

F0 F1 F2 צירים שמתבצעת להם רפליקה באופן מפורש
T0 'a', ‏ 'b' "c" ‎"f"‎
T1 "a",‏ "b" ‎"c", "d"‎ 'g'
T2 "a", "b" ‎"c", "e"‎