Compiler API

רקע

אנחנו מניחים שהקוראים מכירים לפחות את העקרונות הבסיסיים של ייצוג חלוקה, שמתארים איך אפשר לבטא את חלוקת הטנזור ב-Shardy. במסמך הזה מוסבר איך אפשר להשתמש בייצוגים של חלוקה לפלחים בתוכנית, למשל כדי לצרף חלוקה לפלחים לטרנספורמר ספציפי בתוכנית.

התפשטות של חלוקה למקטעים היא התהליך שבו קובעים חלוקה למקטעים לכל טינסור בתוכנית, בהתאם למגבלות חלוקה למקטעים של קבוצת משנה של הטינסורים. ב-Shardy יש כמה דרכים להשפיע על ההפצה של חלוקת המידע לקטעים או לשלוט בה. בנוסף, היא מאפשרת למשתמשים להוסיף לתוכניות שלהם חישובים שמחולקים באופן ידני למקטעים.

מטרה

במסמך הזה מתוארים העיצוב של רכיבי ה-API האלה ב-Shardy וההתנהגות והקבועים שלהם. חשוב לזכור ש-API הזה משמש לצורך בקרה על ההפצה של חלוקת המחיצות, אבל המאמר הזה לא ידון בהתנהגות של ההפצה או בתכנון שלה.

סקירה כללית

  • חלוקה לפלחים של קלט/פלט – מחברים חלוקה לפלחים לקלט או לפלט של הפונקציה הראשית, כדי לציין שכך צריך לפצל את הטנסור של הקלט/הפלט כשנותנים אותו לפונקציה או מקבלים אותו ממנה.

  • אילוץ חלוקה לפלחים – מחברים חלוקה לפלחים לטרנספורמציה ביניים (למשל, התוצאה של matmul) כדי לציין איך צריך לפצל את הטרנספורמציה הזו, או קבוצת משנה של השימושים שלה.

  • קבוצת חלוקה – קיבוץ של מספר טינסורים לפי מזהה כדי לציין שצריך לפצל אותם באותו אופן.

  • חישוב ידני – מקיף חישוב משנה שמחולק באופן ידני באמצעות קבוצת משנה של צירי רשת, כאשר החלוקות לצירים הידניים האלה מצוינות לכל הקלטות והפלטות, ובתוך החישוב המשנה סוגי הטנזורים הם מקומיים ביחס לחלוקות האלה.

תכנון מפורט

חלוקות של קלט ופלט

מאפשרת למשתמשים לציין חלוקה לפלחים של הקלט והפלט של הפונקציה הראשית.

ב-MLIR אפשר לצרף מאפיינים לארגומנטים ולתוצאות של פונקציות, ולכן המשתמשים יכולים לצרף מאפייני חלוקה לפונקציה בדרך הזו.

לדוגמה:

@mesh_xy = <["x"=2, "y"=2]>

// The 1st input has a sharding specified, but the 2nd input doesn't.
// The output has a sharding specified.
func @main(%arg0: tensor<8x8xf32>
            {sdy.sharding = #sdy.sharding<@mesh_xy, [{"x"}, {}]>},
            %arg1: tensor<8x16xf32>)
    -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_xy, [{}, {"y"}]>}) {
  ...
}

אילוץ פיצול

מאפשרת למשתמשים לצרף חלוקה לענפים (sharding) לטנסור ביניים בתוכנית שלהם, כדי להורות למחלק לפלחים איך צריך לפצל את הטנסור הזה או קבוצת משנה של השימושים שלו.

זוהי פעולת MLIR שמקבלת את הטנזור כקלט, ויש לה מאפיין חלוקה (sharding) שמצורף אליה. הפעולה יכולה:

  • אין להם שימוש (dangling) – כלומר, חלוקת המשנה המצורפת היא האופן שבו צריך לפצל את הטנזור עצמו.
  • יש שימושים – כלומר, חלוקת המשנה המצורפת היא האופן שבו צריך לפצל את השימושים של אופרטורי האילוצים של חלוקת המשנה, בעוד שלשימושים אחרים של הטנזור הקלט עשויה להיות חלוקת משנה שונה (אם אין שימושים אחרים לטנזור הקלט, ההתנהגות זהה לתרחיש 'אין שימושים'). ההעברה תגדיר את חלוקת הטנסור עצמו, ותבצע חלוקה מחדש אם יהיה צורך.

יכול להיות שיש לו חלוקות פתוחים של מאפיינים, כלומר אפשר לפצל את המשתנה עוד יותר לפי צירים זמינים.

@mesh_xy = <["x"=2, "y"=2]>

%0 = ... : tensor<8x8xf32>
%1 = sdy.sharding_constraint %0 <@mesh_xy, [{"x"}, {?}]> : tensor<8x8xf32>

קבוצת חלוקה לקטעים

במקרים שבהם אין יחסי תלות בין שני טינסורים או יותר, או שאין יחסי תלות חזקים בין שני טינסורים או יותר, אבל המשתמשים יודעים שצריך לפצל את הטינסורים האלה באותו אופן או באופן דומה, ה-API של Shardy מספק דרך לציין את הקשר הזה. כך המשתמשים יכולים לציין באופן מפורש שצריך לפצל את הטנזורים זה לזה.

כדי לעשות זאת, אנחנו משיקים את הרעיון של קבוצות שרצים, שבכל קבוצה יש מספר בלתי מוגבל של הוראות שמשויכות לאותו מזהה קבוצת שרצים. קבוצות של חלוקה לקטעים אוכפות את החלוקה לקטעים באותה קבוצה.

לדוגמה, בתוכנית משתמש היפותטית כמו זו שמוצגת בהמשך, אנחנו רוצים לפצל את הפלט של התוכנית בדיוק כמו הקלט של התוכנית, בלי יחסי תלות בין השניים.

אם נריץ את התוכנית הזו, ההפצה של חלוקת המשנה לא תוכל להסיק על חלוקת המשנה של הטנסורים %1 ו-%2, והם ייוצרו כעותקים. עם זאת, אם מצרפים מאפיין shard_group שמציין שהקלט %0 והפלט %2 נמצאים באותו shard_group, אפשר להעביר את הפיצול @mesh_xy, [{"x"},{"y"}]> מהקלט %0 לפלט %2, ובתגובה לשאר התרשים, שמשודר כאן כקבוע %1. אפשר להקצות ערך לקבוצה באמצעות הפעולה sdy.sharding_group.

@mesh_xy = <["x"=2, "y"=2]>

module @"jit_zeros_like" {
  func.func @main(%arg0: tensor<8x2xi64> {sdy.sharding = #sdy.sharding<@mesh_xy, [{"x"},{"y"}]>} }) -> (tensor<8x2xi64>) {
    %0 = sdy.sharding_group %arg0, id=0 : tensor<8x2xi64>
    %1 = stablehlo.constant dense<0> : tensor<8x2xi64>
    %2 = sdy.sharding_group %1, id=0 : tensor<8x2xi64>
    return %2 : tensor<8x2xi64>
  }
}

בדוגמה הפשוטה שלמעלה, אפשר היה לציין במקום זאת את אותו חלוקה לפלחים (shard) של הפלט כמו של הקלט, וכך להשיג את אותו אפקט, כי כבר ידענו מראש איזה פלחים אנחנו רוצים להקצות לקלט. אבל במקרים ריאליסטיים יותר, אנחנו משתמשים בחלוקה לפלחים כדי לשמור על סנכרון של חלוקה לפלחים של כמה טינסורים, בלי לדעת בהכרח את החלוקה לפלחים של אף אחד מהם, בעוד ש-Shardy יטפל בשאר הדברים וימצא את החלוקה לפלחים הטובה ביותר להקצות להם.

חישוב ידני

יכול להיות שהמשתמשים ירצו לשלוט באופן מפורש באופן שבו חלקים מהחישוב שלהם מחולקים למחיצות, ובאילו קבוצות שיתופיות נעשה שימוש. לדוגמה, משתמשים מסוימים רוצים להחיל matmul קולקטיבי באופן ידני (מ-API הקצה הקדמי) במקום להעביר את הטיפול למהדר. אנחנו מספקים להם את Manual Computation API שמאפשר להם לעשות זאת.

זוהי פעולת MLIR עם אזור אחד לחישוב המשנה הידני. המשתמשים יצטרכו לציין את חלוקת המשנה של הקלט/הפלט לחישוב המשני הזה באמצעות קבוצת משנה (כולל אולי את כולם) של צירי הרשת. החישוב המשני יהיה מקומי/ידני ביחס לצירים של הרשת שצוינו (נקראים גם צירים ידניים), וגלובלי/ללא חלוקה ביחס לצירים שלא צוינו (נקראים גם צירים חופשיים). אפשר לפצל את החישוב המשני לאורך הצירים החופשיים במהלך ההעברה, באותו אופן שבו אפשר לפצל את החישוב מחוץ לפעולה הזו.

לדוגמה:

@mesh_name = <["data"=2, "model"=2]>

%0 = ... : tensor<16x32xf32>
%1 = sdy.manual_computation(%0)
    in_shardings=[<@mesh_name, [{"data"}, {"model",?}]>]
    out_shardings=[<@mesh_name, [{"data"}, {?}]>]
    manual_axes={"data"}
    (%arg1: tensor<8x32xf32>) {
  // body
  return %42 : tensor<8x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>

Invariants

  1. כל הערכים של in_shardings, ‏ out_shardings ו-manual_axes חייבים להתייחס לאותה רשת. העמודה manual_axes ממוינת ביחס לרשת.

  2. צריך להשתמש ב-manual_axes באופן מפורש בכל החלוקות לפלחים (shards) של קלט/פלט, כלומר בכל חלוקה לפלחים, כל הצירים הידניים חייבים לחלק מאפיין לפלחים או לשכפל אותם באופן מפורש.

  3. אם יש ציר חופשי (כל ציר רשת שלא נמצא ב-manual_axes) באחד מהחלוקות לקטעים של קלט/פלט, הוא חייב להיות משני לכל ציר ידני באותו חלוקה לקטעים של מאפיינים (בדוגמה שלמעלה, חלוקה לקטעים של מאפיינים {"model", "data"} תהיה לא חוקית).

  4. האזור או הגוף של החישוב הוא החישוב המקומי (למשל, כולל קבוצות קולקטיביות שהמשתמשים ציינו). הוא חייב להיות מקומי ביחס לחלוקה לפלחים (sharding) של נתונים נכנסים/יוצאים לפי צירים ידניים (ראו הערה למעלה).

עריכת חישובים ידניים בתוך חישובים אחרים

אפשר להטמיע כמה חישובים ידניים זה בתוך זה, כל עוד כל אחד מהם פועל על קבוצה ייחודית משלו של צירים ידניים.