מפרט StableHLO

‫StableHLO הוא קבוצת פעולות לפעולות ברמה גבוהה (HLO) במודלים של למידת מכונה (ML). ‫StableHLO פועל כשכבת ניוד בין מסגרות שונות של למידת מכונה (ML) ומהדרים של למידת מכונה: מסגרות של למידת מכונה שמפיקות תוכניות StableHLO תואמות למהדרים של למידת מכונה שצורכים תוכניות StableHLO.

המטרה שלנו היא לפשט ולהאיץ את פיתוח ה-ML על ידי יצירת יכולת פעולה הדדית רבה יותר בין מסגרות שונות של ML (כמו TensorFlow, ‏ JAX ו-PyTorch) ובין קומפיילרים של ML (כמו XLA ו-IREE). לכן, במסמך הזה מפורטות ההגדרות של שפת התכנות StableHLO.

המפרט הזה מחולק לשלושה חלקים עיקריים. בקטע Programs מוסבר על המבנה של תוכניות StableHLO, שמורכבות מפונקציות StableHLO, שבתורן מורכבות מאופרטורים של StableHLO. במבנה הזה, בקטע Ops מפורטת הסמנטיקה של פעולות נפרדות. בקטע Execution (הרצה) מוסבר על הסמנטיקה של כל הפעולות האלה שמופעלות יחד בתוכנית. בסוף, בקטע Notation (סימון) מוסבר על הסימון שבו נעשה שימוש לאורך המפרט.

כדי לראות את המפרט מגרסה קודמת של StableHLO, פותחים את המאגר בגרסה המתויגת הרלוונטית. לדוגמה, StableHLO v0.19.0 Spec. כדי לראות את השינויים שחלו בכל עדכון של גרסה משנית של StableHLO, אפשר לעיין ביומן הגרסאות ב-VhloDialect.td.

תוכניות

Program ::= {Func}

תוכניות StableHLO מורכבות ממספר כלשהו של פונקציות StableHLO. למטה מופיעה דוגמה לתוכנית עם פונקציה @main שיש לה 3 קלטים (%image, ‏ %weights ו-%bias) ופלט אחד. הגוף של הפונקציה כולל 6 פעולות.

func.func @main(
  %image: tensor<28x28xf32>,
  %weights: tensor<784x10xf32>,
  %bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
  %0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
  %1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
  %2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  %3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
  %4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  "func.return"(%4): (tensor<1x10xf32>) -> ()
}

פונקציות

Func        ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs  ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput   ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput  ::= ValueType
FuncBody    ::= {Op}

פונקציות StableHLO (שנקראות גם פונקציות בעלות שם) כוללות מזהה, קלט/פלט וגוף. בעתיד, אנחנו מתכננים להוסיף מטא-נתונים לפונקציות כדי לשפר את התאימות ל-HLO (‎#425,‏ ‎#626,‏ ‎#740,‏ ‎#744).

מזהים

FuncId  ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
          | '%' letter {letter | digit}
letter  ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit   ::= '0' | ... | '9'

מזהי StableHLO דומים למזהים בהרבה שפות תכנות, עם שני מאפיינים ייחודיים: 1) לכל המזהים יש סיגילים שמבחינים בין סוגים שונים של מזהים, 2) מזהי ערכים יכולים להיות מספריים לחלוטין כדי לפשט את יצירת התוכניות של StableHLO.

סוגים

Type         ::= ValueType | NonValueType
ValueType    ::= TensorType | QuantizedTensorType | TokenType | TupleType | BufferType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType

סוגי StableHLO מסווגים לסוגי ערכים (שנקראים גם סוגים ממחלקה ראשונה) שמייצגים ערכי StableHLO, ולסוגים שהם לא ערכים שמתארים רכיבים אחרים של התוכנית. הסוגים ב-StableHLO דומים לסוגים בהרבה שפות תכנות, והמאפיין העיקרי שמייחד אותם הוא האופי הספציפי לתחום של StableHLO, שמוביל לתוצאות לא שגרתיות (למשל, סוגים סקלריים הם לא סוגי ערכים).

TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'

סוגי טנסורים מייצגים טנסורים, כלומר מערכים רב-ממדיים. יש להם צורה וסוג רכיב, כאשר צורה מייצגת גדלים של ממדים לא שליליים או לא ידועים בסדר עולה של הממדים המתאימים (שנקראים גם צירים) שממוספרים מ-0 עד R-1. מספר המאפיינים R נקרא דרגה. לדוגמה, tensor<2x3xf32> הוא סוג טנסור עם צורה 2x3 וסוג אלמנט f32. יש לו שני מאפיינים (או במילים אחרות, שני צירים) – מאפיין 0 ומאפיין 1 – שהגדלים שלהם הם 2 ו-3. הדירוג שלו הוא 2.

יכול להיות שצורות יהיו לא ידועות באופן חלקי או מלא (דינמיות), למשל tensor<?x2xf64> לא ידוע באופן חלקי ו-tensor<?x?xf64> לא ידוע באופן מלא. גדלים של מאפיינים דינמיים מיוצגים באמצעות ?. אי אפשר לבטל את הדירוג של צורות.

בעתיד, אנחנו מתכננים לבדוק אפשרות להרחיב את סוגי הטנסורים מעבר לגדלים של ממדים ולסוגי אלמנטים, למשל לכלול פריסות (#629) ודלילות (#1078).

QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
                  QuantizationStorageType
                  ['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
                  ':' QuantizationExpressedType
                  [':' QuantizationDimension]
                  ',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerLiteral
QuantizationStorageMax ::= IntegerLiteral
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerLiteral
QuantizationParameters ::= QuantizationParameter
                         | '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale [':' QuantizationZeroPoint]
QuantizationScale ::= FloatLiteral
QuantizationZeroPoint ::= IntegerLiteral
שם סוג מגבלות
storage_type סוג מספר שלם (C1-C3), (C8)
storage_min קבוע מספרי שלם (C1), (C3), (C7)
storage_max קבוע מספרי שלם (C2), (C3), (C7)
expressed_type סוג נקודה צפה ‫(C4)
quantization_dimension קבוע מספרי אופציונלי ‫(C10-C12)
scales מספר משתנה של קבועים בשיטת נקודה צפה (C4-C6), (C9), (C10), (C13)
zero_points מספר משתנה של קבועים שלמים (C7-C9)

סוגי רכיבים שעברו קוונטיזציה מייצגים ערכים שלמים של סוג אחסון בטווח storage_min עד storage_max (כולל) שתואמים לערכים של נקודה צפה של סוג מבוטא. עבור ערך שלם נתון i, אפשר לחשב את הערך התואם בנקודה צפה f באופן הבא: f = (i - zero_point) * scale, כאשר scale ו-zero_point נקראים פרמטרים של קוונטיזציה. המאפיינים storage_min ו-storage_max הם אופציונליים בתחביר, אבל ערכי ברירת המחדל שלהם הם min_value(storage_type) ו-max_value(storage_type) בהתאמה. יש מגבלות על סוגי רכיבים שעברו קוונטיזציה:

  • (C1) type(storage_min) = storage_type.
  • ‫(C2) type(storage_max) = storage_type.
  • ‫(C3) min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type).
  • ‫(C4) type(scales...) = expressed_type.
  • ‫(C5) 0 < scales.
  • ‫(C6) is_finite(scales...).
  • ‫(C7) storage_min <= zero_points <= storage_max.
  • ‫(C8) type(zero_points...) = storage_type.
  • ‫(C9) size(scales) = size(zero_points).
  • ‫(C10) אם is_empty(quantization_dimension), אז size(scales) = 1.
  • ‫(C11) 0 <= quantization_dimension.

בשלב הזה, QuantizationScale הוא קבוע בשיטת נקודה צפה, אבל יש עניין רב בסולמות שמבוססים על מספרים שלמים, שמיוצגים באמצעות מכפילים והזזות. אנחנו מתכננים לבדוק את האפשרות הזו בעתיד הקרוב (#1404).

מתנהל דיון על הסמנטיקה של QuantizationZeroPoint, כולל הסוג, הערכים והשאלה אם יכולה להיות רק נקודת אפס אחת או כמה נקודות אפס פוטנציאליות בסוג של טנסור שעבר קוונטיזציה. על סמך התוצאות של הדיון הזה, יכול להיות שהמפרט לגבי אפס נקודות ישתנה בעתיד (#1405).

דיון נוסף מתנהל לגבי הסמנטיקה של QuantizationStorageMin ושל QuantizationStorageMax כדי לקבוע אם צריך להגביל את הערכים האלה ואת הערכים של טנסורים שעברו קוונטיזציה (‎#1406).

בסוף, אנחנו מתכננים לבדוק איך לייצג סולמות לא ידועים ונקודות אפס, בדומה לאופן שבו אנחנו מתכננים לבדוק איך לייצג גדלים לא ידועים של מאפיינים (#1407).

סוגי טנסורים שעברו קוונטיזציה מייצגים טנסורים עם אלמנטים שעברו קוונטיזציה. הטנסורים האלה זהים בדיוק לטנסורים רגילים, אלא שהאלמנטים שלהם הם מסוג אלמנטים שעברו קוונטיזציה, במקום סוג אלמנטים רגיל.

בטנסורים שעברו קוונטיזציה, הקוונטיזציה יכולה להיות per-tensor, כלומר, עם scale ו-zero_point אחדים לכל הטנסור, או per-axis, כלומר, עם כמה scales ו-zero_points, זוג אחד לכל פרוסה של מימד מסוים quantization_dimension. באופן רשמי יותר, בטנזור t עם כימות לכל ציר, יש dim(t, quantization_dimension) פרוסות של quantization_dimension: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :] וכו'. כל הרכיבים בפרוסה ה-i משתמשים ב-scales[i] וב-zero_points[i] כפרמטרים של הכימות. יש מגבלות על סוגי טנסורים שעברו קוונטיזציה:

  • לכימות לכל טנסור:
    • אין אילוצים נוספים.
  • לכימות לכל ציר:
    • (C12) quantization_dimension < rank(self).
    • ‫(C13) dim(self, quantization_dimension) = size(scales).
TokenType ::= 'token'

סוגי טוקנים מייצגים טוקנים, כלומר ערכים אטומים שנוצרים ונצרכים על ידי פעולות מסוימות. האסימונים משמשים להטלת סדר ביצוע על פעולות, כפי שמתואר בקטע ביצוע.

TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]

סוגי מאגרים מייצגים מאגרים. לדוגמה, ב-XLA, מאגרי נתונים זמניים הם מערכים רב-ממדיים עם אחסון עקבי. בדומה לסוגי טנסורים, לסוגי מאגרים יש צורה וסוג רכיב, כאשר הצורה מייצגת גדלים של ממדים לא שליליים או לא ידועים בסדר עולה של הממדים המתאימים (שנקראים גם צירים) שממוספרים מ-0 עד R-1. מספר המאפיינים R נקרא דרגה. לדוגמה, memref<2x3xf32> הוא סוג של מאגר עם צורה 2x3 וסוג רכיב f32. יש לו שני ממדים (או במילים אחרות, שני צירים) – ממד 0 וממד 1 – שהגדלים שלהם הם 2 ו-3. הדירוג שלו הוא 2.

אפשר להקצות מאגרי נתונים באמצעות custom_call עד CreateBuffer או Pin, ולבטל את ההקצאה באמצעות custom_call עד Unpin. רק פעולות custom_call יכולות לקרוא ולכתוב את התוכן בתוך מאגרי נתונים זמניים. פרטים נוספים מופיעים במאמר בנושא custom_call.

סוגי טאפלים מייצגים טאפלים, כלומר רשימות הטרוגניות. הטופלים הם תכונה מדור קודם שקיימת רק לצורך תאימות ל-HLO. ב-HLO, נעשה שימוש בטפילים כדי לייצג קלטים ופלט משתנים. ב-StableHLO יש תמיכה מובנית בקלט ובפלט עם מספר משתנה של ארגומנטים, והשימוש היחיד בטפלים ב-StableHLO הוא לייצוג מקיף של HLO ABI, שבו למשל T, ‏ tuple<T> ו-tuple<tuple<T>> עשויים להיות שונים באופן משמעותי בהתאם להטמעה מסוימת. בעתיד, אנחנו מתכננים לבצע שינויים ב-HLO ABI, שיאפשרו לנו להסיר סוגי טאפל מ-StableHLO (‎#598).

TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3'
            | 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2'
            | 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'

סוגי רכיבים מייצגים רכיבים של סוגי טנסורים. בניגוד לשפות תכנות רבות, הסוגים האלה לא מוגדרים כסוגים ממחלקה ראשונה ב-StableHLO. המשמעות היא שתוכניות StableHLO לא יכולות לייצג ישירות ערכים מהסוגים האלה (כתוצאה מכך, מקובל לייצג ערכים סקלריים מהסוג T באמצעות ערכים טנסוריים אפסיים מהסוג tensor<T>).

  • סוג בוליאני מייצג ערכים בוליאניים true ו-false.
  • סוגי מספרים שלמים יכולים להיות עם סימן (si) או ללא סימן (ui), ויכולים להיות בעלי אחת מרוחבי הסיביות הנתמכים (2,‏ 4,‏ 8,‏ 16,‏ 32 או 64). סוגי siN עם סימן מייצגים ערכים של מספרים שלמים מ--2^(N-1) עד 2^(N-1)-1 כולל, וסוגי uiN ללא סימן מייצגים ערכים של מספרים שלמים מ-0 עד 2^N-1 כולל.
  • סוגי נקודה צפה יכולים להיות אחד מהסוגים הבאים:
  • סוגים מורכבים מייצגים ערכים מורכבים שיש להם חלק ממשי וחלק מדומה מאותו סוג אלמנט. הסוגים המורכבים הנתמכים הם complex<f32> (שני החלקים הם מסוג f32) ו-complex<f64> (שני החלקים הם מסוג f64).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]

סוגי פונקציות מייצגים פונקציות בעלות שם ופונקציות אנונימיות. יש להם סוגי קלט (רשימת הסוגים בצד ימין של ->) וסוגי פלט (רשימת הסוגים בצד שמאל של ->). בשפות תכנות רבות, סוגי פונקציות הם מהסוג הראשון, אבל לא ב-StableHLO.

StringType ::= 'string'

סוג מחרוזת מייצג רצפים של בייטים. בניגוד לשפות תכנות רבות, סוג המחרוזת לא מוגדר כסוג בסיסי ב-StableHLO, והוא משמש רק לציון מטא-נתונים סטטיים של רכיבי תוכנית.

תפעול

פעולות StableHLO (שנקראות גם ops) מייצגות קבוצה סגורה של פעולות ברמה גבוהה במודלים של למידת מכונה. כמו שצוין למעלה, התחביר של StableHLO מבוסס במידה רבה על MLIR, שהוא לא בהכרח החלופה הכי נוחה, אבל אפשר לטעון שהוא הכי מתאים למטרה של StableHLO – ליצור יכולת פעולה הדדית טובה יותר בין מסגרות ML וקומפיילרים של ML.

Op            ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName        ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic    ::= 'abs' | 'add' | ...

לפעולות StableHLO (שנקראות גם ops) יש שם, קלט/פלט וחתימה. השם מורכב מהקידומת stablehlo. ומנימוניק שמזהה באופן ייחודי אחת מהפעולות הנתמכות. בהמשך מופיעה רשימה מקיפה של כל הפעולות הנתמכות.

OpInputs        ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues   ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue    ::= ValueId
OpInputFuncs    ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs    ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs       ::= [OpOutput {',' OpOutput} '=']
OpOutput        ::= ValueId

פעולות צורכות קלטים ומפיקות פלטים. הקלט מסווג לערכי קלט (מחושבים במהלך ההפעלה), פונקציות קלט (מסופקות באופן סטטי, כי בפונקציות StableHLO לא מוגדרים ערכים מסוג first-class) ולמאפייני קלט (מסופקים גם באופן סטטי). סוגי הקלט והפלט שפעולה צורכת ומייצרת תלויים בסימן שלה. לדוגמה, הפעולה add צורכת 2 ערכי קלט ומפיקה ערך פלט אחד. לעומת זאת, הפעולה select_and_scatter צורכת 3 ערכי קלט, 2 פונקציות קלט ו-3 מאפייני קלט.

OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused      ::= '^' digit {digit}
              | '^' letter {letter | digit}

פונקציות קלט (שנקראות גם פונקציות אנונימיות) דומות מאוד לפונקציות בעלות שם, אבל יש ביניהן הבדלים: 1) אין להן מזהה (ומכאן השם 'אנונימיות'), 2) הן לא מצהירות על סוגי פלט (סוגי הפלט נגזרים מהאופרטור return בתוך הפונקציה).

התחביר של פונקציות הקלט כולל חלק שלא נמצא בשימוש כרגע (ראו את Unusedהייצור שלמעלה), והוא נועד לתאימות ל-MLIR. ב-MLIR, יש מושג כללי יותר של 'אזורים' שיכולים לכלול כמה 'בלוקים' של פעולות שמחוברים באמצעות פעולות קפיצה. לבלוקים האלה יש מזהים שתואמים לUnused הייצור, כדי שאפשר יהיה להבדיל ביניהם. ב-StableHLO אין אופרטורים של קפיצה, ולכן החלק התואם בתחביר של MLIR לא נמצא בשימוש (אבל הוא עדיין קיים).

OpInputAttr      ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName  ::= letter {letter | digit}
OpInputAttrValue ::= Constant

מאפייני קלט כוללים שם וערך שהוא אחת מהקבועים הנתמכים. הם הדרך העיקרית לציין מטא-נתונים סטטיים עבור רכיבי תוכנה. לדוגמה, האופרטור concatenate משתמש במאפיין dimension כדי לציין את המימד שלפיו ערכי הקלט שלו משורשרים. באופן דומה, האופרטור slice משתמש בכמה מאפיינים כמו start_indices ו-limit_indices כדי לציין את הגבולות שמשמשים לפילוח של ערך הקלט.

בשלב הזה, תוכניות StableHLO בשימוש נרחב מכילות לפעמים מאפיינים שלא מתוארים במסמך הזה. בעתיד, אנחנו מתכננים להוסיף את המאפיינים האלה ל-opset של StableHLO או לאסור את השימוש בהם בתוכניות של StableHLO. בינתיים, הנה רשימת המאפיינים האלה:

  • layout (#629).
  • mhlo.frontend_attributes (#628).
  • mhlo.sharding (מס' 619).
  • output_operand_aliases (#740).
  • מטא-נתונים של מיקום (#594).
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'

חתימת הפעולה מורכבת מסוגי כל ערכי הקלט (רשימת הסוגים בצד ימין של ->) ומסוגי כל ערכי הפלט (רשימת הסוגים בצד שמאל של ->). למעשה, סוגי הקלט הם מיותרים, וגם סוגי הפלט הם כמעט תמיד מיותרים (כי ברוב הפעולות של StableHLO אפשר להסיק את סוגי הפלט מהקלט). עם זאת, חתימת האופרטור היא חלק מכוון בתחביר של StableHLO לצורך תאימות ל-MLIR.

בהמשך מופיעה דוגמה לאופרטור שהנימוניק שלו הוא select_and_scatter. היא צורכת 3 ערכי קלט (%operand, ‏ %source ו-%init_value), 2 פונקציות קלט ו-3 מאפייני קלט (window_dimensions, ‏ window_strides ו-padding). שימו לב שהחתימה של הפעולה כוללת רק את הסוגים של ערכי הקלט שלה (אבל לא את הסוגים של פונקציות הקלט והמאפיינים שמוגדרים בשורה).

%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i32>, tensor<i32>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
    "stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>

ערכים קבועים

Constant ::= BooleanConstant
           | IntegerConstant
           | FloatConstant
           | ComplexConstant
           | TensorConstant
           | QuantizedTensorConstant
           | StringConstant
           | EnumConstant

לקבועים של StableHLO יש ערך מילולי וסוג, שביחד מייצגים ערך של StableHLO. בדרך כלל, הסוג הוא חלק מהתחביר הקבוע, אלא אם הוא חד-משמעי (למשל, קבוע בוליאני הוא חד-משמעי מסוג i1, בעוד שקבוע של מספר שלם יכול להיות מכמה סוגים אפשריים).

BooleanConstant ::= BooleanLiteral
BooleanLiteral  ::= 'true' | 'false'

קבועים בוליאניים מייצגים ערכים בוליאניים true ו-false. לקבועים בוליאניים יש סוג i1.

IntegerConstant   ::= IntegerLiteral ':' IntegerType
IntegerLiteral    ::= ['-' | '+'] DecimalDigits
                    | ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits     ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit      ::= '0' | ... | '9'
hexadecimalDigit  ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'

קבועים של מספרים שלמים מייצגים ערכים של מספרים שלמים באמצעות מחרוזות שמשתמשות בסימון עשרוני או הקסדצימלי. אין תמיכה בבסיסים אחרים, למשל בינארי או אוקטלי. ההגבלות על קבועים מסוג מספר שלם הן:

  • (C1) is_wellformed(integer_literal, integer_type).
FloatConstant  ::= FloatLiteral ':' FloatType
FloatLiteral   ::= SignPart IntegerPart FractionalPart ScientificPart
                 | '0x' [HexadecimalDigits]
SignPart       ::= ['-' | '+']
IntegerPart    ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]

קבועים של נקודה צפה מייצגים ערכים של נקודה צפה באמצעות מחרוזות שמשתמשות בסימון עשרוני או מדעי. בנוסף, אפשר להשתמש בסימון הקסדצימלי כדי לציין ישירות את הביטים הבסיסיים בפורמט הנקודה הצפה של הסוג המתאים. יש מגבלות על קבועים של נקודה צפה:

  • ‫(C1) אם משתמשים בסימון לא הקסדצימלי, is_wellformed(float_literal, float_type).
  • ‫(C2) אם משתמשים בסימון הקסדצימלי, size(hexadecimal_digits) = num_bits(float_type) / 4.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral  ::= '(' RealPart ',' ImaginaryPart ')'
RealPart        ::= FloatLiteral
ImaginaryPart   ::= FloatLiteral

קבועים מרוכבים מייצגים ערכים מרוכבים באמצעות רשימות של חלק ממשי (מופיע ראשון) וחלק דמיוני (מופיע שני). לדוגמה, (1.0, 0.0) : complex<f32> מייצג את 1.0 + 0.0i, ו-(0.0, 1.0) : complex<f32> מייצג את 0.0 + 1.0i. הסדר שבו החלקים האלה נשמרים בזיכרון מוגדר בהטמעה. יש הגבלות על קבועים מורכבים:

  • (C1) is_wellformed(real_part, complex_element_type(complex_type)).
  • ‫(C2) is_wellformed(imaginary_part, complex_element_type(complex_type)).
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral   ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements  ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral

קבועים של טנסור מייצגים ערכי טנסור באמצעות רשימות מקוננות שמוגדרות באמצעות סימון NumPy. לדוגמה, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> מייצג ערך טנזור עם המיפוי הבא מאינדקסים לרכיבים: {0, 0} => 1, ‏ {0, 1} => 2, ‏ {0, 2} => 3, ‏ {1, 0} => 4, ‏ {1, 1} => 5,‏ {1, 2} => 6. הסדר שבו הרכיבים האלה נשמרים בזיכרון מוגדר בהטמעה. ההגבלות הבאות חלות על קבועים של Tensor:

  • ‫(C1) has_syntax(tensor_literal, element_type(tensor_type)), כאשר:
    • has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type).
    • has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type).
  • ‫(C2) has_shape(tensor_literal, shape(tensor_type)), כאשר:
    • has_shape(element_literal: Syntax, []) = true.
    • has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:]).
    • אחרת, false.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'

קבועים של טנסור שעבר קוונטיזציה מייצגים ערכים של טנסור שעבר קוונטיזציה באמצעות אותה שיטה כמו קבועים של טנסור, עם אלמנטים שמוגדרים כקבועים של סוג האחסון שלהם. יש את המגבלות הבאות על קבועים של טנסור שעברו קוונטיזציה:

  • (C1) has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type)).
  • ‫(C2) has_shape(quantized_tensor_literal, shape(quantized_tensor_type)).
StringConstant  ::= StringLiteral
StringLiteral   ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence  ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))

מחרוזות ליטרליות מורכבות מבייטים שצוינו באמצעות תווים בפורמט ASCII ורצפי escape. הן לא תלויות בקידוד, ולכן הפרשנות של הבייטים האלה מוגדרת בהטמעה. למחרוזות ליטרליות יש סוג string.

תפעול

מוחלט

סמנטיקה

מבצע פעולת abs לפי רכיבים בטנזור operand ומפיק טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • למספרים שלמים עם סימן: מודולוס של מספר שלם.
  • למספרים ממשיים: abs מ-IEEE-754.
  • למספרים מרוכבים: מודולוס מרוכב.
  • לסוגים כמותיים: dequantize_op_quantize(abs, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טנזור של מספר שלם עם סימן, מספר נקודה צפה או סוג מרוכב, או טנזור כמותי לכל טנזור (C1-C2)

פלט

שם סוג מגבלות
result טנסור של מספר שלם עם סימן או מספר בשיטת נקודה צפה, או טנסור שעבר קוונטיזציה לכל טנסור (C1-C2)

מגבלות

  • (C1) shape(result) = shape(operand).
  • ‫(C2) baseline_element_type(result) מוגדר כך:
    • complex_element_type(element_type(operand)) if is_complex(operand).
    • baseline_element_type(operand) אחרת.

דוגמאות

// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand)< : (t>ens>or3xi32<) - t>ensor3xi32
// %result: [2, 0, 2]

 דוגמאות נוספות

הוספה

סמנטיקה

מבצעת חיבור של שני טנסורים lhs ו-rhs לפי רכיבים ומפיקה טנסור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • לערכים בוליאניים: OR לוגי.
  • למספרים שלמים: חיבור של מספרים שלמים.
  • למספרים ממשיים: addition מ-IEEE-754.
  • למספרים מרוכבים: חיבור מרוכב.
  • לסוגים כמותיים: dequantize_op_quantize(add, lhs, rhs, type(result)).

קלט

תווית שם סוג מגבלות
(I1) lhs טנזור או טנזור שעבר קוונטיזציה (C1-C6)
(I2) rhs טנזור או טנזור שעבר קוונטיזציה (C1-C5), (C7)

פלט

שם סוג מגבלות
result טנזור או טנזור שעבר קוונטיזציה (C1-C7)

מגבלות

  • אם הפעולה משתמשת בטנסורים לא מכומתים:
    • (C1) type(lhs) = type(rhs) = type(result).
  • אם הפעולה משתמשת בטנסורים שעברו קוונטיזציה:
    • ‫(C2) is_quantized(lhs) and is_quantized(rhs) and is_quantized(result).
    • ‫(C3) storage_type(lhs) = storage_type(rhs) = storage_type(result).
    • ‫(C4) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • ‫(C5) (is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result).
    • ‫(C6) אם is_per_axis_quantized(lhs), אז quantization_dimension(lhs) = quantization_dimension(result).
    • ‫(C7) אם is_per_axis_quantized(rhs), אז quantization_dimension(rhs) = quantization_dimension(result).

דוגמאות

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[6, 8], [10, 12]]

 דוגמאות נוספות

after_all

סמנטיקה

ההגדרה הזו מבטיחה שהפעולות שיוצרות את inputs יבוצעו לפני פעולות שתלויות ב-result. הפעולה הזו לא מבצעת שום דבר, היא קיימת רק כדי ליצור תלות בנתונים מ-result ל-inputs.

קלט

תווית שם סוג
(I1) inputs מספר משתנה של token

פלט

שם סוג
result token

דוגמאות

// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehl>o.token) - !stablehlo.token

 דוגמאות נוספות

all_gather

סמנטיקה

בכל קבוצת תהליכים ברשת התהליכים של StableHLO, הפונקציה מחברת את הערכים של טנסורים operands מכל תהליך לאורך all_gather_dim ומפיקה טנסורים results.

הפעולה מפצלת את רשת התהליכים של StableHLO ל-process_groups, שמוגדרת כך:

  • cross_replica(replica_groups) if channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) if channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) if channel_id > 0 and use_global_device_ids = true.

לאחר מכן, בכל process_group:

  • operands...@receiver = [operand@sender for sender in process_group] לכל receiver ב-process_group.
  • results...@process = concatenate(operands...@process, all_gather_dim) לכל process ב-process_group.

קלט

תווית שם סוג מגבלות
(I1) operands מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור (C1), (C6)
(I2) all_gather_dim קבוע מהסוג si64 (C1), (C6)
(I3) replica_groups קבוע טנסור דו-ממדי מסוג si64 (C2-C4)
(I4) channel_id קבוע מהסוג si64 (C5)
(I5) use_global_device_ids קבוע מהסוג i1 (C5)

פלט

שם סוג מגבלות
results מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור (C6)

מגבלות

  • (C1) 0 <= all_gather_dim < rank(operands...).
  • ‫(C2) is_unique(replica_groups).
  • ‫(C3) size(replica_groups) מוגדר כך:
    • num_replicas אם משתמשים ב-cross_replica.
    • num_replicas אם משתמשים ב-cross_replica_and_partition.
    • num_processes אם משתמשים ב-flattened_ids.
  • ‫(C4) 0 <= replica_groups < size(replica_groups).
  • ‫(C5) אם use_global_device_ids = true, אז channel_id > 0.
  • ‫(C6) type(results...) = type(operands...) except:
    • dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1).

דוגמאות

// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
  all_gather_dim = 1 : i64,
  replica_grou<ps = den>se[[0, 1]<] : ten>sor1x2xi64,
  // channel_id = 0
  channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 0
  // use_global_device_ids = false
}< : (ten>sor2x2xi<64, ten>sor>2x2xi64)< - (ten>sor2x4xi<64, ten>sor2x4xi64)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]

 דוגמאות נוספות

all_reduce

סמנטיקה

בתוך כל קבוצת תהליכים ברשת התהליכים של StableHLO, הפונקציה מחילה צמצום computation על ערכי הטנסורים operands מכל תהליך ומפיקה טנסורים results.

הפעולה מפצלת את רשת התהליכים של StableHLO ל-process_groups, שמוגדרת כך:

  • cross_replica(replica_groups) if channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) if channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) if channel_id > 0 and use_global_device_ids = true.

לאחר מכן, בכל process_group:

  • results...@process[result_index] = exec(schedule) עבור עץ בינארי מסוים ‫schedule כאשר:
    • exec(node) = computation(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule הוא עץ בינארי שמוגדר בהטמעה, והמעבר שלו בסדר עולה הוא to_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0])).

קלט

תווית שם סוג מגבלות
(I1) operands מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור (C5), (C6)
(I2) replica_groups מספר משתנה של קבועי טנסור חד-ממדיים מסוג si64 (C1-C3)
(I3) channel_id קבוע מהסוג si64 ‫(C4)
(I4) use_global_device_ids קבוע מהסוג i1 ‫(C4)
(I5) computation פונקציה (C5)

פלט

שם סוג מגבלות
results מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור (C6-C7)

מגבלות

  • (C1) is_unique(replica_groups).
  • ‫(C2) size(replica_groups) מוגדר כך:
    • num_replicas אם משתמשים ב-cross_replica.
    • num_replicas אם משתמשים ב-cross_replica_and_partition.
    • num_processes אם משתמשים ב-flattened_ids.
  • ‫(C3) 0 <= replica_groups < size(replica_groups).
  • ‫(C4) אם use_global_device_ids = true, אז channel_id > 0.
  • ‫(C5) computation הוא מסוג (tensor<E>, tensor<E>) -> (tensor<E>) כאשר is_promotable(element_type(operand), E).
  • ‫(C6) shape(results...) = shape(operands...).
  • ‫(C7) element_type(results...) = E.

דוגמאות

// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
    %0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
    "stable<hlo>.re>turn"(%0) : (tensori64) - ()<
}) {
  >replica_g<roups => dense[[0, 1]] : tensor1x2xi64,
  // channel_id = 0
  channel_hand<le = #stablehlo.chan>nel_handlehandle = 0, type = 0
  // use_global_<devic>e_ids = <false>
} >: (tenso<r4xi6>4, tenso<r4xi6>4) - (tensor4xi64, tensor4xi64)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]

 דוגמאות נוספות

all_to_all

סמנטיקה

all_to_all

בכל קבוצת תהליכים ברשת התהליכים של StableHLO, הפונקציה מפצלת את הערכים של טנסורים operands לאורך split_dimension לחלקים, מפזרת את החלקים המפוצלים בין התהליכים, משרשרת את החלקים המפוזרים לאורך concat_dimension ומפיקה טנסורים results. הפעולה מפצלת את רשת התהליכים של StableHLO ל-process_groups, שמוגדרת כך:

  • cross_replica(replica_groups) if channel_id <= 0.
  • cross_partition(replica_groups) if channel_id > 0.

לאחר מכן, בכל process_group:

  • split_parts...@sender = split(operands...@sender, split_count, split_dimension) לכל sender ב-process_group.
  • scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group] כשreceiver_index = process_group.index(receiver).
  • results...@process = concatenate(scattered_parts...@process, concat_dimension).

קלט

תווית שם סוג מגבלות
(I1) operands מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור (C1-C3), (C9)
(I2) split_dimension קבוע מהסוג si64 (C1), (C2), (C9)
(I3) concat_dimension קבוע מהסוג si64 (C3), (C9)
(I4) split_count קבוע מהסוג si64 ‫(C2),‏ (C4),‏ (C8),‏ (C9)
(I5) replica_groups קבוע טנסור דו-ממדי מסוג si64 (C5-C8)
(I6) channel_id קבוע מהסוג si64

פלט

שם סוג מגבלות
results מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור (C9)

מגבלות

  • (C1) 0 <= split_dimension < rank(operands...).
  • ‫(C2) dim(operands..., split_dimension) % split_count = 0.
  • ‫(C3) 0 <= concat_dimension < rank(operands...).
  • ‫(C4) 0 < split_count.
  • ‫(C5) is_unique(replica_groups).
  • ‫(C6) size(replica_groups) מוגדר כך:
    • num_replicas אם משתמשים ב-cross_replica.
    • num_partitions אם משתמשים ב-cross_partition.
  • ‫(C7) 0 <= replica_groups < size(replica_groups).
  • ‫(C8) dim(replica_groups, 1) = split_count.
  • ‫(C9) type(results...) = type(operands...) למעט, אם split_dimension != concat_dimension:
    • dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count.
    • dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count.

דוגמאות

// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
//                    [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
//                    [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
//                    [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
//                    [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
  split_dimension = 1 : i64,
  concat_dimension = 0 : i64,
  split_count = 2 : i64,
  replica_grou<ps = den>se[[0, 1]<] : ten>sor1x2xi64
  // channel_id = 0
}< : (ten>sor2x4xi<64, ten>sor>2x4xi64)< - (ten>sor4x2xi<64, ten>sor4x2xi64)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]

 דוגמאות נוספות

וגם

סמנטיקה

מבצעת פעולת AND לפי רכיבים של שני טנסורים, lhs ו-rhs, ומפיקה טנסור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • לערכים בוליאניים: AND לוגי.
  • למספרים שלמים: ערך AND ברמת הביטים.

קלט

תווית שם סוג מגבלות
(I1) lhs טנסור מסוג בוליאני או מספר שלם (C1)
(I2) rhs טנסור מסוג בוליאני או מספר שלם (C1)

פלט

שם סוג מגבלות
result טנסור מסוג בוליאני או מספר שלם (C1)

מגבלות

  • (C1) type(lhs) = type(rhs) = type(result).

דוגמאות

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[1, 2], [3, 0]]

 דוגמאות נוספות

atan2

סמנטיקה

מבצע פעולת atan2 לפי רכיבים בטנזורים lhs ו-rhs ומפיק טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • למספרים ממשיים: atan2 מ-IEEE-754.
  • למספרים מרוכבים: complex atan2.
  • לסוגים כמותיים: dequantize_op_quantize(atan2, lhs, rhs, type(result)).

קלט

תווית שם סוג מגבלות
(I1) lhs טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1)
(I2) rhs טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1)

פלט

שם סוג מגבלות
result טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1)

מגבלות

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

דוגמאות

// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs)< : (t>ensor3xf<64, t>ens>or3xf64<) - t>ensor3xf64
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]

 דוגמאות נוספות

batch_norm_grad

סמנטיקה

מחשב את הגרדיאנטים של כמה תשומות של batch_norm_training באמצעות backpropagation מ-grad_output, ומפיק טנסורים של grad_operand, grad_scale ו-grad_offset. באופן רשמי יותר, אפשר לבטא את הפעולה הזו כפירוק לפעולות קיימות של StableHLO באמצעות תחביר Python באופן הבא:

def compute_sum(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  return sum

def compute_mean(operand, feature_index):
  sum = compute_sum(operand, feature_index)
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
  # Broadcast inputs to type(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance`
  # Intermediate values will be useful for computing gradients
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)

  # Use the implementation from batchnorm_expander.cc in XLA
  # Temporary variables have exactly the same names as in the C++ code
  elements_per_feature = broadcast_in_dim(
      constant(divide(size(operand), dim(operand, feature_index)),
               element_type(grad_output)),
      [], type(operand))
  i1 = multiply(grad_output, elements_per_feature)
  i2 = broadcast_in_dim(
      compute_sum(grad_output, feature_index), [feature_index], type(operand))
  i3 = broadcast_in_dim(
      compute_sum(multiply(grad_output, centered_operand), feature_index),
      [feature_index], type(operand))
  i4 = multiply(i3, centered_operand)
  i5 = divide(i4, add(variance_bcast, epsilon_bcast))
  i6 = subtract(subtract(i1, i2), i5)

  grad_operand =
      multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
  grad_scale =
      compute_sum(multiply(grad_output, normalized_operand), feature_index)
  grad_offset = compute_sum(grad_output, feature_index)

  return grad_operand, grad_scale, grad_offset

לסוגים כמותיים, הפונקציה מבצעת dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean, variance, grad_output: batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index), operand, scale, mean, variance, grad_output, type(grad_operand), type(grad_scale), type(feature_index)).

קלט

תווית שם סוג מגבלות
(I1) operand טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור (C1-C3), (C5)
(I2) scale טנזור חד-ממדי של נקודה צפה או סוג כמותי לכל טנזור (C2), (C4), (C5)
(I3) mean טנזור חד-ממדי של נקודה צפה או סוג כמותי לכל טנזור (C2), (C4)
(I4) variance טנזור חד-ממדי של נקודה צפה או סוג כמותי לכל טנזור (C2), (C4)
(I5) grad_output טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור (C2), (C3)
(I6) epsilon קבוע מהסוג f32
(I7) feature_index קבוע מהסוג si64 (C1), (C5)

פלט

שם סוג מגבלות
grad_operand טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור (C2), (C3)
grad_scale טנזור חד-ממדי של נקודה צפה או סוג כמותי לכל טנזור (C2), (C4)
grad_offset טנזור חד-ממדי של נקודה צפה או סוג כמותי לכל טנזור (C2), (C4)

מגבלות

  • (C1) 0 <= feature_index < rank(operand).
  • ‫(C2) operand, ‏ scale, ‏ mean, ‏ variance, ‏ grad_output, ‏ grad_operand,‏ grad_scale ו-grad_offset כוללים את אותו baseline_element_type.
  • ‫(C3) ל-operand, ל-grad_output ול-grad_operand יש את אותו הצורה.
  • ‫(C4) ל-scale, ‏ mean, ‏ variance, ‏ grad_scale ו-grad_offset יש את אותה צורה.
  • ‫(C5) size(scale) = dim(operand, feature_index).

דוגמאות

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
//                [[0.1, 0.1], [0.1, 0.1]],
//                [[0.1, 0.1], [0.1, 0.1]]
//               ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
}< : (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf<64, t>ensor2xf64,
 <    tenso>r2x>2x2xf64)< - (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf64)
// %grad_operand: [
//                 [[0.0, 0.0], [0.0, 0.0]],
//                 [[0.0, 0.0], [0.0, 0.0]]
//                ]
// %grad_scale:  [0.0, 0.0]
// %grad_offset: [0.4, 0.4]

batch_norm_inference

סמנטיקה

מבצעת נורמליזציה של טנסור operand בכל המימדים, למעט מימד feature_index, ומפיקה טנסור result. באופן רשמי יותר, אפשר לבטא את הפעולה הזו כפירוק לפעולות קיימות של StableHLO באמצעות תחביר Python באופן הבא:

def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
  # Broadcast inputs to shape(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance` instead of
  # computing them like `batch_norm_training` does.
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)
  return add(multiply(scale_bcast, normalized_operand), offset_bcast)

לסוגים כמותיים, הפונקציה מבצעת dequantize_op_quantize(lambda operand, scale, offset, mean, variance: batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index), operand, scale, offset, mean, variance, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור (C1-C7)
(I2) scale טנזור חד-ממדי של נקודה צפה או סוג כמותי לכל טנזור (C2), (C3)
(I3) offset טנזור חד-ממדי של נקודה צפה או סוג כמותי לכל טנזור (C2), (C4)
(I4) mean טנזור חד-ממדי של נקודה צפה או סוג כמותי לכל טנזור (C5)
(I5) variance טנזור חד-ממדי של נקודה צפה או סוג כמותי לכל טנזור (C2), (C6)
(I6) epsilon קבוע מהסוג f32
(I7) feature_index קבוע מהסוג si64 (C1), (C3-C6)

פלט

שם סוג מגבלות
result טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור (C2), (C7)

מגבלות

  • (C1) 0 <= feature_index < rank(operand).
  • ‫(C2) operand, scale, offset, mean, variance ו-result כולן בעלות אותו baseline_element_type.
  • ‫(C3) size(scale) = dim(operand, feature_index).
  • ‫(C4) size(offset) = dim(operand, feature_index).
  • ‫(C5) size(mean) = dim(operand, feature_index).
  • ‫(C6) size(variance) = dim(operand, feature_index).
  • ‫(C7) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
}< : (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf<64, t>ensor2xf<64, t>ens>or2xf64<) - tenso>r2x2x2xf64
// %result: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]

batch_norm_training

סמנטיקה

מחשבת את הממוצע והשונות בכל המאפיינים, למעט המאפיין feature_index ומנרמלת את טנסור operand כדי ליצור את הטנסורים output, batch_mean ו-batch_var. באופן רשמי יותר, אפשר לבטא את הפעולה הזו כפירוק לפעולות קיימות של StableHLO באמצעות תחביר Python באופן הבא:

def compute_mean(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def compute_variance(operand, feature_index):
  mean = compute_mean(operand, feature_index)
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  centered_operand = subtract(operand, mean_bcast)
  return compute_mean(mul(centered_operand, centered_operand), feature_index)

def batch_norm_training(operand, scale, offset, epsilon, feature_index):
  mean = compute_mean(operand, feature_index)
  variance = compute_variance(operand, feature_index)
  return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
                              feature_index),
         mean, variance

לסוגים כמותיים, הפונקציה מבצעת dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset: batch_norm_training(operand, scale, offset, epsilon, feature_index), operand, scale, offset, type(output), type(batch_mean), type(batch_var)).

קלט

תווית שם סוג מגבלות
(I1) operand טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור (C1)
(I2) scale טנזור חד-ממדי של נקודה צפה או של כימות לכל טנזור (C2), (C3)
(I3) offset טנזור חד-ממדי של נקודה צפה או של כימות לכל טנזור (C2), (C4)
(I4) epsilon קבוע מהסוג f32 (C1), (C3-C6)
(I5) feature_index קבוע מהסוג si64 (C1), (C3-C6)

פלט

שם סוג מגבלות
output טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור (C7)
batch_mean טנזור חד-ממדי של נקודה צפה או של כימות לכל טנזור (C2), (C5)
batch_var טנזור חד-ממדי של נקודה צפה או של כימות לכל טנזור (C2), (C6)

מגבלות

  • (C1) 0 <= feature_index < rank(operand).
  • ‫(C2) operand, ‏ scale, ‏ offset, ‏ batch_mean, ‏ batch_var ו-output כוללים את אותו baseline_element_type.
  • ‫(C3) size(scale) = dim(operand, feature_index).
  • ‫(C4) size(offset) = dim(operand, feature_index).
  • ‫(C5) size(batch_mean) = dim(operand, feature_index).
  • ‫(C6) size(batch_var) = dim(operand, feature_index).
  • ‫(C7) baseline_type(output) = baseline_type(operand).

דוגמאות

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
}< : (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ens>or2xf64) -
 <   (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf64)
// %output: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]

bitcast_convert

סמנטיקה

מבצע פעולת bitcast על טנסור operand ומפיק טנסור result שבו הביטים של כל הטנסור operand מפורשים מחדש באמצעות הסוג של הטנסור result.

באופן רשמי יותר, בהינתן E = element_type(operand), E' = element_type(result), ו-R = rank(operand):

  • אם num_bits(E') < num_bits(E), bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1]).
  • אם num_bits(E') > num_bits(E), bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :]).
  • אם num_bits(E') = num_bits(E), bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1]).

הפונקציה bits מחזירה ייצוג בזיכרון של ערך נתון, וההתנהגות שלה מוגדרת על ידי ההטמעה, כי הייצוג המדויק של טנסורים מוגדר על ידי ההטמעה, וגם הייצוג המדויק של סוגי רכיבים מוגדר על ידי ההטמעה.

קלט

תווית שם סוג מגבלות
(I1) operand טנזור או טנזור שעבר קוונטיזציה (C1-C2)

פלט

שם סוג מגבלות
result טנזור או טנזור שעבר קוונטיזציה (C1-C2)

מגבלות

  • ‫(C1) בהינתן E = is_quantized(operand) ? storage_type(operand) : element_type(operand), ‏E' = is_quantized(result) ? storage_type(result) : element_type(result) ו-R = rank(operand):
    • אם num_bits(E') = num_bits(E), shape(result) = shape(operand).
    • אם num_bits(E') < num_bits(E):
    • rank(result) = R + 1.
    • dim(result, i) = dim(operand, i) לכל 0 <= i < R.
    • dim(result, R) * num_bits(E') = num_bits(E).
    • אם num_bits(E') > num_bits(E):
    • rank(result) = R - 1.
    • dim(result, i) = dim(operand, i) לכל 0 <= i < R.
    • dim(operand, R - 1) * num_bits(E) = num_bits(E').
  • ‫(C2) אם is_complex(operand) or is_complex(result), אז is_complex(operand) and is_complex(result).

דוגמאות

// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand)< : >(te>nsorf64<) - t>ensor4xf16
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation

 דוגמאות נוספות

broadcast_in_dim

סמנטיקה

broadcast_in_dim

מרחיב את המאפיינים או את הדירוג של טנסור קלט על ידי שכפול הנתונים בטנסור operand ומפיק טנסור result. באופן רשמי יותר, result[result_index] = operand[operand_index] כאשר לכל d ב- axes(operand):

  • operand_index[d] = 0 if dim(operand, d) = 1.
  • operand_index[d] = result_index[broadcast_dimensions[d]] אחרת.

קלט

תווית שם סוג מגבלות
(I1) operand טנזור או טנזור שעבר קוונטיזציה (C1-C2), (C5-C6)
(I2) broadcast_dimensions קבוע טנסור חד-ממדי מסוג si64 (C2-C6)

פלט

שם סוג מגבלות
result טנזור או טנזור שעבר קוונטיזציה (C1), (C3), (C5-C6)

מגבלות

  • ‫(C1) element_type(result) מחושב לפי הנוסחה הבאה:
    • element_type(operand), אם !is_per_axis_quantized(operand).
    • element_type(operand), למעט quantization_dimension(operand),‏ scales(operand) ו-zero_points(operand) שעשויים להיות שונים מ-quantization_dimension(result),‏ scales(result) ו-zero_points(result) בהתאמה, אחרת.
  • ‫(C2) size(broadcast_dimensions) = rank(operand).
  • ‫(C3) 0 <= broadcast_dimensions < rank(result).
  • ‫(C4) is_unique(broadcast_dimensions).
  • ‫(C5) לכל d בaxes(operand):
    • dim(operand, d) = 1 או
    • dim(operand, d) = dim(result, broadcast_dimensions[d]).
  • ‫(C6) אם is_per_axis_quantized(result):
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].
    • אם dim(operand, quantization_dimension(operand)) = 1, אז scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).

דוגמאות

// %operand: [
//            [1, 2, 3]
//           ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
  broadcast_dimensio<ns = arra>yi64: 2, 1
}< : (ten>sor>1x3xi32<) - tenso>r2x3x2xi32
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

 דוגמאות נוספות

כיסוי

סמנטיקה

הפונקציה מפיקה את הפלט מהרצה של פונקציה אחת בדיוק מתוך branches בהתאם לערך של index. באופן רשמי יותר, result = selected_branch() כאשר:

  • selected_branch = branches[index] if 0 <= index < size(branches).
  • selected_branch = branches[-1] אחרת.

קלט

תווית שם סוג מגבלות
(I1) index טנזור אפס-ממדי מסוג si32
(I2) branches מספר משתנה של פונקציות (C1-C4)

פלט

שם סוג מגבלות
results מספר משתנה של טנסורים, טנסורים או טוקנים שעברו קוונטיזציה ‫(C4)

מגבלות

  • (C1) 0 < size(branches).
  • ‫(C2) input_types(branches...) = [].
  • ‫(C3) same(output_types(branches...)).
  • ‫(C4) type(results...) = output_types(branches[0]).

דוגמאות

// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
  "stablehlo.return"(%result_branch0, %resul<t_bra>nch0) : <(tens>or2>xi64, tensor2xi64) - ()
}, {
  "stablehlo.return"(%result_branc<h1, %>result_b<ranch>1) >: (tensor2xi64, <ten>sor>2xi64) -< ()
}>) : (ten<sori3>2) - (tensor2xi64, tensor2xi64)
// %result0: [1, 1]
// %result1: [1, 1]

 דוגמאות נוספות

cbrt

סמנטיקה

מבצעת פעולת שורש שלישי לפי רכיבים על טנסור operand ומפיקה טנסור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • למספרים ממשיים: rootn(x, 3) מ-IEEE-754.
  • למספרים מרוכבים: שורש מעוקב מרוכב.
  • לסוגים שעברו קוונטיזציה: dequantize_op_quantize(cbrt, operand, type(result))

קלט

תווית שם סוג מגבלות
(I1) operand טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1)

פלט

שם סוג מגבלות
result טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand)< : (t>ens>or4xf64<) - t>ensor4xf64
// %result: [0.0, 1.0, 2.0, 3.0]

 דוגמאות נוספות

ceil

סמנטיקה

מבצעת פעולת ceil על כל רכיב בטנזור operand ומפיקה טנזור result. מיישמת את הפעולה roundToIntegralTowardPositive מהמפרט IEEE-754. לסוגים כמותיים, הפונקציה מבצעת dequantize_op_quantize(ceil, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור (C1)

פלט

שם סוג מגבלות
result טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand)< : (t>ens>or5xf32<) - t>ensor5xf32
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]

 דוגמאות נוספות

cholesky

סמנטיקה

מחשבת את פירוק Cholesky של קבוצת מטריצות.

באופן רשמי יותר, לכל i ב-index_space(result),‏ result[i0, ..., iR-3, :, :] הוא פירוק שולסקי של a[i0, ..., iR-3, :, :], בצורה של מטריצה משולשת תחתונה (אם lower הוא true) או מטריצה משולשת עליונה (אם lower הוא false). ערכי הפלט במשולש הנגדי, כלומר במשולש העליון או במשולש התחתון, בהתאמה, מוגדרים על ידי ההטמעה.

אם קיים i שבו מטריצת הקלט היא לא מטריצה הרמיטית חיובית מוגדרת, ההתנהגות לא מוגדרת.

לסוגים כמותיים, הפונקציה מבצעת dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result)).

קלט

תווית שם סוג מגבלות
(I1) a טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1-C3)
(I2) lower קבוע מהסוג i1

פלט

שם סוג מגבלות
result טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1)

מגבלות

  • (C1) baseline_type(a) = baseline_type(result).
  • ‫(C2) 2 <= rank(a).
  • ‫(C3) dim(a, -2) = dim(a, -1).

דוגמאות

// %a: [
//      [1.0, 2.0, 3.0],
//      [2.0, 20.0, 26.0],
//      [3.0, 26.0, 70.0]
//     ]
%result = "stablehlo.cholesky"(%a) {
  lower = true
}< : (ten>sor>3x3xf32<) - ten>sor3x3xf64
// %result: [
//           [1.0, 0.0, 0.0],
//           [2.0, 4.0, 0.0],
//           [3.0, 5.0, 6.0]
//          ]

תיחום

סמנטיקה

הפונקציה מחזירה טנזור result שבו כל רכיב בטנזור operand מוגבל לערך מינימלי ומקסימלי. באופן רשמי יותר, result[result_index] = minimum(maximum(operand[result_index], min_element), max_element), כאשר min_element = rank(min) = 0 ? min[] : min[result_index], max_element = rank(max) = 0 ? max[] : max[result_index]. לסוגים כמותיים, הפונקציה מבצעת dequantize_op_quantize(clamp, min, operand, max, type(result)).

הגדרת סדר למספרים מרוכבים כוללת סמנטיקה מפתיעה, ולכן בעתיד אנחנו מתכננים להסיר את התמיכה במספרים מרוכבים בפעולה הזו (#560).

קלט

תווית שם סוג מגבלות
(I1) min tensor או per-tensor quantized tensor (C1), (C3)
(I2) operand tensor או per-tensor quantized tensor (C1-C4)
(I3) max tensor או per-tensor quantized tensor (C2), (C3)

פלט

שם סוג מגבלות
result tensor או per-tensor quantized tensor ‫(C4)

מגבלות

  • (C1) rank(min) = 0 or shape(min) = shape(operand).
  • ‫(C2) rank(max) = 0 or shape(max) = shape(operand).
  • ‫(C3) baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max).
  • ‫(C4) baseline_type(operand) = baseline_type(result).

דוגמאות

// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max)< : (t>ensor3xi<32, t>ensor3xi<32, t>ens>or3xi32<) - t>ensor3xi32
// %result: [5, 13, 20]

 דוגמאות נוספות

collective_broadcast

סמנטיקה

בכל קבוצת תהליכים ברשת התהליכים של StableHLO, הערך של טנזור operand נשלח מתהליך המקור לתהליכי היעד ונוצר טנזור result.

הפעולה מפצלת את רשת התהליכים של StableHLO ל-process_groups, שמוגדרת כך:

  • cross_replica(replica_groups) if channel_id <= 0.
  • cross_partition(replica_groups) if channel_id > 0.

לאחר מכן, result@process מחושב לפי הנוסחה הבאה:

  • operand@process_groups[i, 0] אם קיים i כך שהתהליך נמצא ב-process_groups[i].
  • broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result)) אחרת.

קלט

תווית שם סוג מגבלות
(I1) operand tensor או per-tensor quantized tensor (C3)
(I2) replica_groups מספר משתנה של קבועי טנסור חד-ממדיים מסוג si64 (C1), (C2)
(I3) channel_id קבוע מהסוג si64

פלט

שם סוג מגבלות
result tensor או per-tensor quantized tensor (C3)

מגבלות

  • (C1) is_unique(replica_groups).
  • ‫(C2) 0 <= replica_groups < N כאשר N מוגדר כך:
    • num_replicas אם משתמשים ב-cross_replica.
    • num_partitions אם משתמשים ב-cross_partition.
  • ‫(C3) type(result) = type(operand).

דוגמאות

// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
  replica_grou<ps = den>se[[2, 1]<] : ten>sor1x2xi64,
  channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 0
} : (ten>sor>1x2xi64<) - ten>sor1x2xi64
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]

collective_permute

סמנטיקה

בכל קבוצת תהליכים ברשת התהליכים של StableHLO, הערך של טנזור operand נשלח מתהליך המקור לתהליך היעד ונוצר טנזור result.

הפעולה מפצלת את רשת התהליכים של StableHLO ל-process_groups, שמוגדרת כך:

  • cross_replica(source_target_pairs) if channel_id <= 0.
  • cross_partition(source_target_pairs) if channel_id > 0.

לאחר מכן, result@process מחושב לפי הנוסחה הבאה:

  • operand@process_groups[i, 0], אם קיים i כך ש- process_groups[i, 1] = process.
  • broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result)) אחרת.

קלט

תווית שם סוג מגבלות
(I1) operand tensor או per-tensor quantized tensor (C5)
(I2) source_target_pairs קבוע טנסור דו-ממדי מסוג si64 (C1-C4)
(I3) channel_id קבוע מהסוג si64

פלט

שם סוג מגבלות
result tensor או per-tensor quantized tensor (C1)

מגבלות

  • (C1) dim(source_target_pairs, 1) = 2.
  • ‫(C2) is_unique(source_target_pairs[:, 0]).
  • ‫(C3) is_unique(source_target_pairs[:, 1]).
  • ‫(C4) 0 <= source_target_pairs < N, כאשר N מוגדר כך:
    • num_replicas אם משתמשים ב-cross_replica.
    • num_partitions אם משתמשים ב-cross_partition.
  • ‫(C5) type(result) = type(operand).

דוגמאות

// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
  source_target_pai<rs = dense[[0, 1>], [1, 2]<] : ten>sor2x2xi64,
  channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 0
}< : (ten>sor>2x2xi64<) - ten>sor2x2xi64
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]

 דוגמאות נוספות

השוואה

סמנטיקה

מבצע השוואה בין רכיבים של טנסורים lhs ו-rhs בהתאם ל-comparison_direction ול-compare_type, ומפיק טנסור result.

המשמעות של הערכים comparison_direction ו-compare_type היא כדלקמן:

לסוגי רכיבים בוליאניים ומספריים שלמים:

  • EQ: lhs = rhs.
  • NE: lhs != rhs.
  • GE: lhs >= rhs.
  • GT: lhs > rhs.
  • LE: lhs <= rhs.
  • LT: lhs < rhs.

עבור סוגי רכיבים של נקודה צפה עם compare_type = FLOAT, האופרטור מטמיע את הפעולות הבאות של IEEE-754:

  • EQ: compareQuietEqual.
  • NE: compareQuietNotEqual.
  • GE: compareQuietGreaterEqual.
  • GT: compareQuietGreater.
  • LE: compareQuietLessEqual.
  • LT: compareQuietLess.

עבור סוגי רכיבים של נקודה צפה עם compare_type = TOTALORDER, האופרטור משתמש בשילוב של פעולות totalOrder ו-compareQuietEqual מ-IEEE-754.

לסוגי רכיבים מורכבים, מתבצעת השוואה לקסיקוגרפית של זוגות (real, imag) באמצעות comparison_direction ו-compare_type שסופקו. הגדרת סדר למספרים מרוכבים כוללת סמנטיקה מפתיעה, ולכן בעתיד אנחנו מתכננים להסיר את התמיכה במספרים מרוכבים כש-comparison_direction הוא GE, ‏ GT, ‏ LE או LT (מספר 560).

לסוגים כמותיים. הפונקציה מבצעת dequantize_compare(lhs, rhs, comparison_direction).

קלט

תווית שם סוג מגבלות
(I1) lhs tensor או per-tensor quantized tensor (C1-C3)
(I2) rhs tensor או per-tensor quantized tensor (C1-C2)
(I3) comparison_direction enum of EQ, NE, GE, GT, LE, and LT
(I4) compare_type enum of FLOAT, TOTALORDER, SIGNED, and UNSIGNED (C3)

פלט

שם סוג מגבלות
result tensor of boolean type (C2)

מגבלות

  • (C1) baseline_element_type(lhs) = baseline_element_type(rhs).
  • ‫(C2) shape(lhs) = shape(rhs) = shape(result).
  • ‫(C3) compare_type מוגדר כך:
    • SIGNED if is_signed_integer(element_type(lhs)).
    • UNSIGNED if is_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs)).
    • FLOAT או TOTALORDER אם is_float(element_type(lhs)).
    • FLOAT if is_complex(element_type(lhs)).

דוגמאות

// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
  comparison_direction = <#stablehlocomparison_di>rection LT,
  compare_type = <#stablehlocomparison_>type FLOAT
}< : (t>ensor2xf<32, t>ens>or2xf32<) - >tensor2xi1
// %result: [true, false]

 דוגמאות נוספות

מורכב

סמנטיקה

מבצעת המרה של כל רכיב בנפרד לערך מרוכב מתוך זוג של ערכים ממשיים ומדומים, lhs ו-rhs, ומפיקה טנזור result.

קלט

תווית שם סוג מגבלות
(I1) lhs טנסור מסוג f32 או f64 (C1-C3)
(I2) rhs טנסור מסוג f32 או f64 (C1)

פלט

שם סוג מגבלות
result טנסור מסוג מורכב (C2), (C3)

מגבלות

  • (C1) type(lhs) = type(rhs).
  • ‫(C2) shape(result) = shape(lhs).
  • ‫(C3) element_type(result) הוא מסוג complex<E> כאשר E = element_type(lhs).

דוגמאות

// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs)< : (t>ensor2xf<64, t>ens>or2xf64<) - tenso<r2x>>complexf64
// %result: [(1.0, 2.0), (3.0, 4.0)]

 דוגמאות נוספות

מורכב

סמנטיקה

פעולה שמורכבת מפעולות אחרות של StableHLO, מקבלת inputs ו-composite_attributes ומפיקה results. הסמנטיקה של הפעולה מיושמת באמצעות המאפיין decomposition. אפשר להחליף את הפעולה composite בפירוק שלה בלי לשנות את הסמנטיקה של התוכנית. במקרים שבהם הוספת הפירוק לשורה לא מספקת את אותה סמנטיקה של פעולה, עדיף להשתמש ב-custom_call.

השדה version (ברירת המחדל היא 0) משמש לציון המועד שבו הסמנטיקה של רכיב מורכב משתנה.

קלט

תווית שם סוג
(I1) inputs מספר משתנה של ערכים
(I2) name קבוע מהסוג string
(I3) composite_attributes מילון מאפיינים
(I4) decomposition קבוע מהסוג string
(I5) version קבוע מהסוג si32

פלט

שם סוג
results מספר משתנה של ערכים

מגבלות

  • ‫(C1) is_namespaced_op_name(name)
  • ‫(C2) is_defined_in_parent_scope(decomposition)
  • ‫(C3) types(inputs...) == input_types(decomposition)
  • ‫(C4) types(results...) == output_types(decomposition)

דוגמאות

%results = "stablehlo.composite"(%input0, %input1) {
  name = "my_namespace.my_op",
  composite_attributes = {
    my_attribute = "my_value"
  },
  decomposition = @my_op,
 < ve>rsion = <1 :> i3>2
} : (<ten>sorf32, tensorf32) - tensorf32

 דוגמאות נוספות

concatenate

סמנטיקה

משרשר את inputs לאורך המימד dimension באותו סדר של הארגומנטים שצוינו ומפיק טנסור result. בצורה רשמית יותר, result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1], כאשר:

  1. id = d0 + ... + dk-1 + kd.
  2. d שווה ל-dimension, ו-d0, ... הם גדלי המאפיינים d של inputs.

קלט

תווית שם סוג מגבלות
(I1) inputs מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור (C1-C6)
(I2) dimension קבוע מהסוג si64 (C2), (C4), (C6)

פלט

שם סוג מגבלות
result tensor או per-tensor quantized tensor (C5-C6)

מגבלות

  • (C1) same(element_type(inputs...)).
  • ‫(C2) same(shape(inputs...)) למעט dim(inputs..., dimension).
  • ‫(C3) 0 < size(inputs).
  • ‫(C4) 0 <= dimension < rank(inputs[0]).
  • ‫(C5) element_type(result) = element_type(inputs[0]).
  • ‫(C6) shape(result) = shape(inputs[0]) except for:
    • dim(result, dimension) = dim(inputs[0], dimension) + ....

דוגמאות

// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
  dimension = 0 : i64
}< : (ten>sor3x2xi<64, ten>sor>1x2xi64<) - ten>sor4x2xi64
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]

 דוגמאות נוספות

קבוע

סמנטיקה

הפונקציה יוצרת טנסור output מקבוע value.

קלט

תווית שם סוג מגבלות
(I1) value קבוע (C1)

פלט

שם סוג מגבלות
output טנזור או טנזור שעבר קוונטיזציה (C1)

מגבלות

  • (C1) type(value) = type(output).

דוגמאות

%output = "stablehlo.constant"() {
  val<ue = dense[[0.0, 1.0], [>2.0, 3.0]<] : ten>sor2x2xf3>2
} : (<) - ten>sor2x2xf32
// %output: [[0.0, 1.0], [2.0, 3.0]]

 דוגמאות נוספות

להשלים המרה

סמנטיקה

מבצעת המרה מטיפוס אלמנט אחד לטיפוס אלמנט אחר בטנזור operand ומפיקה טנזור result.

בהמרות boolean-to-any-supported-type, הערך false מומר לאפס והערך true מומר לאחד. בהמרות של any-supported-type-to-boolean, ערך אפס מומר ל-false, וערכים שאינם אפס מומרים ל-true. בהמשך מוסבר איך זה עובד עם סוגים מורכבים.

בהמרות שכוללות מספר שלם למספר שלם, מספר שלם למספר נקודה צפה או מספר נקודה צפה למספר נקודה צפה, אם אפשר לייצג את ערך המקור בדיוק בסוג היעד, ערך התוצאה הוא הייצוג המדויק הזה. אחרת, ההתנהגות תוגדר בהמשך (#180).

בהמרות שכוללות floating-point-to-integer, החלק השברי מצומצם. אם אי אפשר לייצג את הערך שקוצץ בסוג היעד, ההתנהגות תוגדר בהמשך (#180).

המרות שכוללות מספרים מרוכבים למספרים מרוכבים מתבצעות באותו אופן כמו המרות של מספרים ממשיים למספרים ממשיים, כשממירים את החלקים הממשיים והדמיוניים.

בהמרות מסוג complex-to-any-other-type ו-any-other-type-to-complex, המערכת מתעלמת מהערך הדמיוני של המקור או מאפסת את הערך הדמיוני של היעד, בהתאמה. ההמרה של החלק האמיתי מתבצעת לפי המרות של נקודה צפה.

באופן עקרוני, הפעולה הזו יכולה לבצע דה-קוונטיזציה (המרה של טנסורים שעברו קוונטיזציה לטנסורים רגילים), קוונטיזציה (המרה של טנסורים רגילים לטנסורים שעברו קוונטיזציה) וקוונטיזציה מחדש (המרה בין טנסורים שעברו קוונטיזציה), אבל כרגע יש לנו פעולות ייעודיות לכך – uniform_dequantize לתרחיש השימוש הראשון ו-uniform_quantize לתרחישי השימוש השני והשלישי. יכול להיות שבעתיד שתי הפעולות האלה יאוחדו ל-convert (מס' 1576).

קלט

תווית שם סוג מגבלות
(I1) operand tensor (C1)

פלט

שם סוג מגבלות
result tensor (C1)

מגבלות

  • (C1) shape(operand) = shape(result).

דוגמאות

// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand)< : (t>ens>or3xi64<) - tenso<r3x>>complexf64
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]

 דוגמאות נוספות

קונבולוציה

סמנטיקה

מחשבת מכפלה סקלרית בין חלונות של lhs ופרוסות של rhs ומפיקה result. בתרשים הבא מוצגות דוגמאות קונקרטיות שממחישות איך מחשבים את הרכיבים ב-result מ-lhs ומ-rhs.

קונבולוציה

באופן רשמי יותר, אפשר להגדיר מחדש את הקלט במונחים של lhs כדי להגדיר חלונות של lhs:

  • lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension)).
  • lhs_window_strides = lhs_shape(1, window_strides, 1).
  • lhs_padding = lhs_shape([0, 0], padding, [0, 0]).
  • lhs_base_dilations = lhs_shape(1, lhs_dilation, 1).
  • lhs_window_dilations = lhs_shape(1, rhs_dilation, 1).

השינוי הזה מתבצע באמצעות פונקציות העזר הבאות:

  • lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]).
  • result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]).
  • permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1] כשj[d] = i[permutation[d]].

אם feature_group_count = 1 ו-batch_group_count = 1, אז לכל output_spatial_index ב-index_space(dim(result, output_spatial_dimensions...)),‏ result[result_shape(:, output_spatial_index, :)] = dot_product כאשר:

  • padding_value = constant(0, element_type(lhs)).
  • padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1).
  • lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides.
  • lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations).
  • reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true]). נראה שהתכונה הזו לא נמצאת בשימוש, ולכן אנחנו מתכננים להסיר אותה בעתיד (מס' 1181).
  • dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension]).

אם feature_group_count > 1:

  • lhses = split(lhs, feature_group_count, input_feature_dimension).
  • rhses = split(rhs, feature_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension).

אם batch_group_count > 1:

  • lhses = split(lhs, batch_group_count, input_batch_dimension).
  • rhses = split(rhs, batch_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension).

לסוגים כמותיים, הפונקציה מבצעת dequantize_op_quantize( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs, type(result)).

עבור סוגים היברידיים של קוונטיזציה, הפעולה היא hybrid_dequantize_then_op( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs).

קלט

תווית שם סוג מגבלות
(I1) lhs tensor או per-tensor quantized tensor (C1), (C10-C11), (C14) (C25), (C27-C28), (C31-C32), (C34)
(I2) rhs טנזור או טנזור שעבר קוונטיזציה (C1), (C14-C16), (C25), (C27-C29), (C31-C34)
(I3) window_strides קבוע טנסור חד-ממדי מסוג si64 (C2-C3), (C25)
(I4) padding קבוע טנסור דו-ממדי מסוג si64 (C4), (C25)
(I5) lhs_dilation קבוע טנסור חד-ממדי מסוג si64 (C5-C6), (C25)
(I6) rhs_dilation קבוע טנסור חד-ממדי מסוג si64 (C7-C8), (C25)
(I7) window_reversal קבוע טנסור חד-ממדי מסוג i1 (C9)
(I8) input_batch_dimension קבוע מהסוג si64 (C10), (C13), (C25)
(I9) input_feature_dimension קבוע מהסוג si64 (C11), (C13-C14)
(I10) input_spatial_dimensions קבוע טנסור חד-ממדי מסוג si64 ‫(C12),‏ (C13),‏ (C25)
(I11) kernel_input_feature_dimension קבוע מהסוג si64 ‪(C14), (C18)
(I12) kernel_output_feature_dimension קבוע מהסוג si64 (C15-C16), (C18), (C25), (C29)
(I13) kernel_spatial_dimensions קבוע טנסור חד-ממדי מסוג si64 ‪(C17-C18), (C25)
(I14) output_batch_dimension קבוע מהסוג si64 (C20), (C25)
(I15) output_feature_dimension קבוע מהסוג si64 (C20), (C25), (C30)
(I16) output_spatial_dimensions קבוע טנסור חד-ממדי מסוג si64 (C19-C20), (C25)
(I17) feature_group_count קבוע מהסוג si64 ‪(C11), (C14), (C16), (C21), (C23)
(I18) batch_group_count קבוע מהסוג si64 (C10), (C15), (C22), (C23), (C25)
(I19) precision_config מספר משתנה של סוגי enum של DEFAULT, HIGH ו-HIGHEST (C24)

פלט

שם סוג מגבלות
result טנזור או טנזור שעבר קוונטיזציה ‫(C25-C28), (C30), (C32-34)

מגבלות

  • (C1) N = rank(lhs) = rank(rhs).
  • ‫(C2) size(window_strides) = N - 2.
  • ‫(C3) 0 < window_strides.
  • ‫(C4) shape(padding) = [N - 2, 2].
  • ‫(C5) size(lhs_dilation) = N - 2.
  • ‫(C6) 0 < lhs_dilation.
  • ‫(C7) size(rhs_dilation) = N - 2.
  • ‫(C8) 0 < rhs_dilation.
  • ‫(C9) size(window_reversal) = N - 2.
  • ‫(C10) dim(lhs, input_batch_dimension) % batch_group_count = 0.
  • ‫(C11) dim(lhs, input_feature_dimension) % feature_group_count = 0.
  • (C12) size(input_spatial_dimensions) = N - 2.
  • ‫(C13) בהינתן input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:
    • is_unique(input_dimensions).
    • 0 <= input_dimensions < N.
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count.
  • ‫(C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0.
  • ‫(C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0.
  • ‫(C17) size(kernel_spatial_dimensions) = N - 2.
  • ‫(C18) בהינתן kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:
    • is_unique(kernel_dimensions).
    • 0 <= kernel_dimensions < N.
  • ‫(C19) size(output_spatial_dimensions) = N - 2.
  • ‫(C20) בהינתן output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:
    • is_unique(output_dimensions).
    • 0 <= output_dimensions < N.
  • ‫(C21) 0 < feature_group_count.
  • (C22) 0 < batch_group_count.
  • ‫(C23) feature_group_count = 1 or batch_group_count = 1.
  • ‫(C24) size(precision_config) = 2.
  • ‫(C25) dim(result, result_dim) מוגדר כך:
    • dim(lhs, input_batch_dimension) / batch_group_count if result_dim = output_batch_dimension.
    • dim(rhs, kernel_output_feature_dimension) if result_dim = output_feature_dimension.
    • num_windows אחרת, כאשר:
    • output_spatial_dimensions[spatial_dim] = result_dim.
    • lhs_dim = input_spatial_dimensions[spatial_dim].
    • rhs_dim = kernel_spatial_dimensions[spatial_dim].
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
  • (C26) rank(result) = N.
  • אם הפעולה משתמשת בטנסורים לא מכומתים:
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result).
  • אם הפעולה משתמשת בטנסורים שעברו קוונטיזציה:
    • ‫(C28) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • ‫(C29) אם is_per_axis_quantized(rhs), אז quantization_dimension(rhs) = kernel_output_feature_dimension.
    • ‫(C30) אם is_per_axis_quantized(result), אז quantization_dimension(result) = output_feature_dimension.
    • אם is_quantized(lhs):
    • ‫(C31) storage_type(lhs) = storage_type(rhs).
    • ‫(C32) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • ‫(C33) אם is_per_tensor_quantized(rhs), אז is_per_tensor_quantized(result).
    • אם !is_quantized(lhs):
    • (C34) element_type(lhs) = expressed_type(rhs) = element_type(result).

דוגמאות

// %lhs: [[
//        [
//          [1], [2], [5], [6]
//        ],
//        [
//          [3], [4], [7], [8]
//        ],
//        [
//          [10], [11], [14], [15]
//        ],
//        [
//          [12], [13], [16], [17]
//        ]
//      ]]
//
// %rhs: [
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]]
//       ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
  window_strid<es = arra>yi64: 4, 4,
  paddi<n>g = dense<0 : ten>sor2x2xi64,
  lhs_dilati<on = arra>yi64: 2, 2,
  rhs_dilati<on = arra>yi64: 1, 1,
  window_revers<al = arrayi1: fa>lse, false,
  // In the StableHLO dialect, dimension numbers are encoded vi<a:
  // `[input >dim<ensions]x[kernel >di>mensions]-[output dimensions]`.
  // "b" is batch dimension, "f" is feature dimension,
  // "i" is input feature dimension, "o" is output feature dimension,
  // "0/1/etc" a<re spatial dimensions.
  d>imension_num>bers = #stablehlo.conv[b, 0, 1, f]x[0, 1, i, o]-[b, 0, 1, f],
  batch_group_count = 1 : i64,
  fea<ture_group_count >= 1 : i64,
 < precision_config> = [#stablehl<oprecision >DEFAULT,< #stablehlo>pre>cision <DEFAULT]
} >: (tensor1x4x4x1xi64, tensor3x3x1x1xi64) - tensor1x2x2x1xi64
// %result: [[
//            [[10], [26]],
//            [[46], [62]]
//          ]]

 דוגמאות נוספות

קוסינוס

סמנטיקה

מבצעת פעולת קוסינוס לפי רכיבים בטנזור operand ומפיקה טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • למספרים ממשיים: cos מ-IEEE-754.
  • למספרים מרוכבים: קוסינוס מרוכב.
  • לסוגים כמותיים: dequantize_op_quantize(cosine, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1)

פלט

שם סוג מגבלות
result טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.cosine"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[1.0, 0.0], [-1.0, 0.0]]

 דוגמאות נוספות

count_leading_zeros

סמנטיקה

מבצעת ספירה של מספר הביטים המובילים של אפס בכל רכיב בטנזור operand ומפיקה טנזור result.

קלט

תווית שם סוג מגבלות
(I1) operand טנזור מסוג מספר שלם (C1)

פלט

שם סוג מגבלות
result טנזור מסוג מספר שלם (C1)

מגבלות

  • (C1) type(operand) = type(result).

דוגמאות

// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand)< : (ten>sor>2x2xi64<) - ten>sor2x2xi64
// %result: [[64, 63], [56, 0]]

 דוגמאות נוספות

custom_call

סמנטיקה

מכיל פעולה שמוגדרת בהטמעה call_target_name שמקבלת את inputs ואת called_computations ומפיקה את results. אפשר להשתמש ב-has_side_effect,‏ backend_config ו-api_version כדי לספק מטא-נתונים נוספים שמוגדרים על ידי ההטמעה.

בשלב הזה, הפעולה הזו מכילה אוסף לא מאורגן של מטא-נתונים שמשקפים את ההתפתחות האורגנית של הפעולה המקבילה לה בקומפיילר XLA. בעתיד, אנחנו מתכננים לאחד את המטא-נתונים האלה (#741).

קלט

תווית שם סוג
(I1) inputs מספר משתנה של ערכים
(I2) call_target_name קבוע מהסוג string
(I3) has_side_effect קבוע מהסוג i1
(I4) backend_config קבוע מהסוג string או מילון מאפיינים
(I5) api_version קבוע מהסוג si32
(I6) called_computations מספר משתנה של קבועים מהסוג string
(I7) output_operand_aliases מציינים את החלקים של ה-aliasing בפלטים ובאופרנדים

פלט

שם סוג
results מספר משתנה של ערכים

‫(XLA GPU Support) Special custom_call targets

יש שלושה סוגים מיוחדים של call_target_name שקשורים ל-buffer: ‫CreateBuffer יוצר buffer לא מאותחל, Pin יוצר buffer מאותחל ו-Unpin מבטל את ההקצאה של buffer ומחזיר את התוכן של buffer.

%uninitialized_buffer = "stablehlo.custom_call"() {
  call_target_name = "CreateBuffer",
  api_version> = 4 : <i32,
>} : () - memref4xf64

%initialized_buffer = "stablehlo.custom_call"(%init_value) {
  call_target_name = "Pin&quo<t;,
 > ap>i_versi<on = >4 : i32,
} : (tensor4xf64) - memref4xf64

%dealloc_buffer = "stablehlo.custom_call"(%initialized_buffer) {
  call_target_na<me = >&qu>ot;Unpi<n&quo>t;,
  api_version = 4 : i32,
} : (memref4xf64) - tensor4xf64

כינוי

יכול להיות שחלק מהפעולות של custom_call ידרשו שחלק מהפלט וחלק מהאופרנדים ישתפו את אותו הזיכרון. אפשר לבטא את זה באמצעות output_operand_aliases. ייצוג של זוג כינויים מורכב מרשימה של אינדקסים של טאפלים של פלט שמייצגים את חלק הפלט, ומ-operand_index יחד עם רשימה של אינדקסים של טאפלים של אופרנד שמייצגים את חלק האופרנד. רשימת הפלט או המדדים של טופל האופרנד ריקה אם הסוג המתאים הוא לא סוג tuple, והיא יכולה להיות ארוכה ככל שרוצים עבור סוג טופל מקונן. הייצוג הזה דומה לייצוג הכינוי XLA.

הסוג של חלק הפלט וחלק הקלט בצמד כינויים חייב להיות זהה. בפעולות custom_call שאינן קריאה אל CreateBuffer, Pin ו-Unpin, אופרנד buffer יכול להופיע בזוג אחד לכל היותר של כינויים, ופלט buffer חייב להופיע בזוג אחד של כינויים.

דוגמאות

%results = "stablehlo.custom_call"(%input0) {
  call_target_name = "foo",
  has_side_effect = false,
  backend_config = {bar = 42 : i32},
  api_version = 4 : i32,
  called_computations <= [>@fo>o]
} : <(te>nsorf64) - tensorf64

%updated_buffer = "stablehlo.custom_call"(%buffer) {
  call_target_name = "Update",
  api_version = 4 : i32,
  output_operand_aliases< = [
    #stablehlo.output_operand_aliasoutput_tuple_indices = [],
      operand_ind>ex = 0,
     < oper>and>_tuple_<indic>es = []]
} : (memref4xf64) - memref4xf64

חילוק

סמנטיקה

מבצעת חילוק של כל רכיב בטנזור המחולק lhs בכל רכיב בטנזור המחלק rhs, ומחזירה טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • למספרים שלמים: חלוקת מספרים שלמים שמניבה את המנה האלגברית בלי החלק השברי.
  • למספרים ממשיים: division מ-IEEE-754.
  • למספרים מרוכבים: חילוק מרוכב.
  • לסוגים שעברו קוונטיזציה:
    • dequantize_op_quantize(divide, lhs, rhs, type(result)).

קלט

תווית שם סוג מגבלות
(I1) lhs טנזור של מספרים שלמים, מספרים ממשיים או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור (C1)
(I2) rhs טנזור של מספרים שלמים, מספרים ממשיים או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור (C1)

פלט

שם סוג מגבלות
result טנזור של מספרים שלמים, מספרים עם נקודה עשרונית או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור (C1)

מגבלות

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

דוגמאות

// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs)< : (t>ensor4xf<32, t>ens>or4xf32<) - t>ensor4xf32
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]

 דוגמאות נוספות

dot_general

סמנטיקה

מחשב מכפלה סקלרית בין פרוסות של lhs לבין פרוסות של rhs ומפיק טנזור result.

באופן רשמי יותר, result[result_index] = dot_product, כאשר:

  • lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions].
  • rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions].
  • result_batching_index + result_lhs_index + result_rhs_index = result_index כאשר size(result_batching_index) = size(lhs_batching_dimensions), size(result_lhs_index) = size(lhs_result_dimensions) ו- size(result_rhs_index) = size(rhs_result_dimensions).
  • transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions).
  • transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :]).
  • reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions)).
  • transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions).
  • transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :]).
  • reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions)).
  • dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y)).

לסוגים כמותיים, הפונקציה מבצעת dequantize_op_quantize( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs, type(result)).

עבור סוגים היברידיים של קוונטיזציה, הפעולה היא hybrid_dequantize_then_op( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs).

precision_config שולט בפשרה בין מהירות לדיוק בחישובים במערכות עורפיות של מאיצים. אפשר להשתמש באחד מהערכים הבאים (בשלב הזה, הסמנטיקה של ערכי ה-enum האלה לא מוגדרת באופן מלא, אבל אנחנו מתכננים לטפל בזה ב-#755):

  • DEFAULT: החישוב הכי מהיר, אבל הקירוב הכי פחות מדויק למספר המקורי.
  • HIGH: חישוב איטי יותר, אבל קירוב מדויק יותר למספר המקורי.
  • HIGHEST: החישוב הכי איטי, אבל הקירוב הכי מדויק למספר המקורי.

DotAlgorithm מגדיר את המאפיינים העיקריים של האלגוריתם שמשמש להטמעה של פעולת הנקודה, ומגדיר גם את הדיוק. אם שדות מאפייני האלגוריתם מוגדרים, הערך של precision_config חייב להיות DEFAULT. DotAlgorithms אין להם ערך ברירת מחדל, כי פרמטרי ברירת המחדל מוגדרים בהטמעה. לכן, אפשר להגדיר את כל השדות של אלגוריתם הנקודות לערך None כדי לציין אלגוריתם נקודות ריק, שבמקומו ישמש הערך precision_config.

השדות DotAlgorithm כוללים:

  • lhs_precision_type ו-rhs_precision_type, רמות הדיוק שאליהן מעוגלים הצד הימני והצד השמאלי של הפעולה. סוגי הדיוק לא תלויים בסוגי האחסון של נתוני הקלט והפלט.
  • accumulation_type רמת הדיוק שמשמשת לחישוב המצטבר.
  • lhs_component_count, rhs_component_count ו-num_primitive_operations רלוונטיים כשאנחנו מפעילים אלגוריתם שמפרק את הצד הימני או השמאלי של המשוואה למספר רכיבים ומבצע מספר פעולות מכפלה סקלרית 'פרימיטיביות' על הערכים האלה – בדרך כלל כדי לדמות דיוק גבוה יותר (לדוגמה, שימוש בסוג הנתונים של בינה מלאכותית bfloat16 לחישובים ברמת דיוק גבוהה יותר: bf16_6x tf32_3x וכו'). עבור אלגוריתמים ללא פירוק, הערכים האלה צריכים להיות 1.
  • allow_imprecise_accumulation כדי לציין אם מותר לצבור נתונים ברמת דיוק נמוכה יותר בחלק מהשלבים (למשל CUBLASLT_MATMUL_DESC_FAST_ACCUM).

מאפיינים DotAlgorithm לדוגמה:

// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
 rhs_precision_type = tf32,
 accumulation_type = f32,
 lhs_component_count = 1,
 rhs_component_count = 1,
 num_primitive_operations = 1,
 allow_imprecise_accumulation = false}


// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
 rhs_precision_type = bf16,
 accumulation_type = f32,
 lhs_component_count = 3,
 rhs_component_count = 3,
 num_primitive_operations = 6,
 allow_imprecise_accumulation = false}


// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
 rhs_precision_type = f8e5m2,
 accumulation_type = f32,
 lhs_component_count = 1,
 rhs_component_count = 1,
 num_primitive_operations = 1,
 allow_imprecise_accumulation = true}

ההטמעות יקבעו אילו שילובים נתמכים. באופן כללי, אין ערובה לכך שכל אלגוריתם נתמך בכל סוג של מאיץ על ידי הצרכן של StableHLO. אם אלגוריתם מסוים לא נתמך, צריך להציג שגיאה ולא לחזור לאלגוריתם חלופי. האימות של StableHLO יספק אימות מיטבי, וימנע שימוש באלגוריתמים שלא ידוע שהם נתמכים בחומרה כלשהי.

במאמר xla_data.proto > Algorithm מפורטים כמה ערכים נתמכים של אלגוריתמים. בכרטיס מספר 2483 מתואר התוכנית ליצירת מסמך מרכזי בנושא אלגוריתמים נתמכים לפי קצה עורפי.

קלט

תווית שם סוג מגבלות
(I1) lhs tensor או per-tensor quantized tensor (C5-C6), (C9-C10), (C12-C14), (C17-C18), (C20)
(I2) rhs טנזור או טנזור שעבר קוונטיזציה (C7-C10), (C12-C20)
(I3) lhs_batching_dimensions קבוע טנסור חד-ממדי מסוג si64 (C1), (C3), (C5), (C9), (C12)
(I4) rhs_batching_dimensions קבוע טנסור חד-ממדי מסוג si64 (C1), (C4), (C7), (C9)
(I5) lhs_contracting_dimensions קבוע טנסור חד-ממדי מסוג si64 (C2), (C3), (C6), (C10)
(I6) rhs_contracting_dimensions קבוע טנסור חד-ממדי מסוג si64 (C2), (C4), (C8), (C10), (C16)
(I7) precision_config מספר משתנה של סוגי enum של DEFAULT, HIGH ו-HIGHEST (C11), (C21)
(I8) lhs_precision_type ‫FloatType או TensorFloat32 (C21)
(I9) rhs_precision_type ‫FloatType או TensorFloat32 (C21)
(I10) accumulation_type ‫FloatType או TensorFloat32 (C21)
(I11) lhs_component_count קבוע מהסוג si32 (C21), (C22)
(I12) rhs_component_count קבוע מהסוג si32 (C21), (C23)
(I13) num_primitive_operations קבוע מהסוג si32 (C21), (C24)
(I14) allow_imprecise_accumulation קבוע מהסוג bool (C21)

פלט

שם סוג מגבלות
result טנזור או טנזור שעבר קוונטיזציה ‫(C12),‏ (C14),‏ (C18-C20)

מגבלות

  • (C1) size(lhs_batching_dimensions) = size(rhs_batching_dimensions).
  • ‫(C2) size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions).
  • ‫(C3) is_unique(lhs_batching_dimensions + lhs_contracting_dimensions).
  • ‫(C4) is_unique(rhs_batching_dimensions + rhs_contracting_dimensions).
  • ‫(C5) 0 <= lhs_batching_dimensions < rank(lhs).
  • ‫(C6) 0 <= lhs_contracting_dimensions < rank(lhs).
  • ‫(C7) 0 <= rhs_batching_dimensions < rank(rhs).
  • ‫(C8) 0 <= rhs_contracting_dimensions < rank(rhs).
  • ‫(C9) dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).
  • ‫(C10) dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...).
  • ‫(C11) size(precision_config) = 2.
  • (C12) shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions).
  • אם הפעולה משתמשת בטנסורים לא מכומתים:
    • ‫(C13) element_type(lhs) = element_type(rhs).
  • אם הפעולה משתמשת בטנסורים שעברו קוונטיזציה:
    • (C14) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • ‫(C15) zero_points(rhs) = 0.
    • ‫(C16) אם is_per_axis_quantized(rhs), אז quantization_dimension(rhs) לא ב-rhs_contracting_dimensions.
    • אם is_quantized(lhs):
    • ‫(C17) storage_type(lhs) = storage_type(rhs).
    • ‫(C18) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • ‫(C19) אם is_per_tensor_quantized(rhs), אז is_per_tensor_quantized(result).
    • אם !is_quantized(lhs):
    • ‫(C20) element_type(lhs) = expressed_type(rhs) = element_type(result).
  • אם !is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation):
    • ‫(C21) precision_config... = DEFAULT.
    • (C22) 0 < lhs_component_count.
    • ‫(C23) 0 < rhs_component_count.
    • ‫(C24) 0 < num_primitive_operations.

דוגמאות

// %lhs: [
//        [[1, 2],
//         [3, 4]],
//        [[5, 6],
//         [7, 8]]
//       ]
// %rhs: [
//        [[1, 0],
//         [0, 1]],
//        [[1, 0],
//         [0, 1]]
//       ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
  dot_dimension_numbers = #sta<blehlo.dot
    lhs_batching_dimensions = [0],
    rhs_batching_dimensions = [0],
    lhs_contracting_dimensions = [2],
    rhs_contracting_dimension>s = [1]
  ,
  precision_config = [<#stablehloprecisi>on DEFAULT, <#stablehloprecisi>on DEFAULT],
  algorithm = #stablehlo.dot<_algorithm
    lhs_precision_type = tf32,
    rhs_precision_type = tf32,
    accumulation_type = f32,
    lhs_component_count = 1,
    rhs_component_count = 1,
    num_primitive_operations = 1,
    allow_imprecise_accumulation >= false
  
}< : (tenso>r2x2x2xi<64, tenso>r2x>2x2xi64<) - tenso>r2x2x2xi64
// %result: [
//           [[1, 2],
//            [3, 4]],
//           [[5, 6],
//            [7, 8]]
//          ]

 דוגמאות נוספות

dynamic_broadcast_in_dim

סמנטיקה

הפעולה הזו זהה מבחינת הפונקציונליות ל-broadcast_in_dim op, אבל צורת התוצאה מוגדרת באופן דינמי באמצעות output_dimensions.

הפעולה מקבלת גם מאפיינים אופציונליים known_expanding_dimensions, known_nonexpanding_dimensions כדי לבטא ידע סטטי לגבי התנהגות ההרחבה של המאפיינים. אם לא מציינים מאפיינים, המערכת מניחה שכל המאפיינים יכולים להתרחב.

קלט

תווית שם סוג מגבלות
(I1) operand טנזור או טנזור שעבר קוונטיזציה (C1-C2), (C5-C6), (C9)
(I2) output_dimensions טנזור חד-ממדי מסוג מספר שלם (C7)
(I3) broadcast_dimensions טנזור קבוע חד-ממדי מסוג מספר שלם (C2-C6)
(I4) known_expanding_dimensions טנזור קבוע חד-ממדי מסוג מספר שלם (C8-C9)
(I5) known_nonexpanding_dimensions טנזור קבוע חד-ממדי מסוג מספר שלם (C8-C9)

פלט

שם סוג מגבלות
result טנזור או טנזור שעבר קוונטיזציה (C1), (C3), (C5-C7)

מגבלות

  • ‫(C1) element_type(result) מחושב לפי הנוסחה הבאה:
    • element_type(operand), אם !is_per_axis_quantized(operand).
    • element_type(operand), למעט quantization_dimension(operand),‏ scales(operand) ו-zero_points(operand) שעשויים להיות שונים מ-quantization_dimension(result),‏ scales(result) ו-zero_points(result) בהתאמה, אחרת.
  • ‫(C2) size(broadcast_dimensions) = rank(operand).
  • ‫(C3) 0 <= broadcast_dimensions < rank(result).
  • ‫(C4) is_unique(broadcast_dimensions).
  • ‫(C5) לכל d בaxes(operand):
    • dim(operand, d) = 1 או
    • dim(operand, d) = dim(result, broadcast_dimensions[d]).
  • ‫(C6) אם is_per_axis_quantized(result):
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].
    • אם dim(operand, quantization_dimension(operand)) = 1, אז scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).
  • ‫(C7) size(output_dimensions) = rank(result).
  • ‫(C8) is_unique(known_expanding_dimensions + known_nonexpanding_dimensions).
  • ‫(C9) 0 <= known_expanding_dimensions < rank(operand).
  • ‫(C10) 0 <= known_nonexpanding_dimensions < rank(operand).

דוגמאות

// %operand: [
//            [1, 2, 3]
//           ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
  broadcast_dimensio<ns = arra>yi64: 2, 1,
  known_expanding_dimensio<ns = a>rrayi64: 0,
  known_nonexpanding_dimensio<ns = a>rrayi64: 1
}< : (ten>sor1x3xi<64, t>ens>or3xi64<) - tenso>r2x3x2xi64
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

 דוגמאות נוספות

dynamic_conv

סמנטיקה

הפעולה הזו זהה מבחינת הפונקציונליות לפעולת הקונבולוציה, אבל הריפוד מוגדר באופן דינמי באמצעות padding.

קלט

תווית שם סוג מגבלות
(I1) lhs tensor או per-tensor quantized tensor (C1), (C10-C11), (C14) (C25), (C26-C27), (C30-C31), (C33)
(I2) rhs טנזור או טנזור שעבר קוונטיזציה ‫(C1),‏ (C14-C16),‏ (C26-C28),‏ (C30-C33)
(I3) padding טנזור דו-ממדי מסוג מספר שלם ‫(C4)
(I4) window_strides קבוע טנסור חד-ממדי מסוג si64 (C2-C3)
(I5) lhs_dilation קבוע טנסור חד-ממדי מסוג si64 (C5-C6)
(I6) rhs_dilation קבוע טנסור חד-ממדי מסוג si64 (C7-C8)
(I7) window_reversal קבוע טנסור חד-ממדי מסוג i1 (C9)
(I8) input_batch_dimension קבוע מהסוג si64 (C10), (C13)
(I9) input_feature_dimension קבוע מהסוג si64 (C11), (C13-C14)
(I10) input_spatial_dimensions קבוע טנסור חד-ממדי מסוג si64 (C12), (C13)
(I11) kernel_input_feature_dimension קבוע מהסוג si64 ‪(C14), (C18)
(I12) kernel_output_feature_dimension קבוע מהסוג si64 (C15-C16), (C18), (C28)
(I13) kernel_spatial_dimensions קבוע טנסור חד-ממדי מסוג si64 (C17-C18)
(I14) output_batch_dimension קבוע מהסוג si64 ‫(C20)
(I15) output_feature_dimension קבוע מהסוג si64 (C20), (C29)
(I16) output_spatial_dimensions קבוע טנסור חד-ממדי מסוג si64 (C19-C20)
(I17) feature_group_count קבוע מהסוג si64 ‪(C11), (C14), (C16), (C21), (C23)
(I18) batch_group_count קבוע מהסוג si64 (C10), (C15), (C22), (C23)
(I19) precision_config מספר משתנה של סוגי enum של DEFAULT, HIGH ו-HIGHEST (C24)

פלט

שם סוג מגבלות
result טנזור או טנזור שעבר קוונטיזציה (C25-C27), (C29), (C31-C33)

מגבלות

  • (C1) N = rank(lhs) = rank(rhs).
  • ‫(C2) size(window_strides) = N - 2.
  • ‫(C3) 0 < window_strides.
  • ‫(C4) shape(padding) = [N - 2, 2].
  • ‫(C5) size(lhs_dilation) = N - 2.
  • ‫(C6) 0 < lhs_dilation.
  • ‫(C7) size(rhs_dilation) = N - 2.
  • ‫(C8) 0 < rhs_dilation.
  • ‫(C9) size(window_reversal) = N - 2.
  • ‫(C10) dim(lhs, input_batch_dimension) % batch_group_count = 0.
  • ‫(C11) dim(lhs, input_feature_dimension) % feature_group_count = 0.
  • (C12) size(input_spatial_dimensions) = N - 2.
  • ‫(C13) בהינתן input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:
    • is_unique(input_dimensions).
    • 0 <= input_dimensions < N.
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count.
  • ‫(C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0.
  • ‫(C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0.
  • ‫(C17) size(kernel_spatial_dimensions) = N - 2.
  • ‫(C18) בהינתן kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:
    • is_unique(kernel_dimensions).
    • 0 <= kernel_dimensions < N.
  • ‫(C19) size(output_spatial_dimensions) = N - 2.
  • ‫(C20) בהינתן output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:
    • is_unique(output_dimensions).
    • 0 <= output_dimensions < N.
  • ‫(C21) 0 < feature_group_count.
  • (C22) 0 < batch_group_count.
  • ‫(C23) feature_group_count = 1 or batch_group_count = 1.
  • ‫(C24) size(precision_config) = 2.
  • ‫(C25) dim(result, result_dim) מוגדר כך:
    • dim(lhs, input_batch_dimension) / batch_group_count if result_dim = output_batch_dimension.
    • dim(rhs, kernel_output_feature_dimension) if result_dim = output_feature_dimension.
    • num_windows אחרת, כאשר:
    • output_spatial_dimensions[spatial_dim] = result_dim.
    • lhs_dim = input_spatial_dimensions[spatial_dim].
    • rhs_dim = kernel_spatial_dimensions[spatial_dim].
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
  • (C26) rank(result) = N.
  • אם הפעולה משתמשת בטנסורים לא מכומתים:
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result).
  • אם הפעולה משתמשת בטנסורים שעברו קוונטיזציה:
    • ‫(C28) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • ‫(C29) אם is_per_axis_quantized(rhs), אז quantization_dimension(rhs) = kernel_output_feature_dimension.
    • ‫(C30) אם is_per_axis_quantized(result), אז quantization_dimension(result) = output_feature_dimension.
    • אם is_quantized(lhs):
    • ‫(C31) storage_type(lhs) = storage_type(rhs).
    • ‫(C32) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • ‫(C33) אם is_per_tensor_quantized(rhs), אז is_per_tensor_quantized(result).
    • אם !is_quantized(lhs):
    • (C34) element_type(lhs) = expressed_type(rhs) = element_type(result).

דוגמאות

// %lhs: [[
//        [[1], [2], [5], [6]],
//        [[3], [4], [7], [8]],
//        [[10], [11], [14], [15]],
//        [[12], [13], [16], [17]]
//      ]]
//
// %rhs: [
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]]
//        ]
// %padding: [[1, 1],
//            [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
  window_strid<es = arra>yi64: 4, 4,
  lhs_dilati<on = arra>yi64: 2, 2,
  rhs_dilati<on = arra>yi64: 1, 1,
  window_revers<al = arrayi1: fa>lse, false,
  dimension_numbers = #stab<lehlo.convraw
    input_batch_dimension = 0,
    input_feature_dimension = 3,
    input_spatial_dimensions = [0, 1],
    kernel_input_feature_dimension = 2,
    kernel_output_feature_dimension = 3,
    kernel_spatial_dimensions = [0, 1],
    output_batch_dimension = 0,
    output_feature_dimension = 3,
    output_spatial_dimensions => [1, 2]
  ,
  feature_group_count = 1 : i64,
  batch_group_count = 1 : i64,
  precision_config = [<#stablehloprecisi>on DEFAULT, <#stablehloprecisi>on DEFAULT]
}< : (tensor1>x4x4x1xi<64, tensor3>x3x1x1xi<64, ten>sor>2x2xi64<) - tensor1>x2x2x1xi64
// %result: [[
//            [[1], [5]],
//            [[10], [14]]
//          ]]

 דוגמאות נוספות

dynamic_gather

סמנטיקה

הפעולה הזו זהה מבחינת הפונקציונליות שלה לפעולה gather, כאשר slice_sizes מוגדר באופן דינמי כערך.

קלט

תווית שם סוג מגבלות
(I1) operand tensor או per-tensor quantized tensor (C1), (C7), (C10-C12), (C14)
(I2) start_indices טנזור מסוג מספר שלם (C2), (C3), (C13)
(I3) slice_sizes טנזור חד-ממדי מסוג מספר שלם (C8), (C11-C13)
(I4) offset_dims קבוע טנסור חד-ממדי מסוג si64 (C1), (C4-C5), (C13)
(I5) collapsed_slice_dims קבוע טנסור חד-ממדי מסוג si64 (C1), (C6-C8), (C13)
(I6) start_index_map קבוע טנסור חד-ממדי מסוג si64 (C3), (C9), (C10)
(I7) index_vector_dim קבוע מהסוג si64 (C2), (C3), (C13)
(I8) indices_are_sorted קבוע מהסוג i1

פלט

שם סוג מגבלות
result tensor או per-tensor quantized tensor (C5), (C13-C14)

מגבלות

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims).
  • ‫(C2) 0 <= index_vector_dim <= rank(start_indices).
  • ‫(C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1.
  • ‫(C4) is_unique(offset_dims) and is_sorted(offset_dims).
  • ‫(C5) 0 <= offset_dims < rank(result).
  • ‫(C6) is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims).
  • ‫(C7) 0 <= collapsed_slice_dims < rank(operand).
  • ‫(C8) slice_sizes[collapsed_slice_dims...] <= 1.
  • ‫(C9) is_unique(start_index_map).
  • ‫(C10) 0 <= start_index_map < rank(operand).
  • ‫(C11) size(slice_sizes) = rank(operand).
  • (C12) 0 <= slice_sizes <= shape(operand).
  • ‫(C13) shape(result) = combine(batch_dim_sizes, offset_dim_sizes) כאשר:
    • batch_dim_sizes = shape(start_indices) חוץ מזה שגודל המאפיין start_indices שמתאים ל-index_vector_dim לא נכלל.
    • offset_dim_sizes = shape(slice_sizes), למעט גדלי המימדים ב-slice_sizes שתואמים ל-collapsed_slice_dims.
    • combine ממקם את batch_dim_sizes על צירים שתואמים ל-batch_dims ואת offset_dim_sizes על צירים שתואמים ל-offset_dims.
  • (C14) element_type(operand) = element_type(result).

דוגמאות

// %operand: [
//            [[1, 2], [3, 4], [5, 6], [7, 8]],
//            [[9, 10],[11, 12], [13, 14], [15, 16]],
//            [[17, 18], [19, 20], [21, 22], [23, 24]]
//           ]
// %start_indices: [
//                  [[0, 0], [1, 0], [2, 1]],
//                  [[0, 1], [1, 1], [0, 2]]
//                 ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
  dimension_numbers = #stable<hlo.gather
    offset_dims = [2, 3],
    collapsed_slice_dims = [0],
    start_index_map = [1, 0],
    index_vect>or_dim = 2,
  indices_are_sorted = false
}< : (tenso>r3x4x2xi<64, tenso>r2x3x2xi<64, t>ens>or3xi64<) - tensor2>x3x2x2xi64
// %result: [
//            [
//              [[1, 2], [3, 4]],
//              [[3, 4], [5, 6]],
//              [[13, 14], [15, 16]]
//            ],
//            [
//              [[9, 10], [11, 12]],
//              [[11, 12], [13, 14]],
//              [[17, 18], [19, 20]]
//            ]
//          ]

 דוגמאות נוספות

dynamic_iota

סמנטיקה

הפעולה הזו זהה מבחינת הפונקציונליות ל-iota op, אבל צורת התוצאה מוגדרת באופן דינמי באמצעות output_shape.

קלט

תווית שם סוג מגבלות
(I1) output_shape טנזור חד-ממדי מסוג מספר שלם (C1), (C2)
(I2) iota_dimension si64 (C1)

פלט

שם סוג מגבלות
result טנזור של מספרים שלמים, מספרים עם נקודה עשרונית או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור (C2)

מגבלות

  • (C1) 0 <= iota_dimension < size(output_shape).
  • ‫(C2) rank(result) = size(output_shape).

דוגמאות

%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
  iota_dimension = 0 : i64
}< : (t>ens>or2xi64<) - ten>sor4x5xi64
// %result: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

 דוגמאות נוספות

dynamic_pad

סמנטיקה

הפעולה הזו זהה מבחינת הפונקציונליות לפעולה pad, אבל הערכים edge_padding_low, edge_padding_high ו-interior_padding מוגדרים באופן דינמי כערכים.

קלט

תווית שם סוג מגבלות
(I1) operand tensor או per-tensor quantized tensor (C1), (C2), (C4)
(I2) padding_value טנזור אפס-ממדי או טנזור שעבר קוונטיזציה ברמת הטנזור (C1)
(I3) edge_padding_low טנזור חד-ממדי מסוג מספר שלם (C1), (C4)
(I4) edge_padding_high טנזור חד-ממדי מסוג מספר שלם (C1), (C4)
(I5) interior_padding טנזור חד-ממדי מסוג מספר שלם (C2-C4)

פלט

שם סוג מגבלות
result tensor או per-tensor quantized tensor (C3-C6)

מגבלות

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result).
  • ‫(C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand).
  • ‫(C3) 0 <= interior_padding.
  • ‫(C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.

דוגמאות

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
  %edge_padding_low, %edge_padding_high, %interior_padding
)< : (ten>sor2x3xi<64,> tensori<64, t>ensor2xi<64, t>ensor2xi<64, t>ens>or2xi64<) - ten>sor5x9xi64
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

 דוגמאות נוספות

dynamic_reshape

סמנטיקה

הפעולה הזו זהה מבחינת הפונקציונליות ל-op‏ reshape, אבל צורת התוצאה מוגדרת באופן דינמי באמצעות output_shape.

קלט

תווית שם סוג מגבלות
(I1) operand טנזור או טנזור שעבר קוונטיזציה (C1-C3)
(I2) output_shape טנזור חד-ממדי מסוג מספר שלם ‫(C4)

פלט

שם סוג מגבלות
result טנזור או טנזור שעבר קוונטיזציה (C1-C4)

מגבלות

  • ‫(C1) element_type(result) מחושב לפי הנוסחה הבאה:
    • element_type(operand), אם !is_per_axis_quantized(operand).
    • element_type(operand) חוץ מהמקרים שבהם quantization_dimension(operand) וquantization_dimension(result) שונים.
  • ‫(C2) size(operand) = size(result).
  • ‫(C3) אם is_per_axis_quantized(operand):
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
  • ‫(C4) size(output_shape) = rank(result).

דוגמאות

// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape)< : (ten>sor2x3xi<64, t>ens>or2xi64<) - ten>sor3x2xi64
// %result: [[1, 2], [3, 4], [5, 6]]

 דוגמאות נוספות

dynamic_slice

סמנטיקה

מחלקים את operand לחלקים באמצעות חישוב דינמי של אינדקסים התחלתיים ומפיקים טנזור result. ‫start_indices מכיל את אינדקס ההתחלה של הפלח לכל מאפיין, בכפוף להתאמה פוטנציאלית, ו-slice_sizes מכיל את הגודל של הפלח לכל מאפיין. באופן רשמי יותר, result[result_index] = operand[operand_index] כאשר:

  • adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes).
  • operand_index = adjusted_start_indices + result_index.

קלט

תווית שם סוג מגבלות
(I1) operand tensor או per-tensor quantized tensor (C1), (C2), (C4)
(I2) start_indices מספר משתנה של טנסורים אפס-ממדיים מסוג מספר שלם (C2), (C3)
(I3) slice_sizes קבוע טנסור חד-ממדי מסוג si64 (C2), (C4), (C5)

פלט

שם סוג מגבלות
result tensor או per-tensor quantized tensor (C1), (C5)

מגבלות

  • (C1) element_type(operand) = element_type(result).
  • ‫(C2) size(start_indices) = size(slice_sizes) = rank(operand).
  • ‫(C3) same(type(start_indices...)).
  • ‫(C4) 0 <= slice_sizes <= shape(operand).
  • ‫(C5) shape(result) = slice_sizes.

דוגמאות

// %operand: [
//            [0, 0, 1, 1],
//            [0, 0, 1, 1],
//            [0, 0, 0, 0],
//            [0, 0, 0, 0]
//           ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
  slice_siz<es = arra>yi64: 2, 2
}< : (ten>sor4x4xi<32,> tensori<64,> te>nsori64<) - ten>sor2x2xi32
// %result: [
//           [1, 1],
//           [1, 1]
//          ]

 דוגמאות נוספות

dynamic_update_slice

סמנטיקה

יוצר טנזור result ששווה לטנזור operand, למעט הפרוסה שמתחילה ב-start_indices, שמעודכנת עם הערכים ב-update. באופן רשמי יותר, result[result_index] מוגדר כך:

  • update[update_index] אם 0 <= update_index < shape(update) כאשר:
    • adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update)).
    • update_index = result_index - adjusted_start_indices.
  • operand[result_index] אחרת.

קלט

תווית שם סוג מגבלות
(I1) operand tensor או per-tensor quantized tensor (C1-C4), (C6)
(I2) update tensor או per-tensor quantized tensor (C2), (C3), (C6)
(I3) start_indices מספר משתנה של טנסורים אפס-ממדיים מסוג מספר שלם (C4), (C5)

פלט

שם סוג מגבלות
result tensor או per-tensor quantized tensor (C1)

מגבלות

  • (C1) type(operand) = type(result).
  • ‫(C2) element_type(update) = element_type(operand).
  • ‫(C3) rank(update) = rank(operand).
  • ‫(C4) size(start_indices) = rank(operand).
  • ‫(C5) same(type(start_indices...)).
  • ‫(C6) 0 <= shape(update) <= shape(operand).

דוגמאות

// %operand: [
//            [1, 1, 0, 0],
//            [1, 1, 0, 0],
//            [1, 1, 1, 1],
//            [1, 1, 1, 1]
//           ]
// %update: [
//           [1, 1],
//           [1, 1]
//          ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
 < : (ten>sor4x4xi<32, ten>sor2x2xi<32,> tensori<64,> te>nsori64<) - ten>sor4x4xi32
// %result: [
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1]
//          ]

 דוגמאות נוספות

מעריכיות

סמנטיקה

מבצעת פעולה מעריכית על כל רכיב בטנזור operand ומפיקה טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • למספרים ממשיים: exp מ-IEEE-754.
  • למספרים מרוכבים: חזקה מרוכבת.
  • לסוגים כמותיים: dequantize_op_quantize(exponential, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1)

פלט

שם סוג מגבלות
result טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]

 דוגמאות נוספות

exponential_minus_one

סמנטיקה

מבצע פעולת חיסור של 1 מכל רכיב של טנסור operand, ומחזיר טנסור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • למספרים ממשיים: expm1 מ-IEEE-754.
  • למספרים מרוכבים: פונקציית האקספוננט המרוכב פחות אחד.
  • לסוגים כמותיים: dequantize_op_quantize(exponential_minus_one, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1)

פלט

שם סוג מגבלות
result טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand)< : (t>ens>or2xf64<) - t>ensor2xf64
// %result: [0.0, 1.71828187]

 דוגמאות נוספות

fft

סמנטיקה

מבצעת את התמרת פורייה הישירה וההפוכה עבור קלט/פלט ממשי ומרוכב.

fft_type הוא אחד מהבאים:

  • FFT: העברה של FFT מורכב למורכב.
  • IFFT: FFT מורכב-למורכב הפוך.
  • RFFT: העברה של FFT מורכב בזמן אמת.
  • IRFFT: FFT הפוך של מספרים ממשיים למספרים מרוכבים (כלומר, מקבל מספרים מרוכבים ומחזיר מספרים ממשיים).

באופן רשמי יותר, בהינתן הפונקציה fft שמקבלת טנסורים חד-ממדיים של סוגים מורכבים כקלט, יוצרת טנסורים חד-ממדיים של אותם סוגים כפלט ומחשבת את התמרת פורייה הדיסקרטית:

במקרה של fft_type = FFT, ‏ result מוגדר כתוצאה הסופית של סדרת חישובים שבהם L = size(fft_length). לדוגמה, עבור L = 3:

  • result1[i0, ..., :] = fft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

בנוסף, בהינתן הפונקציה ifft שיש לה אותה חתימת סוג והיא מחשבת את ההופכי של fft:

במקרה של fft_type = IFFT, ‏ result מוגדרת כהיפוך של החישובים של fft_type = FFT. לדוגמה, עבור L = 3:

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = ifft(result2[i0, ..., :]).

בנוסף, נתונה הפונקציה rfft שמקבלת טנסורים חד-ממדיים של סוגי נתונים מסוג נקודה צפה, יוצרת טנסורים חד-ממדיים של סוגים מורכבים עם אותה סמנטיקה של נקודה צפה ופועלת באופן הבא:

  • rfft(real_operand) = truncated_result where
  • complex_operand... = (real_operand..., 0.0).
  • complex_result = fft(complex_operand).
  • truncated_result = complex_result[:(rank(complex_result) / 2 + 1)].

(כשמחשבים את התמרת פורייה הדיסקרטית עבור אופרנדים ממשיים, ‏N/2 + 1 הרכיבים הראשונים של התוצאה מגדירים באופן חד-משמעי את שאר התוצאה, ולכן התוצאה של rfft נחתכת כדי להימנע מחישוב של רכיבים מיותרים).

במקרה של fft_type = RFFT, ‏ result מוגדר כתוצאה הסופית של סדרת חישובים שבהם L = size(fft_length). לדוגמה, עבור L = 3:

  • result1[i0, ..., :] = rfft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

לבסוף, בהינתן הפונקציה irfft שיש לה אותה חתימת סוג והיא מחשבת את ההופכי של rfft:

במקרה של fft_type = IRFFT, ‏ result מוגדרת כהיפוך של החישובים של fft_type = RFFT. לדוגמה, עבור L = 3:

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = irfft(result2[i0, ..., :]).

קלט

תווית שם סוג מגבלות
(I1) operand טנזור מסוג נקודה צפה או מסוג מרוכב (C1), (C2), (C4), (C5)
(I2) fft_type enum of FFT, IFFT, RFFT, and IRFFT (C2), (C5)
(I3) fft_length קבוע טנסור חד-ממדי מסוג si64 (C1), (C3), (C4)

פלט

שם סוג מגבלות
result טנזור מסוג נקודה צפה או מסוג מרוכב (C2), (C4), (C5)

מגבלות

  • (C1) size(fft_length) <= rank(operand).
  • ‫(C2) הקשר בין סוגי הרכיבים operand ו-result משתנה:
    • אם fft_type = FFT, ‏ element_type(operand) ו-element_type(result) הם מאותו סוג מורכב.
    • אם fft_type = IFFT, ‏ element_type(operand) ו-element_type(result) הם מאותו סוג מורכב.
    • אם fft_type = RFFT, element_type(operand) הוא סוג של נקודה צפה ו-element_type(result) הוא סוג מורכב של אותה סמנטיקה של נקודה צפה.
    • אם fft_type = IRFFT, element_type(operand) הוא סוג מורכב ו-element_type(result) הוא סוג של נקודה צפה עם אותה סמנטיקה של נקודה צפה.
  • ‫(C3) 1 <= size(fft_length) <= 3.
  • ‫(C4) אם בין operand ל-result יש טנזור real מסוג נקודה צפה, אז shape(real)[-size(fft_length):] = fft_length.
  • ‫(C5) shape(result) = shape(operand) except for:
    • אם fft_type = RFFT, dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1.
    • אם fft_type = IRFFT, dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1.

דוגמאות

// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
  fft_type = <#stablehloff>t_type FFT,
  fft_leng<th = a>rrayi64: 4
}< : (tenso<r4x>>com>plexf32<) - tenso<r4x>>complexf32
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]

קומה

סמנטיקה

מבצעת פעולת floor על כל רכיב בטנזור operand ומפיקה טנזור result. מיישמת את הפעולה roundToIntegralTowardNegative מהמפרט IEEE-754. לסוגים כמותיים, הפונקציה מבצעת dequantize_op_quantize(floor, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור (C1)

פלט

שם סוג מגבלות
result טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand)< : (t>ens>or5xf32<) - t>ensor5xf32
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]

 דוגמאות נוספות

לרכז

סמנטיקה

אוספת פרוסות מטנזור operand מההיסטים שצוינו ב-start_indices ומפיקה טנזור result.

בתרשים הבא מוצג מיפוי של רכיבים ב-result לרכיבים ב-operand באמצעות דוגמה קונקרטית. בתרשים מוצגים כמה מדדים לדוגמה result והסבר מפורט לאילו מדדים operand הם מתאימים.

לרכז

באופן רשמי יותר, result[result_index] = operand[operand_index] כאשר:

  • batch_dims = [d for d in axes(result) and d not in offset_dims].
  • batch_index = result_index[batch_dims...].
  • ההגדרה של start_index היא:
    • start_indices[bi0, ..., :, ..., biN] כאשר bi הם רכיבים נפרדים ב-batch_index ו-: מוכנס באינדקס index_vector_dim, אם index_vector_dim < rank(start_indices).
    • [start_indices[batch_index]] אחרת.
  • ‫for d_operand in axes(operand),
    • full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand]) if d_operand = start_index_map[d_start].
    • full_start_index[d_operand] = 0 אחרת.
  • ‫for d_operand in axes(operand),
    • full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)] אם d_operand = operand_batching_dims[i_batching] וגם d_start = start_indices_batching_dims[i_batching].
    • full_batching_index[d_operand] = 0 אחרת.
  • offset_index = result_index[offset_dims...].
  • full_offset_index = [oi0, ..., 0, ..., oiN] כאשר oi הם רכיבים נפרדים ב-offset_index, ו-0 מוכנס באינדקסים מ-collapsed_slice_dims ועד operand_batching_dims.
  • operand_index = full_start_index + full_batching_index + full_offset_index.

אם הערך של indices_are_sorted הוא true, ההטמעה יכולה להניח שהערכים של start_indices ממוינים לפי start_index_map. אחרת, ההתנהגות לא מוגדרת. באופן רשמי יותר, לכל i1 < i2 מ-indices(result), full_start_index(i1) <= full_start_index(i2).

קלט

תווית שם סוג מגבלות
(I1) operand tensor או per-tensor quantized tensor ‪(C1), (C8), (C11), (C17), (C19-C21), (C23)
(I2) start_indices טנזור מסוג מספר שלם (C2-C3), (C14), (C17), (C22)
(I3) offset_dims קבוע טנסור חד-ממדי מסוג si64 (C1), (C4-C5), (C22)
(I4) collapsed_slice_dims קבוע טנסור חד-ממדי מסוג si64 (C1), (C6-C9), (C22)
(I5) operand_batching_dims קבוע טנסור חד-ממדי מסוג si64 ‫(C1),‏ (C6),‏ (C10-C12),‏ (C16-C18),‏ (C22)
(I6) start_indices_batching_dims קבוע טנסור חד-ממדי מסוג si64 (C13-C17)
(I7) start_index_map קבוע טנסור חד-ממדי מסוג si64 (C3), (C18-C19)
(I8) index_vector_dim קבוע מהסוג si64 (C2-C3), (C15), (C22)
(I9) slice_sizes קבוע טנסור חד-ממדי מסוג si64 (C9), (C12), (C20-C22)
(I10) indices_are_sorted קבוע מהסוג i1

פלט

שם סוג מגבלות
result tensor או per-tensor quantized tensor (C5), (C22-C23)

מגבלות

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims).
  • ‫(C2) 0 <= index_vector_dim <= rank(start_indices).
  • ‫(C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1.
  • ‫(C4) is_unique(offset_dims) and is_sorted(offset_dims).
  • ‫(C5) 0 <= offset_dims < rank(result).
  • ‫(C6) is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))
  • ‫(C7) is_sorted(collapsed_slice_dims).
  • ‫(C8) 0 <= collapsed_slice_dims < rank(operand).
  • ‫(C9) slice_sizes[collapsed_slice_dims...] <= 1.
  • ‫(C10) is_sorted(operand_batching_dims).
  • ‫(C11) 0 <= operand_batching_dims < rank(operand).
  • (C12) slice_sizes[operand_batching_dims...] <= 1.
  • ‫(C13) is_unique(start_indices_batching_dims).
  • (C14) 0 <= start_indices_batching_dims < rank(start_indices).
  • ‫(C15) index_vector_dim not in start_indices_batching_dims.
  • ‫(C16) size(operand_batching_dims) == size(start_indices_batching_dims).
  • ‫(C17) dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...).
  • ‫(C18) is_unique(concatenate(start_index_map, operand_batching_dims)).
  • ‫(C19) 0 <= start_index_map < rank(operand).
  • ‫(C20) size(slice_sizes) = rank(operand).
  • ‫(C21) 0 <= slice_sizes <= shape(operand).
  • ‫(C22) shape(result) = combine(batch_dim_sizes, offset_dim_sizes) כאשר:
    • batch_dim_sizes = shape(start_indices) חוץ מזה שגודל המאפיין start_indices שמתאים ל-index_vector_dim לא נכלל.
    • offset_dim_sizes = slice_sizes except that the dimension sizes in slice_sizes corresponding to collapsed_slice_dims and operand_batching_dims are not included.
    • combine ממקם את batch_dim_sizes על צירים שתואמים ל-batch_dims ואת offset_dim_sizes על צירים שתואמים ל-offset_dims.
  • ‫(C23) element_type(operand) = element_type(result).

דוגמאות

// %operand: [
//            [
//             [[1, 2], [3, 4], [5, 6], [7, 8]],
//             [[9, 10],[11, 12], [13, 14], [15, 16]],
//             [[17, 18], [19, 20], [21, 22], [23, 24]]
//            ],
//            [
//             [[25, 26], [27, 28], [29, 30], [31, 32]],
//             [[33, 34], [35, 36], [37, 38], [39, 40]],
//             [[41, 42], [43, 44], [45, 46], [47, 48]]
//            ]
//           ]
// %start_indices: [
//                  [
//                   [[0, 0], [1, 0], [2, 1]],
//                   [[0, 1], [1, 1], [0, 9]]
//                  ],
//                  [
//                   [[0, 0], [2, 1], [2, 2]],
//                   [[1, 2], [0, 1], [1, 0]]
//                  ]
//                 ]
%result = "stablehlo.gather"(%operand, %start_indices) {
  dimension_numbers = #stable<hlo.gather
    offset_dims = [3, 4],
    collapsed_slice_dims = [1],
    operand_batching_dims = [0],
    start_indices_batching_dims = [1],
    start_index_map = [2, 1],
    index_vect>or_dim = 3,
  slice_siz<es = arrayi64: >1, 1, 2, 2,
  indices_are_sorted = false
}< : (tensor2>x3x4x2xi<32, tensor2>x2x>3x2xi64<) - tensor2x2>x3x2x2xi32
// %result: [
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[3, 4], [5, 6]],
//             [[13, 14], [15, 16]]
//            ],
//            [
//             [[33, 34], [35, 36]],
//             [[35, 36], [37, 38]],
//             [[41, 42], [43, 44]]
//            ]
//           ],
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[13, 14], [15, 16]],
//             [[21, 22], [23, 24]]
//            ],
//            [
//             [[43, 44], [45, 46]],
//             [[33, 34], [35, 36]],
//             [[27, 28], [29, 30]]
//            ]
//           ]
//          ]

 דוגמאות נוספות

get_dimension_size

סמנטיקה

הפונקציה מחזירה את הגודל של dimension שצוין ב-operand. באופן רשמי יותר, result = dim(operand, dimension). הסמנטיקה מתייחסת רק לרכיב הצורה של הסוג. סוג הרכיב יכול להיות כל דבר.

קלט

תווית שם סוג מגבלות
(I1) operand טנזור או טנזור שעבר קוונטיזציה (C1)
(I2) dimension קבוע מהסוג si64 (C1)

פלט

שם סוג
result טנזור אפס-ממדי מסוג si32

מגבלות

  • (C1) 0 <= dimension < rank(operand).

דוגמאות

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
  dimension = 1 : i64
}< : (ten>sor>2x3xi64<) -> tensori32
// %result: 3

 דוגמאות נוספות

get_tuple_element

סמנטיקה

מחזירה את הרכיב במיקום index של ה-tuple‏ operand ויוצרת result. באופן רשמי יותר, result = operand[index].

קלט

תווית שם סוג מגבלות
(I1) operand tuple (C1), (C2)
(I2) index קבוע מהסוג si32 (C1), (C2)

פלט

שם סוג מגבלות
result כל ערך (C2)

מגבלות

  • (C1) 0 <= index < size(operand).
  • ‫(C2) type(result) = tuple_element_types(operand)[index].

דוגמאות

// %operand: ([1.0, 2.0], (3))
%result = "stablehlo.get_tuple_element"(<%operand) {index >= 0 : i32<} : (t<uplet>ensor2x<f64, t<upl>>>ete>nsori64<) - t>ensor2xf64
// %result: [1.0, 2.0]

 דוגמאות נוספות

אם

סמנטיקה

הפונקציה מפיקה את הפלט מהפעלת פונקציה אחת בדיוק מתוך true_branch או false_branch, בהתאם לערך של pred. באופן רשמי יותר, result = pred ? true_branch() : false_branch().

קלט

תווית שם סוג מגבלות
(I1) pred טנזור אפס-ממדי מסוג i1
(I2) true_branch פונקציה (C1-C3)
(I3) false_branch פונקציה (C1), (C2)

פלט

שם סוג מגבלות
results מספר משתנה של טנסורים, טנסורים או טוקנים שעברו קוונטיזציה (C3)

מגבלות

  • (C1) input_types(true_branch) = input_types(false_branch) = [].
  • ‫(C2) output_types(true_branch) = output_types(false_branch).
  • ‫(C3) type(results...) = output_types(true_branch).

דוגמאות

// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
  "stablehlo.return"(%result_tr<ue_>bra>nch) : (tensori32) - ()
}, {
  "stablehlo.return"(%<res>ult>_false_branch) :< (>ten>sori32)< - >()
}) : (tensori1) - tensori32
// %result: 10

 דוגמאות נוספות

imag

סמנטיקה

הפונקציה מחלצת את החלק המדומה, לפי רכיב, מ-operand ומפיקה טנסור result. באופן רשמי יותר, לכל רכיב x: imag(x) = is_complex(x) ? imaginary_part(x) : constant(0, element_type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טנזור מסוג נקודה צפה או מסוג מרוכב (C1), (C2)

פלט

שם סוג מגבלות
result טנסור מסוג נקודה צפה (C1), (C2)

מגבלות

  • (C1) shape(result) = shape(operand).
  • ‫(C2) element_type(result) מוגדר כך:
    • complex_element_type(element_type(operand)) if is_complex(operand).
    • element_type(operand) אחרת.

דוגמאות

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand)< : (tenso<r2x>>com>plexf32<) - t>ensor2xf32
// %result: [2.0, 4.0]

 דוגמאות נוספות

בפיד

סמנטיקה

קורא נתונים מהפיד ומפיק results.

הסמנטיקה של infeed_config מוגדרת בהטמעה.

results מורכב מערכי מטען ייעודי (payload) שמופיעים ראשונים וטוקן שמופיע אחרון. בעתיד, אנחנו מתכננים לפצל את מטען הנתונים ואת האסימון לשני פלטים נפרדים כדי לשפר את הבהירות (#670).

קלט

תווית שם סוג
(I1) token token
(I2) infeed_config קבוע מהסוג string

פלט

שם סוג מגבלות
results מספר משתנה של טנסורים, טנסורים או טוקנים שעברו קוונטיזציה (C1-C3)

מגבלות

  • (C1) 0 < size(results).
  • ‫(C2) is_empty(result[:-1]) או is_tensor(type(results[:-1])).
  • ‫(C3) is_token(type(results[-1])).

דוגמאות

// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : >(!stable<hlo.tok>en) - (tensor2x2xi64, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
  infeed_config> = "<;">
} : (!stablehlo.token) - (tensor2x2xi64, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]

 דוגמאות נוספות

iota

סמנטיקה

ממלאת טנסור output בערכים בסדר עולה החל מאפס לאורך המימד iota_dimension. באופן רשמי יותר,

output[output_index] = constant(is_quantized(output) ? quantize(output_index[iota_dimension], element_type(output)) : output_index[iota_dimension], element_type(output)).

קלט

תווית שם סוג מגבלות
(I1) iota_dimension si64 (C1)

פלט

שם סוג מגבלות
output טנזור של מספרים שלמים, מספרים עם נקודה עשרונית או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור (C1)

מגבלות

  • (C1) 0 <= iota_dimension < rank(output).

דוגמאות

%output = "stablehlo.iota"() {
  iota_dimension = 0 : i6>4
} : (<) - ten>sor4x5xi32
// %output: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

%output = "stablehlo.iota"() {
  iota_dimensio>n = 1 :< i64
} >: () - tensor4x5xi32
// %output: [
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4]
//          ]

 דוגמאות נוספות

is_finite

סמנטיקה

הפונקציה מבצעת בדיקה של כל רכיב כדי לראות אם הערך ב-x הוא סופי (כלומר, הוא לא ‎+Inf, ‏ ‎-Inf או NaN) ומפיקה טנסור y. מיישמת את הפעולה isFinite מהמפרט IEEE-754. בסוגים כמותיים, התוצאה היא תמיד true.

קלט

תווית שם סוג מגבלות
(I1) x טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור (C1)

פלט

שם סוג מגבלות
y tensor of boolean type (C1)

מגבלות

  • (C1) shape(x) = shape(y).

דוגמאות

// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x)< : (tens>or7xf64<) - >tensor7xi1
// %y: [false, false, false, true, true, true, true]

 דוגמאות נוספות

log

סמנטיקה

מבצעת פעולת לוגריתם לפי רכיבים בטנזור operand ומפיקה טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • למספרים ממשיים: log מ-IEEE-754.
  • למספרים מרוכבים: לוגריתם מרוכב.
  • לסוגים כמותיים: dequantize_op_quantize(log, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1)

פלט

שם סוג מגבלות
result טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]

 דוגמאות נוספות

log_plus_one

סמנטיקה

מבצעת פעולה של לוגריתם פלוס אחד על כל רכיב בטנזור operand ומפיקה טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • למספרים ממשיים: logp1 מ-IEEE-754.
  • למספרים מרוכבים: complex(log(hypot(real(x) + 1, imag(x))), atan2(imag(x), real(x) + 1))
  • לסוגים כמותיים: dequantize_op_quantize(log_plus_one, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1)

פלט

שם סוג מגבלות
result טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]

 דוגמאות נוספות

לוגיסטי

סמנטיקה

מבצע פעולה לוגיסטית ברמת הרכיב בטנזור operand ומפיק טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • למספרים ממשיים: division(1, addition(1, exp(-x))) מ-IEEE-754.
  • למספרים מרוכבים: לוגיסטיקה מורכבת.
  • לסוגים כמותיים: dequantize_op_quantize(logistic, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1)

פלט

שם סוג מגבלות
result טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]

 דוגמאות נוספות

מפה

סמנטיקה

מחילים פונקציית מיפוי computation על inputs לאורך dimensions ומפיקים טנזור result.

באופן רשמי יותר, result[result_index] = computation(inputs...[result_index]).

קלט

תווית שם סוג מגבלות
(I1) inputs מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור (C1-C4)
(I2) dimensions קבוע טנסור חד-ממדי מסוג si64 (C3)
(I3) computation פונקציה ‫(C4)

פלט

שם סוג מגבלות
result tensor או per-tensor quantized tensor (C1), (C4)

מגבלות

  • (C1) shape(inputs...) = shape(result).
  • ‫(C2) 0 < size(inputs) = N.
  • ‫(C3) dimensions = range(rank(inputs[0])).
  • ‫(C4) computation הוא מסוג (tensor<E0>, ..., tensor<EN-1>) -> tensor<E'> כאשר Ei = element_type(inputs[i]) ו-E' = element_type(result).

דוגמאות

// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
    %0 = stablehlo.multiply %arg0, %arg<1 :> tensori64
    stablehlo.return %<0 :> tensori64
}) {
  dimensio<ns = arra>yi64: 0, 1
}< : (ten>sor2x2xi<64, ten>sor>2x2xi64<) - ten>sor2x2xi64
// %result: [[0, 5], [12, 21]]

 דוגמאות נוספות

מקסימום

סמנטיקה

מבצע פעולת מקסימום לפי רכיבים בטנסורים lhs ו-rhs ומפיק טנסור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • לערכים בוליאניים: OR לוגי.
  • למספרים שלמים: מספר שלם מקסימלי.
  • למספרים ממשיים: maximum מ-IEEE-754.
  • למספרים מרוכבים: המקסימום הלקסיקוגרפי של הצמד (real, imaginary). הגדרת סדר למספרים מרוכבים כוללת סמנטיקה מפתיעה, ולכן בעתיד אנחנו מתכננים להסיר את התמיכה במספרים מרוכבים בפעולה הזו (#560).
  • לסוגים שעברו קוונטיזציה:
    • dequantize_op_quantize(maximum, lhs, rhs, type(result)).

קלט

תווית שם סוג מגבלות
(I1) lhs tensor או per-tensor quantized tensor (C1)
(I2) rhs tensor או per-tensor quantized tensor (C1)

פלט

שם סוג מגבלות
result tensor או per-tensor quantized tensor (C1)

מגבלות

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

דוגמאות

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 6], [7, 8]]

 דוגמאות נוספות

מינימום

סמנטיקה

מבצעת פעולת min לפי רכיבים בטנסורים lhs ו-rhs ומפיקה טנסור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • לערכים בוליאניים: AND לוגי.
  • למספרים שלמים: מספר שלם מינימלי.
  • למספרים ממשיים: minimum מ-IEEE-754.
  • למספרים מורכבים: המינימום הלקסיקוגרפי עבור הצמד (real, imaginary). הגדרת סדר למספרים מרוכבים כוללת סמנטיקה מפתיעה, ולכן בעתיד אנחנו מתכננים להסיר את התמיכה במספרים מרוכבים בפעולה הזו (#560).
  • לסוגים שעברו קוונטיזציה:
    • dequantize_op_quantize(minimum, lhs, rhs, type(result)).

קלט

תווית שם סוג מגבלות
(I1) lhs tensor או per-tensor quantized tensor (C1)
(I2) rhs tensor או per-tensor quantized tensor (C1)

פלט

שם סוג מגבלות
result tensor או per-tensor quantized tensor (C1)

מגבלות

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

דוגמאות

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[1, 2], [3, 4]]

 דוגמאות נוספות

כפל

סמנטיקה

מבצעת מכפלה של שני טנסורים lhs ו-rhs לפי רכיבים ומפיקה טנסור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • לערכים בוליאניים: AND לוגי.
  • למספרים שלמים: כפל של מספרים שלמים.
  • למספרים ממשיים: multiplication מ-IEEE-754.
  • למספרים מרוכבים: כפל מרוכב.
  • לסוגים שעברו קוונטיזציה:
    • dequantize_op_quantize(multiply, lhs, rhs, type(result)).

קלט

תווית שם סוג מגבלות
(I1) lhs tensor או per-tensor quantized tensor (C1)
(I2) rhs tensor או per-tensor quantized tensor (C1)

פלט

שם סוג מגבלות
result tensor או per-tensor quantized tensor (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 12], [21, 32]]

 דוגמאות נוספות

החלפת חיובי/שלילי

סמנטיקה

מבצעת שלילת כל רכיב בטנזור operand ומפיקה טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • למספרים שלמים עם סימן: שלילה של מספר שלם.
  • למספרים שלמים ללא סימן: bitcast למספר שלם עם סימן, שלילת מספר שלם, bitcast בחזרה למספר שלם ללא סימן.
  • למספרים ממשיים: negate מ-IEEE-754.
  • למספרים מרוכבים: שלילה מרוכבת.
  • לסוגים כמותיים: dequantize_op_quantize(negate, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טנזור של מספרים שלמים, מספרים עם נקודה עשרונית או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור (C1)

פלט

שם סוג מגבלות
result טנזור של מספרים שלמים, מספרים עם נקודה עשרונית או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand)< : (t>ens>or2xi32<) - t>ensor2xi32
// %result: [0, 2]

// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"<(%operand<) :>> (t>ensor1x<complexf3<2) >>- tensor1xcomplexf32
// %result: [-2.5, -0.0]

 דוגמאות נוספות

לא

סמנטיקה

מבצעת פעולת NOT על כל רכיב בטנזור operand ומפיקה טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • לערכים בוליאניים: שלילה לוגית.
  • למספרים שלמים: NOT ברמת הסיביות.

ארגומנטים

שם סוג מגבלות
operand טנסור מסוג בוליאני או מספר שלם (C1)

פלט

שם סוג מגבלות
result טנסור מסוג בוליאני או מספר שלם (C1)

מגבלות

  • (C1) type(operand) = type(result).

דוגמאות

// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand)< : (ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[-2, -3], [-4, -5]]

// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"<(%op>era>nd) : (<tens>or2xi1) - tensor2xi1
// %result: [false, true]

 דוגמאות נוספות

optimization_barrier

סמנטיקה

התכונה הזו מבטיחה שהפעולות שמייצרות את operand יבוצעו לפני כל הפעולות שתלויות ב-result, ומונעת מהטרנספורמציות של הקומפיילר להעביר פעולות מעבר למחסום. במקרים אחרים, הפעולה היא זהות, כלומר result = operand.

ארגומנטים

שם סוג מגבלות
operand מספר משתנה של טנסורים, טנסורים או טוקנים שעברו קוונטיזציה לכל טנסור (C1)

פלט

שם סוג מגבלות
result מספר משתנה של טנסורים, טנסורים או טוקנים שעברו קוונטיזציה לכל טנסור (C1)

מגבלות

  • (C1) type(operand...) = type(result...).

דוגמאות

// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1)< : >(tensorf<32,> te>nsorf32)< - >(tensorf<32,> tensorf32)
// %result0: 0.0
// %result1: 1.0

 דוגמאות נוספות

או

סמנטיקה

מבצעת פעולת OR ברמת הרכיב של שני טנסורים, lhs ו-rhs, ומפיקה טנסור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • לערכים בוליאניים: OR לוגי.
  • למספרים שלמים: ערך OR ברמת הסיביות.

קלט

תווית שם סוג מגבלות
(I1) lhs טנזור מסוג מספר שלם או בוליאני (C1)
(I2) rhs טנזור מסוג מספר שלם או בוליאני (C1)

פלט

שם סוג מגבלות
result טנזור מסוג מספר שלם או בוליאני (C1)

מגבלות

  • (C1) type(lhs) = type(rhs) = type(result).

דוגמאות

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 6], [7, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%<lhs, %>rhs) : (<tensor>2x2>xi1, te<nsor2x>2xi1) - tensor2x2xi1
// %result: [[false, true], [true, true]]

 דוגמאות נוספות

outfeed

סמנטיקה

כותב inputs ל-outfeed ויוצר אסימון result.

הסמנטיקה של outfeed_config מוגדרת בהטמעה.

קלט

תווית שם סוג
(I1) inputs מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה
(I2) token token
(I3) outfeed_config קבוע מהסוג string

פלט

שם סוג
result token

דוגמאות

%result = "stablehlo.outfeed"(%input0, %token) {
  outfeed_config = &quo<t;"<>/span>
} : (tensor2x2x2xi64,> !stablehlo.token) - !stablehlo.token

 דוגמאות נוספות

pad

סמנטיקה

מרחיב את operand על ידי הוספת ריפוד סביב הטנזור וגם בין הרכיבים של הטנזור עם הערך padding_value שצוין.

הערכים edge_padding_low ו-edge_padding_high מציינים את כמות הריווח שנוספה בקצה הנמוך (ליד אינדקס 0) ובקצה הגבוה (ליד האינדקס הגבוה ביותר) של כל מימד, בהתאמה. הערך של הריווח יכול להיות שלילי. הערך המוחלט של ריווח שלילי מציין את מספר הרכיבים שיוסרו מהמאפיין שצוין.

interior_padding מציין את כמות הריווח שנוספת בין שני רכיבים בכל מימד, והיא לא יכולה להיות שלילית. ריווח פנימי מתרחש לפני ריווח שוליים, כך שריווח שוליים שלילי יסיר אלמנטים מהאופרנד עם הריווח הפנימי.

באופן רשמי יותר, result[result_index] מוגדר כך:

  • operand[operand_index] if result_index = edge_padding_low + operand_index * (interior_padding + 1).
  • padding_value אחרת.

קלט

תווית שם סוג מגבלות
(I1) operand tensor או per-tensor quantized tensor (C1), (C2), (C4)
(I2) padding_value טנזור אפס-ממדי או טנזור שעבר קוונטיזציה ברמת הטנזור (C1)
(I3) edge_padding_low קבוע טנסור חד-ממדי מסוג si64 (C1), (C4)
(I4) edge_padding_high קבוע טנסור חד-ממדי מסוג si64 (C1), (C4)
(I5) interior_padding קבוע טנסור חד-ממדי מסוג si64 (C2-C4)

פלט

שם סוג מגבלות
result tensor או per-tensor quantized tensor (C3-C6)

מגבלות

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result).
  • ‫(C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand).
  • ‫(C3) 0 <= interior_padding.
  • ‫(C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.

דוגמאות

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
  edge_padding_l<ow = arra>yi64: 0, 1,
  edge_padding_hi<gh = arra>yi64: 2, 1,
  interior_paddi<ng = arra>yi64: 1, 2
}< : (ten>sor2x3xi<32,> te>nsori32<) - ten>sor5x9xi32
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

 דוגמאות נוספות

partition_id

סמנטיקה

יוצרת partition_id של התהליך הנוכחי.

פלט

שם סוג
result טנזור אפס-ממדי מסוג ui32

דוגמאות

%result = "stablehlo.partition_id">;() : (<) - >tensorui32

 דוגמאות נוספות

popcnt

סמנטיקה

מבצעת ספירה של מספר הביטים שמוגדרים בטנזור operand לפי רכיב, ומפיקה טנזור result.

קלט

תווית שם סוג מגבלות
(I1) operand טנזור מסוג מספר שלם (C1)

פלט

שם סוג מגבלות
result טנזור מסוג מספר שלם (C1)

מגבלות

  • (C1) type(operand) = type(result).

דוגמאות

// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand)< : (t>ens>or4xi64<) - t>ensor4xi64
// %result: [0, 1, 1, 7]

 דוגמאות נוספות

כוח

סמנטיקה

מבצע העלאה בחזקה של כל אחד מהרכיבים בטנזור lhs בטנזור rhs ומפיק טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • למספרים שלמים: העלאה בחזקה של מספר שלם.
  • למספרים ממשיים: pow מ-IEEE-754.
  • למספרים מרוכבים: העלאה בחזקה של מספר מרוכב.
  • לסוגים כמותיים: dequantize_op_quantize(power, lhs, rhs, type(result)).

קלט

תווית שם סוג מגבלות
(I1) lhs טנזור של מספרים שלמים, מספרים עם נקודה עשרונית או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור (C1)
(I2) rhs טנזור של מספרים שלמים, מספרים עם נקודה עשרונית או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור (C1)

פלט

שם סוג מגבלות
result טנזור של מספרים שלמים, מספרים עם נקודה עשרונית או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs)< : (t>ensor6xf<64, t>ens>or6xf64<) - t>ensor6xf64
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]

 דוגמאות נוספות

אמיתי

סמנטיקה

הפונקציה מחלצת את החלק הממשי, לפי רכיב, מ-operand ומפיקה טנסור result. באופן רשמי יותר, לכל רכיב x: real(x) = is_complex(x) ? real_part(x) : x.

קלט

תווית שם סוג מגבלות
(I1) operand טנזור מסוג נקודה צפה או מסוג מרוכב (C1), (C2)

פלט

שם סוג מגבלות
result טנסור מסוג נקודה צפה (C1), (C2)

מגבלות

  • (C1) shape(result) = shape(operand).
  • ‫(C2) element_type(result) מוגדר כך:
    • complex_element_type(element_type(operand)) if is_complex(operand).
    • element_type(operand) אחרת.

דוגמאות

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand)< : (tenso<r2x>>com>plexf32<) - t>ensor2xf32
// %result: [1.0, 3.0]

 דוגמאות נוספות

recv

סמנטיקה

מקבל נתונים מערוץ עם channel_id ומפיק results.

אם is_host_transfer הוא true, הפעולה מעבירה נתונים מהמארח. אחרת, הנתונים מועברים ממכשיר אחר על סמך הערכים של source_target_pairs. הסימון הזה משכפל את המידע שמופיע ב-channel_type, ולכן בעתיד אנחנו מתכננים להשאיר רק אחד מהם (מס' 666). אם is_host_transfer = false ו-source_target_pairs הוא None או ריק, ההתנהגות נחשבת ללא מוגדרת.

results מורכב מערכי מטען ייעודי (payload) שמופיעים ראשונים וטוקן שמופיע אחרון. בעתיד, אנחנו מתכננים לפצל את מטען הנתונים ואת האסימון לשני פלטים נפרדים כדי לשפר את הבהירות (#670).

קלט

תווית שם סוג מגבלות
(I1) token token
(I2) channel_id קבוע מהסוג si64
(I3) channel_type enum של DEVICE_TO_DEVICE ו-DEVICE_TO_HOST (C5)
(I4) is_host_transfer קבוע מהסוג i1 (C5-C6)
(I5) source_target_pairs קבוע טנסור דו-ממדי מסוג si64 (C1-C4), (C6)

פלט

שם סוג מגבלות
results מספר משתנה של טנסורים, טנסורים או טוקנים שעברו קוונטיזציה (C2-C4)

מגבלות

  • (C1) dim(source_target_pairs, 1) = 2.
  • ‫(C2) is_unique(source_target_pairs[:, 0]).
  • ‫(C3) is_unique(source_target_pairs[:, 1]).
  • ‫(C4) 0 <= source_target_pairs < N, כאשר N מוגדר כך:
    • num_replicas אם משתמשים ב-cross_replica.
    • num_partitions אם משתמשים ב-cross_partition.
  • ‫(C5) channel_type מוגדר כך:
    • DEVICE_TO_HOST if is_host_transfer = true,
    • DEVICE_TO_DEVICE אחרת.

דוגמאות

%results0, %results1 = "stablehlo.recv"(%token) {
  channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 1,
  is_host_transfer = false,
  source_target_pai<rs = dense[[0, 1>], [1, 2]<] : ten>sor2x2xi64
} : (!stablehl>o.token)< - (ten>sor2x2xi64, !stablehlo.token)

 דוגמאות נוספות

הקטנה

סמנטיקה

מחיל פונקציית צמצום body על inputs ועל init_values לאורך dimensions ומפיק טנסורים של results.

סדר הפעולות מוגדר בהטמעה, כלומר body ו-init_values חייבים ליצור מונואיד כדי להבטיח שהפעולה תפיק את אותן תוצאות לכל הקלטים בכל ההטמעות. עם זאת, התנאי הזה לא מתקיים בהרבה צמצומים פופולריים. לדוגמה, פעולת החיבור של מספרים ממשיים ב-body ושל אפס ב-init_values לא יוצרת מונואיד, כי פעולת החיבור של מספרים ממשיים היא לא אסוציאטיבית.

באופן רשמי יותר, results...[j0, ..., jR-1] = reduce(input_slices_converted) כאשר:

  • input_slices = inputs...[j0, ..., :, ..., jR-1], כאשר : מוכנסים ב-dimensions.
  • input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...).
  • init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...).
  • reduce(input_slices_converted) = exec(schedule) עבור עץ בינארי מסוים ‫schedule כאשר:
    • exec(node) = body(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule הוא עץ בינארי מלא שמוגדר בהטמעה, והמעבר שלו בסדר עולה כולל:
    • ערכים של input_slices_converted...[index], לכל index ב-index_space(input_slices_converted) בסדר לקסיקוגרפי עולה של index.
    • במרווחים של כמות init_values_converted שמוגדרת בהטמעה, במיקומים שמוגדרים בהטמעה.

קלט

תווית שם סוג מגבלות
(I1) inputs מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור (C1-C4), (C6), (C7)
(I2) init_values מספר משתנה של טנסורים אפסי-ממדיים או טנסורים שעברו קוונטיזציה לכל טנסור (C2), (C3)
(I3) dimensions קבוע טנסור חד-ממדי מסוג si64 (C4), (C5), (C7)
(I4) body פונקציה (C6)

פלט

שם סוג מגבלות
results מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור ‪(C3), (C7), (C8)

מגבלות

  • (C1) same(shape(inputs...)).
  • ‫(C2) element_type(inputs...) = element_type(init_values...).
  • ‫(C3) 0 < size(inputs) = size(init_values) = size(results) = N.
  • ‫(C4) 0 <= dimensions < rank(inputs[0]).
  • ‫(C5) is_unique(dimensions).
  • ‫(C6) body הוא מסוג (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>) כאשר is_promotable(element_type(inputs[i]), Ei).
  • ‫(C7) shape(results...) = shape(inputs...) למעט גדלי המאפיינים של inputs... שמתאימים ל-dimensions, שלא נכללים.
  • ‫(C8) element_type(results[i]) = Ei לכל i ב-[0,N).

דוגמאות

// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
    %0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
    "stable<hlo>.re>turn"(%0) : (tensori64) <- ()
}>) {
  dimens<ions = >arrayi64<: 1>
} >: (tens<or1x6>xi64, tensori64) - tensor1xi64
// %result = [15]

 דוגמאות נוספות

reduce_precision

סמנטיקה

מבצע המרה של operand לאלמנטים לסוג אחר של נקודה צפה שמשתמש ב-exponent_bits וב-mantissa_bits, ואז בחזרה לסוג המקורי של נקודה צפה, ומפיק טנסור output.

באופן רשמי יותר:

  • ביטים של המנטיסה של הערך המקורי מעודכנים כדי לעגל את הערך המקורי לערך הקרוב ביותר שאפשר לייצג באמצעות mantissa_bits עם סמנטיקה של roundToIntegralTiesToEven.
  • לאחר מכן, אם mantissa_bits קטן ממספר הביטים של המנטיסה של הערך המקורי, הביטים של המנטיסה נחתכים ל-mantissa_bits.
  • לאחר מכן, אם הביטים של המעריך בתוצאת הביניים לא מתאימים לטווח שצוין על ידי exponent_bits, מתרחש גלישה בתוצאת הביניים לאינסוף באמצעות הסימן המקורי, או גלישה תחתונה לאפס באמצעות הסימן המקורי.
  • לסוגים כמותיים, הפונקציה מבצעת dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור (C1)
(I2) exponent_bits קבוע מהסוג si32 (C2)
(I3) mantissa_bits קבוע מהסוג si32 (C3)

פלט

שם סוג מגבלות
output טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(output).
  • ‫(C2) 1 <= exponent_bits.
  • ‫(C3) 0 <= mantissa_bits.

דוגמאות

// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
  exponent_bits = 5 : i32,
  mantissa_bits = 10 : i32
}< : (t>ens>or6xf64<) - t>ensor6xf64
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]

 דוגמאות נוספות

reduce_scatter

סמנטיקה

reduce_scatter

בכל קבוצת תהליכים ברשת התהליכים של StableHLO, הפונקציה מבצעת צמצום באמצעות computations על הערכים של טנסור operand מכל תהליך, מפצלת את תוצאת הצמצום לאורך scatter_dimension לחלקים, ומפיצה את החלקים המפוצלים בין התהליכים כדי ליצור את result.

הפעולה מפצלת את רשת התהליכים של StableHLO ל-process_groups, שמוגדרת כך:

  • cross_replica(replica_groups) if channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) if channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) if channel_id > 0 and use_global_device_ids = true.

לאחר מכן, בכל process_group:

  • reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation).
  • parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension).
  • result@receiver = parts@sender[receiver_index] לכל sender ב-process_group, כאשר receiver_index = process_group.index(receiver).

קלט

תווית שם סוג מגבלות
(I1) operand tensor או per-tensor quantized tensor (C1), (C2), (C7), (C8)
(I2) scatter_dimension קבוע מהסוג si64 (C1), (C2), (C8)
(I3) replica_groups קבוע טנסור דו-ממדי מסוג si64 (C3-C5)
(I4) channel_id קבוע מהסוג si64 (C6)
(I5) use_global_device_ids קבוע מהסוג i1 (C6)
(I6) computation פונקציה (C7)

פלט

שם סוג מגבלות
result tensor או per-tensor quantized tensor (C8-C9)

מגבלות

  • (C1) dim(operand, scatter_dimension) % dim(process_groups, 1) = 0.
  • ‫(C2) 0 <= scatter_dimension < rank(operand).
  • ‫(C3) is_unique(replica_groups).
  • ‫(C4) size(replica_groups) מוגדר כך:
    • num_replicas אם משתמשים ב-cross_replica.
    • num_replicas אם משתמשים ב-cross_replica_and_partition.
    • num_processes אם משתמשים ב-flattened_ids.
  • ‫(C5) 0 <= replica_groups < size(replica_groups).
  • ‫(C6) אם use_global_device_ids = true, אז channel_id > 0.
  • ‫(C7) computation הוא מסוג (tensor<E>, tensor<E>) -> (tensor<E>) כאשר is_promotable(element_type(operand), E).
  • ‫(C8) shape(result) = shape(operand) except:
    • dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1).
  • ‫(C9) element_type(result) = E.

דוגמאות

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
  %0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
  "stable<hlo>.re>turn"(%0) : (tensori64) - ()
}) {
  scatter_dimension = 1 :< i64,
  >replica_g<roups => dense[[0, 1]] : tensor1x2xi64,
  channel_hand<le = #stablehlo.chan>nel_handleha<ndle = >0, >type = <0
} : (>tensor2x4xi64) - tensor2x2xi64
//
// %result@(0, 0): [[10, 12],
//                  [18, 20]]
// %result@(1, 0): [[14, 16],
//                  [22, 24]]

 דוגמאות נוספות

reduce_window

סמנטיקה

מחיל פונקציית צמצום body על חלונות של inputs ו-init_values ומפיק results.

התרשים הבא מציג דוגמה קונקרטית שממחישה איך מחשבים את הרכיבים ב-results... מתוך inputs....

reduce_window

באופן רשמי יותר, results...[result_index] = reduce(windows, init_values, axes(inputs...), body) (ראו reduce) כאשר:

  • padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1).
  • window_start = result_index * window_strides.
  • window_end = window_start + (window_dimensions - 1) * window_dilations + 1.
  • windows = slice(padded_inputs..., window_start, window_end, window_dilations).

קלט

תווית שם סוג מגבלות
(I1) inputs מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15)
(I2) init_values מספר משתנה של טנסורים אפסי-ממדיים או טנסורים שעברו קוונטיזציה לכל טנסור (C1), (C13)
(I3) window_dimensions קבוע טנסור חד-ממדי מסוג si64 (C4), (C5), (C15)
(I4) window_strides קבוע טנסור חד-ממדי מסוג si64 (C6), (C7), (C15)
(I5) base_dilations קבוע טנסור חד-ממדי מסוג si64 ‫(C8),‏ (C9),‏ (C15)
(I6) window_dilations קבוע טנסור חד-ממדי מסוג si64 ‫(C10),‏ (C11),‏ (C15)
(I7) padding קבוע טנסור דו-ממדי מסוג si64 (C12), (C15)
(I8) body פונקציה (C13)

פלט

שם סוג מגבלות
results מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור (C1), (C14-C16)

מגבלות

  • (C1) 0 < size(inputs) = size(init_values) = size(results) = N.
  • ‫(C2) same(shape(inputs...)).
  • ‫(C3) element_type(inputs...) = element_type(init_values...).
  • ‫(C4) size(window_dimensions) = rank(inputs[0]).
  • ‫(C5) 0 < window_dimensions.
  • ‫(C6) size(window_strides) = rank(inputs[0]).
  • ‫(C7) 0 < window_strides.
  • ‫(C8) size(base_dilations) = rank(inputs[0]).
  • ‫(C9) 0 < base_dilations.
  • ‫(C10) size(window_dilations) = rank(inputs[0]).
  • ‫(C11) 0 < window_dilations.
  • (C12) shape(padding) = [rank(inputs[0]), 2].
  • ‫(C13) body הוא מסוג (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>) כאשר is_promotable(element_type(inputs[i]), Ei).
  • (C14) same(shape(results...)).
  • ‫(C15) shape(results[0]) = num_windows כאשר:
    • dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1.
    • padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1].
    • dilated_window_shape = (window_dimensions - 1) * window_dilations + 1.
    • is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1.
  • ‫(C16) element_type(results[i]) = Ei לכל i ב-[0,N).

דוגמאות

// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
    %0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
    "stable<hlo>.re>turn"(%0) : (tensori64) - ()
})< {
  wind>ow_dimensions = arrayi64: <2, 1,
  w>indow_strides = arrayi64: <4, 1,
  b>ase_dilations = arrayi64: 2,< 1,
  win>dow_dilations = arr<ayi64: 3, 1,
  p>adding = <dense[[>2, 1], [0, 0<]] : te>nsor2x2x<i64>
} >: (tens<or3x2xi>64, tensori64) - tensor2x2xi64
// %result = [[0, 0], [3, 4]]

 דוגמאות נוספות

שארית

סמנטיקה

הפונקציה מחשבת את השארית של כל אחד מהרכיבים של טנזור המחולק lhs וטנזור המחלק rhs, ומחזירה טנזור result.

באופן רשמי יותר, הסימן של התוצאה נלקח מהמחולק, והערך המוחלט של התוצאה תמיד קטן מהערך המוחלט של המחלק. השארית מחושבת כך: lhs - d * rhs, כאשר d מחושב כך:

  • למספרים שלמים: stablehlo.divide(lhs, rhs).
  • למספרים ממשיים: division(lhs, rhs) מ-IEEE-754 עם מאפיין עיגול roundTowardZero.
  • למספרים מרוכבים: ייקבע בהמשך (‎#997).
  • לסוגים שעברו קוונטיזציה:
    • dequantize_op_quantize(remainder, lhs, rhs, type(result)).

במקרה של סוגי רכיבים של נקודה צפה, הפעולה הזו שונה מהפעולה remainder שמוגדרת במפרט IEEE-754, שבה d הוא ערך שלם הכי קרוב לערך המדויק של lhs/rhs, וההכרעה במקרה של שוויון היא לטובת המספר הזוגי.

קלט

תווית שם סוג מגבלות
(I1) lhs טנזור של מספרים שלמים, מספרים ממשיים או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור (C1)
(I2) rhs טנזור של מספרים שלמים, מספרים ממשיים או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור (C1)

פלט

שם סוג מגבלות
result טנזור של מספרים שלמים, מספרים ממשיים או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs)< : (t>ensor4xi<64, t>ens>or4xi64<) - t>ensor4xi64
// %result: [2, -2, 2, -2]

 דוגמאות נוספות

replica_id

סמנטיקה

יוצרת replica_id של התהליך הנוכחי.

פלט

שם סוג
result טנזור אפס-ממדי מסוג ui32

דוגמאות

%result = "stablehlo.replica_id">;() : (<) - >tensorui32

 דוגמאות נוספות

עיצוב מחדש

סמנטיקה

מבצעת שינוי צורה של טנסור operand לטנסור result. מבחינה רעיונית, המשמעות היא לשמור על אותו ייצוג קנוני אבל לשנות את הצורה, למשל מ-tensor<2x3xf32> ל-tensor<3x2xf32> או ל-tensor<6xf32>.

באופן רשמי יותר, result[result_index] = operand[operand_index] כאשר result_index ו-operand_index נמצאים באותו מיקום בסדר הלקסיקוגרפי של index_space(result) ו-index_space(operand).

קלט

תווית שם סוג מגבלות
(I1) operand טנזור או טנזור שעבר קוונטיזציה (C1-C3)

פלט

שם סוג מגבלות
result טנזור או טנזור שעבר קוונטיזציה (C1-C3)

מגבלות

  • ‫(C1) element_type(result) מחושב לפי הנוסחה הבאה:
    • element_type(operand), אם !is_per_axis_quantized(operand).
    • element_type(operand) חוץ מהמקרים שבהם quantization_dimension(operand) וquantization_dimension(result) שונים.
  • ‫(C2) size(operand) = size(result).
  • ‫(C3) אם is_per_axis_quantized(operand):
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).

דוגמאות

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand)< : (ten>sor>2x3xi32<) - ten>sor3x2xi32
// %result: [[1, 2], [3, 4], [5, 6]]

 דוגמאות נוספות

הפוך

סמנטיקה

הפונקציה הופכת את סדר הרכיבים ב-operand לאורך dimensions שצוין ומפיקה טנזור result. באופן רשמי יותר, result[result_index] = operand[operand_index] כאשר:

  • operand_index[d] = dim(result, d) - result_index[d] - 1 if d in dimensions.
  • operand_index[d] = result_index[d] אחרת.

קלט

תווית שם סוג מגבלות
(I1) operand tensor או per-tensor quantized tensor (C1), (C3)
(I2) dimensions קבוע טנסור חד-ממדי מסוג si64 (C2), (C3)

פלט

שם סוג מגבלות
result tensor או per-tensor quantized tensor (C1), (C3)

מגבלות

  • (C1) type(operand) = type(result).
  • ‫(C2) is_unique(dimensions).
  • ‫(C3) 0 <= dimensions < rank(result).

דוגמאות

// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
  dimensio<ns = a>rrayi64: 1
}< : (ten>sor>3x2xi32<) - ten>sor3x2xi32
// %result: [[2, 1], [4, 3], [6, 5]]

 דוגמאות נוספות

rng

סמנטיקה

יוצרת מספרים אקראיים באמצעות אלגוריתם rng_distribution ומפיקה טנזור result בצורה נתונה shape.

אם rng_distribution = UNIFORM, המספרים האקראיים נוצרים לפי התפלגות אחידה במרווח [a, b). אם a >= b, ההתנהגות לא מוגדרת.

אם הערך הוא rng_distribution = NORMAL, המספרים האקראיים נוצרים לפי ההתפלגות הנורמלית עם ממוצע = a וסטיית תקן = b. אם b < 0, ההתנהגות לא מוגדרת.

הדרך המדויקת שבה נוצרים מספרים אקראיים מוגדרת בהטמעה. לדוגמה, יכול להיות שהם דטרמיניסטיים או לא, ויכול להיות שהם משתמשים במצב מוסתר או לא.

בשיחות עם בעלי עניין רבים, עלה שהאופציה הזו הוצאה משימוש בפועל, ולכן בעתיד אנחנו מתכננים לבדוק את האפשרות להסיר אותה (#597).

קלט

תווית שם סוג מגבלות
(I1) a טנזור אפס-ממדי מסוג מספר שלם, בוליאני או נקודה צפה (C1), (C2)
(I2) b טנזור אפס-ממדי מסוג מספר שלם, בוליאני או נקודה צפה (C1), (C2)
(I3) shape קבוע טנסור חד-ממדי מסוג si64 (C3)
(I4) rng_distribution enum of UNIFORM and NORMAL (C2)

פלט

שם סוג מגבלות
result טנזור מסוג מספר שלם, בוליאני או נקודה צפה (C1-C3)

מגבלות

  • (C1) element_type(a) = element_type(b) = element_type(result).
  • ‫(C2) אם rng_distribution = NORMAL, אז is_float(a).
  • ‫(C3) shape(result) = shape.

דוגמאות

// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
  rng_distribution = <#stablehlorng_distributi>on UNIFORM
}< : >(tensori<32,> tensori<32, t>ens>or2xi64<) - ten>sor3x3xi32
// %result: [
//           [1, 0, 1],
//           [1, 1, 1],
//           [0, 0, 0]
//          ]

rng_bit_generator

סמנטיקה

הפונקציה מחזירה output מלא בביטים אקראיים אחידים ומצב פלט מעודכן output_state באמצעות האלגוריתם של מחולל מספרים פסאודו-אקראיים rng_algorithm, בהינתן מצב התחלתי initial_state. הפלט הוא פונקציה דטרמיניסטית של initial_state, אבל לא מובטח שהוא יהיה דטרמיניסטי בין יישומים שונים.

rng_algorithm הוא אחד מהבאים:

  • DEFAULT: אלגוריתם שמוגדר בהטמעה.
  • THREE_FRY: וריאנט של אלגוריתם Threefry שמוגדר בהטמעה.*
  • PHILOX: וריאציה של אלגוריתם Philox שמוגדרת בהטמעה.*

* ראו: Salmon et al. SC 2011. מספרים אקראיים מקבילים: קל ופשוט כמו 1, 2, 3.

קלט

תווית שם סוג מגבלות
(I1) rng_algorithm enum של DEFAULT,‏ THREE_FRY ו-PHILOX (C2)
(I2) initial_state טנזור חד-ממדי מסוג ui64 (C1), (C2)

פלט

שם סוג מגבלות
output_state טנזור חד-ממדי מסוג ui64 (C1)
output טנסור מסוג מספר שלם או מספר בשיטת נקודה צפה

מגבלות

  • (C1) type(initial_state) = type(output_state).
  • ‫(C2) size(initial_state) מוגדר כך:
    • מוגדרת בהטמעה אם rng_algorithm = DEFAULT.
    • 2 if rng_algorithm = THREE_FRY.
    • 2 או 3 אם rng_algorithm = PHILOX.

דוגמאות

// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
  rng_algorithm = <#stablehlorng_algorithm> THREE_FRY
}< : (te>nso>r2xui64)< - (te>nsor2xui<64, tens>or2x2xui64)
// %output_state: [1, 6]
// %output: [
//           [9236835810183407956, 16087790271692313299],
//           [18212823393184779219, 2658481902456610144]
//          ]

round_nearest_afz

סמנטיקה

מעגלת כל רכיב בטנזור operand למספר השלם הקרוב ביותר, ומעגלת מספרים שהם בדיוק באמצע בין שני מספרים שלמים לכיוון הרחוק מאפס. התוצאה היא טנזור result. מטמיע את הפעולה roundToIntegralTiesToAway מהמפרט IEEE-754. עבור סוגים שעברו קוונטיזציה, הפעולה dequantize_op_quantize(round_nearest_afz, operand, type(result)) מתבצעת.

קלט

תווית שם סוג מגבלות
(I1) operand טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור (C1)

פלט

שם סוג מגבלות
result טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]

 דוגמאות נוספות

round_nearest_even

סמנטיקה

הפונקציה מבצעת עיגול של כל רכיב בטנזור operand למספר השלם הקרוב ביותר, ומעגלת מספרים שהם בדיוק באמצע בין שני מספרים שלמים למספר השלם הזוגי. התוצאה היא טנזור result. מיישמת את הפעולה roundToIntegralTiesToEven מהמפרט IEEE-754. לסוגים כמותיים, הפונקציה מבצעת dequantize_op_quantize(round_nearest_even, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור (C1)

פלט

שם סוג מגבלות
result טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]

 דוגמאות נוספות

rsqrt

סמנטיקה

מבצעת פעולת שורש ריבועי הדדי על טנסור operand ומפיקה טנסור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • למספרים ממשיים: rSqrt מ-IEEE-754.
  • למספרים מרוכבים: שורש ריבועי הפוך מרוכב.
  • לסוגים כמותיים: dequantize_op_quantize(rsqrt, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1)

פלט

שם סוג מגבלות
result טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[1.0, 0.5], [0.33333343, 0.2]]

 דוגמאות נוספות

פיזור

סמנטיקה

הפונקציה יוצרת טנסורים results ששווים לטנסורים inputs, אלא שכמה פרוסות שצוינו על ידי scatter_indices מתעדכנות עם הערכים updates באמצעות update_computation.

בתרשים הבא מוצג מיפוי של רכיבים ב-updates... לרכיבים ב-results... באמצעות דוגמה קונקרטית. בתרשים מוצגים כמה מדדים לדוגמה של updates... ומפורט לאילו מדדים של results... הם מתאימים.

פיזור

באופן רשמי יותר, לכל update_index ב-index_space(updates[0]):

  • update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims].
  • update_scatter_index = update_index[update_scatter_dims...].
  • ההגדרה של start_index היא:
    • scatter_indices[si0, ..., :, ..., siN] כאשר si הם רכיבים נפרדים ב-update_scatter_index ו-: מוכנס באינדקס index_vector_dim, אם index_vector_dim < rank(scatter_indices).
    • [scatter_indices[update_scatter_index]] אחרת.
  • d_input ב-axes(inputs[0]),
    • full_start_index[d_input] = start_index[d_start] if d_input = scatter_dims_to_operand_dims[d_start].
    • full_start_index[d_input] = 0 אחרת.
  • d_input ב-axes(inputs[0]),
    • full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)] אם d_input = input_batching_dims[i_batching] וגם d_start = scatter_indices_batching_dims[i_batching].
    • full_batching_index[d_input] = 0 אחרת.
  • update_window_index = update_index[update_window_dims...].
  • full_window_index = [wi0, ..., 0, ..., wiN] כאשר wi הם רכיבים נפרדים ב-update_window_index, ו-0 מוכנס באינדקסים מ-inserted_window_dims ועד input_batching_dims.
  • result_index = full_start_index + full_batching_index + full_window_index.

בהתאם לכך, results = exec(schedule, inputs), כאשר:

  • schedule הוא תמורה שמוגדרת בהטמעה של index_space(updates[0]).
  • exec([update_index, ...], results) = exec([...], updated_results) כאשר:
    • אם result_index נמצא בטווח של shape(results...)
    • updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
    • updated_values = update_computation(results...[result_index], updates_converted)
    • updated_results הוא עותק של results עם results...[result_index] שמוגדר כ-updated_values....
    • אחרת
    • updated_results = results.
  • exec([], results) = results.

אם indices_are_sorted הוא true, ההטמעה יכולה להניח ש-scatter_indices ממוינים לפי scatter_dims_to_operand_dims, אחרת ההתנהגות לא מוגדרת. באופן רשמי יותר, לכל i1 < i2 מ-indices(result), ‏ full_start_index(i1) <= full_start_index(i2).

אם הערך של unique_indices הוא true, ההטמעה יכולה להניח שכל האינדקסים של result_index שמופצים הם ייחודיים. אם הערך של unique_indices הוא true אבל האינדקסים שאליהם מתבצע הפיזור לא ייחודיים, ההתנהגות לא מוגדרת.

קלט

תווית שם סוג מגבלות
(I1) inputs מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור (C1), (C2), (C4-C6), (C11), (C13), (C18), (C21), (C23-C24)
(I2) scatter_indices טנזור מסוג מספר שלם (C4), (C15), (C19), (C22)
(I3) updates מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור (C3-C6), (C8)
(I4) update_window_dims קבוע טנסור חד-ממדי מסוג si64 (C2), (C4), (C7-C8)
(I5) inserted_window_dims קבוע טנסור חד-ממדי מסוג si64 (C2), (C4), (C9-C11)
(I6) input_batching_dims קבוע טנסור חד-ממדי מסוג si64 (C2), (C4), (C9), (C12-13), (C17-18), (C20)
(I7) scatter_indices_batching_dims קבוע טנסור חד-ממדי מסוג si64 (C14-C18)
(I8) scatter_dims_to_operand_dims קבוע טנסור חד-ממדי מסוג si64 (C19-C21)
(I9) index_vector_dim קבוע מהסוג si64 (C4), (C16), (C19), (C22)
(I10) indices_are_sorted קבוע מהסוג i1
(I11) unique_indices קבוע מהסוג i1
(I12) update_computation פונקציה (C23)

פלט

שם סוג מגבלות
results מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור (C24-C25)

מגבלות

  • (C1) same(shape(inputs...)).
  • ‫(C2) rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims) + size(input_batching_dims).
  • ‫(C3) same(shape(updates...)).
  • ‫(C4) shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes) כאשר:
    • update_scatter_dim_sizes = shape(scatter_indices), אלא שגודל המימד של scatter_indices שמתאים ל-index_vector_dim לא נכלל.
    • update_window_dim_sizes <= shape(inputs[0]), רק שגודלי המאפיינים ב-inputs[0] שמתאימים ל-inserted_window_dims ול-input_batching_dims לא נכללים.
    • הפונקציה combine ממקמת את update_scatter_dim_sizes בצירים שתואמים ל-update_scatter_dims ואת update_window_dim_sizes בצירים שתואמים ל-update_window_dims.
  • ‫(C5) 0 < size(inputs) = size(updates) = N.
  • ‫(C6) element_type(updates...) = element_type(inputs...).
  • ‫(C7) is_unique(update_window_dims) and is_sorted(update_window_dims).
  • ‫(C8) 0 <= update_window_dims < rank(updates[0]).
  • ‫(C9) is_unique(concatenate(inserted_window_dims, input_batching_dims))
  • ‫(C10) is_sorted(inserted_window_dims).
  • ‫(C11) 0 <= inserted_window_dims < rank(inputs[0]).
  • (C12) is_sorted(input_batching_dims).
  • ‫(C13) 0 <= input_batching_dims < rank(inputs[0])).
  • (C14) is_unique(scatter_indices_batching_dims).
  • ‫(C15) 0 <= scatter_indices_batching_dims < rank(scatter_indices).
  • ‫(C16) index_vector_dim not in scatter_indices_batching_dims.
  • ‫(C17) size(input_batching_dims) == size(scatter_indices_batching_dims).
  • ‫(C18) dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...).
  • ‫(C19) size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1.
  • ‫(C20) is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims)).
  • ‫(C21) 0 <= scatter_dims_to_operand_dims < rank(inputs[0]).
  • (C22) 0 <= index_vector_dim <= rank(scatter_indices).
  • ‫(C23) update_computation הוא מסוג (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), כאשר is_promotable(element_type(inputs[i]), Ei).
  • ‫(C24) shape(inputs...) = shape(results...).
  • ‫(C25) element_type(results[i]) = Ei לכל i ב-[0,N).

דוגמאות

// %input: [
//          [
//           [[1, 2], [3, 4], [5, 6], [7, 8]],
//           [[9, 10],[11, 12], [13, 14], [15, 16]],
//           [[17, 18], [19, 20], [21, 22], [23, 24]]
//          ],
//          [
//           [[25, 26], [27, 28], [29, 30], [31, 32]],
//           [[33, 34], [35, 36], [37, 38], [39, 40]],
//           [[41, 42], [43, 44], [45, 46], [47, 48]]
//          ]
//         ]
// %scatter_indices: [
//                    [
//                     [[0, 0], [1, 0], [2, 1]],
//                     [[0, 1], [1, 1], [0, 9]]
//                    ],
//                    [
//                     [[0, 0], [2, 1], [2, 2]],
//                     [[1, 2], [0, 1], [1, 0]]
//                    ]
//                   ]
// %update: [
//           [
//            [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
//            [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
//           ],
//           [
//            [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
//            [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
//           ]
//          ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
    %0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
    "stable<hlo>.re>turn"(%0) : (tensori64) - ()
}) {
  scatter_dimensio<n_numbers = #stablehlo.scatter
    update_window_dims = [3, 4],
    inserted_window_dims = [1],
    input_batching_dims = [0],
    scatter_indices_batching_dims = [1],
    scatter_dims_to_operand_dims = [2>, 1],
    index_vector_dim = 3,
  indices_are_sorted = false,
  uniq<ue_indices >= false
<} : (tensor>2x3x4x2x<i64, tensor2x>2x3>x2xi64,< tensor2x2x>3x2x2xi64) - tensor2x3x4x2xi64
// %result: [
//           [
//            [[3, 4], [6, 7], [6, 7], [7, 8]],
//            [[9, 10],[11, 12], [15, 16], [17, 18]],
//            [[17, 18], [19, 20], [22, 23], [24, 25]]
//           ],
//           [
//            [[25, 26], [28, 29], [30, 31], [31, 32]],
//            [[35, 36], [38, 39], [38, 39], [39, 40]],
//            [[41, 42], [44, 45], [46, 47], [47, 48]]
//           ]
//          ]

 דוגמאות נוספות

בחירה

סמנטיקה

הפונקציה יוצרת טנסור result שבו כל רכיב נבחר מתוך טנסור on_true או on_false על סמך הערך של הרכיב התואם בטנסור pred. באופן רשמי יותר, result[result_index] = pred_element ? on_true[result_index] : on_false[result_index], כאשר pred_element = rank(pred) = 0 ? pred[] : pred[result_index]. לסוגים כמותיים, הפונקציה מבצעת dequantize_select_quantize(pred, on_true, on_false, type(result)).

קלט

תווית שם סוג מגבלות
(I1) pred טנסור מסוג i1 (C1)
(I2) on_true tensor או per-tensor quantized tensor (C1-C2)
(I3) on_false tensor או per-tensor quantized tensor (C2)

פלט

שם סוג מגבלות
result tensor או per-tensor quantized tensor (C2)

מגבלות

  • (C1) rank(pred) = 0 or shape(pred) = shape(on_true).
  • ‫(C2) baseline_type(on_true) = baseline_type(on_false) = baseline_type(result).

דוגמאות

// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false)< : (te>nsor2x2x<i1, ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 2], [3, 8]]

 דוגמאות נוספות

select_and_scatter

סמנטיקה

הפעולה מפזרת את הערכים מטנסור source באמצעות scatter על סמך התוצאה של reduce_window של טנסור input באמצעות select, ומפיקה טנסור result.

בתרשים הבא מוצגות דוגמאות קונקרטיות שממחישות איך מחשבים את הרכיבים ב-result מ-operand ומ-source.

select_and_scatter

באופן רשמי יותר:

  • selected_values = reduce_window_without_init(...) עם הקלט הבא:

    • inputs = [operand].
    • window_dimensions, ‏window_strides ו-padding שמשמשים כמו שהם.
    • base_dilations = windows_dilations = 1.
    • ההגדרה של body היא:
    def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>:
      return select(arg0, arg1) ? arg0 : arg1;
    

    כאשר E = element_type(operand) ו-reduce_window_without_init פועלים בדיוק כמו reduce_window, אלא ש-schedule של reduce הבסיסי (ראו reduce) לא כולל ערכי init. בשלב הזה לא ברור מה קורה אם בחלון המתאים אין ערכים (#731).

  • result[result_index] = reduce([source_values], [init_value], [0], scatter) where:

    • source_values = [source[source_index] for source_index in source_indices].
    • selected_index(source_index) = operand_index if ‫selected_values[source_index] has the operand element ‫from operand_index.
    • source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index].

קלט

תווית שם סוג מגבלות
(I1) operand tensor או per-tensor quantized tensor (C1-C4), (C6), (C8-C11)
(I2) source tensor או per-tensor quantized tensor (C1), (C2)
(I3) init_value טנזור אפס-ממדי או טנזור שעבר קוונטיזציה לכל טנזור (C3)
(I4) window_dimensions קבוע טנסור חד-ממדי מסוג si64 (C2), (C4), (C5)
(I5) window_strides קבוע טנסור חד-ממדי מסוג si64 (C2), (C6), (C7)
(I6) padding קבוע טנסור דו-ממדי מסוג si64 (C2), (C8)
(I7) select פונקציה (C9)
(I8) scatter פונקציה (C10)

פלט

שם סוג מגבלות
result tensor או per-tensor quantized tensor ‫(C11-C12)

מגבלות

  • (C1) element_type(operand) = element_type(source).
  • ‫(C2) shape(source) = num_windows כאשר:
    • padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1].
    • is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1.
  • ‫(C3) element_type(init_value) = element_type(operand).
  • ‫(C4) size(window_dimensions) = rank(operand).
  • ‫(C5) 0 < window_dimensions.
  • ‫(C6) size(window_strides) = rank(operand).
  • ‫(C7) 0 < window_strides.
  • ‫(C8) shape(padding) = [rank(operand), 2].
  • ‫(C9) select הוא מסוג (tensor<E>, tensor<E>) -> tensor<i1> כאשר E = element_type(operand).
  • ‫(C10) scatter הוא מסוג (tensor<E>, tensor<E>) -> tensor<E> כאשר is_promotable(element_type(operand), E).
  • ‫(C11) shape(operand) = shape(result).
  • (C12) element_type(result) = E.

דוגמאות

// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_di<rection = #stablehlocom>parison_directio<n G>E
    } <: (>ten>sori64,< t>ensori64) - tensori1
    "stable<hl>o.r>eturn"(%0) : (tensori1) <- (>)
}, {
  ^bb0(%<arg>0: tensori64, %arg1: tensori64):
    %0 = "sta<ble>hlo.add&<quo>t;(>%arg0, <%ar>g1) : (tensori64, tensori64) - tensor<i64>
  >  "stablehlo.return"(%0) :< (tensori>64) - ()
}) {
  window_dim<ensions => arrayi64: 3, 1,
  <window_strides => arrayi64<: 2, 1,>
  padding =< dense[>[0, 1], <[0, 0]]> : tenso<r2x>2xi>64
} : <(tensor>4x2xi64, tensor2x2xi64, tensori64) - tensor4x2xi64
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]

 דוגמאות נוספות

שליחה

סמנטיקה

שליחת inputs לערוץ channel_id. הקלט נשלח למכשירים אחרים לפי הסדר שצוין ב-source_target_pairs. הפעולה יוצרת טוקן result.

אם is_host_transfer הוא true, הפעולה מעבירה נתונים למארח. אחרת, הנתונים מועברים למכשיר אחר על סמך הערכים של source_target_pairs. הסימון הזה משכפל את המידע שמופיע ב-channel_type, ולכן בעתיד אנחנו מתכננים להשאיר רק אחד מהם (מס' 666). אם is_host_transfer = false ו-source_target_pairs הוא None או ריק, ההתנהגות נחשבת ללא מוגדרת.

קלט

תווית שם סוג מגבלות
(I1) inputs מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה
(I2) token token
(I3) channel_id קבוע מהסוג si64
(I4) channel_type enum של DEVICE_TO_DEVICE ו-DEVICE_TO_HOST (C5)
(I5) is_host_transfer קבוע מהסוג i1 (C5-C6)
(I6) source_target_pairs קבוע טנסור דו-ממדי מסוג si64 (C1-C4), (C6)

פלט

שם סוג
result token

מגבלות

  • (C1) dim(source_target_pairs, 1) = 2.
  • ‫(C2) is_unique(source_target_pairs[:, 0]).
  • ‫(C3) is_unique(source_target_pairs[:, 1]).
  • ‫(C4) 0 <= source_target_pairs < N, כאשר N מוגדר כך:
    • num_replicas אם משתמשים ב-cross_replica.
    • num_partitions אם משתמשים ב-cross_partition.
  • ‫(C5) channel_type מוגדר כך:
    • DEVICE_TO_HOST if is_host_transfer = true,
    • DEVICE_TO_DEVICE אחרת.

דוגמאות

%result = "stablehlo.send"(%operand, %token) {
  channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 1,
  is_host_transfer = false,
  source_target_pai<rs = dense[[0, 1>], [1, 2]<] : ten>sor2x2xi64
}< : (ten>sor2x2xi64, !stablehl>o.token) - !stablehlo.token

 דוגמאות נוספות

shift_left

סמנטיקה

מבצעת פעולת הזזה שמאלה ברמת הרכיב בטנזור lhs במספר rhs של סיביות ומפיקה טנזור result.

קלט

תווית שם סוג מגבלות
(I1) lhs טנזור מסוג מספר שלם (C1)
(I2) rhs טנזור מסוג מספר שלם (C1)

פלט

שם סוג מגבלות
result טנזור מסוג מספר שלם (C1)

מגבלות

  • (C1) type(lhs) = type(rhs) = type(result).

דוגמאות

// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs<): (t>ensor3xi<64, t>ens>or3xi64<) - t>ensor3xi64
// %result: [-2, 0, 8]

 דוגמאות נוספות

shift_right_arithmetic

סמנטיקה

מבצעת פעולת הזזה אריתמטית ימינה ברמת הרכיב בטנזור lhs לפי מספר הסיביות rhs ומפיקה טנזור result.

קלט

תווית שם סוג מגבלות
(I1) lhs טנזור מסוג מספר שלם (C1)
(I2) rhs טנזור מסוג מספר שלם (C1)

פלט

שם סוג מגבלות
result טנזור מסוג מספר שלם (C1)

מגבלות

  • (C1) type(lhs) = type(rhs) = type(result).

דוגמאות

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs<): (t>ensor3xi<64, t>ens>or3xi64<) - t>ensor3xi64
// %result: [-1, 0, 1]

 דוגמאות נוספות

shift_right_logical

סמנטיקה

מבצעת פעולת הזזה לוגית ימינה ברמת הרכיב בטנזור lhs במספר rhs של סיביות ומפיקה טנזור result.

קלט

תווית שם סוג מגבלות
(I1) lhs טנזור מסוג מספר שלם (C1)
(I2) rhs טנזור מסוג מספר שלם (C1)

פלט

שם סוג מגבלות
result טנזור מסוג מספר שלם (C1)

מגבלות

  • (C1) type(lhs) = type(rhs) = type(result).

דוגמאות

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs<): (t>ensor3xi<64, t>ens>or3xi64<) - t>ensor3xi64
// %result: [9223372036854775807, 0, 1]

 דוגמאות נוספות

סמל

סמנטיקה

הפונקציה מחזירה את הסימן של כל רכיב בטנזור operand ומפיקה טנזור result. באופן רשמי יותר, אפשר לבטא את הסמנטיקה של כל רכיב x באמצעות תחביר Python באופן הבא:

def sign(x):
  if is_integer(x):
    if compare(x, 0, LT, SIGNED): return -1
    if compare(x, 0, EQ, SIGNED): return 0
    return 1
  elif is_float(x):
    if is_nan(x): return NaN
    if compare(x, -0.0, EQ, FLOAT): return -0.0
    if compare(x, +0.0, EQ, FLOAT): return +0.0
    if compare(x, 0.0, LT, FLOAT): return -1.0
    return 1.0
  elif is_complex(x):
    if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
    if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
    return divide(x, convert(abs(x), type(x)))

לסוגים כמותיים, הפונקציה מבצעת dequantize_op_quantize(sign, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טנזור של מספר שלם עם סימן, מספר נקודה צפה או סוג מרוכב, או טנזור כמותי לכל טנזור (C1)

פלט

שם סוג מגבלות
result טנזור של מספר שלם עם סימן, מספר נקודה צפה או סוג מרוכב, או טנזור כמותי לכל טנזור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]

 דוגמאות נוספות

סינוס

סמנטיקה

מבצעת פעולת סינוס לפי רכיבים בטנזור operand ומפיקה טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • למספרים ממשיים: sin מ-IEEE-754.
  • למספרים מרוכבים: סינוס מרוכב.
  • לסוגים כמותיים: dequantize_op_quantize(sine, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1)

פלט

שם סוג מגבלות
result טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.sine"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[0.0, 1.0], [0.0, -1.0]]

 דוגמאות נוספות

פרוסה (slice)

סמנטיקה

מחזירה פרוסה מ-operand באמצעות אינדקסים התחלתיים שמחושבים באופן סטטי ומפיקה טנזור result. ‫start_indices מכיל את אינדקס ההתחלה של הפרוסה לכל מאפיין, limit_indices מכיל את אינדקס הסיום (לא כולל) של הפרוסה לכל מאפיין, ו-strides מכיל את הצעדים לכל מאפיין.

באופן רשמי יותר, result[result_index] = operand[operand_index] כאשר operand_index = start_indices + result_index * strides.

קלט

תווית שם סוג מגבלות
(I1) operand tensor או per-tensor quantized tensor (C1-C3), (C5)
(I2) start_indices קבוע טנסור חד-ממדי מסוג si64 (C2), (C3), (C5)
(I3) limit_indices קבוע טנסור חד-ממדי מסוג si64 (C2), (C3), (C5)
(I4) strides קבוע טנסור חד-ממדי מסוג si64 (C2), (C4)

פלט

שם סוג מגבלות
result tensor או per-tensor quantized tensor (C1), (C5)

מגבלות

  • (C1) element_type(operand) = element_type(result).
  • ‫(C2) size(start_indices) = size(limit_indices) = size(strides) = rank(operand).
  • ‫(C3) 0 <= start_indices <= limit_indices <= shape(operand).
  • ‫(C4) 0 < strides.
  • ‫(C5) shape(result) = ceil((limit_indices - start_indices) / strides).

דוגמאות

// %operand: [
//            [0, 0, 0, 0],
//            [0, 0, 1, 1],
//            [0, 0, 1, 1]
//           ]
%result = "stablehlo.slice"(%operand) {
  start_indic<es = arra>yi64: 1, 2,
  limit_indic<es = arra>yi64: 3, 4,
  strid<es = arra>yi64: 1, 1
}< : (ten>sor>3x4xi64<) - ten>sor2x2xi64
// % result: [
//            [1, 1],
//            [1, 1]
//           ]

 דוגמאות נוספות

מיון

סמנטיקה

ממיינת פרוסות חד-ממדיות של inputs לאורך המימד dimension יחד, בהתאם ל-comparator ומפיקה results.

בניגוד לקלט דומה בפעולות אחרות, dimension מאפשרת ערכים שליליים, עם הסמנטיקה שמתוארת בהמשך. יכול להיות שבעתיד לא נאפשר את זה מסיבות של עקביות (#1377).

אם is_stable הוא true, המיון יציב, כלומר הסדר היחסי של הרכיבים שנחשבים שווים על ידי פונקציית ההשוואה נשמר. במקרה שיש קלט יחיד, שני האלמנטים e1 ו-e2 נחשבים שווים על ידי הפונקציה להשוואה אם ורק אם comparator(e1, e2) = comparator(e2, e1) = false. בהמשך מוסבר איך להכליל את זה למספר קלטים.

באופן רשמי יותר, לכל result_index ב-index_space(results[0]):

  • adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension.
  • result_slice = [ri0, ..., :, ..., riR-1] כשriN הם רכיבים נפרדים ב-result_index, ו-: מוכנס ב-adjusted_dimension.
  • inputs_together = (inputs[0]..., ..., inputs[N-1]...).
  • results_together[result_slice] = sort(inputs_together[result_slice], comparator_together).
  • כאשר sort ממיין פרוסה חד-ממדית בסדר לא יורד, ומצפה ש-comparator_together יחזיר true אם הארגומנט בצד ימין קטן מהארגומנט השני בצד שמאל.
  • def comparator_together(lhs_together, rhs_together):
      args = []
      for (lhs_el, rhs_el) in zip(lhs_together, rhs_together):
        args.append(lhs_el)
        args.append(rhs_el)
      return comparator(*args)
    
  • (results[0]..., ..., results[N-1]...) = results_together.

קלט

תווית שם סוג מגבלות
(I1) inputs מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור (C1-C5)
(I2) dimension קבוע מהסוג si64 ‫(C4)
(I3) is_stable קבוע מהסוג i1
(I4) comparator פונקציה (C5)

פלט

שם סוג מגבלות
results מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור (C2), (C3)

מגבלות

  • (C1) 0 < size(inputs).
  • ‫(C2) type(inputs...) = type(results...).
  • ‫(C3) same(shape(inputs...) + shape(results...)).
  • ‫(C4) -R <= dimension < R, כשR = rank(inputs[0]).
  • ‫(C5) comparator הוא מסוג (tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>, כאשר Ei = element_type(inputs[i]).

דוגמאות

// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64, %ar<g2:> tensori64, %ar<g3:> tensori64):
    %predicate = "stablehlo.compare"(%arg0, %arg1) {
      comparison_di<rection = #stablehlocom>parison_directio<n G>T
    } <: (>ten>sori64,< t>ensori64) - tensori1
    "stablehlo.retu<rn>&qu>ot;(%predicate) : (tensori1) - ()
}) {
  dimension = 0 : i64,
<  is_st>able = t<rue
} :> (t>ensor2x3<xi64, t>ensor2x3<xi64) -> (tensor2x3xi64, tensor2x3xi64)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]

 דוגמאות נוספות

sqrt

סמנטיקה

מבצעת פעולת שורש ריבועי על כל רכיב בטנזור operand ומפיקה טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • למספרים ממשיים: squareRoot מ-IEEE-754.
  • למספרים מרוכבים: שורש ריבועי מרוכב.
  • לסוגים כמותיים: dequantize_op_quantize(sqrt, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1)

פלט

שם סוג מגבלות
result טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[0.0, 1.0], [2.0, 3.0]]

 דוגמאות נוספות

חיסור

סמנטיקה

מבצעת חיסור של שני טנסורים, lhs ו-rhs, ומפיקה טנסור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • למספרים שלמים: חיסור מספרים שלמים.
  • למספרים ממשיים: subtraction מ-IEEE-754.
  • למספרים מרוכבים: חיסור מרוכב.
  • לסוגים שעברו קוונטיזציה:
    • dequantize_op_quantize(subtract, lhs, rhs, type(result)).

קלט

תווית שם סוג מגבלות
(I1) lhs טנזור של מספרים שלמים, מספרים עם נקודה עשרונית או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור (C1)
(I2) rhs טנזור של מספרים שלמים, מספרים עם נקודה עשרונית או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור (C1)

פלט

שם סוג מגבלות
result טנזור של מספרים שלמים, מספרים עם נקודה עשרונית או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור (C1)

מגבלות

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

דוגמאות

// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs)< : (ten>sor2x2xf<32, ten>sor>2x2xf32)< - (ten>sor2x2xf32)
// %result: [[1, 2], [3, 4]]

 דוגמאות נוספות

tan

סמנטיקה

מבצעת פעולת טנגנס לפי רכיבים בטנזור operand ומפיקה טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • למספרים ממשיים: tan מ-IEEE-754.
  • למספרים מרוכבים: טנגנס מרוכב.
  • לסוגים כמותיים: dequantize_op_quantize(tan, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1)

פלט

שם סוג מגבלות
result טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.tan"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [
//           [0.0, 1.63312e+16],
//           [0.0, 5.44375e+15]
//          ]

 דוגמאות נוספות

tanh

סמנטיקה

מבצעת פעולת טנגנס היפרבולי לפי רכיבים בטנזור operand ומפיקה טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • למספרים ממשיים: tanh מ-IEEE-754.
  • למספרים מרוכבים: טנגנס היפרבולי מרוכב.
  • לסוגים שעברו קוונטיזציה:
    • dequantize_op_quantize(tanh, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1)

פלט

שם סוג מגבלות
result טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand)< : (t>ens>or3xf32<) - t>ensor3xf32
// %result: [-0.76159416, 0.0, 0.76159416]

 דוגמאות נוספות

החלפה

סמנטיקה

מבצעת פרמוטציה של המימדים של טנזור operand באמצעות permutation ומפיקה טנזור result. באופן רשמי יותר, result[result_index] = operand[operand_index] כאשר result_index[d] = operand_index[permutation[d]].

קלט

תווית שם סוג מגבלות
(I1) operand טנזור או טנזור שעבר קוונטיזציה (C1-C4)
(I2) permutation קבוע טנסור חד-ממדי מסוג si64 (C2-C4)

פלט

שם סוג מגבלות
result טנזור או טנזור שעבר קוונטיזציה (C1), (C3-C4)

מגבלות

  • ‫(C1) element_type(result) מחושב לפי הנוסחה הבאה:
    • element_type(operand), אם !is_per_axis_quantized(operand).
    • element_type(operand) חוץ מהמקרים שבהם quantization_dimension(operand) וquantization_dimension(result) שונים.
  • ‫(C2) permutation היא תמורה של range(rank(operand)).
  • ‫(C3) shape(result) = dim(operand, permutation...).
  • ‫(C4) אם is_per_axis_quantized(result), אז quantization_dimension(operand) = permutation(quantization_dimension(result)).

דוגמאות

// %operand: [
//            [[1,2], [3,4], [5,6]],
//            [[7,8], [9,10], [11,12]]
//           ]
%result = "stablehlo.transpose"(%operand) {
  permutati<on = arrayi6>4: 2, 1, 0
}< : (tenso>r2x>3x2xi32<) - tenso>r2x3x2xi32
// %result: [
//           [[1,7], [3,9], [5,11]],
//           [[2,8], [4,10], [6,12]]
//          ]

 דוגמאות נוספות

triangular_solve

סמנטיקה

פותרים קבוצות של מערכות של משוואות ליניאריות עם מטריצות מקדמים משולשות עליונות או תחתונות.

באופן פורמלי יותר, בהינתן a ו-b, ‏ result[i0, ..., iR-3, :, :] הוא הפתרון של op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :] כאשר left_side הוא true או x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :] כאשר left_side הוא false, ופתרון המשתנה x כאשר op(a) נקבע על ידי transpose_a, שיכול להיות אחד מהבאים:

  • NO_TRANSPOSE: ביצוע הפעולה באמצעות a כמו שהוא.
  • TRANSPOSE: ביצוע פעולה על טרנספוזיציה של a.
  • ADJOINT: ביצוע פעולה על המטריצה המשוחפת של a.

נתוני הקלט נקראים רק מהמשולש התחתון של a, אם lower הוא true או מהמשולש העליון של a, אחרת. נתוני הפלט מוחזרים באותו משולש. הערכים במשולש השני מוגדרים על ידי ההטמעה.

אם unit_diagonal הוא true, ההטמעה יכולה להניח שהאלמנטים האלכסוניים של a שווים ל-1, אחרת ההתנהגות לא מוגדרת.

לסוגים כמותיים, הפונקציה מבצעת dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower, unit_diagonal, transpose_a), a, b, type(result)).

קלט

תווית שם סוג מגבלות
(I1) a טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1-C3)
(I2) b טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1-C4)
(I3) left_side קבוע מהסוג i1 (C3)
(I4) lower קבוע מהסוג i1
(I5) unit_diagonal קבוע מהסוג i1
(I6) transpose_a enum של NO_TRANSPOSE,‏ TRANSPOSE ו-ADJOINT

פלט

שם סוג מגבלות
result טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור (C1)

מגבלות

  • (C1) baseline_element_type(a) = baseline_element_type(b).
  • ‫(C2) 2 <= rank(a) = rank(b) = R.
  • ‫(C3) הקשר בין shape(a) ל-shape(b) מוגדר באופן הבא:
    • shape(a)[:-3] = shape(b)[:-3].
    • dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1).
  • ‫(C4) baseline_type(b) = baseline_type(result).

דוגמאות

// %a = [
//       [1.0, 0.0, 0.0],
//       [2.0, 4.0, 0.0],
//       [3.0, 5.0, 6.0]
//      ]
// %b = [
//       [2.0, 0.0, 0.0],
//       [4.0, 8.0, 0.0],
//       [6.0, 10.0, 12.0]
//      ]
%result = "stablehlo.triangular_solve"(%a, %b) {
  left_side = true,
  lower = true,
  unit_diagonal = false,
  transpose_a = <#stablehlotranspose NO>_TRANSPOSE
}< : (ten>sor3x3xf<32, ten>sor>3x3xf32<) - ten>sor3x3xf32
// %result: [
//           [2.0, 0.0, 0.0],
//           [0.0, 2.0, 0.0],
//           [0.0, 0.0, 2.0]
//          ]

tuple

סמנטיקה

הפונקציה יוצרת טאפל result מהערכים val.

קלט

תווית שם סוג מגבלות
(I1) val מספר משתנה של ערכים (C1)

פלט

שם סוג מגבלות
result tuple (C1)

מגבלות

  • ‫(C1) result הוא מסוג tuple<E0, ..., EN-1> כאשר Ei = type(val[i]).

דוגמאות

// %val0: memref[1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1)< : (m>emref2x<f32, t<upl>>ete>nsori3<2) - t<uplem>emref2x<f32, t<upl>>>etensori32
// %result: (memref[1.0, 2.0], (3))

 דוגמאות נוספות

uniform_dequantize

סמנטיקה

הפונקציה מבצעת המרה של טנסור עם כימות operand לטנסור עם נקודה צפה result לפי פרמטרי הכימות שמוגדרים בסוג operand.

באופן רשמי יותר, result = dequantize(operand).

קלט

תווית שם סוג מגבלות
(I1) operand טנסור שעבר קוונטיזציה (C1), (C2)

פלט

שם סוג מגבלות
result טנסור מסוג נקודה צפה (C1), (C2)

מגבלות

  • (C1) shape(operand) = shape(result).
  • ‫(C2) element_type(result) = expressed_type(operand).

דוגמאות

// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand)< : (tensor2x!qua<nt.uniformi8:f32:0, {0.1:-3>>0,0>.5:-20}<) - t>ensor2xf32
// %result: [4.0, 15.0]

uniform_quantize

סמנטיקה

מבצע המרה של טנזור של נקודה צפה או טנזור שעבר קוונטיזציה operand לטנזור שעבר קוונטיזציה result לפי פרמטרי הקוונטיזציה שהוגדרו על ידי הסוג result.

באופן רשמי יותר,

  • אם is_float(operand):
    • result = quantize(operand, type(result)).
  • אם is_quantized(operand):
    • float_result = dequantize(operand).
    • result = quantize(float_result, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טנסור מסוג נקודה צפה או מסוג כמותי (C1), (C2)

פלט

שם סוג מגבלות
result טנסור שעבר קוונטיזציה (C1), (C2)

מגבלות

  • (C1) shape(operand) = shape(result).
  • ‫(C2) expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand).

דוגמאות

// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand)< : (t>ens>or2xf32<) - tensor2x!qua<nt.uniformi8:f32:0, {0.1:-3>>0,0.5:-20}
// %result: [10, 10]

// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"<(%operand) : (te<nsor2x!quant.uniformi8:f32:>>0, >{0.1:-3<0,0.5:-20}) - te<nsor2x!quant.uniformi8:f32:>>0, {0.1:-20,0.2:-30}
// %result: [20, 45]

בזמן ש

סמנטיקה

מפיקה את הפלט מהפעלת הפונקציה body 0 או יותר פעמים בזמן שהפונקציה cond מפיקה את הפלט true. באופן רשמי יותר, אפשר לבטא את הסמנטיקה באמצעות תחביר Python באופן הבא:

internal_state = operand
while cond(*internal_state):
  internal_state = body(*internal_state)
results = internal_state

ההתנהגות של לולאה אינסופית תיקבע בהמשך (#383).

קלט

תווית שם סוג מגבלות
(I1) operand מספר משתנה של ערכים (C1-C3)
(I2) cond פונקציה (C1)
(I3) body פונקציה (C2)

פלט

שם סוג מגבלות
results מספר משתנה של ערכים (C3)

מגבלות

  • ‫(C1) cond הוא מסוג (T0, ..., TN-1) -> tensor<i1>, כאשר Ti = type(operand[i]).
  • ‫(C2) body הוא מסוג (T0, ..., TN-1) -> (T0, ..., TN-1), כאשר Ti = type(operand[i]).
  • ‫(C3) type(results...) = type(operand...).

דוגמאות

// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
    %cond = "stablehlo.compare"(%arg0, %ten) {
      comparison_di<rection = #stablehlocom>parison_directio<n L>T
    } <: (>ten>sori64,< t>ensori64) - tensori1
    stablehlo.r<et>urn %cond : tensori1
  }, {
<  ^>bb0(%arg0: tens<ori>64, %arg1: tensori64):
    %new_sum = stablehlo.add <%ar>g1, %one : tensori64
    %new_i = stablehlo.add <%ar>g0, %one : tensori64
    stablehlo.return %new_<i, >%new_sum< : >tensori64, te<nso>ri64
}) <: (>ten>sori64, <ten>sori64) <- (>tensori64, tensori64)
// %results0: 10
// %results1: 10

 דוגמאות נוספות

xor

סמנטיקה

מבצעת XOR ברמת הרכיבים של שני טנסורים, lhs ו-rhs, ומפיקה טנסור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • עבור ערכים בוליאניים: XOR לוגי.
  • למספרים שלמים: ערך XOR ברמת הביטים.

קלט

תווית שם סוג מגבלות
(I1) lhs טנסור מסוג בוליאני או מספר שלם (C1)
(I2) rhs טנסור מסוג בוליאני או מספר שלם (C1)

פלט

שם סוג מגבלות
result טנסור מסוג בוליאני או מספר שלם (C1)

מגבלות

  • (C1) type(lhs) = type(rhs) = type(result).

דוגמאות

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[4, 4], [4, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%<lhs, %>rhs) : (<tensor>2x2>xi1, te<nsor2x>2xi1) - tensor2x2xi1
// %result: [[false, true], [true, false]]

 דוגמאות נוספות

פעולה הדדית בין ניבים

בשלב הזה, תוכניות StableHLO בשימוש נרחב מכילות לפעמים פעולות שלא מוגדרות על ידי StableHLO.

Module, Function, Call and Return

‫StableHLO משתמש בפעולות MLIR במעלה הזרם עבור ModuleOp,‏ FuncOp,‏ CallOp ו-ReturnOp. השינוי הזה בוצע כדי לשפר את יכולת הפעולה ההדדית עם מנגנון MLIR קיים, כי הרבה מעברים שימושיים נכתבים עם טירגוט של FuncOp ו-ModuleOp, והרבה צינורות של קומפילציה מצפים שהאופרטורים האלה יהיו נוכחים. התאימות המלאה מובטחת לפעולות האלה. אם יחולו שינויים בפעולות האלה באופן לא תואם (כלומר, הסרה), יתווספו מקבילות של StableHLO כדי לשמור על התאימות.

CHLO

קבוצת הפעולות של CHLO מכילה פעולות ברמה גבוהה יותר שמפורקות ל-StableHLO. נכון לעכשיו, אין התחייבות לתאימות של CHLO. כדי להבטיח תאימות, צריך להשתמש בהעברת chlo-legalize-to-stablehlo לפני הסריאליזציה.

פעולות על צורות

מקרה שימוש נפוץ בקהילה הוא שימוש בפעולות מסוימות מניבים בסיסיים של MLIR בתוכניות דינמיות של StableHLO כדי לבצע חישובים של צורות. בדרך כלל, אלה כוללים פעולות של shape דיאלקט כמו shape_of או num_elements, פעולות של tensor דיאלקט כמו dim או from_elements, וסוג ה-build-in index.

במסמך Dynamism RFC > O2 הסוגים האלה מוגדרים כלא רלוונטיים, אבל יש תמיכה מסוימת בסוגים של index לצורך פעולה הדדית. אין התחייבות לתאימות של הפעולות או הסוגים האלה. אפשר להשתמש במעבר shape-legalize-to-stablehlo כדי להמיר את הפעולות האלה לפעולות StableHLO נתמכות באופן מלא.

פעולות שהוצאו משימוש

יש כמה פעולות StableHLO שהתקבלו בירושה מ-MHLO, שהוצאו משימוש ועומדות להיעלם מ-StableHLO. הפרטים המלאים על ההסרות האלה זמינים בניקוי StableHLO v1.0‏ #2283. מספר הבעיה בכלי המעקב לגבי הוצאה משימוש של התכונות האלה הוא 2340.

הפעולות האלה נחלקות לכמה קטגוריות:

  • הקטגוריה 'לא ב-HLO' של פעולות StableHLO – הפעולות האלה היו במקור חלק מה-opset של StableHLO, אבל בהמשך נקבע שהן לא מתאימות לו: broadcast, create_token, cross-replica-sum, dot, einsum, torch_index_select, unary_einsum (מס' 3).
  • פעולות לא בשימוש – יכול להיות שהפעולות האלה היו שימושיות בשלב מסוים, אבל הן לא פותחו מספיק או שהצינורות שמשתמשים בהן עברו שינוי מבנה כך שהן כבר לא נדרשות. השינויים כוללים את map, tuple (#598), השוואות get_tuple_element, rng, complex (#560), וקונבולוציה window_reversal (#1181).

חלק מהפעולות האלה אפשר להסיר בקלות, כי אפשר לבטא אותן באמצעות פעולות קיימות (broadcast,‏ create_token,‏ cross-replica-sum,‏ dot,‏ unary_einsum), והן יוסרו אחרי שתקופת התאימות הקיימת תסתיים (6 חודשים). אנחנו עדיין בודקים אפשרות להסיר פעולות אחרות (einsum,‏ get_tuple_element,‏ map,‏ rng torch_index_select,‏ tuple,‏ complex השוואות, window_reversal). בהתאם למשוב מהקהילה, הפעולות האלה יוסרו או יתווספו למפרט עם תמיכה מלאה. עד שיידעו מה יהיו הגרסאות העתידיות האלה, מובטחת תאימות רק ל-6 חודשים.

ביצוע

ביצוע רציף

כדי להריץ תוכנית StableHLO, צריך לספק ערכי קלט לפונקציה main ולחשב את ערכי הפלט. ערכי הפלט של פונקציה מחושבים על ידי הפעלת הגרף של פעולות שמושרשות בפעולת return המתאימה.

סדר ההפעלה מוגדר בהטמעה, כל עוד הוא תואם לזרימת הנתונים, כלומר אם פעולות מופעלות לפני השימוש בהן. ב-StableHLO, כל פעולות ה-side-effecting צורכות טוקן אחד ומפיקות טוקן אחד (אפשר לבצע מולטיפלקס של כמה טוקנים לטוקן אחד באמצעות after_all), ולכן סדר הביצוע של side-effects תואם גם לזרימת הנתונים. לדוגמה, בתוכנית שלמטה יש שני סדרי ביצוע אפשריים: %0%1%2return ו-%1%0%2return.

func.func @main() -> tensor<f64> {
  %0 = stablehlo.constant dense<1.0> : tensor<f64>
  %1 = stablehlo.constant dense<2.0> : tensor<f64>
  %2 = stablehlo.add %0, %1 : tensor<f64>
  return %2 : tensor<f64>
}

באופן רשמי יותר, תהליך StableHLO הוא שילוב של: ‫1) תוכנית StableHLO,‏ 2) סטטוסים של פעולות (עדיין לא בוצעו, כבר בוצעו) ו-3) ערכי ביניים שהתהליך פועל עליהם. התהליך מתחיל בערכי קלט לפונקציה main, ממשיך בתרשים של פעולות שמעדכנות את סטטוסי הפעולות וערכי הביניים, ומסתיים בערכי פלט. המשך הפורמליזציה ייקבע בהמשך (#484).

ביצוע מקביל

אפשר להריץ תוכניות StableHLO במקביל, והן מאורגנות ברשת תהליכים דו-ממדית של num_replicas על num_partitions, ששניהם הם מסוג ui32.

ברשת התהליכים של StableHLO, ‏num_replicas * num_partitions תהליכים של StableHLO פועלים בו-זמנית. לכל תהליך יש process_id = (replica_id, partition_id) ייחודי, כאשר replica_id ב-replica_ids = range(num_replicas) ו-partition_id ב-partition_ids = range(num_partitions), ולשניהם יש סוג ui32.

גודל רשת התהליכים ידוע באופן סטטי לכל תוכנית (בעתיד, אנחנו מתכננים להפוך אותו לחלק מפורש בתוכניות StableHLO #650), והמיקום ברשת התהליכים ידוע באופן סטטי לכל תהליך. לכל תהליך יש גישה למיקום שלו ברשת התהליכים באמצעות הפעולות replica_id ו-partition_id.

ברשת התהליכים, כל התוכניות יכולות להיות זהות (בסגנון Single Program, Multiple Data), כולן יכולות להיות שונות (בסגנון Multiple Program, Multiple Data) או משהו באמצע. בעתיד, אנחנו מתכננים להוסיף תמיכה בניבים אחרים להגדרת תוכניות מקבילות של StableHLO, כולל GSPMD (‎#619).

בתוך רשת התהליכים, התהליכים הם ברובם בלתי תלויים זה בזה – יש להם סטטוסי פעולה נפרדים, ערכי קלט/ביניים/פלט נפרדים ורוב הפעולות מבוצעות בנפרד בין התהליכים, למעט מספר קטן של פעולות קולקטיביות שמתוארות בהמשך.

בהתחשב בכך שרוב הפעולות משתמשות רק בערכים מאותו תהליך, בדרך כלל ברור למה מתייחסים כשמציינים את השמות של הערכים האלה. עם זאת, כשמתארים סמנטיקה של פעולות קולקטיביות, זה לא מספיק, ולכן משתמשים בסימון name@process_id כדי להתייחס לערך name בתהליך מסוים. (מנקודת המבט הזו, אפשר לראות ב-name לא מתאים קיצור של name@(replica_id(), partition_id())).

סדר הביצוע בין התהליכים מוגדר בהטמעה, למעט הסנכרון שנוצר על ידי תקשורת ישירה ופעולות קולקטיביות, כפי שמתואר בהמשך.

תקשורת מנקודה לנקודה

תהליכי StableHLO יכולים לתקשר זה עם זה באמצעות ערוצי StableHLO. ערוץ מיוצג על ידי מזהה חיובי מסוג si64. באמצעות פעולות שונות, אפשר לשלוח ערכים לערוצים ולקבל אותם מהערוצים.

צריך להגדיר באופן רשמי יותר, למשל מאיפה מגיעים מזהי הערוצים האלה, איך תוכנות תהליכים מודעות להם ואיזה סוג של סנכרון מוצג על ידם. הנושא הזה נמצא בבדיקה (#484).

שידור תקשורת

לכל תהליך StableHLO יש גישה לשני ממשקי סטרימינג:

  • בפיד שאפשר לקרוא ממנו.
  • Outfeed שאפשר לכתוב אליו.

בניגוד לערוצים, שמשמשים לתקשורת בין תהליכים ולכן יש להם תהליכים בשני הקצוות שלהם, ל-infeed ול-outfeed יש קצה אחר שמוגדר על ידי היישום.

המשך הפורמליזציה, למשל איך תקשורת בסטרימינג משפיעה על סדר הביצוע ואיזה סוג של סנכרון מוצג על ידי זה, יפורט בהמשך (#484).

תפעול קולקטיבי

יש שש פעולות קולקטיביות ב-StableHLO: ‏ all_gather,‏ all_reduce,‏ all_to_all,‏ collective_broadcast,‏ collective_permute ו-reduce_scatter. כל הפעולות האלה מפצלות את התהליכים ברשת של תהליך StableHLO לקבוצות של תהליכי StableHLO ומבצעות חישוב משותף בכל קבוצת תהליכים, באופן עצמאי מקבוצות תהליכים אחרות.

בתוך כל קבוצת תהליכים, פעולות קולקטיביות עשויות ליצור מחסום סינכרון. צריך להוסיף עוד פרטים, למשל מתי בדיוק מתבצע הסנכרון הזה, איך בדיוק התהליכים מגיעים למחסום הזה ומה קורה אם הם לא מגיעים אליו. הפרטים האלה יתווספו בהמשך (#484).

אם קבוצת התהליכים כוללת תקשורת בין מחיצות, כלומר יש תהליכים בקבוצת התהליכים שמזהי המחיצות שלהם שונים, אז הביצוע של פעולת ה-Collective צריך ערוץ, ופעולת ה-Collective צריכה לספק channel_id חיובי מסוג si64. לא צריך ערוצים לתקשורת בין העתקים.

החישובים שמבוצעים על ידי הפעולות הקולקטיביות הם ספציפיים לפעולות בודדות, והם מתוארים בסעיפים של הפעולות הבודדות שלמעלה. עם זאת, השיטות שלפיהן רשת התהליכים מחולקת לקבוצות תהליכים משותפות בין הפעולות האלה ומתוארות בקטע הזה. באופן רשמי יותר, StableHLO תומך בארבע השיטות הבאות.

cross_replica

רק תקשורת בין רפליקות מתרחשת בתוך כל קבוצת תהליכים. האסטרטגיה הזו מקבלת את replica_groups – רשימה של רשימות של מזהי עותקים – ומחשבת את המכפלה הקרטזית של replica_groups ב-partition_ids. ‫replica_groups חייב להכיל רכיבים ייחודיים ולכלול את כל replica_ids. באופן רשמי יותר, באמצעות תחביר Python:

def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    for partition_id in partition_ids:
      process_group = []
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
      yield process_group

לדוגמה, עבור replica_groups = [[0, 1], [2, 3]] ו-num_partitions = 2, הפונקציה cross_replica תחזיר [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]].

cross_partition

רק תקשורת בין מחיצות מתרחשת בתוך כל קבוצת תהליכים. האסטרטגיה הזו מקבלת את partition_groups – רשימה של רשימות של מזהי מחיצות – ומחשבת את המכפלה הקרטזית של partition_groups ב-replica_ids. ‫partition_groups חייב להכיל רכיבים ייחודיים ולכלול את כל partition_ids. באופן רשמי יותר, באמצעות תחביר Python:

def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
  for partition_group in partition_groups:
    for replica_id in replica_ids:
      process_group = []
      for partition_id in partition_group:
        process_group.append((replica_id, partition_id))
      yield process_group

לדוגמה, עבור partition_groups = [[0, 1]] ו-num_replicas = 4, הפונקציה cross_partition תחזיר [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]].

cross_replica_and_partition

תקשורת בין עותקים ותקשורת בין מחיצות יכולות להתרחש בכל קבוצת תהליכים. שיטת הבידינג הזו מקבלת replica_groups – רשימה של רשימות של מזהי עותקים – ומחשבת את המכפלה הקרטזית של כל replica_group לפי partition_ids. ‫replica_groups חייב להכיל רכיבים ייחודיים ולכלול את כל replica_ids. באופן רשמי יותר, באמצעות תחביר Python:

def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    process_group = []
    for partition_id in partition_ids:
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
    yield process_group

לדוגמה, עבור replica_groups = [[0, 1], [2, 3]] ו-num_partitions = 2, הפונקציה cross_replica_and_partition תחזיר [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]].

flattened_ids

שיטת הבידינג הזו מקבלת flattened_id_groups – רשימה של רשימות של מזהי תהליכים 'שטוחים' בצורה של replica_id * num_partitions + partition_id – והופכת אותם למזהי תהליכים. ‫flattened_id_groups חייב להכיל רכיבים ייחודיים ולכלול את כל process_ids. באופן רשמי יותר, באמצעות תחביר Python:

def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
  for flattened_id_group in flattened_id_groups:
    process_group = []
    for flattened_id in flattened_id_group:
      replica_id = flattened_id // num_partitions
      partition_id = flattened_id % num_partitions
      process_group.append((replica_id, partition_id))
    yield process_group

לדוגמה, עבור flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]], num_replicas = 4 ו-num_partitions = 2, הפונקציה flattened_ids תחזיר [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]].

דיוק

בשלב הזה, StableHLO לא מספקת ערבויות לגבי דיוק מספרי, אבל יכול להיות שזה ישתנה בעתיד (‎#1156).

סמנטיקה של ביצוע פעולה עם כימות

הפרשנות של פעולות כמותיות ב-StableHLO עשויה להשתנות בהתאם לדרישות וליכולות של החומרה. לדוגמה, יכול להיות שחלק מהחומרה יבחר לפרש פעולות קוונטיות באמצעות אסטרטגיה של 'ביטול קוונטיזציה, ביצוע פעולה של נקודה צפה ולבסוף קוונטיזציה'. אחרים עשויים לבצע את כל החישוב באמצעות אריתמטיקה של מספרים שלמים. לכן, הפרשנות של פעולות StableHLO שעברו קוונטיזציה נקבעת באופן בלעדי על ידי ההטמעה הספציפית. הפרשנות של הכמותיות ההיברידית (#1575) צריכה להתבסס על הסמנטיקה שלה כפי שמוגדר במפרט (דרך 1792).

שגיאות

תוכניות StableHLO עוברות אימות באמצעות קבוצה נרחבת של אילוצים עבור פעולות נפרדות, מה שמונע הרבה סוגים של שגיאות לפני זמן הריצה. עם זאת, עדיין יכולות להיות שגיאות, למשל בגלל הצפה של מספרים שלמים, גישה מחוץ לגבולות וכו'. אלא אם מצוין אחרת, כל השגיאות האלה גורמות להתנהגות שמוגדרת בהטמעה, אבל זה עשוי להשתנות בעתיד (#1157).

חריגים של נקודה צפה

יוצא מן הכלל לכלל הזה הוא חריג של נקודה צפה בתוכניות StableHLO, שבהן ההתנהגות מוגדרת היטב. פעולות שגורמות לחריגות שמוגדרות בתקן IEEE-754 (פעולה לא חוקית, חלוקה באפס, הצפה, חוסר קיבולת או חריגות לא מדויקות) מפיקות תוצאות ברירת מחדל (כפי שמוגדר בתקן) וממשיכות את הביצוע בלי להעלות את דגל הסטטוס המתאים. זה דומה לטיפול בחריגות raiseNoFlag מהתקן. חריגים לפעולות לא סטנדרטיות (למשל, אריתמטיקה מורכבת ופונקציות טרנסצנדנטיות מסוימות) מוגדרים בהטמעה.

אי התאמות בצורות

‫StableHLO תומך בטנסורים עם צורה דינמית. עם זאת, הצורות צריכות להיות זהות בזמן הריצה, אחרת ההתנהגות לא מוגדרת. ‫StableHLO לא מספקת באופן מפורש אופרטור שיכול לקבוע שלטנסור יש צורה מסוימת בזמן הריצה. המפיק אחראי ליצירת קוד נכון.

לדוגמה, התוכנית הבאה תקינה. עם זאת, בזמן הריצה, הצורות המדויקות של %arg0 ושל %arg1 חייבות להיות זהות, אחרת ההתנהגות של התוכנית לא מוגדרת:

func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
    %0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
    return %0 : tensor<?xi32>
}

סימון

כדי לתאר את התחביר, במסמך הזה נעשה שימוש בגרסה של תחביר EBNF לפי תקן ISO (ISO/IEC 14977:1996,‏ Wikipedia) עם שני שינויים: 1) כללים מוגדרים באמצעות ::= ולא באמצעות =,

‫2) שרשור מבוטא באמצעות הצמדה ולא באמצעות ,.

כדי לתאר סמנטיקה (כלומר, בקטעים Types (סוגים), Constants (קבועים) ו-Ops (פעולות)), אנחנו משתמשים בנוסחאות שמבוססות על תחביר Python עם תמיכה מורחבת בביטוי תמציתי של פעולות על מערכים, כפי שמתואר בהמשך. השיטה הזו מתאימה לקטעי קוד קטנים, אבל במקרים נדירים שבהם נדרשים קטעי קוד גדולים יותר, אנחנו משתמשים בתחביר Python רגיל, שתמיד מוצג באופן מפורש.

נוסחאות

נבחן איך נוסחאות פועלות על סמך דוגמה מהdot_general מפרט. אחת המגבלות של הפעולה הזו נראית כך: dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).

השמות שמשמשים בנוסחה הזו מגיעים משני מקורות: 1) פונקציות גלובליות, כלומר dim, 2) הגדרות של רכיב התוכנית המתאים, כלומר lhs, lhs_batching_dimensions, rhs וקלט rhs_batching_dimensions שמוגדרים בקטע Inputs (קלט) של dot_general.

כמו שצוין למעלה, התחביר של הנוסחה הזו מבוסס על Python עם כמה הרחבות שנועדו לפשט את הנוסחה. כדי להבין את הנוסחה, נמיר אותה לתחביר של Python רגיל.

א) בנוסחאות האלה, אנחנו משתמשים ב-= כדי לייצג שוויון, ולכן השלב הראשון לקבלת תחביר Python הוא להחליף את = ב-==, כמו בדוגמה הבאה: dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...).

ב) הנוסחאות האלה תומכות גם בסימן שלוש הנקודות (...), שהופך ביטויים סקלריים לביטויים טנסוריים. בקיצור, f(xs...) בערך אומר "לכל סקלר x בטנזור xs, מחשבים סקלר f(x) ואז מחזירים את כל התוצאות הסקלריות האלה יחד כתוצאת טנזור". בתחביר Python רגיל, הנוסחה לדוגמה הופכת ל: [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] == [dim(rhs, dim2) for dim2 in rhs_batching_dimensions].

בזכות סימני האליפסה, לעיתים קרובות אפשר להימנע מעבודה ברמה של סקלרים בודדים. עם זאת, במקרים מורכבים מסוימים, יכול להיות שנעשה שימוש בתחביר לא רשמי ברמה נמוכה יותר, כמו בנוסחה start_indices[bi0, ..., :, ..., biN] מהמפרט gather. כדי לשמור על תמציתיות, אנחנו לא מספקים פורמליזם מדויק לתרגום תחביר כזה ל-Python רגיל, בתקווה שעדיין אפשר להבין אותו באופן אינטואיטיבי על בסיס כל מקרה לגופו. אם יש נוסחאות ספציפיות שנראות לכם לא ברורות, אתם יכולים להודיע לנו ואנחנו ננסה לשפר אותן.

בנוסף, תשימו לב שבנוסחאות נעשה שימוש בסימן שלוש הנקודות כדי להרחיב כל מיני רשימות, כולל טנסורים, רשימות של טנסורים (שיכולות להיווצר ממספר משתנה של טנסורים) וכו'. זה עוד תחום שבו אנחנו לא מספקים פורמליזם מדויק (למשל, רשימות הן לא חלק ממערכת הסוגים של StableHLO), אלא מסתמכים על מובנות אינטואיטיבית.

ג) כלי נוסף וחשוב שבו אנחנו משתמשים הוא שידור מרומז. למרות שערכת הפעולות (opset) של StableHLO לא תומכת בשידור מרומז, הנוסחאות כן תומכות, גם כדי לשמור על תמציתיות. בקיצור, אם נעשה שימוש בסקלר בהקשר שבו צפוי טנסור, הסקלר מועבר לשידור לצורה הצפויה.

כדי להמשיך את הדוגמה של dot_general, הנה עוד אילוץ: 0 <= lhs_batching_dimensions < rank(lhs). כפי שמוגדר במפרט dot_general, ‏ lhs_batching_dimensions הוא טנזור, אבל 0 ו-rank(lhs) הם סקלרים. אחרי שנחיל שידור מרומז, הנוסחה תהפוך ל-[0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)].

כשמחילים את הנוסחה הזו על פעולה מסוימת של dot_general, היא מחזירה טנזור של ערכים בוליאניים. כשמשתמשים בנוסחאות כאילוצים, האילוץ מתקיים אם הנוסחה מחזירה את הערך true או טנסור שמכיל רק רכיבים עם הערך true.

שמות

בנוסחאות, היקף לקסיקלי כולל: 1) פונקציות גלובליות, 2) הגדרות של חברים,

‫3) הגדרות מקומיות. בהמשך מופיעה רשימה של פונקציות גלובליות. רשימת ההגדרות של הרכיבים תלויה ברכיב התוכנית שאליו מוחל הסימון:

  • בפעולות, ההגדרות של חברים כוללות שמות שמופיעים בקטעים 'קלט' ו'פלט'.
  • בכל שאר המקרים, ההגדרות של חברי המועדון כוללות חלקים מבניים של רכיב התוכנית, שנקראים על שם המונחים הלא-סופיים המתאימים של EBNF. ברוב המקרים, השמות של החלקים המבניים האלה מתקבלים על ידי המרת השמות של הלא-טרמינלים ל-snake case (למשל, IntegerLiteral => integer_literal), אבל לפעמים השמות מקוצרים בתהליך (למשל, QuantizationStorageType => storage_type). במקרה כזה, השמות מוצגים באופן מפורש בדומה לקטעים 'קלט' / 'פלט' במפרטים של פעולות.
  • בנוסף, הגדרות החברות תמיד כוללות את התו self כדי להפנות לרכיב התוכנית המתאים.

ערכים

כשמבצעים הערכה של נוסחאות, הן פועלות עם סוגי הערכים הבאים: ‫1) Value (ערכים בפועל, למשל dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>; תמיד ידועים הסוגים שלהם), ‫2) Placeholder (ערכים עתידיים, למשל lhs, ‏ rhs או result; הערכים בפועל שלהם עדיין לא ידועים, רק הסוגים שלהם ידועים), ‫3) Type (סוגים כפי שהוגדרו בקטע 'סוגים'), ‫4) Function (פונקציות גלובליות כפי שהוגדרו בקטע 'פונקציות').

בהתאם להקשר, השמות עשויים להתייחס לערכים שונים. באופן ספציפי יותר, בקטע 'סמנטיקה' של פעולות (ומקבילות לרכיבי תוכנה אחרים) מוגדרת לוגיקה של זמן ריצה, כך שכל הקלט זמין כ-Value. לעומת זאת, בקטע 'מגבלות' של פעולות (ומקבילות) מוגדרת לוגיקה של 'זמן קומפילציה', כלומר משהו שבדרך כלל מבוצע לפני זמן הריצה, ולכן רק קלטים קבועים זמינים כ-Value וקלטים אחרים זמינים רק כ-Placeholder.

שמות בקטע 'סמנטיקה' בקטע 'מגבלות'
פונקציות גלובליות Function Function
קלט קבוע Value Value
קלט לא קבוע Value Placeholder
פלט Value Placeholder
הגדרות מקומיות תלוי בהגדרה תלוי בהגדרה

נבחן עכשיו דוגמה לפעולה transpose:

%result = "stablehlo.transpose"(%operand) {
  permutati<on = dens>e[2, 1, 0<] : t>ensor3xi64
}< : (tenso>r2x>3x2xi32<) - tenso>r2x3x2xi32

בפעולה הזו, permutation הוא קבוע, ולכן הוא זמין כ-Value גם בסמנטיקה וגם באילוצים. לעומת זאת, operand ו-result זמינים כ-Value בסמנטיקה, אבל רק כ-Placeholder במגבלות.

פונקציות

בנייה של סוגים

אין פונקציות שאפשר להשתמש בהן כדי ליצור סוגים. במקום זאת, אנחנו משתמשים ישירות בתחביר של סוגים כי הוא בדרך כלל תמציתי יותר. לדוגמה: (tensor<E>, tensor<E>) -> (tensor<E>) במקום function_type( [tensor_type([], E), tensor_type([], E)], [tensor_type([], E)]).

פונקציות לפי סוגים

  • הפונקציה element_type מוגדרת בסוגי טנסורים ובסוגי טנסורים שעברו קוונטיזציה, ומחזירה את החלק TensorElementType או QuantizedTensorElementType של TensorType או QuantizedTensorType המתאימים, בהתאמה.
def element_type(x: Value | Placeholder | Type):
 if type(x) == TensorType:
    return tensor_element_type(x)
  if type(x) == QuantizedTensorType:
    return quantized_tensor_element_type(x)
  if type(x) is not Type:
    return element_type(type(x))
  • is_per_axis_quantized(x: Value | Placeholder | Type) -> Value הוא קיצור דרך ל-is_quantized(x) and quantization_dimension(x) is not None.

  • is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value היא קיצור דרך לis_quantized(x) and quantization_dimension(x) is None.

  • is_promotable(x: Type, y: Type) -> bool בודק אם אפשר להמיר את הסוג x לסוג y. אם x ו-y הם QuantizedTensorElementType, המבצע חל רק על storage_type. הגרסה הספציפית הזו של קידום מכירות משמשת כרגע בהקשר של חישוב הפחתה (פרטים נוספים זמינים ב-RFC).

def is_promotable(x: Type, y: Type) -> Value:
  is_same_type = (is_bool(x) and is_bool(y)) or
    (is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
    (is_complex(x) and is_complex(y)) or
    (is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))

  if is_same_type == False:
    return False

  if is_integer(x) or is_float(x):
    return bitwidth(x) <= bitwidth(y)

  if is_complex(x):
    return bitwidth(element_type(x)) <= bitwidth(element_type(y))

  if is_quantized(x):
    return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))

  return false
  • is_quantized(x: Value | Placeholder | Type) -> Value הוא קיצור דרך ל-is_quantized_tensor_element_type(x).

  • is_type_name(x: Value | Placeholder | Type) -> Value. זמין לכל הסוגים. לדוגמה, הפונקציה is_float(x) מחזירה true אם x הוא FloatType. אם x הוא ערך או placeholder, הפונקציה הזו היא קיצור של is_type_name(type(x)).

  • max_value(x: Type) -> Value מחזירה את הערך המקסימלי של TensorElementType. אם x הוא לא TensorElementType, הפונקציה מחזירה None.

  • min_value(x: Type) -> Value מחזירה את הערך המינימלי האפשרי של TensorElementType. אם x הוא לא TensorElementType, הפונקציה מחזירה None.

  • member_name(x: Value | Placeholder | Type) -> Any. זמין לכל הגדרות החברים member_name מכל הסוגים. לדוגמה, הפקודה tensor_element_type(x) תחזיר את החלק TensorElementType של TensorType תואם. אם x הוא ערך או placeholder, הפונקציה הזו היא קיצור של member_name(type(x)). אם x הוא לא סוג שיש לו חבר מתאים, או ערך או placeholder מסוג כזה, הפונקציה מחזירה את x.None

  • is_empty_algorithm(*args: Type) בודקת אם כל השדות של אלגוריתם הנקודות מוגדרים ל-None. הסיבה לכך היא שלאלגוריתמים של נקודות יש התנהגויות ברירת מחדל שמוגדרות בהטמעה, ולכן ציון ערך ברירת מחדל יהיה שגוי.

בניית ערכים

  • operation_name(*xs: Value | Type) -> Value. זמין לכל הפעולות. לדוגמה, הפונקציה add(lhs, rhs) מקבלת שני ערכי טנסור lhs ו-rhs ומחזירה את הפלט של הערכת הפעולה add עם הקלטים האלה. לפעולות מסוימות, למשל broadcast_in_dim, סוגי הפלט שלהן הם 'נושאי עומס', כלומר נדרשים להערכת פעולה. במקרה כזה, הפונקציה מקבלת את הסוגים האלה כארגומנטים.

פונקציות על ערכים

  • כל האופרטורים והפונקציות של Python זמינים. לדוגמה, אפשר להשתמש בסימון subscription וגם בסימון slicing מ-Python כדי ליצור אינדקס לטנסורים, לטנסורים שעברו קוונטיזציה ולטפלים.

  • to_destination_type(x: Value, destination_type: Type) -> Value מוגדר בטנסורים ומחזיר את הערך המומר של x על סמך type(x) ו-destination_type באופן הבא:

def to_destination_type(x: Value, destination_type: Type) -> Value:
  if type(x) == destination_type:
    return x

  if is_quantized(destination_type):
    if is_quantized(type(x)):
      return quantize(x, destination_type)
    assert is_float(type(x))
    return quantize(x, destination_type)

  if is_quantized(type(x)):
    assert destination_type = expressed_type(type(x))
    return dequantize(type(x))

  return convert(x, destination_type)

מתנהל דיון ראשוני על מיזוג הפעולות convert, uniform_quantize ו-uniform_dequantize (#1576). אחרי המיזוג, אין צורך בפונקציה שלמעלה ואפשר להשתמש בשם הפעולה convert במקום.

  • הפונקציה is_nan(x: Value) -> Value מוגדרת על טנסורים ומחזירה true אם כל האלמנטים של x הם NaN או false אחרת. אם x הוא לא טנסור, הפונקציה מחזירה None.

  • הפונקציה is_sorted(x: Value) -> Value מוגדרת על טנסורים ומחזירה true אם הרכיבים של x ממוינים בסדר עולה ביחס לסדר הלקסיקוגרפי העולה של האינדקסים שלהם, או false אחרת. אם x הוא לא טנזור, הפונקציה מחזירה None.

  • הפונקציה is_unique(x: Value) -> Value מוגדרת לטנסורים ומחזירה true אם ל-x אין רכיבים כפולים, או false אחרת. אם x הוא לא טנסור, הפונקציה מחזירה None.

  • הערך member_name(x: Value) -> Any מוגדר לכל הגדרות החברים member_name של כל הערכים. לדוגמה, הפקודה real_part(x) מחזירה את החלק RealPart של ComplexConstant התואם. אם x הוא לא ערך שיש לו חבר מתאים, הפונקציה מחזירה את None.

  • הפונקציה same(x: Value) -> Value מוגדרת על טנסורים ומחזירה true אם כל הרכיבים של x שווים זה לזה, או false אחרת. אם לטנזור אין אלמנטים, זה נחשב כ'כל האלמנטים שווים זה לזה', כלומר הפונקציה מחזירה true. אם x הוא לא טנסור, הפונקציה מחזירה None.

  • הפונקציה split(x: Value, num_results: Value, axis: Value) -> Value מוגדרת על טנסורים ומחזירה פרוסות num_results של x לאורך הציר axis. אם x הוא לא טנסור או dim(x, axis) % num_results != 0, הפונקציה מחזירה None.

  • הפונקציה is_defined_in_parent_scope(x: Value) -> Value מוגדרת במחרוזות ומחזירה את הערך true אם x הוא השם של פונקציה שמוגדרת באותו היקף כמו פונקציית ההורה של הפעולה הרלוונטית.

  • הפונקציה is_namespaced_op_name(x: Value) -> Value מוגדרת במחרוזות ומחזירה true אם x הוא שם אופרטור תקין, כלומר הוא תואם לביטוי הרגולרי הבא: [a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+

חישובים של צורות

  • axes(x: Value | Placeholder | Type) -> Value הוא קיצור דרך ל-range(rank(x)).

  • dim(x: Value | Placeholder | Type, axis: Value) -> Value הוא קיצור דרך ל-shape(x)[axis].

  • dims(x: Value | Placeholder | Type, axes: List) -> List הוא קיצור דרך ל-list(map(lambda axis: dim(x, axis), axes)).

  • index_space(x: Value | Placeholder | Type) -> Value מוגדר על טנסורים ומחזיר את האינדקסים של size(x) המתאימים, שמוינו בסדר לקסיקוגרפי עולה, כלומר [0, ..., 0],‏ [0, ..., 1] וכן הלאה עד shape(x) - 1.TensorType אם x הוא לא סוג טנזור, סוג טנזור שעבר קוונטיזציה, ערך או placeholder של אחד מהסוגים האלה, הפונקציה מחזירה את x.None

  • rank(x: Value | Placeholder | Type) -> Value הוא קיצור דרך ל-size(shape(x)).

  • shape(x: Value | Placeholder | Type) -> Value מוגדר בקטע 'פונקציות על סוגים' באמצעות member_name.

  • size(x: Value | Placeholder | Type) -> Value הוא קיצור דרך ל-reduce(lambda x, y: x * y, shape(x)).

חישובי קוונטיזציה

  • def baseline_element_type(x: Value | Placeholder | Type) -> Type היא קיצור דרך לelement_type(baseline_type(x)).

  • baseline_type מוגדר בסוגי טנסור ובסוגי טנסור שעברו קוונטיזציה, והוא משנה אותם ל'בסיס', כלומר לסוג עם אותה צורה אבל עם פרמטרים של קוונטיזציה של סוג האלמנט שמאופסים לערכי ברירת מחדל. הטריק הזה שימושי להשוואה אחידה בין שני סוגים של טנסורים: טנסור רגיל וטנסור שעבר קוונטיזציה. השוואה כזו נדרשת לעיתים קרובות. עבור סוגים שעברו קוונטיזציה, האפשרות הזו מאפשרת להשוות בין סוגים תוך התעלמות מפרמטרים של קוונטיזציה, כלומר, כל הפרמטרים shape,‏ storage_type,‏ expressed_type,‏ storage_min,‏ storage_max ו-quantization_dimension (עבור סוג שעבר קוונטיזציה לכל ציר) חייבים להיות זהים, אבל הפרמטרים scales ו-zero points יכולים להיות שונים.

def baseline_type(x: Value | Placeholder | Type) -> Type:
  if type(x) == TensorType:
    return x
  if type(x) == QuantizedTensorType:
    element_type = quantized_tensor_element_type(x)
    baseline_element_type = QuantizedTensorElementType(
      storage_type = storage_type(element_type),
      storage_min = storage_min(element_type),
      storage_max = storage_max(element_type),
      expressed_type = expressed_type(element_type),
      quantization_dimension = quantization_dimension(element_type),
      scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
      zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
    return QuantizedTensorType(shape(x), baseline_element_type)
  if type(x) is not Type:
    return baseline_element_type(type(x))
  • dequantize מוגדר בסוגי טנסורים שעברו קוונטיזציה והופך אותם לסוגי טנסורים של נקודה צפה. ההמרה מתבצעת על ידי המרת רכיבים כמותיים שמייצגים ערכים שלמים של סוג האחסון לערכים תואמים של נקודה צפה של הסוג המבוטא, באמצעות נקודת האפס והקנה מידה שמשויכים לסוג הרכיב הכמותי.
def compute_zero_points(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      zero_points[i] = zero_points(quantized_type)[i[d]]
    return zero_points

def compute_scales(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
            type(result_type))
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      scales[i] = scales(quantized_type)[i[d]]
    return scales

def dequantize(x: Value) -> Value:
  assert is_quantized(x)
  x_storage = bitcast_convert(x, storage_type(x))
  x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
  x_expressed_sub = convert(x_storage_sub, expressed_type(x))
  return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
  • quantize מוגדר בסוגי טנסור של נקודה צפה והופך אותם לסוגי טנסור שעברו קוונטיזציה. ההמרה מתבצעת על ידי המרת ערכים מסוג נקודה צפה (floating-point) של הסוג המפורש לערכים שלמים תואמים של סוג האחסון, באמצעות נקודת האפס והקנה מידה שמשויכים לסוג האלמנט המכומת.
def quantize(x: Value, result_type: Type) -> Value:
  assert is_float(x) and is_quantized(result_type)
  zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
  converted_zero_points = convert(zero_points, expressed_type(result_type))
  converted_min = convert(storage_min(result_type), expressed_type(result_type))
  converted_max = convert(storage_max(result_type), expressed_type(result_type))

  x_scaled = x / compute_scales(result_type, type(x))
  x_scaled_add_zp = x_scaled + converted_zero_points
  x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
  x_rounded = round_nearest_even(x_clamped)
  return convert(x_rounded, result_type)
  • dequantize_op_quantize משמשת לציון חישובים לפי רכיבים בטנסורים שעברו קוונטיזציה. היא מבצעת דה-קוונטיזציה, כלומר הופכת רכיבים שעברו קוונטיזציה לסוגים המפורשים שלהם, ואז מבצעת פעולה, ואז מבצעת קוונטיזציה, כלומר הופכת את התוצאות בחזרה לסוגי האחסון שלהן. בשלב הזה, הפונקציה הזו פועלת רק עבור קוונטיזציה לכל טנסור. אנחנו עובדים על כימות לכל ציר (‎#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
  inputs = inputs_and_output_type[:-1]
  output_type = inputs_and_output_type[-1]

  float_inputs = map(dequantize, inputs)
  float_result = op(*float_inputs)
  return quantize(float_result, output_type)

def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
  inputs = inputs_and_output_type[:-3]
  float_inputs = map(dequantize, inputs)
  float_results = op(*float_inputs)
  return map(quantize, float_results, inputs_and_output_type[-3:])

def dequantize_compare(lhs, rhs, comparison_direction):
  float_lhs = dequantize(lhs)
  float_rhs = dequantize(rhs)
  return compare(float_lhs, float_rhs, comparison_direction, FLOAT)

def dequantize_select_quantize(pred, on_true, on_false, output_type):
  float_on_true = dequantize(on_true)
  float_on_false = dequantize(on_false)
  float_result = select(pred, float_on_true, float_on_false)
  return quantize(float_result, output_type)
  • hybrid_dequantize_then_op משמש לציון קוונטיזציה רק של משקלים עבור פעולה היברידית שמקבלת lhs בנקודה צפה ו-rhs בסוגים שעברו קוונטיזציה. היא מבצעת דה-קוונטיזציה של קלטים שעברו קוונטיזציה לסוגים שלהם ומבצעת חישוב בערך float. סוג הרכיב של טנזור lhs מסוג float וסוג הרכיב של טנזור rhs מסוג quantized צריכים להיות זהים.
def hybrid_dequantize_then_op(op, lhs, rhs):
  assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
  return op(lhs, dequantize(rhs))

חישובים ברשת

  • cross_partition(replica_groups: Value) -> Value. אפשר לעיין בקטע cross_replica למעלה.

  • cross_replica(replica_groups: Value) -> Value. אפשר לעיין בקטע cross_replica למעלה.

  • cross_replica_and_partition(replica_groups: Value) -> Value. אפשר לעיין בקטע cross_replica_and_partition למעלה.

  • flattened_ids(replica_groups: Value) -> Value. אפשר לעיין בקטע 'flattened_ids' למעלה.

דינמיות

לערכים ב-StableHLO יכולים להיות גדלים דינמיים של מאפיינים, למשל tensor<?xi64>. עם זאת, לערכים ב-StableHLO לא יכול להיות מספר דינמי של מימדים (דינמיות לא מדורגת, למשל tensor<*xi64>). מותר להשתמש באופרנדים ובתוצאות בגדלים דינמיים של מימדים, גם אם יש אילוצים על הגדלים. אם אפשר, המערכת תאמת את האילוצים באופן סטטי. אחרת, היא תדחה את האימות לזמן הריצה, ואי-התאמות יובילו להתנהגות לא מוגדרת. למטה מפורטות דוגמאות.

אי התאמה בצורות של פעולות אלמנטריות אוניטריות

נניח שיש לנו את תוכנית הצעצוע הבאה:

func.func @foo(%arg0: tensor<?xf64>) {
  %0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
  return
}

תוכנית כזו היא לא שגרתית, כי בדרך כלל יודעים מה הצורה של התוצאה אבל לא מה הצורה של הקלט. למרות זאת, זו תוכנית StableHLO תקינה. אי אפשר לאמת באופן סטטי את הפעולה abs בתוכנית הזו, כי הצורה המדויקת של האופרנד לא ידועה. עם זאת, הצורות תואמות בוודאות, ואפשר לבדוק את זה באופן סטטי: יכול להיות ש-? יהפוך ל-2 בזמן הריצה, ולא תהיה בעיה. עם זאת, יכול להיות ש-? יהיה מספר שלם אחר, ובמקרה כזה ההתנהגות לא מוגדרת.

הערה: אם הגודל של מאפיין הוא דינמי בתוצאה, לא יכול להיות מצב של התנהגות לא מוגדרת. למעשה, אין גודל 'צפוי', ולכן לא יכול להיות חוסר התאמה.

אי התאמה בצורות של פעולות בינאריות ברמת הרכיב

נניח שיש לנו את תוכנית הצעצוע הבאה:

func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
  %0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
  return
}

בפעולות בינאריות ברמת הרכיב, הצורות של הקלט והתוצאה חייבות להיות זהות בזמן הריצה. בזמן ההידור, המימדים הסטטיים צריכים להיות שווים, אחרת הם צריכים להיות רק תואמים. אם אחד מהמאפיינים הוא דינמי בקלט, יכול להיות שההתנהגות בזמן הריצה לא תהיה מוגדרת, כי הגודל הדינמי לא יכול להתאים לגודל התואם באופרנד השני (סטטי או דינמי). אם כל נתוני הקלט הם סטטיים, לא משנה אם התוצאה דינמית או לא: מאפיינים סטטיים ידועים ייבדקו באופן סטטי, ומאפיינים דינמיים לא יחייבו אילוצים.

אי התאמות בצורות של פעולות שצורת הפלט שלהן משמשת כאופרנד

נניח שיש לנו את תוכנית הצעצוע הבאה:

func.func @foo(%arg0: tensor<2xi32>) {
  %0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
  return
}

הערכים באופרנד של הצורה בזמן הריצה צריכים להתאים לצורה של התוצאה, אחרת ההתנהגות לא מוגדרת. כלומר, בזמן הריצה הערך של %arg0 חייב להיות dense<[3, 4]> : tensor<2xi32>. אם האופרנד של הצורה הוא קבוע, אפשר לאמת את זה באופן סטטי. אם צורת התוצאה דינמית לחלוטין, לא יכול להיות חוסר התאמה.