מפרט StableHLO

StableHLO היא פעולה שהוגדרה לפעולות ברמה גבוהה (HLO) במודלים של למידת מכונה (ML). ‏StableHLO פועלת כשכבת ניידות בין מסגרות שונות של למידת מכונה ומהדרי למידת מכונה: מסגרות של למידת מכונה שיוצרות תוכניות StableHLO תואמות למהדרי למידת מכונה שמשתמשים בתוכניות StableHLO.

המטרה שלנו היא לפשט ולהאיץ את הפיתוח של למידת המכונה על ידי יצירת יכולת פעולה הדדית בין מסגרות שונות של למידת מכונה (כמו TensorFlow,‏ JAX ו-PyTorch) ומהדרי למידת מכונה (כמו XLA ו-IREE). לשם כך, המסמך הזה מכיל מפרט של שפת התכנות StableHLO.

המפרט הזה כולל שלושה קטעים עיקריים. בקטע Programs מתוארים המבנה של תוכניות StableHLO, שמכילות פונקציות StableHLO שמכילות בעצמן פעולות StableHLO. במסגרת המבנה הזה, הקטע Ops מציין את הסמנטיקה של פעולות ספציפיות. בקטע ביצוע מוצג סמנטיקה לכל הפעולות האלה שמתבצעות ביחד בתוכנית. בסיום, בקטע סימון מוסבר על הסימון שנעשה בו שימוש לאורך המפרט.

כדי לצפות במפרט מגרסה קודמת של StableHLO, פותחים את המאגר בגרסה המתויגת הרצויה. לדוגמה, StableHLO v0.19.0 Spec. כדי לראות את השינויים שהתרחשו בכל גרסה משנית של StableHLO, אפשר לעיין ביומן הגרסאות ב-VhloDialect.td.

תוכניות

Program ::= {Func}

תוכניות StableHLO מורכבות ממספר שרירותי של פונקציות StableHLO. בהמשך מופיעה תוכנית לדוגמה עם פונקציה @main שיש לה 3 מקורות קלט (%image,‏ %weights ו-%bias) ופלט אחד. גוף הפונקציה מכיל 6 פעולות.

func.func @main(
  %image: tensor<28x28xf32>,
  %weights: tensor<784x10xf32>,
  %bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
  %0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
  %1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
  %2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  %3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
  %4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  "func.return"(%4): (tensor<1x10xf32>) -> ()
}

פונקציות

Func        ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs  ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput   ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput  ::= ValueType
FuncBody    ::= {Op}

פונקציות StableHLO (שנקראות גם פונקציות בעלות שם) כוללות מזהה, קלט/פלט וגוף. בעתיד אנחנו מתכננים להוסיף מטא-נתונים נוספים לפונקציות כדי לשפר את התאימות ל-HLO (#425,‏ #626,‏ #740,‏ #744).

מזהים

FuncId  ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
          | '%' letter {letter | digit}
letter  ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit   ::= '0' | ... | '9'

מזהי StableHLO דומים למזהים בשפות תכנות רבות, עם שתי ייחודיות: 1) לכל המזהים יש תגים שמבחינים בין סוגים שונים של מזהים, 2) מזהי ערכים יכולים להיות מספריים לחלוטין, כדי לפשט את היצירה של תוכנות StableHLO.

סוגים

Type         ::= ValueType | NonValueType
ValueType    ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType

סוגי StableHLO מסווגים לסוגי ערכים (שנקראים גם סוגי מחלקה ראשונה) שמייצגים ערכים של StableHLO, ולסוגי ערכים שאינם ערכים שמתארים רכיבי תוכנית אחרים. סוגי StableHLO דומים לסוגים בשפות תכנות רבות. המאפיין העיקרי הוא האופי הספציפי לדומיין של StableHLO, מה שמוביל לתוצאות חריגות (למשל, סוגים סקלריים הם לא סוגי ערכים).

TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'

סוגי טינסורים מייצגים טינסורים, כלומר מערכי משנה מרובים. לכל רכיב יש צורה וסוג רכיב. הצורה מייצגת גדלים של מאפיינים שאינם שליליים או לא ידועים, בסדר עולה של המאפיינים התואמים (שנקראים גם צירים) שממוספרים מ-0 עד R-1. מספר המאפיינים R נקרא דירוג. לדוגמה, tensor<2x3xf32> הוא סוג t e n s o r f l o w, עם הצורה 2x3 וסוג הרכיב f32. יש לו שני מאפיינים (או, במילים אחרות, שני צירים) – המאפיין ה-0 והמאפיין ה-1 – שהגדלים שלהם הם 2 ו-3. הדירוג שלו הוא 2.

צורות יכולות להיות לא ידועות באופן חלקי או מלא (דינמית), למשל, tensor<?x2xf64> לא ידועה באופן חלקי ו-tensor<?x?xf64> לא ידוע לחלוטין. גדלים של מאפיינים דינמיים מיוצגים באמצעות ?. לא ניתן לבטל את הדירוג של צורות.

בעתיד, אנחנו מתכננים להרחיב את השימוש בסוגי טינזור מעבר לגדלים של המאפיינים ולסוגי הרכיבים שלהם. לדוגמה, כדי לכלול פריסות (#629) ואת חלקן (#1078).

QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
                  QuantizationStorageType
                  ['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
                  ':' QuantizationExpressedType
                  [':' QuantizationDimension]
                  ',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerLiteral
QuantizationStorageMax ::= IntegerLiteral
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerLiteral
QuantizationParameters ::= QuantizationParameter
                         | '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale [':' QuantizationZeroPoint]
QuantizationScale ::= FloatLiteral
QuantizationZeroPoint ::= IntegerLiteral
שם סוג אילוצים
storage_type טיפוס שלם (C1-C3),‏ (C8)
storage_min קבוע שלם (C1), (C3), (C7)
storage_max מספר שלם קבוע (C2), (C3), (C7)
expressed_type סוג נקודה צפה (floating-point) (C4)
quantization_dimension ערך קבוע שלם אופציונלי (C10-C12)
scales מספר משתנה של קבועים עם נקודה צפה (float) (C4-C6),‏ (C9),‏ (C10),‏ (C13)
zero_points מספר משתנה של קבועים של מספרים שלמים (C7-C9)

סוגי רכיבים מרובעניים מייצגים ערכים שלמים של סוג אחסון בטווח מ-storage_min עד storage_max (כולל), שתואמים לערכים של נקודה צפה של סוג מפורש. בערך נתון i, את הערך המתאים בנקודה הצפה f אפשר לחשב כ-f = (i - zero_point) * scale, ו-scale ו-zero_point נקראים פרמטרים של קוונטיזציה. המאפיינים storage_min ו-storage_max הם אופציונליים ב-grammar, אבל ערכי ברירת המחדל שלהם הם min_value(storage_type) ו-max_value(storage_type), בהתאמה. לסוגים של רכיבים שעברו קצירה חלות המגבלות הבאות:

  • (C1) type(storage_min) = storage_type.
  • (C2) type(storage_max) = storage_type.
  • (C3) min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type).
  • (C4) type(scales...) = expressed_type.
  • (C5) 0 < scales.
  • (C6) is_finite(scales...).
  • (C7) storage_min <= zero_points <= storage_max.
  • (C8) type(zero_points...) = storage_type.
  • (C9) size(scales) = size(zero_points).
  • (C10) אם is_empty(quantization_dimension), אז size(scales) = 1.
  • (C11) 0 <= quantization_dimension.

בשלב זה, QuantizationScale הוא קבוע של נקודה צפה, אבל יש עניין רב בסולמות שמבוססים על מספרים שלמים, שמיוצגים באמצעות מכפילים והזזות. אנחנו מתכננים לבדוק את הנושא הזה בעתיד הקרוב (#1404).

יש דיון מתמשך על הסמנטיקה של QuantizationZeroPoint, כולל הסוג, הערכים והשאלה האם יכולה להיות רק נקודה אחת או כמה נקודות אפסיות בסוג img_tensor. בהתאם לתוצאות הדיון, המפרט לגבי אפס נקודות עשוי להשתנות בעתיד (#1405).

דיון מתמשך נוסף נוגע לסמנטיקה של QuantizationStorageMin ו-QuantizationStorageMax, כדי לקבוע אם צריך להחיל אילוצים על הערכים האלה ועל הערכים של טינסורים מקוטנים (#1406).

לבסוף, אנחנו מתכננים לבדוק את ייצוג של סולמות לא ידועים ונקודות אפס, בדומה לאופן שבו אנחנו מתכננים לבדוק את ייצוג של גדלים לא ידועים של מאפיינים (#1407).

סוגי טינסורים מרובים (quantized) מייצגים טינסורים עם רכיבים מרובים (quantized). הטנזורים האלה זהים לטנזורים רגילים, מלבד העובדה שלרכיבים שלהם יש סוגי רכיבים שמוגדרים בקוונטים, במקום סוגי רכיבים רגילים.

בטנסורים שמבוצעת בהם כימות, אפשר לבצע כימות לכל טנסור, כלומר להשתמש ב-scale וב-zero_point אחד לכל הטנסור, או לבצע כימות לכל ציר, כלומר להשתמש במספר ערכים של scales ו-zero_points, לכל זוג של פרוסה של מאפיין מסוים quantization_dimension. באופן רשמי יותר, בטנסור t עם כיווץ לכל ציר יש dim(t, quantization_dimension) פרוסות של quantization_dimension: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :] וכו'. כל הרכיבים בפרוסת i משתמשים ב-scales[i] וב-zero_points[i] כפרמטרים שלהם לצורך כיווץ. לסוגים של Tensor כמות יש את האילוצים הבאים:

  • לכימות לפי טנזור:
    • אין אילוצים נוספים.
  • לכימות לפי ציר:
    • (C12) quantization_dimension < rank(self).
    • (C13) dim(self, quantization_dimension) = size(scales).
TokenType ::= 'token'

סוגי אסימונים מייצגים אסימונים, כלומר ערכים אטומים שנוצרים ומנוצלים על ידי פעולות מסוימות. האסימונים משמשים להגדרת סדר הביצוע של פעולות, כפי שמתואר בקטע ביצוע.

TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]

סוגי צמדי-ערכי-מפתח (Tuple) מייצגים צמדי-ערכי-מפתח, כלומר רשימות הטרוגניות. צמדי מפתח/ערך הם תכונה מדור קודם שקיימת רק לצורך תאימות ל-HLO. ב-HLO, צמדים משמשים לייצוג משתני קלט ופלט שונים. ב-StableHLO יש תמיכה בקלט ופלט שונים, והשימוש היחיד בפלטים ב-StableHLO הוא לייצג באופן מקיף HLO ABI כאשר למשל T, tuple<T> ו-tuple<tuple<T>> עשויים להיות שונים באופן מהותי בהתאם להטמעה מסוימת. בעתיד, אנחנו מתכננים לבצע שינויים ב-HLO ABI, שיכולים לאפשר לנו להסיר סוגים שונים של Tuple מ-StableHLO (#598).

TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3'
            | 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2'
            | 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'

סוגי רכיבים מייצגים רכיבים של סוגי טינסורים. בניגוד לשפות תכנות רבות, הסוגים האלה הם לא סוגים ברמה ראשונה ב-StableHLO. המשמעות היא שתוכניות StableHLO לא יכולות לייצג באופן ישיר ערכים מהסוגים האלה (כתוצאה מכך, מקובל לייצג ערכים סקלרים מסוג T באמצעות ערכים של טינסורים ב-0 מימדים מסוג tensor<T>).

  • סוג בוליאני מייצג את הערכים הבוליאניים true ו-false.
  • סוגי מספרים שלמים יכולים להיות עם סימן (si) או ללא סימן (ui), ולכלול את אחד מרוחב הביטים הנתמכים (2,‏ 4,‏ 8,‏ 16,‏ 32 או 64). סוגי siN עם סימן מייצגים ערכים שלמים מ--2^(N-1) עד 2^(N-1)-1, וסוגי uiN ללא סימן מייצגים ערכים שלמים מ-0 עד 2^N-1.
  • סוגי נקודה צפה יכולים להיות אחד מהסוגים הבאים:
  • סוגי רכיבים מורכבים מייצגים ערכים מורכבים שיש להם חלק ממשי וחלק מדומה מאותו סוג רכיב. הסוגים המורכבים הנתמכים הם complex<f32> (שני החלקים הם מסוג f32) ו-complex<f64> (שני החלקים הם מסוג f64).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]

סוגי פונקציות מייצגים גם פונקציות בעלות שם וגם פונקציות אנונימיות. יש להם סוגי קלט (רשימת הסוגים שמשמאל ל-->) וסוגי פלט (רשימת הסוגים בצד שמאל של ->). בשפות תכנות רבות, סוגי הפונקציות הם מחלקה ראשונה, אבל לא ב-StableHLO.

StringType ::= 'string'

String type מייצג רצפים של בייטים. בניגוד לשפות תכנות רבות, סוג המחרוזת הוא לא סוג ברמה ראשונה ב-StableHLO, והוא משמש רק לציון מטא-נתונים סטטיים של רכיבי תוכנית.

תפעול

פעולות StableHLO (שנקראות גם ops) מייצגות קבוצה סגורה של פעולות ברמה גבוהה במודלים של למידת מכונה. כפי שצוין למעלה, התחביר של StableHLO מבוסס במידה רבה על MLIR. זו לא בהכרח החלופה הארגונומית ביותר, אבל היא מתאימה ביותר למטרה של StableHLO – ליצור יכולת פעולה הדדית בין מסגרות למידת מכונה לבין מהדרי למידת מכונה.

Op            ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName        ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic    ::= 'abs' | 'add' | ...

לפעולות StableHLO (שנקראות גם ops) יש שם, קלט/פלט וחתימה. השם מורכב מהתחילית stablehlo. וממונח mnemotechnic שמזהה באופן ייחודי אחת מהפעולות הנתמכות. בהמשך מופיעה רשימה מקיפה של כל הפעולות הנתמכות.

OpInputs        ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues   ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue    ::= ValueId
OpInputFuncs    ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs    ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs       ::= [OpOutput {',' OpOutput} '=']
OpOutput        ::= ValueId

צוות התפעול צורך קלט ויוצר פלט. הקלטות מסווגות לפי ערכים של קלט (שמתקבלים במהלך הביצוע), פונקציות קלט (שסופקו באופן סטטי, כי ב-StableHLO פונקציות הן לא ערכים ברמה ראשונה) ומאפייני קלט (שסופקו גם באופן סטטי). סוג הקלט והפלט שהפעולה צורכת ומייצרת תלוי במונח המנומני שלה. לדוגמה, הפעולה add צורכת 2 ערכי קלט ומפיקה ערך פלט אחד. לעומת זאת, הפעולה select_and_scatter צורכת 3 ערכי קלט, 2 פונקציות קלט ו-3 מאפייני קלט.

OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused      ::= '^' digit {digit}
              | '^' letter {letter | digit}

פונקציות קלט (שנקראות גם פונקציות אנונימיות) דומות מאוד לפונקציות בעלות שם, מלבד העובדות הבאות: 1) אין להן מזהה (לכן השם 'אנונימיות'), 2) הן לא מגדירות סוגי פלט (סוגי הפלט משוערים מהפעולה return בתוך הפונקציה).

התחביר של פונקציות הקלט כולל חלק שלא בשימוש כרגע (ראו את הפונקציה Unused למעלה), שנועד לצורך תאימות ל-MLIR. ב-MLIR יש מושג כללי יותר של 'אזורים', שיכולים לכלול כמה 'בלוקים' של פעולות שמחוברים יחד באמצעות פעולות מעבר (jump). למקטעים האלה יש מזהי Unused שמתאימים לסביבת הייצור, כדי שאפשר יהיה להבדיל ביניהם. ב-StableHLO אין פעולות קפיצה, ולכן החלק התואם של תחביר MLIR לא בשימוש (אבל הוא עדיין קיים).

OpInputAttr      ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName  ::= letter {letter | digit}
OpInputAttrValue ::= Constant

מאפייני קלט כוללים שם וערך שהוא אחד מהקבועים הנתמכים. הם הדרך העיקרית לציין מטא-נתונים סטטיים לרכיבי תוכנה. לדוגמה, הפעולה concatenate משתמשת במאפיין dimension כדי לציין את המאפיין שבו ערכים הקלט שלה מקושרים. באופן דומה, הפונקציה slice משתמשת במספר מאפיינים כמו start_indices ו-limit_indices כדי לציין את הגבולות שבהם נעשה שימוש כדי לחתוך את ערך הקלט.

בשלב זה, תוכניות StableHLO בטבע מכילות לפעמים מאפיינים שלא מתוארים במסמך הזה. בעתיד אנחנו מתכננים לשלב את המאפיינים האלה ב-opset של StableHLO או לאסור עליהם להופיע בתוכניות של StableHLO. בינתיים, הנה רשימת המאפיינים האלה:

  • layout (#629).
  • mhlo.frontend_attributes (#628).
  • mhlo.sharding (#619).
  • output_operand_aliases (#740).
  • מטא-נתונים של מיקום (#594).
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'

חתימת הפעולה מורכבת מהסוגים של כל ערכי הקלט (רשימת הסוגים בצד ימין של ->) ומהסוגים של כל ערכי הפלט (רשימת הסוגים בצד ימין של ->). באופן מדויק, סוגי הקלט הם יתירות, וסוגי הפלט הם כמעט תמיד יתירות גם כן (כי ברוב הפעולות של StableHLO, אפשר להסיק את סוגי הפלט מהקלט). עם זאת, חתימה של אופרטורים היא חלק מכוון מתחביר StableHLO לצורך תאימות ל-MLIR.

בהמשך מופיעה דוגמה לפקודה שהמפתח שלה הוא select_and_scatter. הפונקציה צורכת 3 ערכים של קלט (%operand,‏ %source ו-%init_value), 2 פונקציות קלט ו-3 מאפייני קלט (window_dimensions,‏ window_strides ו-padding). שימו לב שהחתימה של הפעולה כוללת רק את סוגי הערכים של הקלט (אבל לא את סוגי הפונקציות והמאפיינים של הקלט שסופקו בשורה).

%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i32>, tensor<i32>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
    "stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>

קבועים

Constant ::= BooleanConstant
           | IntegerConstant
           | FloatConstant
           | ComplexConstant
           | TensorConstant
           | QuantizedTensorConstant
           | StringConstant
           | EnumConstant

לקבועים מסוג StableHLO יש ליריקה וסוג, שמייצגים יחד ערך StableHLO. באופן כללי, הסוג הוא חלק מהתחביר של הקבוע, אלא אם הוא ברור (למשל, לקבוע בוליאני יש תמיד את הסוג i1, ואילו לקבוע שלם יכולים להיות כמה סוגים אפשריים).

BooleanConstant ::= BooleanLiteral
BooleanLiteral  ::= 'true' | 'false'

קבועים בוליאניים מייצגים ערכים בוליאניים true ו-false. למשתני קבוע בוליאני יש את הסוג i1.

IntegerConstant   ::= IntegerLiteral ':' IntegerType
IntegerLiteral    ::= ['-' | '+'] DecimalDigits
                    | ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits     ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit      ::= '0' | ... | '9'
hexadecimalDigit  ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'

קבועים שלמים מייצגים ערכים שלמים באמצעות מחרוזות שמשתמשות בסימון עשרוני או בשש עשרה ספרות. אין תמיכה בבסיסים אחרים, למשל בינארי או אוקטלי. למשתני קבוע שלם יש את האילוצים הבאים:

  • (C1) is_wellformed(integer_literal, integer_type).
FloatConstant  ::= FloatLiteral ':' FloatType
FloatLiteral   ::= SignPart IntegerPart FractionalPart ScientificPart
                 | '0x' [HexadecimalDigits]
SignPart       ::= ['-' | '+']
IntegerPart    ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]

קבועים של נקודה צפה מייצגים ערכים של נקודה צפה באמצעות מחרוזות שמשתמשות ברישום עשרוני או ברישום מדעי. בנוסף, אפשר להשתמש בסימון הקסדצימלי כדי לציין ישירות את הביטים הבסיסיים בפורמט הנקודה הצפה של הסוג המתאים. למשתני קבועים של נקודה צפה יש את האילוצים הבאים:

  • (C1) אם משתמשים בסימון שאינו הקסדצימלי, is_wellformed(float_literal, float_type).
  • (C2) אם נעשה שימוש בסימון הקסדצימלי, size(hexadecimal_digits) = num_bits(float_type) / 4.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral  ::= '(' RealPart ',' ImaginaryPart ')'
RealPart        ::= FloatLiteral
ImaginaryPart   ::= FloatLiteral

קבועים מורכבים מייצגים ערכים מורכבים באמצעות רשימות של חלק ממשי (מגיע קודם) וחלק דמיוני (מגיע שני). לדוגמה, הערך (1.0, 0.0) : complex<f32> מייצג את הערך 1.0 + 0.0i, והערך (0.0, 1.0) : complex<f32> מייצג את הערך 0.0 + 1.0i. הסדר שבו החלקים האלה מאוחסנים בזיכרון מוגדר לפי ההטמעה. למשתני קבועים מורכבים יש את האילוצים הבאים:

  • (C1) is_wellformed(real_part, complex_element_type(complex_type)).
  • (C2) is_wellformed(imaginary_part, complex_element_type(complex_type)).
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral   ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements  ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral

קבועים של טינסורים מייצגים ערכים של טינסורים באמצעות רשימות בתצוגת עץ שצוינו באמצעות סימון NumPy. לדוגמה, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> מייצג ערך t e n s o r f l o w, עם המיפוי הבא מאינדקסים לאלמנטים: {0, 0} => 1, {0, 1} => 2, {0, 2} => 3, {1, 0} => 4, {1, 1} => 5, {1, 2} => 6. הסדר שבו הרכיבים האלה מאוחסנים בזיכרון מוגדר בהטמעה. למשתני הטנזור הקבועים יש את האילוצים הבאים:

  • (C1) has_syntax(tensor_literal, element_type(tensor_type)), כאשר:
    • has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type).
    • has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type).
  • (C2) has_shape(tensor_literal, shape(tensor_type)), כאשר:
    • has_shape(element_literal: Syntax, []) = true.
    • has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:]).
    • אחרת, false.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'

קבועים של טנזור כמותי מייצגים ערכים של טנזור כמותי באמצעות סימון זהה לזה של קבועי טנזור, עם אלמנטים שצוינו כקבועים של סוג האחסון שלהם. למשתני טינסור מרובע (quantized tensor) יש את האילוצים הבאים:

  • (C1) has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type)).
  • (C2) has_shape(quantized_tensor_literal, shape(quantized_tensor_type)).
StringConstant  ::= StringLiteral
StringLiteral   ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence  ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))

מחרוזות ליסטריות מורכבות מבייטים שצוינו באמצעות תווים מסוג ASCII ורצפי בריחה. הם לא תלויים בקידוד, ולכן הפרשנות של הבייטים האלה מוגדרת בהטמעה. ליטרלים של מחרוזות הם מסוג string.

תפעול

בטן

סמנטיקה

הפונקציה מבצעת פעולת abs לפי רכיבים בטנסור operand ומייצרת טנסור result. בהתאם לסוג הרכיב, הפעולות הבאות:

  • למספרים שלמים בעלי סימן: מודולוס של מספר שלם.
  • למספרים מסוג float: abs מ-IEEE-754.
  • למספרים מרוכבים: מודולוס מורכב.
  • לסוגי נתונים מרוסקים: dequantize_op_quantize(abs, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טינסור של מספר שלם עם סימן, מספר נקודה צפה או מספר מרוכב, או טינסור מקודד לכל טינסור (C1-C2)

פלט

שם סוג מגבלות
result טינסור של מספר שלם חתום או של מספר נקודה צפה (floating-point), או טינסור מקודד לכל טינסור (C1-C2)

אילוצים

  • (C1) shape(result) = shape(operand).
  • (C2) baseline_element_type(result) מוגדר בתור:
    • complex_element_type(element_type(operand)) אם is_complex(operand).
    • baseline_element_type(operand) אחרת.

דוגמאות

// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]

 דוגמאות נוספות

add

סמנטיקה

הפונקציה מבצעת הוספה של שני טינסורים lhs ו-rhs לפי רכיבים, ומפיקה טינסור result. בהתאם לסוג הרכיב, הפעולות הבאות:

  • למשתנים בוליאניים: 'או' לוגי.
  • למספרים שלמים: חיבור של מספרים שלמים.
  • לצפים: addition מ-IEEE-754.
  • למספרים מרוכבים: חיבור מורכב.
  • לסוגי נתונים מרוסקים: dequantize_op_quantize(add, lhs, rhs, type(result)).

קלט

תווית שם סוג מגבלות
(I1) lhs טינסור או טינסור מצטבר (C1-C6)
(I2) rhs טינסור או טינסור מצטבר (C1-C5),‏ (C7)

פלט

שם סוג אילוצים
result טינסור או טינסור מצטבר (C1-C7)

אילוצים

  • אם הפעולה משתמשת בטנסורים לא מקובצים:
    • (C1) type(lhs) = type(rhs) = type(result).
  • אם הפעולה משתמשת בטנסורים מרובים (quantization):
    • (C2) is_quantized(lhs) and is_quantized(rhs) and is_quantized(result).
    • (C3) storage_type(lhs) = storage_type(rhs) = storage_type(result).
    • (C4) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C5) (is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result).
    • (C6) אם is_per_axis_quantized(lhs), אז quantization_dimension(lhs) = quantization_dimension(result).
    • (C7) אם is_per_axis_quantized(rhs), אז quantization_dimension(rhs) = quantization_dimension(result).

דוגמאות

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]

 דוגמאות נוספות

after_all

סמנטיקה

מוודא שהפעולות שמניבות את inputs מבוצעות לפני כל פעולה שתלויה ב-result. ביצוע הפעולה הזו לא עושה כלום, היא קיימת רק כדי ליצור יחסי תלות בין הנתונים מ-result ל-inputs.

קלט

תווית שם סוג
(I1) inputs מספר משתנה של token

פלט

שם סוג
result token

דוגמאות

// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token

 דוגמאות נוספות

all_gather

סמנטיקה

בתוך כל קבוצת תהליכים ברשת התהליכים של StableHLO, הפונקציה מקשרת את הערכים של הטנסורים operands מכל תהליך לאורך all_gather_dim ויוצרת טנסורים מסוג results.

הפעולה מפצלת את רשת התהליכים של StableHLO ל-process_groups, שמוגדר כך:

  • cross_replica(replica_groups) אם channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) אם channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) אם channel_id > 0 and use_global_device_ids = true.

לאחר מכן, בכל process_group:

  • operands...@receiver = [operand@sender for sender in process_group] לכל receiver ב-process_group.
  • results...@process = concatenate(operands...@process, all_gather_dim) לכל process ב-process_group.

קלט

תווית שם סוג מגבלות
(I1) operands מספר וריאדי של טינסורים או טינסורים מקוטנים לכל טינסור (C1),‏ (C6)
(I2) all_gather_dim קבוע מסוג si64 (C1), (C6)
(I3) replica_groups קבוע טינסור דו-מימדי מסוג si64 (C2-C4)
(I4) channel_id קבוע מסוג si64 (C5)
(I5) use_global_device_ids קבוע מסוג i1 (C5)

פלט

שם סוג אילוצים
results מספר וריאדי של טינסורים או טינסורים מקוטנים לכל טינסור (C6)

אילוצים

  • (C1) 0 <= all_gather_dim < rank(operands...).
  • (C2) is_unique(replica_groups).
  • (C3) size(replica_groups) מוגדר כך:
    • num_replicas אם משתמשים ב-cross_replica.
    • num_replicas אם משתמשים ב-cross_replica_and_partition.
    • num_processes אם משתמשים ב-flattened_ids.
  • (C4) 0 <= replica_groups < size(replica_groups).
  • (C5) אם use_global_device_ids = true, אז channel_id > 0.
  • (C6) type(results...) = type(operands...) מלבד:
    • dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1).

דוגמאות

// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
  all_gather_dim = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  // channel_id = 0
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  // use_global_device_ids = false
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]

 דוגמאות נוספות

all_reduce

סמנטיקה

בכל קבוצת תהליכים ברשת התהליכים של StableHLO, מחילים פונקציית הפחתה computation על הערכים של הטנסורים operands מכל תהליך ויוצרים טנסורים results.

הפעולה מפצלת את רשת התהליכים של StableHLO ל-process_groups, שמוגדר כך:

  • cross_replica(replica_groups) אם channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) אם channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) אם channel_id > 0 and use_global_device_ids = true.

לאחר מכן, בכל process_group:

  • results...@process[result_index] = exec(schedule) לעץ בינארי כלשהו schedule כאשר:
    • exec(node) = computation(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule הוא עץ בינארי שהוגדר בהטמעה, והטראברסה בסדר שלו היא to_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0])).

קלט

תווית שם סוג מגבלות
(I1) operands מספר וריאדי של טינסורים או טינסורים מקוטנים לכל טינסור (C5),‏ (C6)
(I2) replica_groups מספר משתנה של קבועים של טינסור חד-מימדי מסוג si64 (C1-C3)
(I3) channel_id קבוע מסוג si64 (C4)
(I4) use_global_device_ids קבוע מסוג i1 (C4)
(I5) computation פונקציה (C5)

פלט

שם סוג אילוצים
results מספר משתנה של טנסטורים או טנזורים כמותיים לכל טנזור (C6-C7)

אילוצים

  • (C1) is_unique(replica_groups).
  • (C2) size(replica_groups) מוגדר בתור:
    • num_replicas אם נעשה שימוש ב-cross_replica.
    • num_replicas אם משתמשים ב-cross_replica_and_partition.
    • num_processes אם משתמשים ב-flattened_ids.
  • (C3) 0 <= replica_groups < size(replica_groups).
  • (C4) אם הערך הוא use_global_device_ids = true, אז channel_id > 0.
  • (C5) השדה computation הוא מסוג (tensor<E>, tensor<E>) -> (tensor<E>) כאשר is_promotable(element_type(operand), E).
  • (C6) shape(results...) = shape(operands...).
  • (C7) element_type(results...) = E.

דוגמאות

// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  // channel_id = 0
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  // use_global_device_ids = false
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]

 דוגמאות נוספות

all_to_all

סמנטיקה

all_to_all

בכל קבוצת תהליכים ברשת התהליכים של StableHLO, הפונקציה מפצלת את הערכים של הטנסורים operands לאורך split_dimension לחלקים, מפזרת את החלקים המפוצלים בין התהליכים, מקשרת את החלקים המפוזרים לאורך concat_dimension ויוצרת טנסורים מסוג results. הפעולה מפצלת את רשת התהליכים של StableHLO ל-process_groups, שמוגדר כך:

  • cross_replica(replica_groups) אם channel_id <= 0.
  • cross_partition(replica_groups) אם channel_id > 0.

לאחר מכן, בכל process_group:

  • split_parts...@sender = split(operands...@sender, split_count, split_dimension) לכל sender ב-process_group.
  • scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group] כאשר receiver_index = process_group.index(receiver).
  • results...@process = concatenate(scattered_parts...@process, concat_dimension).

קלט

תווית שם סוג מגבלות
(I1) operands מספר וריאדי של טינסורים או טינסורים מקוטנים לכל טינסור (C1-C3),‏ (C9)
(I2) split_dimension קבוע מסוג si64 (C1),‏ (C2),‏ (C9)
(I3) concat_dimension קבוע מסוג si64 (C3),‏ (C9)
(I4) split_count קבוע מסוג si64 (C2),‏ (C4),‏ (C8),‏ (C9)
(I5) replica_groups קבוע טינסור דו-מימדי מסוג si64 (C5-C8)
(I6) channel_id קבוע מסוג si64

פלט

שם סוג אילוצים
results מספר וריאדי של טינסורים או טינסורים מקוטנים לכל טינסור (C9)

אילוצים

  • (C1) 0 <= split_dimension < rank(operands...).
  • (C2) dim(operands..., split_dimension) % split_count = 0.
  • (C3) 0 <= concat_dimension < rank(operands...).
  • (C4) 0 < split_count.
  • (C5) is_unique(replica_groups).
  • (C6) size(replica_groups) מוגדר כך:
    • num_replicas אם משתמשים ב-cross_replica.
    • num_partitions אם משתמשים ב-cross_partition.
  • (C7) 0 <= replica_groups < size(replica_groups).
  • (C8) dim(replica_groups, 1) = split_count.
  • (C9) type(results...) = type(operands...), אלא אם split_dimension != concat_dimension:
    • dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count.
    • dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count.

דוגמאות

// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
//                    [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
//                    [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
//                    [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
//                    [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
  split_dimension = 1 : i64,
  concat_dimension = 0 : i64,
  split_count = 2 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  // channel_id = 0
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]

דוגמאות נוספות

וגם

סמנטיקה

ביצוע AND של שני טינסורים lhs ו-rhs לפי רכיבים, יצירת טינסור result. בהתאם לסוג הרכיב, מבצע את הפעולות הבאות:

  • לערך בוליאני: לוגי AND.
  • למספרים שלמים: AND ברמת הביטים.

קלט

תווית שם סוג מגבלות
(I1) lhs טינסור מסוג בוליאני או שלם (C1)
(I2) rhs טינסור מסוג בוליאני או שלם (C1)

פלט

שם סוג אילוצים
result טינסור מסוג בוליאני או שלם (C1)

מגבלות

  • (C1) type(lhs) = type(rhs) = type(result).

דוגמאות

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]

דוגמאות נוספות

atan2

סמנטיקה

מבצעת פעולת atan2 מבחינת הרכיבים ב-lhs וב-Tensor rhs ומפיקה טינזור result. בהתאם לסוג הרכיב, מבצע את הפעולות הבאות:

  • למספרים מסוג float: atan2 מ-IEEE-754.
  • למספרים מרוכבים: complex atan2.
  • לסוגי נתונים מרוסקים: dequantize_op_quantize(atan2, lhs, rhs, type(result)).

קלט

תווית שם סוג מגבלות
(I1) lhs טינסור מסוג נקודה צפה או מסוג מורכב, או טינסור מקודד לכל טינסור (C1)
(I2) rhs טינסור מסוג נקודה צפה או מסוג מורכב, או טינסור מקודד לכל טינסור (C1)

פלט

שם סוג מגבלות
result טינסור מסוג נקודה צפה או מסוג מורכב, או טינסור מקודד לכל טינסור (C1)

אילוצים

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

דוגמאות

// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]

 דוגמאות נוספות

batch_norm_grad

סמנטיקה

מחשבת הדרגתיות של מספר קלטים של batch_norm_training שמופצת לאחור מ-grad_output, ומפיקה טנזטורים grad_operand, grad_scale ו-grad_offset. באופן רשמי יותר, אפשר לבטא את הפעולה הזו מפירוק לפעולות StableHLO קיימות באמצעות תחביר Python באופן הבא:

def compute_sum(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  return sum

def compute_mean(operand, feature_index):
  sum = compute_sum(operand, feature_index)
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
  # Broadcast inputs to type(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance`
  # Intermediate values will be useful for computing gradients
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)

  # Use the implementation from batchnorm_expander.cc in XLA
  # Temporary variables have exactly the same names as in the C++ code
  elements_per_feature = broadcast_in_dim(
      constant(divide(size(operand), dim(operand, feature_index)),
               element_type(grad_output)),
      [], type(operand))
  i1 = multiply(grad_output, elements_per_feature)
  i2 = broadcast_in_dim(
      compute_sum(grad_output, feature_index), [feature_index], type(operand))
  i3 = broadcast_in_dim(
      compute_sum(multiply(grad_output, centered_operand), feature_index),
      [feature_index], type(operand))
  i4 = multiply(i3, centered_operand)
  i5 = divide(i4, add(variance_bcast, epsilon_bcast))
  i6 = subtract(subtract(i1, i2), i5)

  grad_operand =
      multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
  grad_scale =
      compute_sum(multiply(grad_output, normalized_operand), feature_index)
  grad_offset = compute_sum(grad_output, feature_index)

  return grad_operand, grad_scale, grad_offset

לסוגי נתונים שמפורטים במספרים, הפונקציה מבצעת את הפעולה dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean, variance, grad_output: batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index), operand, scale, mean, variance, grad_output, type(grad_operand), type(grad_scale), type(feature_index)).

קלט

תווית שם סוג מגבלות
(I1) operand טינסור מסוג נקודה צפה (floating-point) או טינסור מקודד לכל טינסור (C1-C3),‏ (C5)
(I2) scale טינסור חד-מימדי מסוג נקודה צפה (floating-point) או מסוג טינסור מקודד לכל טינסור (C2),‏ (C4),‏ (C5)
(I3) mean טנזור חד-ממדי של נקודה צפה (floating-point) או סוג כמותי לכל טנזור (C2),‏ (C4)
(I4) variance טינסור חד-מימדי מסוג נקודה צפה (floating-point) או מסוג טינסור מקודד לכל טינסור (C2),‏ (C4)
(I5) grad_output טינסור מסוג נקודה צפה (floating-point) או טינסור מקודד לכל טינסור (C2),‏ (C3)
(I6) epsilon קבוע מסוג f32
(I7) feature_index קבוע מסוג si64 (C1),‏ (C5)

פלט

שם סוג אילוצים
grad_operand טינסור מסוג נקודה צפה (floating-point) או טינסור מקודד לכל טינסור (C2), (C3)
grad_scale טנזור חד-ממדי של נקודה צפה (floating-point) או סוג כמותי לכל טנזור (C2),‏ (C4)
grad_offset טינסור חד-מימדי מסוג נקודה צפה (floating-point) או מסוג טינסור מקודד לכל טינסור (C2),‏ (C4)

אילוצים

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) ל-operand, ל-scale, ל-mean, ל-variance, ל-grad_output, ל-grad_operand, ל-grad_scale ול-grad_offset יש אותו baseline_element_type.
  • (C3) operand, grad_output ו-grad_operand יש את אותה הצורה.
  • (C4) ל-scale, ל-mean, ל-variance, ל-grad_scale ול-grad_offset יש אותה צורה.
  • (C5) size(scale) = dim(operand, feature_index).

דוגמאות

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
//                [[0.1, 0.1], [0.1, 0.1]],
//                [[0.1, 0.1], [0.1, 0.1]]
//               ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
     tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
//                 [[0.0, 0.0], [0.0, 0.0]],
//                 [[0.0, 0.0], [0.0, 0.0]]
//                ]
// %grad_scale:  [0.0, 0.0]
// %grad_offset: [0.4, 0.4]

batch_norm_inference

סמנטיקה

נורמליזציה של הטנזור operand בכל המאפיינים מלבד המאפיין feature_index, וייצור טנזור result. באופן רשמי יותר, אפשר לבטא את הפעולה הזו כפירוק לפעולות StableHLO קיימות באמצעות תחביר Python באופן הבא:

def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
  # Broadcast inputs to shape(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance` instead of
  # computing them like `batch_norm_training` does.
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)
  return add(multiply(scale_bcast, normalized_operand), offset_bcast)

לסוגים כמותיים, הפונקציה מחזירה dequantize_op_quantize(lambda operand, scale, offset, mean, variance: batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index), operand, scale, offset, mean, variance, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טינסור מסוג נקודה צפה (floating-point) או טינסור מקודד לכל טינסור (C1-C7)
(I2) scale טינסור חד-מימדי מסוג נקודה צפה (floating-point) או מסוג טינסור מקודד לכל טינסור (C2),‏ (C3)
(I3) offset טנזור חד-ממדי של נקודה צפה (floating-point) או סוג כמותי לכל טנזור (C2),‏ (C4)
(I4) mean טינסור חד-מימדי מסוג נקודה צפה (floating-point) או מסוג טינסור מקודד לכל טינסור (C5)
(I5) variance טנזור חד-ממדי של נקודה צפה (floating-point) או סוג כמותי לכל טנזור (C2),‏ (C6)
(I6) epsilon קבוע מסוג f32
(I7) feature_index קבוע מסוג si64 (C1),‏ (C3-C6)

פלט

שם סוג אילוצים
result טינסור מסוג נקודה צפה (floating-point) או טינסור מקודד לכל טינסור (C2), (C7)

אילוצים

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) ל-operand, ל-scale, ל-offset, ל-mean, ל-variance ול-result יש אותו baseline_element_type.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(mean) = dim(operand, feature_index).
  • (C6) size(variance) = dim(operand, feature_index).
  • (C7) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]

batch_norm_training

סמנטיקה

הפונקציה מחשבת את הממוצע והשונות בכל המאפיינים, מלבד המאפיין feature_index, ומנורמלת את הטנזור operand כדי ליצור את הטנסורים output,‏ batch_mean ו-batch_var. באופן רשמי יותר, אפשר לבטא את הפעולה הזו כפירוק לפעולות StableHLO קיימות באמצעות תחביר Python באופן הבא:

def compute_mean(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def compute_variance(operand, feature_index):
  mean = compute_mean(operand, feature_index)
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  centered_operand = subtract(operand, mean_bcast)
  return compute_mean(mul(centered_operand, centered_operand), feature_index)

def batch_norm_training(operand, scale, offset, epsilon, feature_index):
  mean = compute_mean(operand, feature_index)
  variance = compute_variance(operand, feature_index)
  return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
                              feature_index),
         mean, variance

לסוגי נתונים שמפורטים במספרים, הפונקציה מבצעת את הפעולה dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset: batch_norm_training(operand, scale, offset, epsilon, feature_index), operand, scale, offset, type(output), type(batch_mean), type(batch_var)).

קלט

תווית שם סוג מגבלות
(I1) operand טינסור מסוג נקודה צפה (floating-point) או טינסור מקודד לכל טינסור (C1)
(I2) scale טינסור חד-מימדי של נקודה צפה (floating-point) או טינסור שמקוונטיז לפי טינסור (C2),‏ (C3)
(I3) offset טינסור חד-מימדי של נקודה צפה (floating-point) או טינסור שמקוונטיז לפי טינסור (C2),‏ (C4)
(I4) epsilon קבוע מסוג f32 (C1),‏ (C3-C6)
(I5) feature_index קבוע מסוג si64 (C1),‏ (C3-C6)

פלט

שם סוג אילוצים
output t e n s o r f l o w או t e n s o r f l o w, (C7)
batch_mean טנזור חד-ממדי של נקודה צפה (floating-point) או לכל טנזור כמותי (C2),‏ (C5)
batch_var טינסור חד-מימדי של נקודה צפה (floating-point) או טינסור שמקוונטיז לפי טינסור (C2),‏ (C6)

אילוצים

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) ל-operand, ל-scale, ל-offset, ל-batch_mean, ל-batch_var ול-output יש את אותו baseline_element_type.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(batch_mean) = dim(operand, feature_index).
  • (C6) size(batch_var) = dim(operand, feature_index).
  • (C7) baseline_type(output) = baseline_type(operand).

דוגמאות

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
    (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]

bitcast_convert

סמנטיקה

הפונקציה מבצעת פעולת bitcast על הטנזור operand ויוצרת טנזור result שבו הביטים של כל הטנזור operand מפורשים מחדש לפי הסוג של הטנזור result.

באופן רשמי יותר, בהינתן E = element_type(operand),‏ E' = element_type(result) ו-R = rank(operand):

  • אם num_bits(E') < num_bits(E), bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1]).
  • אם num_bits(E') > num_bits(E), bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :]).
  • אם הערך שלו הוא num_bits(E') = num_bits(E), bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1]).

הפונקציה bits מחזירה ייצוג בזיכרון של ערך נתון, וההתנהגות שלה מוגדרת בהתאם להטמעה כי הייצוג המדויק של הטנסורים מוגדר בהתאם להטמעה, וגם הייצוג המדויק של סוגי הרכיבים מוגדר בהתאם להטמעה.

קלט

תווית שם סוג מגבלות
(I1) operand טינסור או טינסור מצטבר (C1-C2)

פלט

שם סוג אילוצים
result טינסור או טינסור מצטבר (C1-C2)

אילוצים

  • (C1) בהינתן E = is_quantized(operand) ? storage_type(operand) : element_type(operand),‏ E' = is_quantized(result) ? storage_type(result) : element_type(result) ו-R = rank(operand):
    • אם num_bits(E') = num_bits(E), shape(result) = shape(operand).
    • אם num_bits(E') < num_bits(E):
    • rank(result) = R + 1.
    • dim(result, i) = dim(operand, i) לכל 0 <= i < R.
    • dim(result, R) * num_bits(E') = num_bits(E).
    • אם num_bits(E') > num_bits(E):
    • rank(result) = R - 1.
    • dim(result, i) = dim(operand, i) לכל 0 <= i < R.
    • dim(operand, R - 1) * num_bits(E) = num_bits(E').
  • (C2) אם is_complex(operand) or is_complex(result), אז is_complex(operand) and is_complex(result).

דוגמאות

// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation

 דוגמאות נוספות

broadcast_in_dim

סמנטיקה

הרחבת המימדים ו/או הדירוג של מפריד קלט על ידי שכפול הנתונים במשתנה operand ויצירה של טנזור result. באופן רשמי יותר, result[result_index] = operand[operand_index], כאשר לכל d ב-axes(operand):

  • operand_index[d] = 0 אם dim(operand, d) = 1.
  • operand_index[d] = result_index[broadcast_dimensions[d]] אחרת.

קלט

תווית שם סוג מגבלות
(I1) operand טינסור או טינסור מצטבר (C1-C2),‏ (C5-C6)
(I2) broadcast_dimensions קבוע טינסור חד-מימדי מסוג si64 (C2-C6)

פלט

שם סוג אילוצים
result טינסור או טינסור מצטבר (C1),‏ (C3),‏ (C5-C6)

מגבלות

  • (C1) הערך של element_type(result) מחושב לפי:
    • element_type(operand), אם !is_per_axis_quantized(operand).
    • element_type(operand) חוץ מהבדל אחד: quantization_dimension(operand), scales(operand) ו-zero_points(operand), יכול להיות שונה מ-quantization_dimension(result), מ-scales(result) ומ-zero_points(result), אחרת.
  • (C2) size(broadcast_dimensions) = rank(operand).
  • (C3) 0 <= broadcast_dimensions < rank(result).
  • (C4) is_unique(broadcast_dimensions).
  • (C5) לכל d ב-axes(operand):
    • dim(operand, d) = 1 או
    • dim(operand, d) = dim(result, broadcast_dimensions[d]).
  • (C6) אם is_per_axis_quantized(result):
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].
    • אם dim(operand, quantization_dimension(operand)) = 1, אז scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).

דוגמאות

// %operand: [
//            [1, 2, 3]
//           ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
  broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

 דוגמאות נוספות

כיסוי

סמנטיקה

הפונקציה יוצרת את הפלט מהרצה של פונקציה אחת בלבד מ-branches, בהתאם לערך של index. באופן רשמי יותר, result = selected_branch() כאשר:

  • selected_branch = branches[index] אם 0 <= index < size(branches).
  • selected_branch = branches[-1] אחרת.

קלט

תווית שם סוג מגבלות
(I1) index טינסור של מימד 0 מסוג si32
(I2) branches מספר פונקציות וריאדי (C1-C4)

פלט

שם סוג אילוצים
results מספר משתנה של טינסורים, טינסורים מרובים או אסימונים (C4)

אילוצים

  • (C1) 0 < size(branches).
  • (C2) input_types(branches...) = [].
  • (C3) same(output_types(branches...)).
  • (C4) type(results...) = output_types(branches[0]).

דוגמאות

// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
  "stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
  "stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]

 דוגמאות נוספות

cbrt

סמנטיקה

הפונקציה מבצעת פעולת שורש חזק שלישי של כל רכיב בטנסור operand ומפיקה טנסור result. בהתאם לסוג הרכיב, מבצע את הפעולות הבאות:

  • למספרים מסוג float: rootn(x, 3) מ-IEEE-754.
  • למספרים מרוכבים: שורש מעוקבים מורכב.
  • לסוגים מרובעניים: dequantize_op_quantize(cbrt, operand, type(result))

קלט

תווית שם סוג מגבלות
(I1) operand טינסור מסוג נקודה צפה או מסוג מורכב, או טינסור מקודד לכל טינסור (C1)

פלט

שם סוג מגבלות
result טינסור מסוג נקודה צפה או מסוג מורכב, או טינסור מקודד לכל טינסור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]

 דוגמאות נוספות

ceil

סמנטיקה

הפונקציה מבצעת חישוב של חזקה שלם של רכיבי הטנזור operand ומפיקה טנזור result. הטמעה של הפעולה roundToIntegralTowardPositive מהמפרט של IEEE-754. לסוגים כמותיים, הפונקציה מחזירה dequantize_op_quantize(ceil, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טינסור מסוג נקודה צפה (floating-point) או טינסור מקודד לכל טינסור (C1)

פלט

שם סוג אילוצים
result טינסור מסוג נקודה צפה (floating-point) או טינסור מקודד לכל טינסור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]

 דוגמאות נוספות

cholesky

סמנטיקה

חישוב הפירוק של Cholesky של קבוצה של מטריצות.

באופן רשמי יותר, לכל i ב-index_space(result), result[i0, ..., iR-3, :, :] הוא פירוק של כולסקי של a[i0, ..., iR-3, :, :], בצורת מטריצה בצורת משולש תחתון (אם lower הוא true) או משולש עליון (אם lower הוא false). ערכי הפלט במשולש הנגדי, כלומר המשולש העליון המחמיר או המשולש התחתון המחמיר בהתאמה, מוגדרים על פי ההטמעה.

אם קיים i שבו מטריצת הקלט היא לא מטריצת Hermitian חיובית-מוגדרת, ההתנהגות לא מוגדרת.

לסוגי נתונים שמפורטים במספרים, הפונקציה מבצעת את הפעולה dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result)).

קלט

תווית שם סוג מגבלות
(I1) a t e n s o r f l o w, או t e n s o r f l o w, או t e tensor tenor (C1-C3)
(I2) lower קבוע טינסור של מימד 0 מסוג i1

פלט

שם סוג מגבלות
result טינסור מסוג נקודה צפה או מסוג מורכב, או טינסור מקודד לכל טינסור (C1)

אילוצים

  • (C1) baseline_type(a) = baseline_type(result).
  • (C2) 2 <= rank(a).
  • (C3) dim(a, -2) = dim(a, -1).

דוגמאות

// %a: [
//      [1.0, 2.0, 3.0],
//      [2.0, 20.0, 26.0],
//      [3.0, 26.0, 70.0]
//     ]
%result = "stablehlo.cholesky"(%a) {
  lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
//           [1.0, 0.0, 0.0],
//           [2.0, 4.0, 0.0],
//           [3.0, 5.0, 6.0]
//          ]

הצמדה

סמנטיקה

מצמידה כל רכיב של טנזור operand בין ערך מינימלי ומקסימלי, ויוצרת טנזור result. באופן רשמי יותר, result[result_index] = minimum(maximum(operand[result_index], min_element), max_element), כאשר min_element = rank(min) = 0 ? min[] : min[result_index], max_element = rank(max) = 0 ? max[] : max[result_index]. לסוגי נתונים שמפורטים במספרים, הפונקציה מבצעת את הפעולה dequantize_op_quantize(clamp, min, operand, max, type(result)).

הטמעת סדר במספרים מורכבים כרוכה בסמינטיקה מפתיעה, ולכן בעתיד אנחנו מתכננים להסיר את התמיכה במספרים מורכבים לפעולה הזו (#560).

קלט

תווית שם סוג מגבלות
(I1) min טינסור או טינסור מותאם (quantized) לכל טינסור (C1), (C3)
(I2) operand טנזור או לכל טנזור קוונטי (C1-C4)
(I3) max טינסור או טינסור מותאם (quantized) לכל טינסור (C2),‏ (C3)

פלט

שם סוג מגבלות
result טינסור או טינסור מותאם (quantized) לכל טינסור (C4)

אילוצים

  • (C1) rank(min) = 0 or shape(min) = shape(operand).
  • (C2) rank(max) = 0 or shape(max) = shape(operand).
  • (C3) baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max).
  • (C4) baseline_type(operand) = baseline_type(result).

דוגמאות

// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]

דוגמאות נוספות

collective_broadcast

סמנטיקה

בכל קבוצת תהליכים ברשת התהליכים של StableHLO, שולחים את הערך של הטנזור operand מתהליך המקור לתהליכי היעד ויוצרים טנזור result.

הפעולה מפצלת את רשת התהליכים של StableHLO ל-process_groups, שמוגדר כך:

  • cross_replica(replica_groups) אם channel_id <= 0.
  • cross_partition(replica_groups) אם channel_id > 0.

לאחר מכן, result@process מחושב לפי הנוסחה הבאה:

  • operand@process_groups[i, 0] אם קיים i כזה שהתהליך נמצא ב-process_groups[i].
  • broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result)) אחרת.

קלט

תווית שם סוג מגבלות
(I1) operand טנזור או לכל טנזור קוונטי (C3)
(I2) replica_groups מספר משתנה של קבועים של טינסור חד-מימדי מסוג si64 (C1),‏ (C2)
(I3) channel_id קבוע מסוג si64

פלט

שם סוג מגבלות
result טנזור או לכל טנזור קוונטי (C3)

מגבלות

  • (C1) is_unique(replica_groups).
  • (C2) 0 <= replica_groups < N כאשר N מוגדר בתור:
    • num_replicas אם משתמשים ב-cross_replica.
    • num_partitions אם משתמשים ב-cross_partition.
  • (C3) type(result) = type(operand).

דוגמאות

// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
  replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]

collective_permute

סמנטיקה

בכל קבוצת תהליכים ברשת התהליכים של StableHLO, שולחים את הערך של המכפלה operand מתהליך המקור לתהליך היעד ויוצרים מכפלה result.

הפעולה מפצלת את רשת התהליכים של StableHLO ל-process_groups, שמוגדר כך:

  • cross_replica(source_target_pairs) אם channel_id <= 0.
  • cross_partition(source_target_pairs) אם channel_id > 0.

לאחר מכן, result@process ניתן על ידי:

  • operand@process_groups[i, 0], אם קיים i כזה process_groups[i, 1] = process.
  • broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result)) אחרת.

קלט

תווית שם סוג מגבלות
(I1) operand טנזור או לכל טנזור קוונטי (C5)
(I2) source_target_pairs קבוע טינסור דו-מימדי מסוג si64 (C1-C4)
(I3) channel_id קבוע מסוג si64

פלט

שם סוג מגבלות
result טינסור או טינסור מותאם (quantized) לכל טינסור (C1)

אילוצים

  • (C1) dim(source_target_pairs, 1) = 2.
  • (C2) is_unique(source_target_pairs[:, 0]).
  • (C3) is_unique(source_target_pairs[:, 1]).
  • (C4) 0 <= source_target_pairs < N, כאשר N מוגדר בתור:
    • num_replicas אם משתמשים ב-cross_replica.
    • num_partitions אם משתמשים ב-cross_partition.
  • (C5) type(result) = type(operand).

דוגמאות

// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
  source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]

 דוגמאות נוספות

compare

סמנטיקה

הפונקציה מבצעת השוואה של הרכיבים של הטנסורים lhs ו-rhs לפי comparison_direction ו-compare_type, ומפיקה טנסור result.

הערכים של comparison_direction ו-compare_type הם:

לסוגים של רכיבים בוליאניים ומספריים שלמים:

  • EQ:‏ lhs = rhs.
  • NE:‏ lhs != rhs.
  • GE:‏ lhs >= rhs.
  • GT:‏ lhs > rhs.
  • LE:‏ lhs <= rhs.
  • LT:‏ lhs < rhs.

לסוגי רכיבים של נקודה צפה עם compare_type = FLOAT, הפעולה מיישמת את הפעולות הבאות של IEEE-754:

  • EQ:‏ compareQuietEqual.
  • NE:‏ compareQuietNotEqual.
  • GE:‏ compareQuietGreaterEqual.
  • GT:‏ compareQuietGreater.
  • LE:‏ compareQuietLessEqual.
  • LT:‏ compareQuietLess.

לסוגים של רכיבים עם נקודה צפה עם compare_type = TOTALORDER, הפעולה משתמשת בשילוב של פעולות totalOrder ו-compareQuietEqual מ-IEEE-754.

בסוגי רכיבים מורכבים, מתבצעת השוואה לקסיקוגרפית של זוגות (real, imag) באמצעות comparison_direction ו-compare_type שסופקו. קביעת סדר של מספרים מרוכבים כרוכה בסמנטיקה מפתיעה, ולכן בעתיד אנחנו מתכננים להסיר את התמיכה במספרים מרוכבים כשהערך של comparison_direction הוא GE, GT, LE או LT (#560).

לסוגי נתונים שמקובצים. מבצעת את הפונקציה dequantize_compare(lhs, rhs, comparison_direction).

קלט

תווית שם סוג מגבלות
(I1) lhs טינסור או טינסור מותאם (quantized) לכל טינסור (C1-C3)
(I2) rhs טנזור או לכל טנזור קוונטי (C1-C2)
(I3) comparison_direction enum של EQ, NE, GE, GT, LE וגם LT
(I4) compare_type enum של FLOAT,‏ TOTALORDER,‏ SIGNED וגם UNSIGNED (C3)

פלט

שם סוג אילוצים
result טינסור מסוג בוליאני (C2)

אילוצים

  • (C1) baseline_element_type(lhs) = baseline_element_type(rhs).
  • (C2) shape(lhs) = shape(rhs) = shape(result).
  • (C3) compare_type מוגדר בתור:
    • SIGNED אם is_signed_integer(element_type(lhs)).
    • UNSIGNED אם is_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs)).
    • FLOAT או TOTALORDER אם is_float(element_type(lhs)).
    • FLOAT אם is_complex(element_type(lhs)).

דוגמאות

// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
  comparison_direction = #stablehlo<comparison_direction LT>,
  compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]

 דוגמאות נוספות

מורכב

סמנטיקה

ביצוע המרה של רכיבים לערך מורכב מזוג ערכים אמיתיים ומדומים, lhs ו-rhs, יצירת טינסור result.

קלט

תווית שם סוג מגבלות
(I1) lhs טינסור מסוג f32 או f64 (C1-C3)
(I2) rhs טינסור מסוג f32 או f64 (C1)

פלט

שם סוג אילוצים
result t e n s o r f l o w (C2),‏ (C3)

אילוצים

  • (C1) type(lhs) = type(rhs).
  • (C2) shape(result) = shape(lhs).
  • (C3) element_type(result) מסוג complex<E> כאשר E = element_type(lhs).

דוגמאות

// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]

 דוגמאות נוספות

מורכב

סמנטיקה

כוללת פעולה שמורכבת (מורכבת) מפעולות StableHLO אחרות, שלוקחת inputs ו-composite_attributes ומפיקה results. הסמנטיקה של הפעולה מיושמת באמצעות המאפיין decomposition. אפשר להחליף את הפעולה composite בפירוק שלה בלי לשנות את הסמנטיקה של התוכנית. במקרים שבהם הטמעת הפירוק בקוד לא מספקת את אותה סמנטיקה של הפעולה, עדיף להשתמש ב-custom_call.

השדה version (ברירת המחדל היא 0) משמש לציון מתי הסמנטיקה של רכיב מורכב משתנה.

קלט

תווית שם סוג
(I1) inputs מספר משתנה של ערכים
(I2) name קבוע מסוג string
(I3) composite_attributes מילון מאפיינים
(I4) decomposition קבוע מסוג string
(I5) version קבוע מסוג si32

פלט

שם סוג
results מספר משתנה של ערכים

אילוצים

  • (C1) is_namespaced_op_name(name)
  • (C2) is_defined_in_parent_scope(decomposition)
  • (C3) types(inputs...) == input_types(decomposition)
  • (C4) types(results...) == output_types(decomposition)

דוגמאות

%results = "stablehlo.composite"(%input0, %input1) {
  name = "my_namespace.my_op",
  composite_attributes = {
    my_attribute = "my_value"
  },
  decomposition = @my_op,
  version = 1 : i32
} : (tensor<f32>, tensor<f32>) -> tensor<f32>

 דוגמאות נוספות

לשרשר

סמנטיקה

שרשור של inputs לאורך המאפיין dimension באותו הסדר של הארגומנטים שצוינו, יצירת טינסור result. באופן רשמי יותר, result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1], כאשר:

  1. id = d0 + ... + dk-1 + kd.
  2. הערך של d שווה ל-dimension, ו-d0, … הם גדלי המאפיין ה-d של inputs.

קלט

תווית שם סוג מגבלות
(I1) inputs מספר וריאדי של טינסורים או טינסורים מקוטנים לכל טינסור (C1-C6)
(I2) dimension קבוע מסוג si64 (C2),‏ (C4),‏ (C6)

פלט

שם סוג מגבלות
result טנזור או לכל טנזור קוונטי (C5-C6)

אילוצים

  • (C1) same(element_type(inputs...)).
  • (C2) same(shape(inputs...)) למעט dim(inputs..., dimension).
  • (C3) 0 < size(inputs).
  • (C4) 0 <= dimension < rank(inputs[0]).
  • (C5) element_type(result) = element_type(inputs[0]).
  • (C6) shape(result) = shape(inputs[0]) למעט:
    • dim(result, dimension) = dim(inputs[0], dimension) + ....

דוגמאות

// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
  dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]

 דוגמאות נוספות

קבוע

סמנטיקה

הפונקציה יוצרת טינסור output מקבוע value.

קלט

תווית שם סוג מגבלות
(I1) value קבוע (C1)

פלט

שם סוג אילוצים
output טינסור או טינסור מצטבר (C1)

אילוצים

  • (C1) type(value) = type(output).

דוגמאות

%output = "stablehlo.constant"() {
  value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]

 דוגמאות נוספות

להשלים המרה

סמנטיקה

מבצעת המרה ברמת הרכיבים מסוג רכיב אחד לסוג אחר ב-operand Tensor ומפיקה טנזור result.

להמרות מסוג boolean-to-any-supported-type, הערך false מומר לאפס והערך true מומר ל-1. בהמרות מסוג any-supported-type-to-boolean, ערך אפס מומר לערך false, וערכים שאינם אפס מומרים לערך true. בהמשך מוסבר איך לעשות את זה עם סוגים מורכבים.

בהמרות שכוללות מספר שלם למספר שלם, מספר שלם לנקודת צפה או נקודת צפה לנקודת צפה, אם אפשר לייצג בדיוק את ערך המקור בסוג היעד, ערך התוצאה הוא הייצוג המדויק הזה. אחרת, אופן הפעולה הזה טרם נקבע (#180).

בהמרות שכוללות מספרים בספרות עשרוניות למספרים שלמים, החלק השברי נקטע. אם אי אפשר לייצג את הערך המקוצר בסוג היעד, ההתנהגות תהיה 'בהמתנה' (#180).

המרות עם מורכב למורכב פועלות באותו אופן של המרות מנקודה צפה לנקודה צפה (floating-point) כדי להמיר חלקים אמיתיים ומדומים.

בהמרות מסוג מורכב לכל סוג אחר והמרות מסוג כל סוג אחר למורכב, המערכת מתעלמת מהערך הדמיוני של המקור או שהערך הדמיוני של היעד מאופס, בהתאמה. ההמרה של החלק הממשי מתבצעת אחרי ההמרות של הנקודה הצפה.

באופן עקרוני, הפעולה הזו יכולה לבטא ביטול קצירה (המרה מטנסורים מקובצים לטנסורים רגילים), קצירה (המרה מטנסורים רגילים לטנסורים מקובצים) וקצירה מחדש (המרה בין טנסורים מקובצים), אבל כרגע יש לנו פעולות ייעודיות לכך – uniform_dequantize לתרחיש לדוגמה הראשון ו-uniform_quantize לתרחישי לדוגמה השני והשלישי. בעתיד, יכול להיות ששתי הפעולות האלה ימוזגו ל-convert (#1576).

קלט

תווית שם סוג מגבלות
(I1) operand טרנספורמציה (C1)

פלט

שם סוג אילוצים
result טרנספורמציה (C1)

מגבלות

  • (C1) shape(operand) = shape(result).

דוגמאות

// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]

דוגמאות נוספות

convolve

סמנטיקה

הפונקציה מחשבת מכפלות סקלריות בין חלונות של lhs לבין פרוסות של rhs, ומייצרת את הערך result. בתרשים הבא אפשר לראות איך מחושבים יסודות ב-result מ-lhs ומ-rhs באמצעות דוגמה קונקרטית.

convolve

באופן רשמי יותר, אפשר להציג מחדש את הקלט במונחים של lhs כדי להביע חלונות של lhs:

  • lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension)).
  • lhs_window_strides = lhs_shape(1, window_strides, 1).
  • lhs_padding = lhs_shape([0, 0], padding, [0, 0]).
  • lhs_base_dilations = lhs_shape(1, lhs_dilation, 1).
  • lhs_window_dilations = lhs_shape(1, rhs_dilation, 1).

כדי לבצע את שינוי הפריים הזה, נעשה שימוש בפונקציות העזר הבאות:

  • lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]).
  • result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]).
  • permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1] כאשר j[d] = i[permutation[d]].

אם הערך הוא feature_group_count = 1 וגם batch_group_count = 1, אז לכל output_spatial_index ב-index_space(dim(result, output_spatial_dimensions...)), result[result_shape(:, output_spatial_index, :)] = dot_product כאשר:

  • padding_value = constant(0, element_type(lhs)).
  • padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1).
  • lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides.
  • lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations).
  • reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true]). נראה שהתכונה הזו לא בשימוש, ולכן אנחנו מתכננים להסיר אותה בעתיד (#1181).
  • dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension]).

אם feature_group_count > 1:

  • lhses = split(lhs, feature_group_count, input_feature_dimension).
  • rhses = split(rhs, feature_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension).

אם batch_group_count > 1:

  • lhses = split(lhs, batch_group_count, input_batch_dimension).
  • rhses = split(rhs, batch_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension).

לסוגי נתונים שמסומנים בקונטיינרים, הפונקציה מבצעת את הפעולה dequantize_op_quantize( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs, type(result)).

לסוגי נתונים מרומזים היברידיים, הפונקציה מבצעת את הפעולה hybrid_dequantize_then_op( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs).

קלט

תווית שם סוג מגבלות
(I1) lhs טינסור או טינסור מותאם (quantized) לכל טינסור (C1),‏ (C10-C11),‏ (C14) (C25),‏ (C27-C28),‏ (C31-C32),‏ (C34)
(I2) rhs t e n s o r f l o w, או t e n s o r f l o w, (C1),‏ (C14-C16),‏ (C25),‏ (C27-C29),‏ (C31-C34)
(I3) window_strides קבוע טינסור חד-מימדי מסוג si64 (C2-C3), (C25)
(I4) padding קבוע טינסור דו-מימדי מסוג si64 (C4),‏ (C25)
(I5) lhs_dilation קבוע טינסור חד-מימדי מסוג si64 (C5-C6), (C25)
(I6) rhs_dilation קבוע טינסור חד-מימדי מסוג si64 (C7-C8), (C25)
(I7) window_reversal קבוע טינסור חד-מימדי מסוג i1 (C9)
(I8) input_batch_dimension קבוע מסוג si64 (C10),‏ (C13),‏ (C25)
(I9) input_feature_dimension קבוע מסוג si64 (C11),‏ (C13-C14)
(I10) input_spatial_dimensions קבוע טינסור חד-מימדי מסוג si64 (C12), (C13), (C25)
(I11) kernel_input_feature_dimension קבוע מסוג si64 (C14), (C18)
(I12) kernel_output_feature_dimension קבוע מסוג si64 (C15-C16),‏ (C18),‏ (C25),‏ (C29)
(I13) kernel_spatial_dimensions קבוע טינסור חד-מימדי מסוג si64 (C17-C18),‏ (C25)
(I14) output_batch_dimension קבוע מסוג si64 (C20),‏ (C25)
(I15) output_feature_dimension קבוע מסוג si64 (C20), (C25), (C30)
(I16) output_spatial_dimensions קבוע טינסור חד-מימדי מסוג si64 (C19-C20),‏ (C25)
(I17) feature_group_count קבוע מסוג si64 (C11), (C14), (C16), (C21), (C23)
(I18) batch_group_count קבוע מסוג si64 (C10), (C15), (C22), (C23), (C25)
(I19) precision_config מספר ושונות של טיפוסים בני מנייה (enum) של DEFAULT, HIGH ו-HIGHEST (C24)

פלט

שם סוג אילוצים
result טינסור או טינסור מצטבר (C25-C28),‏ (C30),‏ (C32-34)

מגבלות

  • (C1) N = rank(lhs) = rank(rhs).
  • (C2) size(window_strides) = N - 2.
  • (C3) 0 < window_strides.
  • (C4) shape(padding) = [N - 2, 2].
  • (C5) size(lhs_dilation) = N - 2.
  • (C6) 0 < lhs_dilation.
  • (C7) size(rhs_dilation) = N - 2.
  • (C8) 0 < rhs_dilation.
  • (C9) size(window_reversal) = N - 2.
  • (C10) dim(lhs, input_batch_dimension) % batch_group_count = 0.
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0.
  • (C12) size(input_spatial_dimensions) = N - 2.
  • (C13) בהינתן input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:
    • is_unique(input_dimensions).
    • 0 <= input_dimensions < N.
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count.
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0.
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0.
  • (C17) size(kernel_spatial_dimensions) = N - 2.
  • (C18) בהינתן kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:
    • is_unique(kernel_dimensions).
    • 0 <= kernel_dimensions < N.
  • (C19) size(output_spatial_dimensions) = N - 2.
  • (C20) בהינתן output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:
    • is_unique(output_dimensions).
    • 0 <= output_dimensions < N.
  • (C21) 0 < feature_group_count.
  • (C22) 0 < batch_group_count.
  • (C23) feature_group_count = 1 or batch_group_count = 1.
  • (C24) size(precision_config) = 2.
  • (C25) dim(result, result_dim) מוגדר בתור:
    • dim(lhs, input_batch_dimension) / batch_group_count אם result_dim = output_batch_dimension.
    • dim(rhs, kernel_output_feature_dimension) אם result_dim = output_feature_dimension.
    • num_windows אחרת, כאשר:
    • output_spatial_dimensions[spatial_dim] = result_dim.
    • lhs_dim = input_spatial_dimensions[spatial_dim].
    • rhs_dim = kernel_spatial_dimensions[spatial_dim].
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
  • (C26) rank(result) = N.
  • אם הפעולה משתמשת בטנסורים לא מקובצים:
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result).
  • אם הפעולה משתמשת במקדמי אחיזה כמותיים:
    • (C28) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • (C29) אם is_per_axis_quantized(rhs), אז quantization_dimension(rhs) = kernel_output_feature_dimension.
    • (C30) אם is_per_axis_quantized(result), אז quantization_dimension(result) = output_feature_dimension.
    • אם is_quantized(lhs):
    • (C31) storage_type(lhs) = storage_type(rhs).
    • (C32) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C33) אם is_per_tensor_quantized(rhs), אז is_per_tensor_quantized(result).
    • אם !is_quantized(lhs):
    • (C34) element_type(lhs) = expressed_type(rhs) = element_type(result).

דוגמאות

// %lhs: [[
//        [
//          [1], [2], [5], [6]
//        ],
//        [
//          [3], [4], [7], [8]
//        ],
//        [
//          [10], [11], [14], [15]
//        ],
//        [
//          [12], [13], [16], [17]
//        ]
//      ]]
//
// %rhs: [
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]]
//       ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
  window_strides = array<i64: 4, 4>,
  padding = dense<0> : tensor<2x2xi64>,
  lhs_dilation = array<i64: 2, 2>,
  rhs_dilation = array<i64: 1, 1>,
  window_reversal = array<i1: false, false>,
  // In the StableHLO dialect, dimension numbers are encoded via:
  // `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
  // "b" is batch dimension, "f" is feature dimension,
  // "i" is input feature dimension, "o" is output feature dimension,
  // "0/1/etc" are spatial dimensions.
  dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
  batch_group_count = 1 : i64,
  feature_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
//            [[10], [26]],
//            [[46], [62]]
//          ]]

 דוגמאות נוספות

קוסינוס

סמנטיקה

הפונקציה מבצעת פעולת קוסינוס לפי רכיבים בטנסור operand ומפיקה טנסור result. בהתאם לסוג הרכיב, הפעולות הבאות:

  • למספרים מסוג float: cos מ-IEEE-754.
  • למספרים מרוכבים: קוסינוס מרוכב.
  • לסוגי נתונים מרוסקים: dequantize_op_quantize(cosine, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טינסור מסוג נקודה צפה או מסוג מורכב, או טינסור מקודד לכל טינסור (C1)

פלט

שם סוג מגבלות
result טינסור מסוג נקודה צפה או מסוג מורכב, או טינסור מקודד לכל טינסור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]

 דוגמאות נוספות

count_leading_zeros

סמנטיקה

הפונקציה מבצעת ספירה מבחינת הרכיבים של מספר אפס הביטים המובילים בטנזור operand ומפיקה את טנזור result.

קלט

תווית שם סוג מגבלות
(I1) operand טינסור מסוג מספר שלם (C1)

פלט

שם סוג מגבלות
result טינסור מסוג מספר שלם (C1)

אילוצים

  • (C1) type(operand) = type(result).

דוגמאות

// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]

דוגמאות נוספות

custom_call

סמנטיקה

האנקפסולציה של פעולה call_target_name שמוגדרת בהטמעה, שמקבלת את הערכים inputs ו-called_computations ומפיקה את הערך results. אפשר להשתמש ב-has_side_effect, ב-backend_config וב-api_version כדי לספק מטא-נתונים נוספים שמוגדרים בהטמעה.

בשלב הזה, הפעולה הזו מכילה אוסף לא מאורגן למדי של מטא-נתונים, שמשקף את ההתפתחות האורגנית של הפעולה המקבילה שלה במהדר XLA. בעתיד אנחנו מתכננים לאחד את המטא-נתונים האלה (#741).

קלט

תווית שם סוג
(I1) inputs מספר משתנה של ערכים
(I2) call_target_name קבוע מסוג string
(I3) has_side_effect קבוע מסוג i1
(I4) backend_config קבוע מסוג string או מילון מאפיינים
(I5) api_version קבוע מסוג si32
(I6) called_computations מספר משתנה של קבועים מסוג string

פלט

שם סוג
results מספר משתנה של ערכים

דוגמאות

%results = "stablehlo.custom_call"(%input0) {
  call_target_name = "foo",
  has_side_effect = false,
  backend_config = {bar = 42 : i32},
  api_version = 4 : i32,
  called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>

חילוק

סמנטיקה

הפונקציה מבצעת חילוק של המכפיל lhs במכפיל rhs לפי רכיבים, ומפיקה את הטנזור result. בהתאם לסוג הרכיב, מבצע את הפעולות הבאות:

  • למספרים שלמים: חלוקת מספרים שלמים שמניבה את החלק האלגברי, בלי החלק העשרוני.
  • למספרים מסוג float: division מ-IEEE-754.
  • למספרים מרוכבים: חלוקה מורכבת.
  • לסוגים מרובעניים:
    • dequantize_op_quantize(divide, lhs, rhs, type(result)).

קלט

תווית שם סוג מגבלות
(I1) lhs t e n s o r f l o w, t e n s o r f l o w, או t tensor quantor, (C1)
(I2) rhs טינסור מסוג שלם, של נקודה צפה או של מספר מרוכב, או טינסור מקודד לכל טינסור (C1)

פלט

שם סוג אילוצים
result טינסור מסוג שלם, של נקודה צפה או של מספר מרוכב, או טינסור מקודד לכל טינסור (C1)

אילוצים

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

דוגמאות

// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]

 דוגמאות נוספות

dot_general

סמנטיקה

הפונקציה מחשבת מכפלות נקודה בין פרוסות של lhs לבין פרוסות של rhs ומפיקה טינסור result.

באופן רשמי יותר, result[result_index] = dot_product, כאשר:

  • lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions].
  • rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions].
  • result_batching_index + result_lhs_index + result_rhs_index = result_index כאשר size(result_batching_index) = size(lhs_batching_dimensions), size(result_lhs_index) = size(lhs_result_dimensions) ו-size(result_rhs_index) = size(rhs_result_dimensions).
  • transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions).
  • transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :]).
  • reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions)).
  • transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions).
  • transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :]).
  • reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions)).
  • dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y)).

לסוגי נתונים שמסומנים בקונטיינרים, הפונקציה מבצעת את הפעולה dequantize_op_quantize( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs, type(result)).

לסוגי נתונים מרומזים היברידיים, הפונקציה מבצעת את הפעולה hybrid_dequantize_then_op( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs).

precision_config קובע את האיזון בין המהירות לדיוק של החישובים בקצוות העורפי של המאיצים. הערך יכול להיות אחד מהערכים הבאים (כרגע, הסמנטיקה של ערכי המניין האלה לא מוגדרת במלואה, אבל אנחנו מתכננים לטפל בבעיה הזו בבקשה מספר 755):

  • DEFAULT: החישוב המהיר ביותר, אבל ההערכה הכי פחות מדויקת למספר המקורי.
  • HIGH: חישוב איטי יותר, אבל קירוב מדויק יותר למספר המקורי.
  • HIGHEST: החישוב האיטי ביותר, אבל ההערכה המדויקת ביותר למספר המקורי.

השדה DotAlgorithm מגדיר את המאפיינים העיקריים של האלגוריתם שמשמש להטמעת פעולת הנקודה, והוא גם מגדיר את הדיוק. אם השדות של מאפיין האלגוריתם מוגדרים, הערך של precision_config חייב להיות DEFAULT. DotAlgorithms לא יש ערך ברירת מחדל, כי פרמטרים שמוגדרים כברירת מחדל מוגדרים בהטמעה. לכן, אפשר להגדיר את כל השדות של אלגוריתם הנקודות לערך None כדי לציין אלגוריתם נקודות ריק, שישתמש במקום זאת בערך precision_config.

השדות DotAlgorithm כוללים את אלה:

  • lhs_precision_type ו-rhs_precision_type, רמות הדיוק שאליהם מתבצעת העיגול של הצד הימני והשמאלי של הפעולה. סוגי הדיוק לא תלויים בסוגי האחסון של מקורות הקלט והפלט.
  • accumulation_type רמת הדיוק שמשמשת לצבירה.
  • הערכים lhs_component_count,‏ rhs_component_count ו-num_primitive_operations רלוונטיים כשאנחנו משתמשים באלגוריתם שמפרק את הצד הימני או הימני של הביטוי לרכיבים מרובים ומבצע כמה פעולות 'פרימיטיביות' של מכפלת מטריצות על הערכים האלה – בדרך כלל כדי לחקות רמת דיוק גבוהה יותר (למשל, שימוש ב-bfloat16 כסוג נתונים של בינה מלאכותית לצורך חישובים ברמת דיוק גבוהה יותר:‏ bf16_6x,‏ tf32_3x וכו'). באלגוריתמים שלא כוללים פירוק, צריך להגדיר את הערכים האלה ל-1.
  • allow_imprecise_accumulation כדי לציין אם מותר לצבור במדויקות נמוכה יותר בחלק מהשלבים (למשל CUBLASLT_MATMUL_DESC_FAST_ACCUM).

דוגמאות למאפייני DotAlgorithm:

// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
 rhs_precision_type = tf32,
 accumulation_type = f32,
 lhs_component_count = 1,
 rhs_component_count = 1,
 num_primitive_operations = 1,
 allow_imprecise_accumulation = false}


// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
 rhs_precision_type = bf16,
 accumulation_type = f32,
 lhs_component_count = 3,
 rhs_component_count = 3,
 num_primitive_operations = 6,
 allow_imprecise_accumulation = false}


// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
 rhs_precision_type = f8e5m2,
 accumulation_type = f32,
 lhs_component_count = 1,
 rhs_component_count = 1,
 num_primitive_operations = 1,
 allow_imprecise_accumulation = true}

בעלי ההטמעות מחליטים אילו שילובים יהיו נתמכים. באופן כללי, לא בטוח שהצרכן של StableHLO ייתמך כל אלגוריתם בכל סוג מאיץ. אם אין תמיכה באלגוריתם מסוים, צריכה להיות שגיאה ולא חזרה לחלופה. האימות של StableHLO יספק אימות לפי היכולת הטובה ביותר, כדי למנוע שימוש באלגוריתמים שלא ידועה תמיכה בהם בכל חומרה.

ראו xla_data.proto > Algorithm לגבי חלק מערכי האלגוריתם הנתמכים. כרטיס מספר 2483 מתעד את התוכנית ליצור מסמך מרכזי בנושא אלגוריתמים נתמכים לפי קצה עורפי.

קלט

תווית שם סוג מגבלות
(I1) lhs טינסור או טינסור מותאם (quantized) לכל טינסור (C5-C6), (C9-C10), (C12-C14), (C17-C18) (C20)
(I2) rhs טינסור או טינסור מצטבר (C7-C10),‏ (C12-C20)
(I3) lhs_batching_dimensions קבוע טינסור חד-מימדי מסוג si64 (C1),‏ (C3),‏ (C5),‏ (C9),‏ (C12)
(I4) rhs_batching_dimensions קבוע מפריד חד-ממדי מסוג si64 (C1), (C4), (C7), (C9)
(I5) lhs_contracting_dimensions קבוע מפריד חד-ממדי מסוג si64 (C2),‏ (C3),‏ (C6),‏ (C10)
(I6) rhs_contracting_dimensions קבוע טינסור חד-מימדי מסוג si64 (C2),‏ (C4),‏ (C8),‏ (C10),‏ (C16)
(I7) precision_config מספר משתנה של ערכים מתוך הממשק DEFAULT, HIGH ו-HIGHEST (C11), (C21)
(I8) lhs_precision_type FloatType או TensorFloat32 (C21)
(I9) rhs_precision_type FloatType או TensorFloat32 (C21)
(I10) accumulation_type FloatType או TensorFloat32 (C21)
(I11) lhs_component_count קבוע מסוג si32 (C21),‏ (C22)
(I12) rhs_component_count קבוע מסוג si32 (C21),‏ (C23)
(I13) num_primitive_operations קבוע מסוג si32 (C21), (C24)
(I14) allow_imprecise_accumulation קבוע מסוג bool (C21)

פלט

שם סוג אילוצים
result t e n s o r f l o w, או t e n s o r f l o w, (C12), (C14), (C18-C20)

אילוצים

  • (C1) size(lhs_batching_dimensions) = size(rhs_batching_dimensions).
  • (C2) size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions).
  • (C3) is_unique(lhs_batching_dimensions + lhs_contracting_dimensions).
  • (C4) is_unique(rhs_batching_dimensions + rhs_contracting_dimensions).
  • (C5) 0 <= lhs_batching_dimensions < rank(lhs).
  • (C6) 0 <= lhs_contracting_dimensions < rank(lhs).
  • (C7) 0 <= rhs_batching_dimensions < rank(rhs).
  • (C8) 0 <= rhs_contracting_dimensions < rank(rhs).
  • (C9) dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).
  • (C10) dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...).
  • (C11) size(precision_config) = 2.
  • (C12) shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions).
  • אם הפעולה משתמשת בטנסורים לא מקובצים:
    • (C13) element_type(lhs) = element_type(rhs).
  • אם הפעולה משתמשת בטנסורים מרובים (quantization):
    • (C14) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • (C15) zero_points(rhs) = 0.
    • (C16) אם is_per_axis_quantized(rhs), אז quantization_dimension(rhs) לא נמצא ב-rhs_contracting_dimensions.
    • אם is_quantized(lhs):
    • (C17) storage_type(lhs) = storage_type(rhs).
    • (C18) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C19) אם is_per_tensor_quantized(rhs), אז is_per_tensor_quantized(result).
    • אם !is_quantized(lhs):
    • (C20) element_type(lhs) = expressed_type(rhs) = element_type(result).
  • אם !is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation):
    • (C21) precision_config... = DEFAULT.
    • (C22) 0 < lhs_component_count.
    • (C23) 0 < rhs_component_count.
    • (C24) 0 < num_primitive_operations.

דוגמאות

// %lhs: [
//        [[1, 2],
//         [3, 4]],
//        [[5, 6],
//         [7, 8]]
//       ]
// %rhs: [
//        [[1, 0],
//         [0, 1]],
//        [[1, 0],
//         [0, 1]]
//       ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
  dot_dimension_numbers = #stablehlo.dot<
    lhs_batching_dimensions = [0],
    rhs_batching_dimensions = [0],
    lhs_contracting_dimensions = [2],
    rhs_contracting_dimensions = [1]
  >,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
  algorithm = #stablehlo.dot_algorithm<
    lhs_precision_type = tf32,
    rhs_precision_type = tf32,
    accumulation_type = f32,
    lhs_component_count = 1,
    rhs_component_count = 1,
    num_primitive_operations = 1,
    allow_imprecise_accumulation = false
  >
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
//           [[1, 2],
//            [3, 4]],
//           [[5, 6],
//            [7, 8]]
//          ]

 דוגמאות נוספות

dynamic_broadcast_in_dim

סמנטיקה

הפעולה הזו זהה מבחינה פונקציונלית ל-broadcast_in_dim, אבל צורת התוצאה מוגדרת באופן דינמי דרך output_dimensions.

הפעולה מקבלת גם את המאפיינים האופציונליים known_expanding_dimensions, known_nonexpanding_dimensions כדי להביע ידע סטטי לגבי התנהגות ההרחבה של המאפיינים. אם לא מציינים זאת, ההנחה היא שכל המאפיינים מתרחבים.

קלט

תווית שם סוג מגבלות
(I1) operand טינסור או טינסור מצטבר (C1-C2), (C5-C6), (C9)
(I2) output_dimensions טינסור חד-מימדי מסוג מספר שלם (C7)
(I3) broadcast_dimensions טנזור קבוע חד-ממדי מסוג מספר שלם (C2-C6)
(I4) known_expanding_dimensions טנזור קבוע חד-ממדי מסוג מספר שלם (C8-C9)
(I5) known_nonexpanding_dimensions טנזור קבוע חד-ממדי מסוג מספר שלם (C8-C9)

פלט

שם סוג אילוצים
result טינסור או טינסור מצטבר (C1),‏ (C3),‏ (C5-C7)

מגבלות

  • (C1) הערך של element_type(result) מחושב לפי:
    • element_type(operand), אם !is_per_axis_quantized(operand).
    • element_type(operand) חוץ מהבדל אחד: quantization_dimension(operand), scales(operand) ו-zero_points(operand), יכול להיות שונה מ-quantization_dimension(result), מ-scales(result) ומ-zero_points(result), אחרת.
  • (C2) size(broadcast_dimensions) = rank(operand).
  • (C3) 0 <= broadcast_dimensions < rank(result).
  • (C4) is_unique(broadcast_dimensions).
  • (C5) לכל d ב-axes(operand):
    • dim(operand, d) = 1 או
    • dim(operand, d) = dim(result, broadcast_dimensions[d]).
  • (C6) אם is_per_axis_quantized(result):
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].
    • אם dim(operand, quantization_dimension(operand)) = 1, אז scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).
  • (C7) size(output_dimensions) = rank(result).
  • (C8) is_unique(known_expanding_dimensions + known_nonexpanding_dimensions).
  • (C9) 0 <= known_expanding_dimensions < rank(operand).
  • (C10) 0 <= known_nonexpanding_dimensions < rank(operand).

דוגמאות

// %operand: [
//            [1, 2, 3]
//           ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
  broadcast_dimensions = array<i64: 2, 1>,
  known_expanding_dimensions = array<i64: 0>,
  known_nonexpanding_dimensions = array<i64: 1>
} : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

 דוגמאות נוספות

dynamic_conv

סמנטיקה

הפעולה הזו זהה מבחינה פונקציונלית לפעולת convolution, אבל המילוי מצוין באופן דינמי באמצעות padding.

קלט

תווית שם סוג מגבלות
(I1) lhs טינסור או טינסור מותאם (quantized) לכל טינסור (C1), (C10-C11), (C14) (C25), (C26-C27), (C30-C31) (C33)
(I2) rhs t e n s o r f l o w, או t e n s o r f l o w, (C1),‏ (C14-C16),‏ (C26-C28),‏ (C30-C33)
(I3) padding טנזור דו-ממדי של מספר שלם (C4)
(I4) window_strides קבוע טינסור חד-מימדי מסוג si64 (C2-C3)
(I5) lhs_dilation קבוע טינסור חד-מימדי מסוג si64 (C5-C6)
(I6) rhs_dilation קבוע מפריד חד-ממדי מסוג si64 (C7-C8)
(I7) window_reversal קבוע טינסור חד-מימדי מסוג i1 (C9)
(I8) input_batch_dimension קבוע מסוג si64 (C10),‏ (C13)
(I9) input_feature_dimension קבוע מסוג si64 (C11),‏ (C13-C14)
(I10) input_spatial_dimensions קבוע טינסור חד-מימדי מסוג si64 (C12), (C13)
(I11) kernel_input_feature_dimension קבוע מסוג si64 (C14), (C18)
(I12) kernel_output_feature_dimension קבוע מסוג si64 (C15-C16),‏ (C18),‏ (C28)
(I13) kernel_spatial_dimensions קבוע טינסור חד-מימדי מסוג si64 (C17-C18)
(I14) output_batch_dimension קבוע מסוג si64 (C20)
(I15) output_feature_dimension קבוע מסוג si64 (C20), (C29)
(I16) output_spatial_dimensions קבוע טינסור חד-מימדי מסוג si64 (C19-C20)
(I17) feature_group_count קבוע מסוג si64 (C11), (C14), (C16), (C21), (C23)
(I18) batch_group_count קבוע מסוג si64 (C10), (C15), (C22), (C23)
(I19) precision_config מספר ושונות של טיפוסים בני מנייה (enum) של DEFAULT, HIGH ו-HIGHEST (C24)

פלט

שם סוג אילוצים
result טינסור או טינסור מצטבר (C25-C27),‏ (C29),‏ (C31-C33)

אילוצים

  • (C1) N = rank(lhs) = rank(rhs).
  • (C2) size(window_strides) = N - 2.
  • (C3) 0 < window_strides.
  • (C4) shape(padding) = [N - 2, 2].
  • (C5) size(lhs_dilation) = N - 2.
  • (C6) 0 < lhs_dilation.
  • (C7) size(rhs_dilation) = N - 2.
  • (C8) 0 < rhs_dilation.
  • (C9) size(window_reversal) = N - 2.
  • (C10) dim(lhs, input_batch_dimension) % batch_group_count = 0.
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0.
  • (C12) size(input_spatial_dimensions) = N - 2.
  • (C13) בהינתן input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:
    • is_unique(input_dimensions).
    • 0 <= input_dimensions < N.
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count.
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0.
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0.
  • (C17) size(kernel_spatial_dimensions) = N - 2.
  • (C18) בהינתן kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:
    • is_unique(kernel_dimensions).
    • 0 <= kernel_dimensions < N.
  • (C19) size(output_spatial_dimensions) = N - 2.
  • (C20) בהינתן output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:
    • is_unique(output_dimensions).
    • 0 <= output_dimensions < N.
  • (C21) 0 < feature_group_count.
  • (C22) 0 < batch_group_count.
  • (C23) feature_group_count = 1 or batch_group_count = 1.
  • (C24) size(precision_config) = 2.
  • (C25) dim(result, result_dim) מוגדר בתור:
    • dim(lhs, input_batch_dimension) / batch_group_count אם result_dim = output_batch_dimension.
    • dim(rhs, kernel_output_feature_dimension) אם result_dim = output_feature_dimension.
    • num_windows אחרת, כאשר:
    • output_spatial_dimensions[spatial_dim] = result_dim.
    • lhs_dim = input_spatial_dimensions[spatial_dim].
    • rhs_dim = kernel_spatial_dimensions[spatial_dim].
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
  • (C26) rank(result) = N.
  • אם הפעולה משתמשת בטנסורים לא מקובצים:
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result).
  • אם הפעולה משתמשת במקדמי אחיזה כמותיים:
    • (C28) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • (C29) אם is_per_axis_quantized(rhs), אז quantization_dimension(rhs) = kernel_output_feature_dimension.
    • (C30) אם is_per_axis_quantized(result), אז quantization_dimension(result) = output_feature_dimension.
    • אם is_quantized(lhs):
    • (C31) storage_type(lhs) = storage_type(rhs).
    • (C32) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C33) אם is_per_tensor_quantized(rhs), אז is_per_tensor_quantized(result).
    • אם !is_quantized(lhs):
    • (C34) element_type(lhs) = expressed_type(rhs) = element_type(result).

דוגמאות

// %lhs: [[
//        [[1], [2], [5], [6]],
//        [[3], [4], [7], [8]],
//        [[10], [11], [14], [15]],
//        [[12], [13], [16], [17]]
//      ]]
//
// %rhs: [
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]]
//        ]
// %padding: [[1, 1],
//            [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
  window_strides = array<i64: 4, 4>,
  lhs_dilation = array<i64: 2, 2>,
  rhs_dilation = array<i64: 1, 1>,
  window_reversal = array<i1: false, false>,
  dimension_numbers = #stablehlo.conv<raw
    input_batch_dimension = 0,
    input_feature_dimension = 3,
    input_spatial_dimensions = [0, 1],
    kernel_input_feature_dimension = 2,
    kernel_output_feature_dimension = 3,
    kernel_spatial_dimensions = [0, 1],
    output_batch_dimension = 0,
    output_feature_dimension = 3,
    output_spatial_dimensions = [1, 2]
  >,
  feature_group_count = 1 : i64,
  batch_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
//            [[1], [5]],
//            [[10], [14]]
//          ]]

דוגמאות נוספות

dynamic_gather

סמנטיקה

הפעולה הזו זהה מבחינה פונקציונלית ל-gather, כאשר slice_sizes מצוין באופן דינמי כערך.

קלט

תווית שם סוג מגבלות
(I1) operand טינסור או טינסור מותאם (quantized) לכל טינסור (C1),‏ (C7),‏ (C10-C12),‏ (C14)
(I2) start_indices Tensor מסוג מספר שלם (C2),‏ (C3),‏ (C13)
(I3) slice_sizes זווית חד-ממדית של מספר שלם מסוג מספר שלם (C8),‏ (C11-C13)
(I4) offset_dims קבוע טינסור חד-מימדי מסוג si64 (C1), (C4-C5), (C13)
(I5) collapsed_slice_dims קבוע מפריד חד-ממדי מסוג si64 (C1),‏ (C6-C8),‏ (C13)
(I6) start_index_map קבוע טינסור חד-מימדי מסוג si64 (C3),‏ (C9),‏ (C10)
(I7) index_vector_dim קבוע מסוג si64 (C2), (C3), (C13)
(I8) indices_are_sorted קבוע מסוג i1

פלט

שם סוג מגבלות
result טינסור או טינסור מותאם (quantized) לכל טינסור (C5), (C13-C14)

אילוצים

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims).
  • (C2) 0 <= index_vector_dim <= rank(start_indices).
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1.
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims).
  • (C5) 0 <= offset_dims < rank(result).
  • (C6) is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims).
  • (C7) 0 <= collapsed_slice_dims < rank(operand).
  • (C8) slice_sizes[collapsed_slice_dims...] <= 1.
  • (C9) is_unique(start_index_map).
  • (C10) 0 <= start_index_map < rank(operand).
  • (C11) size(slice_sizes) = rank(operand).
  • (C12) 0 <= slice_sizes <= shape(operand).
  • (C13) shape(result) = combine(batch_dim_sizes, offset_dim_sizes) כאשר:
    • batch_dim_sizes = shape(start_indices), מלבד העובדה שגודל המאפיין של start_indices שתואם ל-index_vector_dim לא נכלל.
    • offset_dim_sizes = shape(slice_sizes), מלבד העובדה שגודל המאפיינים ב-slice_sizes שתואם ל-collapsed_slice_dims לא נכלל.
    • הפונקציה combine ממפה את batch_dim_sizes לצירים התואמים ל-batch_dims ואת offset_dim_sizes לצירים התואמים ל-offset_dims.
  • (C14) element_type(operand) = element_type(result).

דוגמאות

// %operand: [
//            [[1, 2], [3, 4], [5, 6], [7, 8]],
//            [[9, 10],[11, 12], [13, 14], [15, 16]],
//            [[17, 18], [19, 20], [21, 22], [23, 24]]
//           ]
// %start_indices: [
//                  [[0, 0], [1, 0], [2, 1]],
//                  [[0, 1], [1, 1], [0, 2]]
//                 ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [2, 3],
    collapsed_slice_dims = [0],
    start_index_map = [1, 0],
    index_vector_dim = 2>,
  indices_are_sorted = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi64>
// %result: [
//            [
//              [[1, 2], [3, 4]],
//              [[3, 4], [5, 6]],
//              [[13, 14], [15, 16]]
//            ],
//            [
//              [[9, 10], [11, 12]],
//              [[11, 12], [13, 14]],
//              [[17, 18], [19, 20]]
//            ]
//          ]

דוגמאות נוספות

dynamic_iota

סמנטיקה

הפעולה הזו זהה מבחינה פונקציונלית לפעולת iota, אבל צורת התוצאה מצוינה באופן דינמי באמצעות output_shape.

קלט

תווית שם סוג מגבלות
(I1) output_shape זווית חד-ממדית של מספר שלם מסוג מספר שלם (C1),‏ (C2)
(I2) iota_dimension si64 (C1)

פלט

שם סוג אילוצים
result טינסור מסוג שלם, של נקודה צפה או של מספר מרוכב, או טינסור מקודד לכל טינסור (C2)

מגבלות

  • (C1) 0 <= iota_dimension < size(output_shape).
  • (C2) rank(result) = size(output_shape).

דוגמאות

%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
  iota_dimension = 0 : i64
} : (tensor<2xi64>) -> tensor<4x5xi64>
// %result: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

דוגמאות נוספות

dynamic_pad

סמנטיקה

הפעולה הזו זהה מבחינה פונקציונלית לפעולת pad, אבל הערכים של edge_padding_low,‏ edge_padding_high ו-interior_padding מוגדרים באופן דינמי.

קלט

תווית שם סוג מגבלות
(I1) operand טינסור או טינסור מותאם (quantized) לכל טינסור (C1),‏ (C2),‏ (C4)
(I2) padding_value טינסור של מימד 0 או טינסור מקודד לכל טינסור (C1)
(I3) edge_padding_low טינסור חד-מימדי מסוג מספר שלם (C1), (C4)
(I4) edge_padding_high טינסור חד-מימדי מסוג מספר שלם (C1), (C4)
(I5) interior_padding זווית חד-ממדית של מספר שלם מסוג מספר שלם (C2-C4)

פלט

שם סוג מגבלות
result טינסור או טינסור מותאם (quantized) לכל טינסור (C3-C6)

אילוצים

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result).
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand).
  • (C3) 0 <= interior_padding.
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.

דוגמאות

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
  %edge_padding_low, %edge_padding_high, %interior_padding
) : (tensor<2x3xi64>, tensor<i64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x9xi64>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

 דוגמאות נוספות

dynamic_reshape

סמנטיקה

הפעולה הזו זהה מבחינה פונקציונלית לעיצוב מחדש, אבל צורת התוצאה מוגדרת באופן דינמי באמצעות output_shape.

קלט

תווית שם סוג מגבלות
(I1) operand טינסור או טינסור מצטבר (C1-C3)
(I2) output_shape טינסור חד-מימדי מסוג מספר שלם (C4)

פלט

שם סוג אילוצים
result טינסור או טינסור מצטבר (C1-C4)

אילוצים

  • (C1) הערך של element_type(result) מחושב לפי:
    • element_type(operand), אם !is_per_axis_quantized(operand).
    • element_type(operand), אלא ש-quantization_dimension(operand) ו-quantization_dimension(result) עשויים להיות שונים.
  • (C2) size(operand) = size(result).
  • (C3) אם is_per_axis_quantized(operand):
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
  • (C4) size(output_shape) = rank(result).

דוגמאות

// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
// %result: [[1, 2], [3, 4], [5, 6]]

 דוגמאות נוספות

dynamic_slice

סמנטיקה

הפונקציה מחלצת פרוסה מ-operand באמצעות אינדקסים ראשוניים שמחושבים באופן דינמי, ויוצרת טינסור result. העמודה start_indices מכילה את אינדקסי ההתחלה של הפלחים לכל מאפיין שעשוי לעבור התאמה, והעמודה slice_sizes מכילה את הגדלים של הפלחים לכל מאפיין. באופן רשמי יותר, result[result_index] = operand[operand_index] כאשר:

  • adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes).
  • operand_index = adjusted_start_indices + result_index.

קלט

תווית שם סוג מגבלות
(I1) operand טינסור או טינסור מותאם (quantized) לכל טינסור (C1),‏ (C2),‏ (C4)
(I2) start_indices מספר משתנה של טינסורים ב-0 מימדים מסוג מספר שלם (C2),‏ (C3)
(I3) slice_sizes קבוע מפריד חד-ממדי מסוג si64 (C2), (C4), (C5)

פלט

שם סוג מגבלות
result טינסור או טינסור מותאם (quantized) לכל טינסור (C1),‏ (C5)

אילוצים

  • (C1) element_type(operand) = element_type(result).
  • (C2) size(start_indices) = size(slice_sizes) = rank(operand).
  • (C3) same(type(start_indices...)).
  • (C4) 0 <= slice_sizes <= shape(operand).
  • (C5) shape(result) = slice_sizes.

דוגמאות

// %operand: [
//            [0, 0, 1, 1],
//            [0, 0, 1, 1],
//            [0, 0, 0, 0],
//            [0, 0, 0, 0]
//           ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
  slice_sizes = array<i64: 2, 2>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
//           [1, 1],
//           [1, 1]
//          ]

 דוגמאות נוספות

dynamic_update_slice

סמנטיקה

הפונקציה יוצרת טינסור result שווה לטינסור operand, מלבד העובדה שהפרוסה שמתחילה ב-start_indices מתעדכנת בערכים של update. באופן רשמי יותר, result[result_index] מוגדר כך:

  • update[update_index] אם 0 <= update_index < shape(update) כאשר:
    • adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update)).
    • update_index = result_index - adjusted_start_indices.
  • אחרת, operand[result_index].

קלט

תווית שם סוג מגבלות
(I1) operand טינסור או טינסור מותאם (quantized) לכל טינסור (C1-C4), (C6)
(I2) update טינסור או טינסור מותאם (quantized) לכל טינסור (C2), (C3), (C6)
(I3) start_indices מספר משתנה של טינסורים ב-0 מימדים מסוג מספר שלם (C4), (C5)

פלט

שם סוג מגבלות
result טינסור או טינסור מותאם (quantized) לכל טינסור (C1)

אילוצים

  • (C1) type(operand) = type(result).
  • (C2) element_type(update) = element_type(operand).
  • (C3) rank(update) = rank(operand).
  • (C4) size(start_indices) = rank(operand).
  • (C5) same(type(start_indices...)).
  • (C6) 0 <= shape(update) <= shape(operand).

דוגמאות

// %operand: [
//            [1, 1, 0, 0],
//            [1, 1, 0, 0],
//            [1, 1, 1, 1],
//            [1, 1, 1, 1]
//           ]
// %update: [
//           [1, 1],
//           [1, 1]
//          ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
  : (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1]
//          ]

 דוגמאות נוספות

מעריכיות

סמנטיקה

הפונקציה מבצעת פעולה מעריכית לפי רכיבים בטנסור operand ומפיקה טנסור result. בהתאם לסוג הרכיב, הפעולות הבאות:

  • למספרים מסוג float: exp מ-IEEE-754.
  • למספרים מרוכבים: מעריכי מרוכבים.
  • לסוגים הבאים לפי כמות: dequantize_op_quantize(exponential, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טינסור מסוג נקודה צפה או מסוג מורכב, או טינסור מקודד לכל טינסור (C1)

פלט

שם סוג מגבלות
result טינסור מסוג נקודה צפה או מסוג מורכב, או טינסור מקודד לכל טינסור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]

 דוגמאות נוספות

exponential_minus_one

סמנטיקה

מבצעת פעולה מעריכית מבחינת רכיבים מינוס פעולה אחת על טנזור operand ומפיקה טנזור result. בהתאם לסוג הרכיב, מבצע את הפעולות הבאות:

  • למספרים מסוג float: expm1 מ-IEEE-754.
  • למספרים מרוכבים: מעריכי מרוכבים פחות 1.
  • לסוגים הבאים לפי כמות: dequantize_op_quantize(exponential_minus_one, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טינסור מסוג נקודה צפה או מסוג מורכב, או טינסור מקודד לכל טינסור (C1)

פלט

שם סוג מגבלות
result טינסור מסוג נקודה צפה או מסוג מורכב, או טינסור מקודד לכל טינסור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]

 דוגמאות נוספות

fft

סמנטיקה

ביצוע התמרות פורייה ישירות ובעקבותיה הפוכות עבור קלט/פלט אמיתיים ומעורבים.

fft_type הוא אחד מהבאים:

  • FFT: FFT קדימה מסוג מורכב-מורכב.
  • IFFT: קובץ FFT מרוכב הופכי למורכב.
  • RFFT: FFT קדימה מערכים ריאליים למערכים מורכבים.
  • IRFFT: FFT הפוך מ-real ל-complex (כלומר, הפונקציה מקבלת ערכים מורכבים ומחזירה ערכים אמיתיים).

באופן רשמי יותר, בהינתן הפונקציה fft שמקבלת כקלט טנסורים חד-ממדיים מסוגים מורכבים, יוצרת טנסורים חד-ממדיים מאותו סוג כפלט ומחשבת את טרנספורמציית Fourier הדיסקרטית:

בשביל fft_type = FFT, result מוגדר כתוצאה הסופית של סדרת חישובי L כאשר L = size(fft_length). לדוגמה, עבור L = 3:

  • result1[i0, ..., :] = fft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

בנוסף, בהינתן הפונקציה ifft שיש לה אותה חתימה של סוג ומחשבת את הפונקציה ההפוכה של fft:

עבור fft_type = IFFT, result מוגדר כהפוך של החישובים של fft_type = FFT. לדוגמה, עבור L = 3:

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = ifft(result2[i0, ..., :]).

בנוסף, הפונקציה rfft מקבלת טנסורים חד-ממדיים של סוגי נקודה צפה, יוצרת טנסורים חד-ממדיים של סוגי נתונים מורכבים באותה סמנטיקה של נקודה צפה ופועלת באופן הבא:

  • rfft(real_operand) = truncated_result איפה
  • complex_operand... = (real_operand..., 0.0).
  • complex_result = fft(complex_operand).
  • truncated_result = complex_result[:(rank(complex_result) / 2 + 1)].

(כשמחושב טרנספורמציית פורייה מובחנת לאופרנדים אמיתיים, רכיבי N/2 + 1 הראשונים של התוצאה מגדירים באופן חד-משמעי את שאר התוצאה, כך שהתוצאה של rfft תקוצר כדי להימנע מחישוב של רכיבים מיותרים).

עבור fft_type = RFFT, הערך result מוגדר כתוצאה הסופית של סדרה של חישובי L, כאשר L = size(fft_length). לדוגמה, עבור L = 3:

  • result1[i0, ..., :] = rfft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

לבסוף, בהינתן הפונקציה irfft שיש לה אותה חתימה של סוג ומחשבת את ההפכי של rfft:

עבור fft_type = IRFFT, result מוגדר כהפוך של החישובים של fft_type = RFFT. לדוגמה, עבור L = 3:

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = irfft(result2[i0, ..., :]).

קלט

תווית שם סוג מגבלות
(I1) operand Tensor של נקודה צפה (floating-point) או סוג מורכב (C1),‏ (C2),‏ (C4),‏ (C5)
(I2) fft_type enum של FFT,‏ IFFT,‏ RFFT וגם IRFFT (C2),‏ (C5)
(I3) fft_length קבוע טינסור חד-מימדי מסוג si64 (C1),‏ (C3),‏ (C4)

פלט

שם סוג אילוצים
result טינסור מסוג נקודה צפה (floating-point) או מסוג מורכב (C2),‏ (C4),‏ (C5)

אילוצים

  • (C1) size(fft_length) <= rank(operand).
  • (C2) הקשר בין סוגי הרכיבים operand ו-result משתנה:
    • אם fft_type = FFT, element_type(operand) ו-element_type(result) הם מאותו סוג מורכב.
    • אם ל-fft_type = IFFT, ל-element_type(operand) ול-element_type(result) יש אותו סוג מורכב.
    • אם fft_type = RFFT,‏ element_type(operand) הוא סוג של נקודת צפה ו-element_type(result) הוא סוג מורכב עם אותה סמנטיקה של נקודת צפה.
    • אם fft_type = IRFFT, ‏ element_type(operand) הוא סוג מורכב ו-element_type(result) הוא סוג של נקודת צפה עם אותה סמנטיקה של נקודת צפה.
  • (C3) 1 <= size(fft_length) <= 3.
  • (C4) אם בין operand ל-result יש טינסור real מסוג נקודה צפה, אז shape(real)[-size(fft_length):] = fft_length.
  • (C5) shape(result) = shape(operand) מלבד:
    • אם fft_type = RFFT, dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1.
    • אם fft_type = IRFFT, dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1.

דוגמאות

// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
  fft_type = #stablehlo<fft_type FFT>,
  fft_length = array<i64: 4>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]

קומה

סמנטיקה

הפונקציה מבצעת חישוב של ⌊x⌋ לאלמנטים של הטנזור operand ומפיקה טנזור result. הטמעה של הפעולה roundToIntegralTowardNegative מהמפרט של IEEE-754. לסוגי נתונים שמפורטים במספרים, הפונקציה מבצעת את הפעולה dequantize_op_quantize(floor, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טינסור מסוג נקודה צפה (floating-point) או טינסור מקודד לכל טינסור (C1)

פלט

שם סוג אילוצים
result טינסור מסוג נקודה צפה (floating-point) או טינסור מקודד לכל טינסור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]

 דוגמאות נוספות

לאסוף

סמנטיקה

הפונקציה אוספת פרוסות מטנזור operand מההיסטים שצוינו ב-start_indices ומפיקה טנזור result.

בתרשים הבא מוצגת דוגמה קונקרטית לאופן שבו רכיבים ב-result ממפים לרכיבים ב-operand. בתרשים מוצגות כמה דוגמאות למדדי result, ומוסבר בפירוט לאילו מדדי operand הם תואמים.

איסוף

באופן רשמי יותר, result[result_index] = operand[operand_index] כאשר:

  • batch_dims = [d for d in axes(result) and d not in offset_dims].
  • batch_index = result_index[batch_dims...].
  • הערך start_index מוגדר בתור:
    • start_indices[bi0, ..., :, ..., biN] כאשר bi הם רכיבים נפרדים ב-batch_index ו-: מוחדר באינדקס index_vector_dim, אם index_vector_dim < rank(start_indices).
    • [start_indices[batch_index]] אחרת.
  • ל-d_operand ב-axes(operand),
    • full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand]) אם d_operand = start_index_map[d_start].
    • full_start_index[d_operand] = 0 אחרת.
  • עבור d_operand ב-axes(operand),
    • full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)] אם d_operand = operand_batching_dims[i_batching] וגם d_start = start_indices_batching_dims[i_batching].
    • full_batching_index[d_operand] = 0 אחרת.
  • offset_index = result_index[offset_dims...].
  • full_offset_index = [oi0, ..., 0, ..., oiN] כאשר oi הם רכיבים נפרדים ב-offset_index, ו-0 מוחדר באינדיקטורים מ-collapsed_slice_dims ומ-operand_batching_dims.
  • operand_index = full_start_index + full_batching_index + full_offset_index.

אם הערך של indices_are_sorted הוא true, אז ההטמעה יכולה להניח שהערכים start_indices ממוינים ביחס ל-start_index_map, אחרת ההתנהגות לא מוגדרת. באופן רשמי יותר, לכל i1 < i2 מ-indices(result), full_start_index(i1) <= full_start_index(i2).

קלט

תווית שם סוג מגבלות
(I1) operand טינסור או טינסור מותאם (quantized) לכל טינסור (C1),‏ (C8),‏ (C11),‏ (C17),‏ (C19-C21),‏ (C23)
(I2) start_indices טינסור מסוג מספר שלם (C2-C3),‏ (C14),‏ (C17),‏ (C22)
(I3) offset_dims קבוע טינסור חד-מימדי מסוג si64 (C1),‏ (C4-C5),‏ (C22)
(I4) collapsed_slice_dims קבוע טינסור חד-מימדי מסוג si64 (C1),‏ (C6-C9),‏ (C22)
(I5) operand_batching_dims קבוע טינסור חד-מימדי מסוג si64 (C1), (C6), (C10-C12), (C16-C18) (C22)
(I6) start_indices_batching_dims קבוע מפריד חד-ממדי מסוג si64 (C13-C17)
(I7) start_index_map קבוע מפריד חד-ממדי מסוג si64 (C3),‏ (C18-C19)
(I8) index_vector_dim קבוע מסוג si64 (C2-C3),‏ (C15),‏ (C22)
(I9) slice_sizes קבוע טינסור חד-מימדי מסוג si64 (C9),‏ (C12),‏ (C20-C22)
(I10) indices_are_sorted קבוע מסוג i1

פלט

שם סוג מגבלות
result טנזור או לכל טנזור קוונטי (C5),‏ (C22-C23)

אילוצים

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims).
  • (C2) 0 <= index_vector_dim <= rank(start_indices).
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1.
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims).
  • (C5) 0 <= offset_dims < rank(result).
  • (C6) is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))
  • (C7) is_sorted(collapsed_slice_dims).
  • (C8) 0 <= collapsed_slice_dims < rank(operand).
  • (C9) slice_sizes[collapsed_slice_dims...] <= 1.
  • (C10) is_sorted(operand_batching_dims).
  • (C11) 0 <= operand_batching_dims < rank(operand).
  • (C12) slice_sizes[operand_batching_dims...] <= 1.
  • (C13) is_unique(start_indices_batching_dims).
  • (C14) 0 <= start_indices_batching_dims < rank(start_indices).
  • (C15) index_vector_dim not in start_indices_batching_dims.
  • (C16) size(operand_batching_dims) == size(start_indices_batching_dims).
  • (C17) dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...).
  • (C18) is_unique(concatenate(start_index_map, operand_batching_dims)).
  • (C19) 0 <= start_index_map < rank(operand).
  • (C20) size(slice_sizes) = rank(operand).
  • (C21) 0 <= slice_sizes <= shape(operand).
  • (C22) shape(result) = combine(batch_dim_sizes, offset_dim_sizes) כאשר:
    • batch_dim_sizes = shape(start_indices) למעט העובדה שגודל המימד start_indices שתואם ל-index_vector_dim לא נכלל.
    • offset_dim_sizes = slice_sizes, מלבד העובדה שגודל המאפיינים ב-slice_sizes שתואם ל-collapsed_slice_dims ול-operand_batching_dims לא נכלל.
    • הפונקציה combine ממפה את batch_dim_sizes לצירים התואמים ל-batch_dims ואת offset_dim_sizes לצירים התואמים ל-offset_dims.
  • (C23) element_type(operand) = element_type(result).

דוגמאות

// %operand: [
//            [
//             [[1, 2], [3, 4], [5, 6], [7, 8]],
//             [[9, 10],[11, 12], [13, 14], [15, 16]],
//             [[17, 18], [19, 20], [21, 22], [23, 24]]
//            ],
//            [
//             [[25, 26], [27, 28], [29, 30], [31, 32]],
//             [[33, 34], [35, 36], [37, 38], [39, 40]],
//             [[41, 42], [43, 44], [45, 46], [47, 48]]
//            ]
//           ]
// %start_indices: [
//                  [
//                   [[0, 0], [1, 0], [2, 1]],
//                   [[0, 1], [1, 1], [0, 9]]
//                  ],
//                  [
//                   [[0, 0], [2, 1], [2, 2]],
//                   [[1, 2], [0, 1], [1, 0]]
//                  ]
//                 ]
%result = "stablehlo.gather"(%operand, %start_indices) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [3, 4],
    collapsed_slice_dims = [1],
    operand_batching_dims = [0],
    start_indices_batching_dims = [1],
    start_index_map = [2, 1],
    index_vector_dim = 3>,
  slice_sizes = array<i64: 1, 1, 2, 2>,
  indices_are_sorted = false
} : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32>
// %result: [
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[3, 4], [5, 6]],
//             [[13, 14], [15, 16]]
//            ],
//            [
//             [[33, 34], [35, 36]],
//             [[35, 36], [37, 38]],
//             [[41, 42], [43, 44]]
//            ]
//           ],
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[13, 14], [15, 16]],
//             [[21, 22], [23, 24]]
//            ],
//            [
//             [[43, 44], [45, 46]],
//             [[33, 34], [35, 36]],
//             [[27, 28], [29, 30]]
//            ]
//           ]
//          ]

 דוגמאות נוספות

get_dimension_size

סמנטיקה

הפונקציה מחזירה את הגודל של dimension שצוין ב-operand. באופן רשמי יותר, result = dim(operand, dimension). הסמנטיקה מתייחסת רק לרכיב הצורה של הסוג. סוג הרכיב יכול להיות כל דבר.

קלט

תווית שם סוג מגבלות
(I1) operand טינסור או טינסור מצטבר (C1)
(I2) dimension קבוע מסוג si64 (C1)

פלט

שם סוג
result טינסור של מימד 0 מסוג si32

אילוצים

  • (C1) 0 <= dimension < rank(operand).

דוגמאות

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
  dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3

דוגמאות נוספות

get_tuple_element

סמנטיקה

מחלצת רכיב במיקום index של ה-tuple של operand ומפיקה result. יותר רשמי: result = operand[index].

קלט

תווית שם סוג מגבלות
(I1) operand קבוצה (C1),‏ (C2)
(I2) index קבוע מסוג si32 (C1), (C2)

פלט

שם סוג אילוצים
result כל סוג נתמך (C2)

מגבלות

  • (C1) 0 <= index < size(operand).
  • (C2) type(result) = tuple_element_types(operand)[index].

דוגמאות

// %operand: ([1.0, 2.0], (3))
  index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]

 דוגמאות נוספות

אם

סמנטיקה

הפונקציה יוצרת את הפלט מהרצה של פונקציה אחת בלבד מ-true_branch או מ-false_branch, בהתאם לערך של pred. יותר רשמי: result = pred ? true_branch() : false_branch().

קלט

תווית שם סוג מגבלות
(I1) pred טינסור של מימד 0 מסוג i1
(I2) true_branch פונקציה (C1-C3)
(I3) false_branch פונקציה (C1),‏ (C2)

פלט

שם סוג אילוצים
results מספר משתנה של טנזורים, טנזורים כמותיים או אסימונים (C3)

אילוצים

  • (C1) input_types(true_branch) = input_types(false_branch) = [].
  • (C2) output_types(true_branch) = output_types(false_branch).
  • (C3) type(results...) = output_types(true_branch).

דוגמאות

// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
  "stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
  "stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10

דוגמאות נוספות

imag

סמנטיקה

הפונקציה מחלצת את החלק המדומה, לפי רכיבים, מ-operand ויוצרת טינסור result. באופן רשמי יותר, לכל רכיב x: imag(x) = is_complex(x) ? imaginary_part(x) : constant(0, element_type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טינסור מסוג נקודה צפה (floating-point) או מסוג מורכב (C1), (C2)

פלט

שם סוג אילוצים
result טינסור מסוג נקודה צפה (floating-point) (C1),‏ (C2)

מגבלות

  • (C1) shape(result) = shape(operand).
  • (C2) element_type(result) מוגדר בתור:
    • complex_element_type(element_type(operand)) אם is_complex(operand).
    • אחרת, element_type(operand).

דוגמאות

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]

 דוגמאות נוספות

Infeed

סמנטיקה

קריאת נתונים מה-infeed וייצור של results.

הסמנטיקה של infeed_config מוגדרת בהתאם להטמעה.

results מורכב מערכי טען שימושי שמופיעים קודם ומאסימון שמופיע בסוף. בעתיד, אנחנו מתכננים לפצל את המטען הייעודי (payload) ואת האסימון לשני פלטים נפרדים כדי לשפר את הבהירות (#670).

קלט

תווית שם סוג
(I1) token token
(I2) infeed_config קבוע מסוג string

פלט

שם סוג אילוצים
results מספר משתנה של טנזורים, טנזורים כמותיים או אסימונים (C1-C3)

אילוצים

  • (C1) 0 < size(results).
  • (C2) is_empty(result[:-1]) או is_tensor(type(results[:-1])).
  • (C3) is_token(type(results[-1])).

דוגמאות

// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]

 דוגמאות נוספות

iota

סמנטיקה

מילוי טינסור output בערכים בסדר עולה, החל מאפס, לאורך המאפיין iota_dimension. יותר רשמי,

output[output_index] = constant(is_quantized(output) ? quantize(output_index[iota_dimension], element_type(output)) : output_index[iota_dimension], element_type(output)).

קלט

תווית שם סוג מגבלות
(I1) iota_dimension si64 (C1)

פלט

שם סוג אילוצים
output טינסור מסוג שלם, של נקודה צפה או של מספר מרוכב, או טינסור מקודד לכל טינסור (C1)

אילוצים

  • (C1) 0 <= iota_dimension < rank(output).

דוגמאות

%output = "stablehlo.iota"() {
  iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

%output = "stablehlo.iota"() {
  iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4]
//          ]

 דוגמאות נוספות

is_finite

סמנטיקה

הפונקציה מבצעת בדיקה של כל רכיב כדי לבדוק אם הערך ב-x סופי (כלומר, הוא לא +Inf,‏ -Inf או NaN) ויוצרת טינסור y. הטמעה של הפעולה isFinite מהמפרט IEEE-754. בסוגים מרובעניים, התוצאה תמיד true.

קלט

תווית שם סוג מגבלות
(I1) x טינסור מסוג נקודה צפה (floating-point) או טינסור מקודד לכל טינסור (C1)

פלט

שם סוג אילוצים
y טינסור מסוג בוליאני (C1)

אילוצים

  • (C1) shape(x) = shape(y).

דוגמאות

// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]

דוגמאות נוספות

log

סמנטיקה

מבצעת פעולת לוגריתם ברמת הרכיבים ב-Tensor operand ומפיקה טנזור result. בהתאם לסוג הרכיב, מבצע את הפעולות הבאות:

  • לצפים: log מ-IEEE-754.
  • למספרים מרוכבים: לוגריתם מרוכב.
  • לסוגי נתונים מרוסקים: dequantize_op_quantize(log, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טינסור מסוג נקודה צפה או מסוג מורכב, או טינסור מקודד לכל טינסור (C1)

פלט

שם סוג מגבלות
result טינסור מסוג נקודה צפה או מסוג מורכב, או טינסור מקודד לכל טינסור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]

 דוגמאות נוספות

log_plus_one

סמנטיקה

מבצעת לוגריתם ברמת הרכיבים ופעולה אחת על טנזור operand ומפיקה טנזור result. בהתאם לסוג הרכיב, מבצע את הפעולות הבאות:

  • לצפים: logp1 מ-IEEE-754.
  • למספרים מורכבים: לוגריתם מורכב ועוד אחד.
  • לסוגים מרושתים: dequantize_op_quantize(log_plus_one, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טינסור מסוג נקודה צפה או מסוג מורכב, או טינסור מקודד לכל טינסור (C1)

פלט

שם סוג מגבלות
result טינסור מסוג נקודה צפה או מסוג מורכב, או טינסור מקודד לכל טינסור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]

דוגמאות נוספות

לוגיסטית

סמנטיקה

ביצוע פעולה לוגיסטית לפי רכיבים בטנסור operand וייצור טנסור result. בהתאם לסוג הרכיב, הפעולות הבאות:

  • למספרים מסוג float: division(1, addition(1, exp(-x))) מ-IEEE-754.
  • למספרים מרוכבים: פונקציה לוגיסטית מורכבת.
  • לסוגים מרושתים: dequantize_op_quantize(logistic, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טינסור מסוג נקודה צפה או מסוג מורכב, או טינסור מקודד לכל טינסור (C1)

פלט

שם סוג מגבלות
result טינסור מסוג נקודה צפה או מסוג מורכב, או טינסור מקודד לכל טינסור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]

 דוגמאות נוספות

מפה

סמנטיקה

הפונקציה מחילה פונקציית מפה computation על inputs לאורך dimensions ומייצרת טינסור result.

באופן רשמי יותר, result[result_index] = computation(inputs...[result_index]).

קלט

תווית שם סוג מגבלות
(I1) inputs מספר וריאדי של טינסורים או טינסורים מקוטנים לכל טינסור (C1-C4)
(I2) dimensions קבוע מפריד חד-ממדי מסוג si64 (C3)
(I3) computation פונקציה (C4)

פלט

שם סוג מגבלות
result טינסור או טינסור מותאם (quantized) לכל טינסור (C1),‏ (C4)

אילוצים

  • (C1) shape(inputs...) = shape(result).
  • (C2) 0 < size(inputs) = N.
  • (C3) dimensions = range(rank(inputs[0])).
  • (C4) השדה computation הוא מסוג (tensor<E0>, ..., tensor<EN-1>) -> tensor<E'> כאשר Ei = element_type(inputs[i]) ו-E' = element_type(result).

דוגמאות

// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
    stablehlo.return %0 : tensor<i64>
}) {
  dimensions = array<i64: 0, 1>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]

 דוגמאות נוספות

מקסימום

סמנטיקה

מבצעת פעולה מקסימלית מבחינת הרכיבים על הטנזור lhs ו-rhs, ומפיקה טנזור result. בהתאם לסוג הרכיב, מבצע את הפעולות הבאות:

  • למשתנים בוליאניים: 'או' לוגי.
  • למספרים שלמים: מספר שלם מקסימלי.
  • למספרים מסוג float: maximum מ-IEEE-754.
  • למספרים מורכבים: הערך המילולי המקסימלי של הצמד (real, imaginary). הטמעת סדר במספרים מורכבים כרוכה בסמינטיקה מפתיעה, ולכן בעתיד אנחנו מתכננים להסיר את התמיכה במספרים מורכבים לפעולה הזו (#560).
  • לסוגים בכמות גדולה:
    • dequantize_op_quantize(maximum, lhs, rhs, type(result)).

קלט

תווית שם סוג מגבלות
(I1) lhs טנזור או לכל טנזור קוונטי (C1)
(I2) rhs טינסור או טינסור מותאם (quantized) לכל טינסור (C1)

פלט

שם סוג מגבלות
result טינסור או טינסור מותאם (quantized) לכל טינסור (C1)

אילוצים

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

דוגמאות

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]

 דוגמאות נוספות

מינימום

סמנטיקה

הפונקציה מבצעת פעולת min לפי רכיבים בטנסורים lhs ו-rhs, ומפיקה טנסור result. בהתאם לסוג הרכיב, מבצע את הפעולות הבאות:

  • למשתנים בוליאניים: 'וגם' בלוגיקה.
  • למספרים שלמים: מספר שלם מינימלי.
  • למספרים מסוג float: minimum מ-IEEE-754.
  • למספרים מרוכבים: ערך מינימום לקסיקוגרפי של הצמד (real, imaginary). הטמעת סדר במספרים מורכבים כרוכה בסמינטיקה מפתיעה, ולכן בעתיד אנחנו מתכננים להסיר את התמיכה במספרים מורכבים לפעולה הזו (#560).
  • לסוגים מרובעניים:
    • dequantize_op_quantize(minimum, lhs, rhs, type(result)).

קלט

תווית שם סוג מגבלות
(I1) lhs טנזור או לכל טנזור קוונטי (C1)
(I2) rhs טינסור או טינסור מותאם (quantized) לכל טינסור (C1)

פלט

שם סוג מגבלות
result טינסור או טינסור מותאם (quantized) לכל טינסור (C1)

אילוצים

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

דוגמאות

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]

 דוגמאות נוספות

כפל

סמנטיקה

מבצעת מכפלה ברמת הרכיב של שני טנסטורים lhs ו-rhs ומפיקה טנזור result. בהתאם לסוג הרכיב, מבצע את הפעולות הבאות:

  • למשתנים בוליאניים: 'וגם' בלוגיקה.
  • למספרים שלמים: כפל של מספרים שלמים.
  • למספרים מסוג float: multiplication מ-IEEE-754.
  • למספרים מרוכבים: כפל מורכב.
  • לסוגים מרובעניים:
    • dequantize_op_quantize(multiply, lhs, rhs, type(result)).

קלט

תווית שם סוג מגבלות
(I1) lhs טנזור או לכל טנזור קוונטי (C1)
(I2) rhs טינסור או טינסור מותאם (quantized) לכל טינסור (C1)

פלט

שם סוג מגבלות
result טינסור או טינסור מותאם (quantized) לכל טינסור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]

 דוגמאות נוספות

החלפת חיובי/שלילי

סמנטיקה

מבצע השלילה מבחינת הרכיבים של operand Tensor ויוצר טנזור result. בהתאם לסוג הרכיב, מבצע את הפעולות הבאות:

  • למספרים שלמים עם סימן: ביטול של מספר שלם.
  • למספרים שלמים לא חתומים: bitcast למספר שלם חתום, השמטה של מספרים שלמים, ו-bitcast בחזרה למספר שלם לא חתום.
  • למספרים מסוג float: negate מ-IEEE-754.
  • למספרים מרוכבים: שלילה מרוכבת.
  • לסוגים מרושתים: dequantize_op_quantize(negate, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand t e n s o r f l o w, t e n s o r f l o w, או t e n s o r f l o w, (C1)

פלט

שם סוג אילוצים
result טינסור מסוג שלם, של נקודה צפה או של מספר מרוכב, או טינסור מקודד לכל טינסור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]

// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]

 דוגמאות נוספות

לא

סמנטיקה

הפונקציה מבצעת פעולת NOT של רכיבי הטנזור operand ומפיקה טנזור result. בהתאם לסוג הרכיב, מבצע את הפעולות הבאות:

  • עבור ערכים בוליאניים: NOT לוגי.
  • למספרים שלמים: NOT ברמת הסיביות.

ארגומנטים

שם סוג אילוצים
operand טינסור מסוג בוליאני או שלם (C1)

פלט

שם סוג אילוצים
result טינסור מסוג בוליאני או שלם (C1)

אילוצים

  • (C1) type(operand) = type(result).

דוגמאות

// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]

// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]

 דוגמאות נוספות

optimization_barrier

סמנטיקה

מוודא שהפעולות שמניבות את operand מבוצעות לפני כל פעולה שמסתמכת על result, ומונע טרנספורמציות של המהדר מלהעביר פעולות מעבר למחסום. חוץ מזה, הפעולה היא זהות, כלומר result = operand.

ארגומנטים

שם סוג אילוצים
operand מספר משתנה של טנסטורים, טנזורים כמותיים לכל טנזור או אסימונים (C1)

פלט

שם סוג אילוצים
result מספר משתנה של טנסטורים, טנזורים כמותיים לכל טנזור או אסימונים (C1)

אילוצים

  • (C1) type(operand...) = type(result...).

דוגמאות

// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0

 דוגמאות נוספות

או

סמנטיקה

הפונקציה מבצעת או על כל רכיב של שני הטנסורים lhs ו-rhs ומפיקה טנסור result. בהתאם לסוג הרכיב, מבצע את הפעולות הבאות:

  • למשתנים בוליאניים: 'או' לוגי.
  • למספרים שלמים: OR ברמת הביטים.

קלט

תווית שם סוג מגבלות
(I1) lhs טינסור מסוג מספר שלם או בוליאני (C1)
(I2) rhs טינסור מסוג מספר שלם או בוליאני (C1)

פלט

שם סוג אילוצים
result tensor של מספר שלם או סוג בוליאני (C1)

מגבלות

  • (C1) type(lhs) = type(rhs) = type(result).

דוגמאות

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]

 דוגמאות נוספות

פלט

סמנטיקה

הקוד כותב את inputs ב-outfeed ויוצר אסימון result.

הסמנטיקה של outfeed_config מוגדרת בהתאם להטמעה.

קלט

תווית שם סוג
(I1) inputs מספר משתנה של טינסורים או טינסורים מרובים
(I2) token token
(I3) outfeed_config קבוע מסוג string

פלט

שם סוג
result token

דוגמאות

%result = "stablehlo.outfeed"(%input0, %token) {
  outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token

 דוגמאות נוספות

כרית

סמנטיקה

הרחבה של operand על ידי הוספת padding_value לאורך הטנזור וגם בין הרכיבים שלו.

edge_padding_low ו-edge_padding_high מציינים את כמות המילוי שנוספה בקצוות הנמוכים (לצד המדד 0) ובקצוות הגבוהים (לצד המדד הגבוה ביותר) של כל מאפיין, בהתאמה. כמות המילוי יכולה להיות שלילית, כאשר הערך המוחלט של המילוי השלילי מציין את מספר הרכיבים שצריך להסיר מהמאפיין שצוין.

interior_padding מציין את המרווח הפנימי שנוסף בין שני רכיבים בכל מימד, שלא יכול להיות שלילי. הוספת מרווח פנימי מתבצעת לפני הוספת מרווח קצה, כך שהוספת מרווח קצה שלילי תסיר רכיבים מהאופרטנד עם הוספת מרווח פנימי.

באופן רשמי יותר, result[result_index] מוגדר בתור:

  • operand[operand_index] אם result_index = edge_padding_low + operand_index * (interior_padding + 1).
  • אחרת, padding_value.

קלט

תווית שם סוג מגבלות
(I1) operand טינסור או טינסור מותאם (quantized) לכל טינסור (C1),‏ (C2),‏ (C4)
(I2) padding_value טינסור של מימד 0 או טינסור מקודד לכל טינסור (C1)
(I3) edge_padding_low קבוע טינסור חד-מימדי מסוג si64 (C1), (C4)
(I4) edge_padding_high קבוע מפריד חד-ממדי מסוג si64 (C1), (C4)
(I5) interior_padding קבוע טינסור חד-מימדי מסוג si64 (C2-C4)

פלט

שם סוג מגבלות
result טינסור או טינסור מותאם (quantized) לכל טינסור (C3-C6)

אילוצים

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result).
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand).
  • (C3) 0 <= interior_padding.
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.

דוגמאות

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
  edge_padding_low = array<i64: 0, 1>,
  edge_padding_high = array<i64: 2, 1>,
  interior_padding = array<i64: 1, 2>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

 דוגמאות נוספות

partition_id

סמנטיקה

הפונקציה יוצרת את partition_id של התהליך הנוכחי.

פלט

שם סוג
result טינסור של מימד 0 מסוג ui32

דוגמאות

%result = "stablehlo.partition_id"() : () -> tensor<ui32>

 דוגמאות נוספות

popcnt

סמנטיקה

הפונקציה מבצעת ספירה של מספר הביטים שמוגדרים בטנסור operand לפי רכיבים, ויוצרת טנסור result.

קלט

תווית שם סוג מגבלות
(I1) operand טינסור מסוג מספר שלם (C1)

פלט

שם סוג מגבלות
result טינסור מסוג מספר שלם (C1)

אילוצים

  • (C1) type(operand) = type(result).

דוגמאות

// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]

 דוגמאות נוספות

כוח

סמנטיקה

מבצעת הערכה ברמת הרכיב של lhs טנזור rhs ומפיקה טנזור result. בהתאם לסוג הרכיב, מבצע את הפעולות הבאות:

  • למספרים שלמים: העלאה בחזקה של מספר שלם.
  • למספרים מסוג float: pow מ-IEEE-754.
  • למספרים מרוכבים: העלאה בחזקה מורכבת.
  • לסוגי נתונים מרוסקים: dequantize_op_quantize(power, lhs, rhs, type(result)).

קלט

תווית שם סוג מגבלות
(I1) lhs t e n s o r f l o w, t e n s o r f l o w, או t e n s o r f l o w, (C1)
(I2) rhs טינסור מסוג שלם, של נקודה צפה או של מספר מרוכב, או טינסור מקודד לכל טינסור (C1)

פלט

שם סוג אילוצים
result טינסור מסוג שלם, של נקודה צפה או של מספר מרוכב, או טינסור מקודד לכל טינסור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]

 דוגמאות נוספות

ריאל

סמנטיקה

מחלצת את החלק האמיתי, מבחינת היסודות, מה-operand ומפיקה טנזור result. באופן רשמי יותר, לכל רכיב x: real(x) = is_complex(x) ? real_part(x) : x.

קלט

תווית שם סוג מגבלות
(I1) operand טינסור מסוג נקודה צפה (floating-point) או מסוג מורכב (C1), (C2)

פלט

שם סוג אילוצים
result טינסור מסוג נקודה צפה (floating-point) (C1),‏ (C2)

מגבלות

  • (C1) shape(result) = shape(operand).
  • (C2) element_type(result) מוגדר בתור:
    • complex_element_type(element_type(operand)) אם is_complex(operand).
    • אחרת, element_type(operand).

דוגמאות

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]

 דוגמאות נוספות

recv

סמנטיקה

מקבל נתונים מערוץ עם channel_id ומפיק results.

אם הערך של is_host_transfer הוא true, הפעולה מעבירה נתונים מהמארח. אחרת, הוא מעביר נתונים ממכשיר אחר. המשמעות של הדבר היא שהיא מוגדרת בהטמעה. הדגל הזה מכיל את אותו המידע שמופיע בשדה channel_type, ולכן בעתיד אנחנו מתכננים להשאיר רק אחד מהם (#666).

results מורכב מערכי טען שימושי שמופיעים קודם ומאסימון שמופיע בסוף. בעתיד אנחנו מתכננים לפצל את עומס העבודה ואת האסימון לשני משתני פלט נפרדים כדי לשפר את הבהירות (#670).

קלט

תווית שם סוג מגבלות
(I1) token token (C4)
(I2) channel_id קבוע מסוג si64
(I3) channel_type enum של DEVICE_TO_DEVICE ו-HOST_TO_DEVICE (C1)
(I4) is_host_transfer קבוע מסוג i1 (C1)

פלט

שם סוג אילוצים
results מספר משתנה של טינסורים, טינסורים מרובים או אסימונים (C2-C4)

אילוצים

  • (C1) channel_type מוגדר בתור:
    • HOST_TO_DEVICE אם is_host_transfer = true,
    • DEVICE_TO_DEVICE אחרת.
  • (C2) 0 < size(results).
  • (C3) is_empty(result[:-1]) או is_tensor(type(results[:-1])).
  • (C4) is_token(type(results[-1])).

דוגמאות

%results0, %results1 = "stablehlo.recv"(%token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
  is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)

 דוגמאות נוספות

צמצום

סמנטיקה

הפונקציה מחילה פונקציית הפחתה body על inputs ו-init_values לאורך dimensions ויוצרת טינסורים של results.

סדר הפחתות מוגדר בהתאם להטמעה, כלומר body ו-init_values חייבים ליצור מונויד כדי להבטיח שהפעולה תניב את אותם תוצאות לכל הקלט בכל ההטמעות. עם זאת, התנאי הזה לא תקף להרבה הנחות פופולריות. לדוגמה, הוספה של נקודה צפה עבור body ואפס עבור init_values לא יוצרת למעשה מונויד כי הוספה של נקודה צפה היא לא אסוסיטיבית.

באופן רשמי יותר, results...[j0, ..., jR-1] = reduce(input_slices_converted) כאשר:

  • input_slices = inputs...[j0, ..., :, ..., jR-1], שבו : מוכנסים ב-dimensions.
  • input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...).
  • init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...).
  • reduce(input_slices_converted) = exec(schedule) לעץ בינארי כלשהו schedule כאשר:
    • exec(node) = body(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule הוא עץ בינארי מלא שמוגדר בהטמעה, והטראברסיה בסדר שלו מורכבת מ:
    • ערכים של input_slices_converted...[index], לכל index ב-index_space(input_slices_converted) בסדר האלפביתי המילולי (לפי אותיות קטנות) של index.
    • משולבות עם כמות init_values_converted שהוגדרה על ידי ההטמעה במיקומים שהוגדרו על ידי ההטמעה.

קלט

תווית שם סוג מגבלות
(I1) inputs מספר משתנה של טנסטורים או טנזורים כמותיים לכל טנזור (C1-C4),‏ (C6),‏ (C7)
(I2) init_values מספר משתנה של טינסורים ב-0 מימדים או טינסורים מקוטנים לכל טינסור (C2),‏ (C3)
(I3) dimensions קבוע טינסור חד-מימדי מסוג si64 (C4),‏ (C5),‏ (C7)
(I4) body פונקציה (C6)

פלט

שם סוג אילוצים
results מספר וריאדי של טינסורים או טינסורים מקוטנים לכל טינסור (C3),‏ (C7),‏ (C8)

אילוצים

  • (C1) same(shape(inputs...)).
  • (C2) element_type(inputs...) = element_type(init_values...).
  • (C3) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (C4) 0 <= dimensions < rank(inputs[0]).
  • (C5) is_unique(dimensions).
  • (C6) body הוא מסוג (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>) כאשר is_promotable(element_type(inputs[i]), Ei).
  • (C7) shape(results...) = shape(inputs...), מלבד העובדה שגודל המאפיינים של inputs... שתואם ל-dimensions לא נכלל.
  • (C8) element_type(results[i]) = Ei לכל i ב-[0,N).

דוגמאות

// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]

דוגמאות נוספות

reduce_precision

סמנטיקה

הפונקציה מבצעת המרה של operand לפי רכיבים לסוג אחר של נקודה צפה שמשתמש ב-exponent_bits וב-mantissa_bits, ואז חזרה לסוג הנקודה הצפה המקורי, ויוצרת טינסור output.

באופן רשמי יותר:

  • הביטים של המנטיסה של הערך המקורי מתעדכנים כדי לעגל את הערך המקורי לערך הקרוב ביותר שאפשר לייצג באמצעות mantissa_bits באמצעות סמנטיקה של roundToIntegralTiesToEven.
  • לאחר מכן, אם mantissa_bits קטן ממספר הביטים של המאנטיסה מהערך המקורי, ביטים של מנטיסה ייחתכו ל-mantissa_bits.
  • לאחר מכן, אם הביטים של המעריך של תוצאת הביניים לא מתאימים לטווח שסופק על ידי exponent_bits, תוצאת הביניים גולשת לאינסוף באמצעות הסימן המקורי או עוברת על הערכים מתחת לאפס באמצעות הסימן המקורי.
  • לסוגי נתונים שמסומנים בקונטיינרים, הפונקציה מבצעת את הפעולה dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טינסור מסוג נקודה צפה (floating-point) או טינסור מקודד לכל טינסור (C1)
(I2) exponent_bits קבוע מסוג si32 (C2)
(I3) mantissa_bits קבוע מסוג si32 (C3)

פלט

שם סוג אילוצים
output טינסור מסוג נקודה צפה (floating-point) או טינסור מקודד לכל טינסור (C1)

אילוצים

  • (C1) baseline_type(operand) = baseline_type(output).
  • (C2) 1 <= exponent_bits.
  • (C3) 0 <= mantissa_bits.

דוגמאות

// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
  exponent_bits = 5 : i32,
  mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]

 דוגמאות נוספות

reduce_scatter

סמנטיקה

reduce_scatter

בכל קבוצת תהליכים ברשת התהליכים של StableHLO, מתבצעת הפחתה באמצעות computations על הערכים של הטנזור operand מכל תהליך, מחלקים את תוצאת ההפחתה לפי scatter_dimension לחלקים ומפזרים את החלקים המפוצלים בין התהליכים כדי ליצור את result.

הפעולה מפצלת את רשת התהליכים של StableHLO ל-process_groups, שמוגדר כך:

  • cross_replica(replica_groups) אם channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) אם channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) אם channel_id > 0 and use_global_device_ids = true.

לאחר מכן, בכל process_group:

  • reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation).
  • parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension).
  • result@receiver = parts@sender[receiver_index] לכל sender ב-process_group, כאשר receiver_index = process_group.index(receiver).

קלט

תווית שם סוג מגבלות
(I1) operand טנזור או לכל טנזור קוונטי (C1),‏ (C2),‏ (C7),‏ (C8)
(I2) scatter_dimension קבוע מסוג si64 (C1), (C2), (C8)
(I3) replica_groups קבוע טינסור דו-מימדי מסוג si64 (C3-C5)
(I4) channel_id קבוע מסוג si64 (C6)
(I5) use_global_device_ids קבוע מסוג i1 (C6)
(I6) computation פונקציה (C7)

פלט

שם סוג מגבלות
result טינסור או טינסור מותאם (quantized) לכל טינסור (C8-C9)

אילוצים

  • (C1) dim(operand, scatter_dimension) % dim(process_groups, 1) = 0.
  • (C2) 0 <= scatter_dimension < rank(operand).
  • (C3) is_unique(replica_groups).
  • (C4) size(replica_groups) מוגדר בתור:
    • num_replicas אם משתמשים ב-cross_replica.
    • num_replicas אם משתמשים ב-cross_replica_and_partition.
    • num_processes אם משתמשים ב-flattened_ids.
  • (C5) 0 <= replica_groups < size(replica_groups).
  • (C6) אם use_global_device_ids = true, אז channel_id > 0.
  • (C7) computation הוא מסוג (tensor<E>, tensor<E>) -> (tensor<E>) כאשר is_promotable(element_type(operand), E).
  • (C8) shape(result) = shape(operand) מלבד:
    • dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1).
  • (C9) element_type(result) = E.

דוגמאות

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
  "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
//                  [18, 20]]
// %result@(1, 0): [[14, 16],
//                  [22, 24]]

דוגמאות נוספות

reduce_window

סמנטיקה

הפונקציה מחילה פונקציית הפחתה body על חלונות של inputs ו-init_values ומפיקה את הערך results.

בתרשים הבא מוצגת דוגמה קונקרטית לחישוב הרכיבים ב-results... מ-inputs....

reduce_window

באופן רשמי יותר, results...[result_index] = reduce(windows, init_values, axes(inputs...), body) (ראו reduce) כאשר:

  • padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1).
  • window_start = result_index * window_strides.
  • window_end = window_start + (window_dimensions - 1) * window_dilations + 1.
  • windows = slice(padded_inputs..., window_start, window_end, window_dilations).

קלט

תווית שם סוג מגבלות
(I1) inputs מספר וריאדי של טינסורים או טינסורים מקוטנים לכל טינסור (C1-C4),‏ (C6),‏ (C8),‏ (C10),‏ (C12),‏ (C13),‏ (C15)
(I2) init_values מספר משתנה של טינסורים ב-0 מימדים או טינסורים מקוטנים לכל טינסור (C1),‏ (C13)
(I3) window_dimensions קבוע מפריד חד-ממדי מסוג si64 (C4),‏ (C5),‏ (C15)
(I4) window_strides קבוע טינסור חד-מימדי מסוג si64 (C6), (C7), (C15)
(I5) base_dilations קבוע מפריד חד-ממדי מסוג si64 (C8),‏ (C9),‏ (C15)
(I6) window_dilations קבוע טינסור חד-מימדי מסוג si64 (C10),‏ (C11),‏ (C15)
(I7) padding קבוע טינסור דו-מימדי מסוג si64 (C12),‏ (C15)
(I8) body פונקציה (C13)

פלט

שם סוג אילוצים
results מספר משתנה של טנסטורים או טנזורים כמותיים לכל טנזור (C1),‏ (C14-C16)

מגבלות

  • (C1) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (C2) same(shape(inputs...)).
  • (C3) element_type(inputs...) = element_type(init_values...).
  • (C4) size(window_dimensions) = rank(inputs[0]).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(inputs[0]).
  • (C7) 0 < window_strides.
  • (C8) size(base_dilations) = rank(inputs[0]).
  • (C9) 0 < base_dilations.
  • (C10) size(window_dilations) = rank(inputs[0]).
  • (C11) 0 < window_dilations.
  • (C12) shape(padding) = [rank(inputs[0]), 2].
  • (C13) body מסוג (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>) כאשר is_promotable(element_type(inputs[i]), Ei).
  • (C14) same(shape(results...)).
  • (C15) shape(results[0]) = num_windows כאשר:
    • dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1.
    • padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1].
    • dilated_window_shape = (window_dimensions - 1) * window_dilations + 1.
    • is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1.
  • (C16) element_type(results[i]) = Ei לכל i ב-[0,N).

דוגמאות

// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = array<i64: 2, 1>,
  window_strides = array<i64: 4, 1>,
  base_dilations = array<i64: 2, 1>,
  window_dilations = array<i64: 3, 1>,
  padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]

 דוגמאות נוספות

היתרה

סמנטיקה

פונקציה זו מבצעת את שארית הרכיבים של המחלק lhs והמחלק rhs כדי ליצור טנזור result.

באופן רשמי יותר, סימן התוצאה נלקח מהמחלק, והערך המוחלט של התוצאה תמיד קטן מהערך המוחלט של המחלק. השארית מחושבת כ-lhs - d * rhs, כאשר d מחושב לפי:

  • למספרים שלמים: stablehlo.divide(lhs, rhs).
  • למספרים שגופים: division(lhs, rhs) מ-IEEE-754 עם מאפיין עיגול roundTowardZero.
  • למספרים מרוכבים: TBD (#997).
  • לסוגים מרובעניים:
    • dequantize_op_quantize(remainder, lhs, rhs, type(result)).

לסוגים של רכיבים של נקודה צפה, הפעולה הזו שונה מהפעולה remainder במפרט IEEE-754, שבה d הוא ערך שלם הקרוב ביותר לערך המדויק של lhs/rhs, עם עדיפות לערכים זוגיים.

קלט

תווית שם סוג מגבלות
(I1) lhs t e n s o r f l o w, t e n s o r f l o w, או t tensor quantor, (C1)
(I2) rhs טינסור מסוג שלם, של נקודה צפה או של מספר מרוכב, או טינסור מקודד לכל טינסור (C1)

פלט

שם סוג אילוצים
result טינסור מסוג שלם, של נקודה צפה או של מספר מרוכב, או טינסור מקודד לכל טינסור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]

דוגמאות נוספות

replica_id

סמנטיקה

הפונקציה יוצרת את replica_id של התהליך הנוכחי.

פלט

שם סוג
result טינסור של מימד 0 מסוג ui32

דוגמאות

%result = "stablehlo.replica_id"() : () -> tensor<ui32>

 דוגמאות נוספות

עיצוב מחדש

סמנטיקה

ביצוע שינוי צורה של הטנזור operand לטנזור result. באופן קונספטואלי, זהו שינוי של הצורה של הייצוג הקנוני, למשל מ-tensor<2x3xf32> ל-tensor<3x2xf32> או ל-tensor<6xf32>.

באופן רשמי יותר, result[result_index] = operand[operand_index] כאשר ל-result_index ול-operand_index יש את אותו המיקום בסדר האלפביתי של index_space(result) ו-index_space(operand).

קלט

תווית שם סוג מגבלות
(I1) operand טינסור או טינסור מצטבר (C1-C3)

פלט

שם סוג אילוצים
result t e n s o r f l o w, או t e n s o r f l o w, (C1-C3)

אילוצים

  • (C1) הערך של element_type(result) מחושב לפי:
    • element_type(operand), אם !is_per_axis_quantized(operand).
    • element_type(operand), אלא ש-quantization_dimension(operand) ו-quantization_dimension(result) עשויים להיות שונים.
  • (C2) size(operand) = size(result).
  • (C3) אם is_per_axis_quantized(operand):
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).

דוגמאות

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]

 דוגמאות נוספות

הפוך

סמנטיקה

הופכת את הסדר של הרכיבים ב-operand לאורך dimensions שצוין, ויוצרת טינסור result. באופן רשמי יותר, result[result_index] = operand[operand_index] כאשר:

  • operand_index[d] = dim(result, d) - result_index[d] - 1 אם d ב-dimensions.
  • operand_index[d] = result_index[d] אחרת.

קלט

תווית שם סוג מגבלות
(I1) operand טינסור או טינסור מותאם (quantized) לכל טינסור (C1), (C3)
(I2) dimensions קבוע טינסור חד-מימדי מסוג si64 (C2),‏ (C3)

פלט

שם סוג מגבלות
result טינסור או טינסור מותאם (quantized) לכל טינסור (C1),‏ (C3)

מגבלות

  • (C1) type(operand) = type(result).
  • (C2) is_unique(dimensions).
  • (C3) 0 <= dimensions < rank(result).

דוגמאות

// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
  dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]

 דוגמאות נוספות

rng

סמנטיקה

יוצרת מספרים אקראיים באמצעות האלגוריתם rng_distribution ומפיקה טנזור result של צורה נתונה shape.

אם rng_distribution = UNIFORM, המספרים האקראיים נוצרים לפי התפלגות אחידה במרווח [a, b). אם a >= b, ההתנהגות לא מוגדרת.

אם rng_distribution = NORMAL, המספרים האקראיים נוצרים לפי ההתפלגות הנורמלית עם ממוצע = a וסטיית תקן = b. אם b < 0, ההתנהגות לא מוגדרת.

האופן המדויק שבו נוצרים מספרים אקראיים מוגדר בהטמעה. לדוגמה, יכול להיות שהן יהיו גורמיות או לא גורמיות, ויכול להיות שהן ישתמשו במצבים מוסתרים או לא.

בשיחות עם בעלי עניין רבים, התברר שהפעולה הזו לא בשימוש, ולכן בעתיד אנחנו מתכננים לבדוק את האפשרות להסיר אותה (#597).

קלט

תווית שם סוג מגבלות
(I1) a טינסור של 0 מימדים מסוג שלם, בוליאני או נקודה צפה (floating-point) (C1),‏ (C2)
(I2) b טנזור 0-ממדי של מספר שלם, בוליאני או סוג נקודה צפה (floating-point) (C1),‏ (C2)
(I3) shape קבוע טינסור חד-מימדי מסוג si64 (C3)
(I4) rng_distribution enum של UNIFORM ו-NORMAL (C2)

פלט

שם סוג אילוצים
result Tensor של מספר שלם, בוליאני או סוג נקודה צפה (floating-point) (C1-C3)

אילוצים

  • (C1) element_type(a) = element_type(b) = element_type(result).
  • (C2) אם rng_distribution = NORMAL, אז is_float(a).
  • (C3) shape(result) = shape.

דוגמאות

// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
  rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
//           [1, 0, 1],
//           [1, 1, 1],
//           [0, 0, 0]
//          ]

rng_bit_generator

סמנטיקה

הפונקציה מחזירה ערך output שמלא ביטים אקראיים אחידים ומצב פלט מעודכן output_state באמצעות האלגוריתם המדומה של מחולל המספרים rng_algorithm בהינתן המצב הראשוני initial_state. מובטח שהפלט יהיה פונקציה determinstic של initial_state, אבל לא מובטח שהוא יהיה determinstic בין הטמעות.

rng_algorithm הוא אחד מהבאים:

  • DEFAULT: אלגוריתם שמוגדר בהטמעה.
  • THREE_FRY: וריאנט של אלגוריתם Threefry שמוגדר בהטמעה.*
  • PHILOX: וריאנט של אלגוריתם Philox שמוגדר בהטמעה.*

* ראו: Salmon et al. SC 2011. מספרים אקראיים מקבילים: החל מ-1, 2, 3.

קלט

תווית שם סוג מגבלות
(I1) rng_algorithm טיפוסים בני מנייה (enum) של DEFAULT, THREE_FRY ו-PHILOX (C2)
(I2) initial_state טינסור חד-מימדי מסוג ui64 (C1),‏ (C2)

פלט

שם סוג מגבלות
output_state זווית חד-ממדית אחת מסוג ui64 (C1)
output Tensor של מספר שלם או מסוג נקודה צפה (floating-point)

אילוצים

  • (C1) type(initial_state) = type(output_state).
  • (C2) size(initial_state) מוגדר כך:
    • מוגדרת בהטמעה אם rng_algorithm = DEFAULT.
    • 2 אם rng_algorithm = THREE_FRY.
    • 2 או 3 אם rng_algorithm = PHILOX.

דוגמאות

// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
  rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
//           [9236835810183407956, 16087790271692313299],
//           [18212823393184779219, 2658481902456610144]
//          ]

round_nearest_afz

סמנטיקה

הפונקציה מבצעת עיגול של כל רכיב בכיוון המספר השלם הקרוב ביותר, ומפיקה טינסור result. הטמעה של הפעולה roundToIntegralTiesToAway מהמפרט IEEE-754. לסוגים שמקובצים, הפונקציה מבצעת את הפעולה dequantize_op_quantize(round_nearest_afz, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טינסור מסוג נקודה צפה (floating-point) או טינסור מקודד לכל טינסור (C1)

פלט

שם סוג אילוצים
result טינסור מסוג נקודה צפה (floating-point) או טינסור מקודד לכל טינסור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]

דוגמאות נוספות

round_nearest_even

סמנטיקה

הפונקציה מבצעת עיגול של כל רכיב בכיוון המספר השלם הקרוב ביותר, ומפרקת שווים לכיוון המספר השלם הזוגי, בטנסור operand ומפיקה טנסור result. הטמעה של הפעולה roundToIntegralTiesToEven מהמפרט של IEEE-754. לסוגים כמותיים, הפונקציה מחזירה dequantize_op_quantize(round_nearest_even, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טינסור מסוג נקודה צפה (floating-point) או טינסור מקודד לכל טינסור (C1)

פלט

שם סוג אילוצים
result טינסור מסוג נקודה צפה (floating-point) או טינסור מקודד לכל טינסור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]

 דוגמאות נוספות

rsqrt

סמנטיקה

הפונקציה מבצעת פעולת שורש ריבועי הפוך לפי רכיבים בטנסור operand ומייצרת טנסור result. בהתאם לסוג הרכיב, מבצע את הפעולות הבאות:

  • למספרים מסוג float: rSqrt מ-IEEE-754.
  • למספרים מרוכבים: שורש ריבועי מורכב הפוך.
  • לסוגים הבאים לפי כמות: dequantize_op_quantize(rsqrt, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טינסור מסוג נקודה צפה או מסוג מורכב, או טינסור מקודד לכל טינסור (C1)

פלט

שם סוג מגבלות
result טינסור מסוג נקודה צפה או מסוג מורכב, או טינסור מקודד לכל טינסור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]

 דוגמאות נוספות

פיזור

סמנטיקה

מפיקה כל טנזור results ששווים לטינוטורים של inputs, חוץ מזה שמקטעים מסוימים שצוינו על ידי scatter_indices מתעדכנים בערכים updates באמצעות update_computation.

בתרשים הבא אפשר לראות איך אלמנטים ב-updates... ממפים על אלמנטים ב-results... באמצעות דוגמה קונקרטית. בתרשים מוצגות כמה דוגמאות למדדי updates..., ומוסבר בפירוט לאילו מדדי results... הם תואמים.

פיזור

באופן רשמי יותר, לכל update_index ב-index_space(updates[0]):

  • update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims].
  • update_scatter_index = update_index[update_scatter_dims...].
  • start_index מוגדר כך:
    • scatter_indices[si0, ..., :, ..., siN] כאשר si הם רכיבים נפרדים ב-update_scatter_index ו-: מוכנס באינדקס index_vector_dim, אם index_vector_dim < rank(scatter_indices).
    • [scatter_indices[update_scatter_index]] אחרת.
  • ל-d_input ב-axes(inputs[0]),
    • full_start_index[d_input] = start_index[d_start] אם d_input = scatter_dims_to_operand_dims[d_start].
    • full_start_index[d_input] = 0 אחרת.
  • עבור d_input ב-axes(inputs[0]),
    • full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)] אם d_input = input_batching_dims[i_batching] וגם d_start = scatter_indices_batching_dims[i_batching].
    • אחרת, full_batching_index[d_input] = 0.
  • update_window_index = update_index[update_window_dims...].
  • full_window_index = [wi0, ..., 0, ..., wiN] כאשר wi הם רכיבים נפרדים ב-update_window_index, ו-0 מוחדר באינדיקטורים מ-inserted_window_dims ומ-input_batching_dims.
  • result_index = full_start_index + full_batching_index + full_window_index.

בהתאם לכך, results = exec(schedule, inputs), כאשר:

  • schedule הוא תמורה מוגדרת על ידי ההטמעה של index_space(updates[0]).
  • exec([update_index, ...], results) = exec([...], updated_results) כאשר:
    • אם הערך של result_index נמצא בטווח של shape(results...)
    • updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
    • updated_values = update_computation(results...[result_index], updates_converted)
    • updated_results הוא עותק של results שבו results...[result_index] מוגדר ל-updated_values....
    • אחרת
    • updated_results = results.
  • exec([], results) = results.

אם הערך של indices_are_sorted הוא true, אז ההטמעה יכולה להניח שהערכים scatter_indices ממוינים ביחס ל-scatter_dims_to_operand_dims, אחרת ההתנהגות לא מוגדרת. באופן רשמי יותר, לכל i1 < i2 מ-indices(result), full_start_index(i1) <= full_start_index(i2).

אם הערך של unique_indices הוא true, אז ההטמעה יכולה להניח שכל המדדים של result_index שמפוזרים הם ייחודיים. אם הערך של unique_indices הוא true אבל האינדקסים שאליהם מתבצעת הפצת הנתונים לא ייחודיים, ההתנהגות לא מוגדרת.

קלט

תווית שם סוג מגבלות
(I1) inputs מספר וריאדי של טינסורים או טינסורים מקוטנים לכל טינסור (C1),‏ (C2),‏ (C4-C6),‏ (C11),‏ (C13),‏ (C18),‏ (C21),‏ (C23-C24)
(I2) scatter_indices טינסור מסוג מספר שלם (C4), (C15), (C19), (C22)
(I3) updates מספר וריאדי של טינסורים או טינסורים מקוטנים לכל טינסור (C3-C6),‏ (C8)
(I4) update_window_dims קבוע טינסור חד-מימדי מסוג si64 (C2),‏ (C4),‏ (C7-C8)
(I5) inserted_window_dims קבוע טינסור חד-מימדי מסוג si64 (C2),‏ (C4),‏ (C9-C11)
(I6) input_batching_dims קבוע טינסור חד-מימדי מסוג si64 (C2),‏ (C4),‏ (C9),‏ (C12-13),‏ (C17-18),‏ (C20)
(I7) scatter_indices_batching_dims קבוע טינסור חד-מימדי מסוג si64 (C14-C18)
(I8) scatter_dims_to_operand_dims קבוע טינסור חד-מימדי מסוג si64 (C19-C21)
(I9) index_vector_dim קבוע מסוג si64 (C4),‏ (C16),‏ (C19),‏ (C22)
(I10) indices_are_sorted קבוע מסוג i1
(I11) unique_indices קבוע מסוג i1
(I12) update_computation פונקציה (C23)

פלט

שם סוג אילוצים
results מספר וריאדי של טינסורים או טינסורים מקוטנים לכל טינסור (C24-C25)

אילוצים

  • (C1) same(shape(inputs...)).
  • (C2) ‏`rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims)
    • size(input_batching_dims)`.
  • (C3) same(shape(updates...)).
  • (C4) shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes) כאשר:
    • update_scatter_dim_sizes = shape(scatter_indices), מלבד העובדה שגודל המאפיין scatter_indices שתואם ל-index_vector_dim לא נכלל.
    • update_window_dim_sizes <= shape(inputs[0]) מלבד זאת, אם מידות המאפיינים ב-inputs[0] שתואמות ל-inserted_window_dims ול-input_batching_dims לא נכללות.
    • combine מציבה את update_scatter_dim_sizes בצירים שתואמים ל-update_scatter_dims ול-update_window_dim_sizes בצירים התואמים ל-update_window_dims.
  • (C5) 0 < size(inputs) = size(updates) = N.
  • (C6) element_type(updates...) = element_type(inputs...).
  • (C7) is_unique(update_window_dims) and is_sorted(update_window_dims).
  • (C8) 0 <= update_window_dims < rank(updates[0]).
  • (C9) is_unique(concatenate(inserted_window_dims, input_batching_dims))
  • (C10) is_sorted(inserted_window_dims).
  • (C11) 0 <= inserted_window_dims < rank(inputs[0]).
  • (C12) is_sorted(input_batching_dims).
  • (C13) 0 <= input_batching_dims < rank(inputs[0])).
  • (C14) is_unique(scatter_indices_batching_dims).
  • (C15) 0 <= scatter_indices_batching_dims < rank(scatter_indices).
  • (C16) index_vector_dim not in scatter_indices_batching_dims.
  • (C17) size(input_batching_dims) == size(scatter_indices_batching_dims).
  • (C18) dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...).
  • (C19) size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1.
  • (C20) is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims)).
  • (C21) 0 <= scatter_dims_to_operand_dims < rank(inputs[0]).
  • (C22) 0 <= index_vector_dim <= rank(scatter_indices).
  • (C23) הסוג (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>) של update_computation הוא is_promotable(element_type(inputs[i]), Ei).
  • (C24) shape(inputs...) = shape(results...).
  • (C25) element_type(results[i]) = Ei לכל i ב-[0,N).

דוגמאות

// %input: [
//          [
//           [[1, 2], [3, 4], [5, 6], [7, 8]],
//           [[9, 10],[11, 12], [13, 14], [15, 16]],
//           [[17, 18], [19, 20], [21, 22], [23, 24]]
//          ],
//          [
//           [[25, 26], [27, 28], [29, 30], [31, 32]],
//           [[33, 34], [35, 36], [37, 38], [39, 40]],
//           [[41, 42], [43, 44], [45, 46], [47, 48]]
//          ]
//         ]
// %scatter_indices: [
//                    [
//                     [[0, 0], [1, 0], [2, 1]],
//                     [[0, 1], [1, 1], [0, 9]]
//                    ],
//                    [
//                     [[0, 0], [2, 1], [2, 2]],
//                     [[1, 2], [0, 1], [1, 0]]
//                    ]
//                   ]
// %update: [
//           [
//            [[1, 1], [1, 1], [1, 1]],
//            [[1, 1], [1, 1], [1, 1]]
//           ],
//           [
//            [[1, 1], [1, 1], [1, 1]],
//            [[1, 1], [1, 1], [1, 1]]
//           ]
//          ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension_numbers = #stablehlo.scatter<
    update_window_dims = [3, 4],
    inserted_window_dims = [1],
    input_batching_dims = [0],
    scatter_indices_batching_dims = [1],
    scatter_dims_to_operand_dims = [2, 1],
    index_vector_dim = 3>,
  indices_are_sorted = false,
  unique_indices = false
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
// %result: [
//           [
//            [[3, 4], [6, 7], [6, 7], [7, 8]],
//            [[9, 10],[11, 12], [15, 16], [17, 18]],
//            [[17, 18], [19, 20], [22, 23], [24, 25]]
//           ],
//           [
//            [[25, 26], [28, 29], [30, 31], [31, 32]],
//            [[35, 36], [38, 39], [38, 39], [39, 40]],
//            [[41, 42], [44, 45], [46, 47], [47, 48]]
//           ]
//          ]

 דוגמאות נוספות

בחירה

סמנטיקה

הפונקציה יוצרת טינסור result שבו כל רכיב נבחר מהטינסור on_true או on_false על סמך הערך של הרכיב התואם ב-pred. באופן רשמי יותר, result[result_index] = pred_element ? on_true[result_index] : on_false[result_index], כאשר pred_element = rank(pred) = 0 ? pred[] : pred[result_index]. לסוגי נתונים שמפורטים במספרים, הפונקציה מבצעת את הפעולה dequantize_select_quantize(pred, on_true, on_false, type(result)).

קלט

תווית שם סוג מגבלות
(I1) pred טינסור מסוג i1 (C1)
(I2) on_true טנזור או לכל טנזור קוונטי (C1-C2)
(I3) on_false טינסור או טינסור מותאם (quantized) לכל טינסור (C2)

פלט

שם סוג מגבלות
result טינסור או טינסור מותאם (quantized) לכל טינסור (C2)

אילוצים

  • (C1) rank(pred) = 0 or shape(pred) = shape(on_true).
  • (C2) baseline_type(on_true) = baseline_type(on_false) = baseline_type(result).

דוגמאות

// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]

 דוגמאות נוספות

select_and_scatter

סמנטיקה

מפזר את הערכים מהטנזור source באמצעות scatter על סמך התוצאה של reduce_window של הטינזור input באמצעות select, ומפיק את הטינזור result.

בתרשים הבא מוצגת דוגמה קונקרטית לחישוב הרכיבים ב-result מ-operand ו-source.

select_and_scatter

באופן רשמי יותר:

  • selected_values = reduce_window_without_init(...) עם ערכי הקלט הבאים:

    • inputs = [operand].
    • window_dimensions, ‏window_strides ו-padding, שנעשה בהם שימוש כפי שהם.
    • base_dilations = windows_dilations = 1.
    • הערך body מוגדר כך:
    def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>:
      return select(arg0, arg1) ? arg0 : arg1;
    

    כאשר E = element_type(operand) ו-reduce_window_without_init פועלים בדיוק כמו reduce_window, חוץ מכך ש-schedule של reduce הבסיסי (ראו reduce) לא כולל ערכי init. בשלב זה לא צוין מה קורה אם בחלון התואם אין ערכים (#731).

  • result[result_index] = reduce([source_values], [init_value], [0], scatter) כאשר:

    • source_values = [source[source_index] for source_index in source_indices].
    • selected_index(source_index) = operand_index אם ל-selected_values[source_index] יש את הרכיב operand מ-operand_index.
    • source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index].

קלט

תווית שם סוג מגבלות
(I1) operand טנזור או לכל טנזור קוונטי (C1-C4),‏ (C6),‏ (C8-C11)
(I2) source טינסור או טינסור מותאם (quantized) לכל טינסור (C1),‏ (C2)
(I3) init_value טינסור של מימד 0 או טינסור מקודד לכל טינסור (C3)
(I4) window_dimensions קבוע מפריד חד-ממדי מסוג si64 (C2),‏ (C4),‏ (C5)
(I5) window_strides קבוע טינסור חד-מימדי מסוג si64 (C2),‏ (C6),‏ (C7)
(I6) padding קבוע טינסור דו-מימדי מסוג si64 (C2), (C8)
(I7) select פונקציה (C9)
(I8) scatter פונקציה (C10)

פלט

שם סוג מגבלות
result טנזור או לכל טנזור קוונטי (C11-C12)

אילוצים

  • (C1) element_type(operand) = element_type(source).
  • (C2) shape(source) = num_windows כאשר:
    • padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1].
    • is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1.
  • (C3) element_type(init_value) = element_type(operand).
  • (C4) size(window_dimensions) = rank(operand).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(operand).
  • (C7) 0 < window_strides.
  • (C8) shape(padding) = [rank(operand), 2].
  • (C9) select הוא מסוג (tensor<E>, tensor<E>) -> tensor<i1> כאשר E = element_type(operand).
  • (C10) ‏scatter הוא מסוג (tensor<E>, tensor<E>) -> tensor<E> כאשר is_promotable(element_type(operand), E).
  • (C11) shape(operand) = shape(result).
  • (C12) element_type(result) = E.

דוגמאות

// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = array<i64: 3, 1>,
  window_strides = array<i64: 2, 1>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]

דוגמאות נוספות

שליחה

סמנטיקה

הפונקציה שולחת את הערך inputs לערוץ channel_id ומפיקה אסימון result.

אם הערך של is_host_transfer הוא true, הפעולה מעבירה נתונים למארח. אחרת, הנתונים מועברים למכשיר אחר. המשמעות של הדבר היא שהיא מוגדרת בהטמעה. הדגל הזה מכיל את אותו המידע שמופיע בשדה channel_type, ולכן בעתיד אנחנו מתכננים להשאיר רק אחד מהם (#666).

קלט

תווית שם סוג מגבלות
(I1) inputs מספר משתנה של טינסורים או טינסורים מרובים
(I2) token token
(I3) channel_id קבוע מסוג si64
(I4) channel_type enum של DEVICE_TO_DEVICE ו-DEVICE_TO_HOST (C1)
(I5) is_host_transfer קבוע מסוג i1 (C1)

פלט

שם סוג
result token

אילוצים

  • (C1) channel_type מוגדר בתור:
    • DEVICE_TO_HOST אם is_host_transfer = true,
    • DEVICE_TO_DEVICE אחרת.

דוגמאות

%result = "stablehlo.send"(%operand, %token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
  is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token

 דוגמאות נוספות

shift_left

סמנטיקה

הפונקציה מבצעת פעולת הזזה שמאלה של כל רכיב בטנסור lhs ב-rhs מספר ביטים, ומפיקה טנסור result.

קלט

תווית שם סוג מגבלות
(I1) lhs טינסור מסוג מספר שלם (C1)
(I2) rhs טינסור מסוג מספר שלם (C1)

פלט

שם סוג מגבלות
result טינסור מסוג מספר שלם (C1)

מגבלות

  • (C1) type(lhs) = type(rhs) = type(result).

דוגמאות

// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]

 דוגמאות נוספות

shift_right_arithmetic

סמנטיקה

מבצעת פעולת אריתמטיקה אריתמטית ימינה ברמת הרכיבים על הטנזור lhs במספר rhs ביטים, ומפיקה טנזור result.

קלט

תווית שם סוג מגבלות
(I1) lhs טינסור מסוג מספר שלם (C1)
(I2) rhs טינסור מסוג מספר שלם (C1)

פלט

שם סוג מגבלות
result טינסור מסוג מספר שלם (C1)

מגבלות

  • (C1) type(lhs) = type(rhs) = type(result).

דוגמאות

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]

 דוגמאות נוספות

shift_right_logical

סמנטיקה

הפונקציה מבצעת פעולת הזזה ימינה לוגית לפי רכיבים בטנסור lhs במספר הסיביות rhs, ויוצרת טנסור result.

קלט

תווית שם סוג מגבלות
(I1) lhs טינסור מסוג מספר שלם (C1)
(I2) rhs טינסור מסוג מספר שלם (C1)

פלט

שם סוג מגבלות
result טינסור מסוג מספר שלם (C1)

מגבלות

  • (C1) type(lhs) = type(rhs) = type(result).

דוגמאות

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]

דוגמאות נוספות

סמל

סמנטיקה

הפונקציה מחזירה את הסימן של operand לפי רכיבים ומפיקה טינסור result. באופן רשמי יותר, לכל אלמנט x, אפשר לבטא את הסמנטיקה באמצעות התחביר של Python באופן הבא:

def sign(x):
  if is_integer(x):
    if compare(x, 0, LT, SIGNED): return -1
    if compare(x, 0, EQ, SIGNED): return 0
    return 1
  elif is_float(x):
    if is_nan(x): return NaN
    if compare(x, -0.0, EQ, FLOAT): return -0.0
    if compare(x, +0.0, EQ, FLOAT): return +0.0
    if compare(x, 0.0, LT, FLOAT): return -1.0
    return 1.0
  elif is_complex(x):
    if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
    if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
    return divide(x, convert(abs(x), type(x)))

לסוגים כמותיים, הפונקציה מחזירה dequantize_op_quantize(sign, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טינסור של מספר שלם עם סימן, מספר נקודה צפה או מספר מרוכב, או טינסור מקודד לכל טינסור (C1)

פלט

שם סוג אילוצים
result Tensor של מספר שלם חתום, נקודה צפה (float-point) או סוג מרוכב, או t e n s o r f l o w (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]

 דוגמאות נוספות

סינוס

סמנטיקה

מבצעת פעולת סינוס ברמת הרכיב בטינזור operand ומפיקה טנזור result. בהתאם לסוג הרכיב, מבצע את הפעולות הבאות:

  • למספרים מסוג float: sin מ-IEEE-754.
  • למספרים מרוכבים: סינוס מרוכב.
  • לסוגי נתונים מרוסקים: dequantize_op_quantize(sine, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טינסור מסוג נקודה צפה או מסוג מורכב, או טינסור מקודד לכל טינסור (C1)

פלט

שם סוג מגבלות
result טינסור מסוג נקודה צפה או מסוג מורכב, או טינסור מקודד לכל טינסור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]

 דוגמאות נוספות

פרוסה (slice)

סמנטיקה

מחלצת פרוסה מה-operand באמצעות אינדקסים מתחילים שמחושבים באופן סטטי ומפיקה טנזור result. start_indices מכיל את אינדקסי ההתחלה של הפלחים לכל מאפיין, limit_indices מכיל את אינדקסי הסיום (לא כולל) של הפלחים לכל מאפיין ו-strides מכיל את הצעדים לכל מאפיין.

באופן רשמי יותר, result[result_index] = operand[operand_index] כאשר operand_index = start_indices + result_index * strides.

קלט

תווית שם סוג מגבלות
(I1) operand טינסור או טינסור מותאם (quantized) לכל טינסור (C1-C3), (C5)
(I2) start_indices קבוע טינסור חד-מימדי מסוג si64 (C2),‏ (C3),‏ (C5)
(I3) limit_indices קבוע טינסור חד-מימדי מסוג si64 (C2),‏ (C3),‏ (C5)
(I4) strides קבוע טינסור חד-מימדי מסוג si64 (C2),‏ (C4)

פלט

שם סוג מגבלות
result טינסור או טינסור מותאם (quantized) לכל טינסור (C1),‏ (C5)

אילוצים

  • (C1) element_type(operand) = element_type(result).
  • (C2) size(start_indices) = size(limit_indices) = size(strides) = rank(operand).
  • (C3) 0 <= start_indices <= limit_indices <= shape(operand).
  • (C4) 0 < strides.
  • (C5) shape(result) = ceil((limit_indices - start_indices) / strides).

דוגמאות

// %operand: [
//            [0, 0, 0, 0],
//            [0, 0, 1, 1],
//            [0, 0, 1, 1]
//           ]
%result = "stablehlo.slice"(%operand) {
  start_indices = array<i64: 1, 2>,
  limit_indices = array<i64: 3, 4>,
  strides = array<i64: 1, 1>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
//            [1, 1],
//            [1, 1]
//           ]

 דוגמאות נוספות

מיון

סמנטיקה

הפונקציה ממיינת יחד פרוסות חד-ממדיות של inputs לפי המאפיין dimension, בהתאם ל-comparator, ויוצרת את results.

בניגוד לקלטים דומים בפעולות אחרות, ב-dimension אפשר להשתמש בערכים שליליים, עם הסמנטיקה שמתוארת בהמשך. יכול להיות שבעתיד לא נאפשר זאת מסיבות עקביות (#1377).

אם הערך של is_stable הוא True, המיון יציב, כלומר הסדר היחסי של הרכיבים שנחשבים זהים על ידי המבצע נשמר. במקרה שיש קלט יחיד, שני הרכיבים e1 ו-e2 נחשבים זהים על ידי המבצע אם ורק אם comparator(e1, e2) = comparator(e2, e1) = false. בהמשך מוסבר איך אפשר להכליל את הנוסחה הזו למספר קלטות.

באופן רשמי יותר, לכל result_index ב-index_space(results[0]):

  • adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension.
  • result_slice = [ri0, ..., :, ..., riR-1] כאשר riN הם רכיבים נפרדים ב-result_index, ו-: מוכנס ב-adjusted_dimension.
  • inputs_together = (inputs[0]..., ..., inputs[N-1]...).
  • results_together[result_slice] = sort(inputs_together[result_slice], comparator_together).
  • כאשר sort ממיין פרוסה חד-ממדית בסדר לא יורד, מצפה שהארגומנט comparator_together יחזיר true אם הארגומנט בצד שמאל קטן מהארגומנט בצד ימין.
  • def comparator_together(lhs_together, rhs_together):
      args = []
      for (lhs_el, rhs_el) in zip(lhs_together, rhs_together):
        args.append(lhs_el)
        args.append(rhs_el)
      return comparator(*args)
    
  • (results[0]..., ..., results[N-1]...) = results_together.

קלט

תווית שם סוג מגבלות
(I1) inputs מספר וריאדי של טינסורים או טינסורים מקוטנים לכל טינסור (C1-C5)
(I2) dimension קבוע מסוג si64 (C4)
(I3) is_stable קבוע מסוג i1
(I4) comparator פונקציה (C5)

פלט

שם סוג אילוצים
results מספר וריאדי של טינסורים או טינסורים מקוטנים לכל טינסור (C2),‏ (C3)

אילוצים

  • (C1) 0 < size(inputs).
  • (C2) type(inputs...) = type(results...).
  • (C3) same(shape(inputs...) + shape(results...)).
  • (C4) -R <= dimension < R, כאשר R = rank(inputs[0]).
  • (C5) השדה comparator הוא מסוג (tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>, כאשר Ei = element_type(inputs[i]).

דוגמאות

// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
    %predicate = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
  dimension = 0 : i64,
  is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]

 דוגמאות נוספות

sqrt

סמנטיקה

מבצעת פעולת שורש ריבועית ברמת הרכיב ב-Tensor operand ומפיקה טנזור result. בהתאם לסוג הרכיב, מבצע את הפעולות הבאות:

  • לצפים: squareRoot מ-IEEE-754.
  • למספרים מורכבים: שורש ריבועי מורכב.
  • לסוגים הבאים לפי כמות: dequantize_op_quantize(sqrt, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טינסור מסוג נקודה צפה או מסוג מורכב, או טינסור מקודד לכל טינסור (C1)

פלט

שם סוג מגבלות
result טינסור מסוג נקודה צפה או מסוג מורכב, או טינסור מקודד לכל טינסור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]

 דוגמאות נוספות

הפחתה

סמנטיקה

הפונקציה מבצעת חיסור של שני צמדי רכיבים lhs ו-rhs ומפיקה צמדי רכיבים result. בהתאם לסוג הרכיב, מבצע את הפעולות הבאות:

  • למספרים שלמים: חיסור של מספרים שלמים.
  • למספרים מסוג float: subtraction מ-IEEE-754.
  • למספרים מרוכבים: חיסור מרוכב.
  • לסוגים מרובעניים:
    • dequantize_op_quantize(subtract, lhs, rhs, type(result)).

קלט

תווית שם סוג מגבלות
(I1) lhs t e n s o r f l o w, t e n s o r f l o w, או t e n s o r f l o w, (C1)
(I2) rhs טינסור מסוג שלם, של נקודה צפה או של מספר מרוכב, או טינסור מקודד לכל טינסור (C1)

פלט

שם סוג אילוצים
result טינסור מסוג שלם, של נקודה צפה או של מספר מרוכב, או טינסור מקודד לכל טינסור (C1)

אילוצים

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

דוגמאות

// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]

 דוגמאות נוספות

tan

סמנטיקה

הפונקציה מבצעת פעולת טנגנט לפי רכיבים בטנסור operand ומפיקה טנסור result. בהתאם לסוג הרכיב, מבצע את הפעולות הבאות:

  • למספרים מסוג float: tan מ-IEEE-754.
  • למספרים מרוכבים: טנגנס מרוכב.
  • לסוגי נתונים מרוסקים: dequantize_op_quantize(tan, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טינסור מסוג נקודה צפה או מסוג מורכב, או טינסור מקודד לכל טינסור (C1)

פלט

שם סוג מגבלות
result טינסור מסוג נקודה צפה או מסוג מורכב, או טינסור מקודד לכל טינסור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.tan"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [
//           [0.0, 1.63312e+16],
//           [0.0, 5.44375e+15]
//          ]

 דוגמאות נוספות

tanh

סמנטיקה

הפונקציה מבצעת פעולת טנגנס היפרבולי לפי רכיבים בטנסור operand ומייצרת טנסור result. בהתאם לסוג הרכיב, מבצע את הפעולות הבאות:

  • למספרים מסוג float: tanh מ-IEEE-754.
  • למספרים מרוכבים: טנגנס היפרבולי מורכב.
  • לסוגים מרובעניים:
    • dequantize_op_quantize(tanh, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand טינסור מסוג נקודה צפה או מסוג מורכב, או טינסור מקודד לכל טינסור (C1)

פלט

שם סוג מגבלות
result טינסור מסוג נקודה צפה או מסוג מורכב, או טינסור מקודד לכל טינסור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]

דוגמאות נוספות

לשנות

סמנטיקה

הפונקציה מבצעת תמורה של המאפיינים של הטנזור operand באמצעות permutation ויוצרת טנזור result. באופן רשמי יותר, result[result_index] = operand[operand_index] כאשר result_index[d] = operand_index[permutation[d]].

קלט

תווית שם סוג מגבלות
(I1) operand טינסור או טינסור מצטבר (C1-C4)
(I2) permutation קבוע טינסור חד-מימדי מסוג si64 (C2-C4)

פלט

שם סוג אילוצים
result t e n s o r f l o w, או t e n s o r f l o w, (C1),‏ (C3-C4)

אילוצים

  • (C1) הערך של element_type(result) מחושב לפי:
    • element_type(operand), אם !is_per_axis_quantized(operand).
    • element_type(operand), אלא ש-quantization_dimension(operand) ו-quantization_dimension(result) עשויים להיות שונים.
  • (C2) permutation היא תמורה של range(rank(operand)).
  • (C3) shape(result) = dim(operand, permutation...).
  • (C4) אם is_per_axis_quantized(result), אז quantization_dimension(operand) = permutation(quantization_dimension(result)).

דוגמאות

// %operand: [
//            [[1,2], [3,4], [5,6]],
//            [[7,8], [9,10], [11,12]]
//           ]
%result = "stablehlo.transpose"(%operand) {
  permutation = array<i64: 2, 1, 0>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
//           [[1,7], [3,9], [5,11]],
//           [[2,8], [4,10], [6,12]]
//          ]

 דוגמאות נוספות

triangular_solve

סמנטיקה

פתרון של קבוצות של מערכות של משוואות ליניאריות עם מטריצות של גורמים משולשיים עליונים או תחתונים.

באופן רשמי יותר, בהינתן a ו-b, result[i0, ..., iR-3, :, :] הוא הפתרון של op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :] כאשר left_side הוא true או x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :] כאשר left_side הוא false, פתרון המשתנה x כאשר op(a) נקבע על ידי transpose_a, שיכול להיות אחד מהערכים הבאים:

  • NO_TRANSPOSE: ביצוע הפעולה באמצעות a כפי שהוא.
  • TRANSPOSE: ביצוע פעולה על הטרספורמציה של a.
  • ADJOINT: ביצוע פעולה על החלפה צמודה של a.

נתוני הקלט נקראים רק מהמשולש התחתון של a. אחרת, lower הוא true או משולש עליון של a. נתוני הפלט מוחזרים באותו משולש, והערכים במשולש השני מוגדרים בהתאם להטמעה.

אם הערך של unit_diagonal הוא true, ההטמעה יכולה להניח שהאלמנטים של האלכסון של a שווים ל-1. אחרת, ההתנהגות לא מוגדרת.

לסוגי נתונים שמפורטים במספרים, הפונקציה מבצעת את הפעולה dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower, unit_diagonal, transpose_a), a, b, type(result)).

קלט

תווית שם סוג מגבלות
(I1) a t e n s o r f l o w, או t e n s o r f l o w, או t e tensor tenor (C1-C3)
(I2) b t e n s o r f l o w, או t e n s o r f l o w, או t e tensor tenor (C1-C4)
(I3) left_side קבוע מסוג i1 (C3)
(I4) lower קבוע מסוג i1
(I5) unit_diagonal קבוע מסוג i1
(I6) transpose_a enum של NO_TRANSPOSE,‏ TRANSPOSE ו-ADJOINT

פלט

שם סוג מגבלות
result טינסור מסוג נקודה צפה או מסוג מורכב, או טינסור מקודד לכל טינסור (C1)

אילוצים

  • (C1) baseline_element_type(a) = baseline_element_type(b).
  • (C2) 2 <= rank(a) = rank(b) = R.
  • (C3) הקשר בין shape(a) לבין shape(b) מוגדר כך:
    • shape(a)[:-3] = shape(b)[:-3].
    • dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1).
  • (C4) baseline_type(b) = baseline_type(result).

דוגמאות

// %a = [
//       [1.0, 0.0, 0.0],
//       [2.0, 4.0, 0.0],
//       [3.0, 5.0, 6.0]
//      ]
// %b = [
//       [2.0, 0.0, 0.0],
//       [4.0, 8.0, 0.0],
//       [6.0, 10.0, 12.0]
//      ]
%result = "stablehlo.triangular_solve"(%a, %b) {
  left_side = true,
  lower = true,
  unit_diagonal = false,
  transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
//           [2.0, 0.0, 0.0],
//           [0.0, 2.0, 0.0],
//           [0.0, 0.0, 2.0]
//          ]

קבוצה

סמנטיקה

הפונקציה יוצרת קבוצת ערכים (tuple) מסוג result מהערכים val.

קלט

תווית שם סוג מגבלות
(I1) val מספר משתנה של ערכים (C1)

פלט

שם סוג מגבלות
result קבוצה (C1)

אילוצים

  • (C1) result מסוג tuple<E0, ..., EN-1> כאשר Ei = type(val[i]).

דוגמאות

// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))

 דוגמאות נוספות

uniform_dequantize

סמנטיקה

הפונקציה מבצעת המרה של רכיבי הטנזור המצטבר operand לטנזור של נקודה צפה result בהתאם לפרמטרים של הערך המצטבר operand.

באופן רשמי יותר, result = dequantize(operand).

קלט

תווית שם סוג מגבלות
(I1) operand t e n s o r f l o w (C1),‏ (C2)

פלט

שם סוג אילוצים
result טינסור מסוג נקודה צפה (floating-point) (C1),‏ (C2)

אילוצים

  • (C1) shape(operand) = shape(result).
  • (C2) element_type(result) = expressed_type(operand).

דוגמאות

// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]

uniform_quantize

סמנטיקה

הפונקציה מבצעת המרה של רכיבים של טינסור של נקודה צפה או טינסור מקודד operand לטינסור מקודד result, בהתאם לפרמטרי הקידוד שמוגדרים לפי סוג result.

יותר רשמי,

  • אם is_float(operand):
    • result = quantize(operand, type(result)).
  • אם is_quantized(operand):
    • float_result = dequantize(operand).
    • result = quantize(float_result, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand Tensor מסוג נקודה צפה (floating-point) או סוג כמותי (C1), (C2)

פלט

שם סוג אילוצים
result טינסור מרובע (C1),‏ (C2)

אילוצים

  • (C1) shape(operand) = shape(result).
  • (C2) expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand).

דוגמאות

// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]

// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]

בזמן

סמנטיקה

הפונקציה יוצרת את הפלט מהפעלת הפונקציה body 0 פעמים או יותר, בזמן שפונקציית cond יוצרת את הפלט true. באופן רשמי יותר, אפשר לבטא את הסמנטיקה באמצעות תחביר Python באופן הבא:

internal_state = operand
while cond(*internal_state):
  internal_state = body(*internal_state)
results = internal_state

ההתנהגות של לולאה אינסופית תיקבע בהמשך (#383).

קלט

תווית שם סוג מגבלות
(I1) operand מספר משתנה של טינסורים, טינסורים מרובים או אסימונים (C1-C3)
(I2) cond פונקציה (C1)
(I3) body פונקציה (C2)

פלט

שם סוג אילוצים
results מספר משתנה של טנזורים, טנזורים כמותיים או אסימונים (C3)

אילוצים

  • (C1) השדה cond הוא מסוג (T0, ..., TN-1) -> tensor<i1>, כאשר Ti = type(operand[i]).
  • (C2) body הוא מסוג (T0, ..., TN-1) -> (T0, ..., TN-1), כאשר Ti = type(operand[i]).
  • (C3) type(results...) = type(operand...).

דוגמאות

// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %cond = "stablehlo.compare"(%arg0, %ten) {
      comparison_direction = #stablehlo<comparison_direction LT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    stablehlo.return %cond : tensor<i1>
  }, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %new_sum = stablehlo.add %arg1, %one : tensor<i64>
    %new_i = stablehlo.add %arg0, %one : tensor<i64>
    stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10

דוגמאות נוספות

xor

סמנטיקה

ביצוע XOR של שני טינסורים lhs ו-rhs לפי רכיבים, יצירת טינסור result. בהתאם לסוג הרכיב, הפעולות הבאות:

  • עבור ערכים בוליאניים: XOR לוגי.
  • למספרים שלמים: XOR ברמת הביטים.

קלט

תווית שם סוג מגבלות
(I1) lhs טינסור מסוג בוליאני או שלם (C1)
(I2) rhs טינסור מסוג בוליאני או שלם (C1)

פלט

שם סוג אילוצים
result טינסור מסוג בוליאני או שלם (C1)

מגבלות

  • (C1) type(lhs) = type(rhs) = type(result).

דוגמאות

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]

 דוגמאות נוספות

פעולה הדדית בין ניבים

בשלב זה, תוכניות StableHLO בטבע כוללות לפעמים פעולות שלא מוגדרות על ידי StableHLO.

מודול, פונקציה, קריאה וחזרה

ב-StableHLO נעשה שימוש בפעולות MLIR במקור ל-ModuleOp,‏ FuncOp,‏ CallOp ו-ReturnOp. עשינו את זה כדי לשפר את יכולת הפעולה ההדדית עם מכונות MLIR קיימות, כי הרבה מעברים שימושיים נכתבים ומטרגטים את FuncOp ו-ModuleOp, וצינורות עיבוד נתונים רבים של הידור מצפים שהתפעול הזה יהיה קיים. מבטיחים תאימות מלאה לפעולות האלה. אם יהיו שינויים כלשהם בפעולות האלה באופן לא תואם (למשל הסרה), יתווספו להם אנלוגים של StableHLO כדי לשמור על תאימות.

CHLO

קבוצת הפעולות של CHLO מכילה פעולות ברמה גבוהה יותר שמתפרקות ל-StableHLO. בשלב הזה אין התחייבות לתאימות ל-CHLO. כדי להבטיח תאימות, צריך להשתמש במעבר chl-legalize-to-stablehlo לפני הסריאליזציה.

פעולות על צורות

תרחיש לדוגמה נפוץ בקרב חברי הקהילה הוא שימוש בפעולות מסוימות מהדיאלקטים של הליבה של MLIR בתוכניות דינמיות של StableHLO כדי לבצע חישובי צורות. בדרך כלל, אלה כוללות ניב של shape כמו shape_of או num_elements, ניב של tensor, פעולות כמו dim או from_elements, והסוג המובנה index.

ב-Dynamism RFC > O2 מצוין שהם לא נכללים בהיקף, אבל יש תמיכה מסוימת בסוגי index למטרות יכולת פעולה הדדית. אין התחייבות לתאימות של הפונקציות או הסוגים האלה. אפשר להשתמש במעבר shape-legalize-to-stablehlo כדי להמיר את הפעולות האלה לפעולות StableHLO נתמכות במלואן.

פעולות שהוצאו משימוש

יש כמה פעולות StableHLO שעברו בירושה מ-MHLO שהוצאו משימוש ויצאו מ-StableHLO. הפרטים המלאים על ההסרות האלה מפורטים בבקשה מס' 2283 בנושא ניקוי של StableHLO v1.0. מספר הבעיה במעקב אחר ההוצאות משימוש האלה הוא 2340.

הפעולות האלה נכללות בכמה קטגוריות:

  • הקטגוריה 'לא ב-HLO' של פעולות StableHLO – הן היו בהתחלה חלק מקבוצת הפעולות של StableHLO, אבל בהמשך הוחלט שהן לא מתאימות לה: broadcast,‏ create_token,‏ cross-replica-sum,‏ dot,‏ einsum,‏ torch_index_select,‏ unary_einsum (#3).
  • פעולות שלא בשימוש – יכול להיות שהפעולות האלה היו שימושיות בשלב מסוים, אבל הן לא מפותחות מספיק או שצינורות עיבוד הנתונים שמשתמשים בהן עברו שינוי כך שהן לא נדרשות יותר. אלה כוללים את הפונקציות map,‏ tuple (#598),‏ get_tuple_element,‏ rng,‏ complex (#560) וכן פונקציית הגלול window_reversal (#1181).

אפשר להסיר בקלות חלק מהפעולות האלה, כי אפשר להביע אותן באמצעות פעולות קיימות (broadcast, create_token, cross-replica-sum, dot, unary_einsum), והן יוסרו אחרי שחלון התאימות הקיים יסתיים (6 חודשים). אנחנו עדיין בודקים אם להסיר פעולות אחרות (einsum,‏ get_tuple_element,‏ map,‏ rng,‏ torch_index_select,‏ tuple,‏ complex,‏ השוואות,‏ window_reversal). בהתאם למשוב מהקהילה, הפעולות האלה יוסרו או יתווספו למפרט עם תמיכה מלאה. כל עוד לא נודע על האפשרויות האלה, מובטחת תאימות ל-6 חודשים בלבד.

ביצוע

ביצוע ברצף

כדי להריץ תוכנית StableHLO, מספקים ערכי קלט לפונקציה main ומחשבים את ערכי הפלט. כדי לחשב את ערכי הפלט של פונקציה, מפעילים את תרשים הפעולות שמתחיל בפעולה המתאימה של return.

סדר הביצוע מוגדר בהטמעה, כל עוד הוא תואם למעבר הנתונים, כלומר אם הפעולות מבוצעות לפני השימוש בהן. ב-StableHLO, כל הפעולות עם תוצאות לוואי צורכות אסימון אחד ומפיקות אסימון אחד (אפשר למזג כמה אסימונים לאסימון אחד באמצעות after_all), כך שסדר הביצוע של תוצאות הלוואי תואם גם לזרימת הנתונים. לדוגמה, בתוכנית הבאה יש שתי סדרות אפשריות לביצוע: %0%1%2return ו-%1%0%2return.

func.func @main() -> tensor<f64> {
  %0 = stablehlo.constant dense<1.0> : tensor<f64>
  %1 = stablehlo.constant dense<2.0> : tensor<f64>
  %2 = stablehlo.add %0, %1 : tensor<f64>
  return %2 : tensor<f64>
}

באופן רשמי יותר, תהליך StableHLO הוא שילוב של: 1) תוכנית StableHLO, 2) סטטוסים של פעולות (עדיין לא בוצעו, כבר בוצעו) ו-3) ערכים ביניים שהתהליך פועל עליהם. התהליך מתחיל בערכי הקלט של הפונקציה main, עובר דרך התרשים של סטטוסי הפעולות והערכים הביניים של העדכון של צוות התפעול, ומסתיים בערכי הפלט. פורמט רשמי יותר נקבע בהמשך (#484).

ביצוע במקביל

אפשר להריץ תוכניות StableHLO במקביל, כאשר הן מאורגנות בתבנית של רשת תהליכים דו-מימדית של num_replicas על num_partitions, ושני הערכים האלה הם מסוג ui32.

ברשת התהליכים StableHLO, num_replicas * num_partitions של תהליכי StableHLO מבוצעים בו-זמנית. לכל תהליך יש process_id = (replica_id, partition_id) ייחודי, כאשר replica_id ב-replica_ids = range(num_replicas) ו-partition_id ב-partition_ids = range(num_partitions), שניהם מסוג ui32.

גודל רשת התהליכים ידוע באופן סטטי לכל תוכנית (בעתיד אנחנו מתכננים להפוך אותה לחלק מפורש בתוכניות StableHLO #650), והמיקום ברשת התהליכים ידוע באופן סטטי לכל תהליך. לכל תהליך יש גישה למיקום שלו בתצוגת התהליכים באמצעות הפונקציות replica_id ו-partition_id.

בתצוגת התהליך, כל התוכניות יכולות להיות זהות (בסגנון 'תוכנית אחת, נתונים מרובים'), שונות (בסגנון 'תוכניות מרובות, נתונים מרובים') או משהו באמצע. בעתיד אנחנו מתכננים להוסיף תמיכה בדרכים נוספות להגדרת תוכניות StableHLO מקבילות, כולל GSPMD‏ (#619).

בתצוגת התהליכים, התהליכים עצמאיים זה מזה ברוב המקרים – יש להם סטטוסים נפרדים של פעולות, ערכים נפרדים של קלט/שלב ביניים/פלט, ורוב הפעולות מתבצעות בנפרד בין התהליכים, מלבד מספר קטן של פעולות קולקטיביות שמתוארות בהמשך.

מכיוון שבביצוע רוב הפעולות נעשה שימוש רק בערכים מאותו תהליך, בדרך כלל אפשר להפנות לערכים האלה לפי השמות שלהם. עם זאת, כשמתארים את הסמנטיקה של פעולות קולקטיביות, זה לא מספיק, ולכן השתמשו בסימון name@process_id כדי להפנות לערך name בתוך תהליך מסוים. (מנקודת המבט הזו, אפשר להתייחס ל-name ללא הסמכת ‎@‎ כקיצור דרך ל-name@(replica_id(), partition_id())).

סדר הביצוע בתהליכים מוגדר על ידי ההטמעה, מלבד הסנכרון שמתבצע באמצעות תקשורת נקודה לנקודה ופעולות קולקטיביות, כפי שמתואר בהמשך.

תקשורת מנקודה לנקודה

תהליכי StableHLO יכולים לתקשר ביניהם באמצעות ערוצי StableHLO. ערוץ מיוצג על ידי מזהה חיובי מסוג si64. באמצעות פעולות שונות, אפשר לשלוח ערכים לערוצים ולקבל אותם מהם.

אנחנו עדיין לא יודעים איך נציג את המידע הזה באופן רשמי, למשל, מאיפה מגיעים מזהי הערוצים האלה, איך תוכניות העיבוד מקבלות הודעה עליהם ואיזה סוג של סנכרון הן גורמות (#484).

תקשורת בסטרימינג

לכל תהליך StableHLO יש גישה לשני ממשקי סטרימינג:

  • Infeed שאפשר לקרוא מהם.
  • Outfeed שאפשר לכתוב אליו.

בניגוד לערוצים, שמשמשים לתקשורת בין תהליכים ולכן קיימים תהליכים בשני הצדדים, בפידים ובפידים אחרים מוגדר הטמעת קצה אחרת.

טרם נקבע, למשל, איך תקשורת בסטרימינג משפיעה על סדר הביצוע ואיזה סוג של סנכרון נכנס אליה, (#484).

פעולות קבוצתיות

יש שישה אופרטורים קולקטיביים ב-StableHLO: all_gather,‏ all_reduce,‏ all_to_all,‏ collective_broadcast,‏ collective_permute ו-reduce_scatter. כל הפעולות האלה מפצלות את התהליכים ברשת התהליכים של StableHLO לקבוצות תהליכים של StableHLO, ומבצעות חישוב משותף בתוך כל קבוצת תהליכים, בנפרד מקבוצות עיבוד אחרות.

בתוך כל קבוצת תהליכים, פעולות קולקטיביות עשויות ליצור מחסום סנכרון. אנחנו עדיין לא יודעים מתי בדיוק מתרחש הסנכרון הזה, איך התהליכים מגיעים למחסום הזה ומה קורה אם הם לא מגיעים אליו (#484).

אם קבוצת התהליכים כוללת תקשורת בין מחיצות, כלומר יש תהליכים בקבוצת התהליכים שמזהי המחיצות שלהם שונים, אז כדי להריץ את הפעולה הקבוצתית צריך ערוץ, והפעולה הקבוצתית צריכה לספק channel_id חיובי מסוג si64. תקשורת חוצת-עותקים לא צריכה ערוצים.

החישובים שמבצעות הפעולות הקולקטיביות ספציפיים לכל פעולה, כפי שמתואר בסעיפים של הפעולות הספציפיות למעלה. עם זאת, השיטות שבאמצעותן רשת התהליכים מחולקת לקבוצות תהליכים משותפות בין צוותי התפעול האלה, ומתוארות בקטע הזה. יותר רשמי, StableHLO תומך בארבע האסטרטגיות הבאות.

cross_replica

בכל קבוצת תהליכים מתקיימת רק תקשורת בין עותקים. אסטרטגיה זו לוקחת replica_groups - רשימה של רשימות של מזהי רפליקות - ומחשבת מכפלה קרטזית של replica_groups עד partition_ids. replica_groups חייב לכלול רכיבים ייחודיים ולכסות את כל replica_ids. באופן רשמי יותר, באמצעות התחביר של Python:

def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    for partition_id in partition_ids:
      process_group = []
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
      yield process_group

לדוגמה, בשביל replica_groups = [[0, 1], [2, 3]] ו-num_partitions = 2, cross_replica יפיק [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]].

cross_partition

רק תקשורת בין מחיצות מתרחשת בכל קבוצת תהליכים. האסטרטגיה הזו מקבלת את partition_groups – רשימה של רשימות של מזהי מחיצות – ומחשבת מכפלה קרטוזיאנית של partition_groups ב-replica_ids. partition_groups חייב לכלול רכיבים ייחודיים ולכסות את כל partition_ids. באופן רשמי יותר, שימוש בתחביר Python:

def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
  for partition_group in partition_groups:
    for replica_id in replica_ids:
      process_group = []
      for partition_id in partition_group:
        process_group.append((replica_id, partition_id))
      yield process_group

לדוגמה, עבור partition_groups = [[0, 1]] ו-num_replicas = 4, הפונקציה cross_partition תיצור את הערך [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]].

cross_replica_and_partition

יכול להיות שבתוך כל קבוצת תהליכים תתבצע גם תקשורת בין מחיצות וגם תקשורת בין מחיצות. בשיטה הזו, המערכת מקבלת את replica_groups – רשימה של רשימות של מזהי רפליקות – ומחשבת את המוצרים הקרטזיאניים של כל replica_group לפי partition_ids. replica_groups חייב לכלול אלמנטים ייחודיים ויכסה את כל ה-replica_ids. באופן רשמי יותר, באמצעות תחביר Python:

def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    process_group = []
    for partition_id in partition_ids:
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
    yield process_group

לדוגמה, עבור replica_groups = [[0, 1], [2, 3]] ו-num_partitions = 2, הפונקציה cross_replica_and_partition תיצור את הערך [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]].

flattened_ids

בשיטה הזו, המערכת מקבלת את flattened_id_groups – רשימה של רשימות של מזהי תהליכים 'שטוחים' בצורת replica_id * num_partitions + partition_id – וממירה אותם למזהי תהליכים. flattened_id_groups חייבים לכלול רכיבים ייחודיים ויכסו את כל process_ids. באופן רשמי יותר, באמצעות תחביר Python:

def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
  for flattened_id_group in flattened_id_groups:
    process_group = []
    for flattened_id in flattened_id_group:
      replica_id = flattened_id // num_partitions
      partition_id = flattened_id % num_partitions
      process_group.append((replica_id, partition_id))
    yield process_group

לדוגמה, עבור flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]],‏ num_replicas = 4 ו-num_partitions = 2, הפונקציה flattened_ids תיצור את הערך [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]].

דיוק

בשלב זה, StableHLO לא מספק ערבויות לגבי דיוק מספרי, אבל זה עשוי להשתנות בעתיד (#1156).

סמנטיקה של ביצוע פעולה מרוסקת

המשמעות של פעולות StableHLO כמותיות עשויה להשתנות בהתאם לדרישות וליכולות של החומרה. לדוגמה, חלק מהחומרה עשויה להחליט אם לפרש פעולות כמותיות באמצעות אסטרטגיה של מפענח, ביצוע פעולה בנקודה צפה ולבסוף כימות. משתמשים אחרים יכולים לבצע את החישוב כולו באמצעות חשבונית של מספרים שלמים. כתוצאה מכך, הפרשנות של פעולות StableHLO בקנה מידה נקבע אך ורק על ידי ההטמעה הספציפית. הפרשנות של קידוד היברידי (#1575) צריכה להתבסס על הסמנטיקה שלו כפי שמפורט במפרט (דרך 1792).

שגיאות

תוכניות StableHLO מאומתות באמצעות קבוצה נרחבת של אילוצים לפעולות ספציפיות, שמבטלים סיווגים רבים של שגיאות לפני זמן הריצה. עם זאת, עדיין אפשר להגדיר תנאי שגיאה, למשל חריגה ממספרים שלמים, גישה מחוץ לתחום וכו'. אם לא צוין במפורש, כל השגיאות האלו יגרמו להתנהגות המוגדרת על ידי ההטמעה, אבל הדבר עשוי להשתנות בעתיד (#1157).

חריגים ברמה של נקודה צפה (floating-point)

חריג לכלל הזה הוא שהתנהגות של חריגות של נקודות צפות בתוכניות StableHLO מוגדרת היטב. פעולות שמובילות לחריגות שמוגדרות בתקן IEEE-754 (פעולה לא חוקית, חלוקת אפס, חריגה ממכסת הנתונים, חריגה ממכסת הנתונים הנמוכה או חריגות לא מדויקות) יוצרות תוצאות ברירת מחדל (כפי שמוגדרות בתקן) וממשיכות את הביצוע בלי להפעיל את דגל הסטטוס המתאים, בדומה לטיפול בחריגות מסוג raiseNoFlag מהתקן. חריגים לפעולות לא סטנדרטיות (למשל, אריתמטיקה מורכבת ופונקציות טרנסצנדנטליות מסוימות) מוגדרים בהטמעה.

אי התאמה של צורות

‏StableHLO תומך בטנסורים בעלי צורה דינמית. עם זאת, הצורות צריכות להיות זהות בסביבת זמן הריצה, אחרת ההתנהגות לא מוגדרת. ב-StableHLO לא קיימת אופרטורה מפורשת שיכולה לאמת שללטנזר יש צורה מסוימת בזמן הריצה. היצירה של קוד תקין היא באחריות המפיק.

לדוגמה, התוכנית שבהמשך תקינה. עם זאת, במהלך זמן הריצה, הצורות המדויקות של %arg0 ו-%arg1 צריכות להיות זהות, אחרת ההתנהגות של התוכנית לא מוגדרת:

func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
    %0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
    return %0 : tensor<?xi32>
}

תווים

כדי לתאר את התחביר, במסמך הזה נעשה שימוש בטעימה המשופרת של ISO בתחביר EBNF (ISO/IEC 14977:1996,‏ Wikipedia), עם שני שינויים: 1) הכללים מוגדרים באמצעות ::= במקום =,

2) שרשור מוצג באמצעות צירוף ולא באמצעות ,.

כדי לתאר את הסמנטיקה (כלומר בקטעים Types‏, Constants ו-Ops), אנחנו משתמשים בנוסחאות שמבוססות על תחביר Python עם תמיכה בביטוי תמציתי של פעולות מערך, כפי שמתואר בהמשך. האפשרות הזו עובדת טוב בקטעי קוד קטנים, אבל במקרים נדירים שבהם יש צורך בקטעי קוד גדולים יותר, אנחנו משתמשים בתחביר וניל Python, שתמיד מצוין באופן מפורש.

נוסחאות

נלמד איך נוסחאות פועלות על סמך דוגמה מהמפרט dot_general. אחת מהמגבלות של הפעולה הזו נראית כך: dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).

השמות שמופיעים בנוסחה הזו מגיעים משני מקורות: 1) פונקציות גלובליות, כלומר dim, 2) הגדרות של חברים ברכיב התכנית התואם, כלומר lhs, lhs_batching_dimensions, rhs ו-rhs_batching_dimensions, שהם מקורות הקלט שמוגדרים בקטע Inputs (קלט) של dot_general.

כפי שצוין למעלה, התחביר של הנוסחה הזו מבוסס על Python עם כמה תוספי תמציתיות. כדי להבין את הנוסחה, נבצע בה שינוי ותרגום ל-Python רגיל.

א) בנוסחאות האלה, אנחנו משתמשים ב-= כדי לייצג שוויון, ולכן השלב הראשון בקבלת התחביר של Python הוא החלפת = ב-==, באופן הבא: dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...).

ב) בנוסף, הנוסחאות האלה תומכות באליפסות (...) שממירות ביטויים סקלריים לביטויים של טינסורים. בסקירה תמציתית, המשמעות של f(xs...) היא בערך 'לכל סקלר x במטווח xs, צריך לחשב f(x) סקלרי ואז להחזיר את כל התוצאות הסקלריות האלה יחד כתוצאה של טנזור. בתחביר Python רגיל, הנוסחה לדוגמה הופכת ל-[dim(lhs, dim1) for dim1 in lhs_batching_dimensions] == [dim(rhs, dim2) for dim2 in rhs_batching_dimensions].

בזכות האליפסות, לרוב אפשר להימנע מעבודה ברמת המשתנים הסקלרים הבודדים. עם זאת, במקרים מסוימים, אפשר להשתמש בתחביר לא רשמי ברמה נמוכה יותר, כמו בנוסחה start_indices[bi0, ..., :, ..., biN] מהמפרט gather. כדי לשמור על תמציתיות, אנחנו לא מספקים פורמליזם מדויק לתרגום תחביר כזה ל-Python רגיל, בתקווה שעדיין יהיה אפשר להבין אותו באופן אינטואיטיבי על בסיס כל מקרה לגופו. אם נוסחאות מסוימות נראות לכם לא ברורות, נשמח לדעת וננסה לשפר אותן.

כמו כן, תבחינו שהנוסחאות משתמשות באליפסות כדי להרחיב כל מיני רשימות, כולל טיינורים, רשימות של טיינורים (שיכולות לנבוע, למשל, ממספר משתנה של טיינורים) וכו'. זהו תחום נוסף שבו אנחנו לא מספקים פורמליזם מדויק (למשל, רשימות הן לא חלק ממערכת הסוגים של StableHLO), אלא מסתמכים על הבנה אינטואיטיבית.

ג) הכלי האחרון שאנחנו משתמשים בו הוא שידור משתמע. אופרטורים של StableHLO לא תומכים בשידור משתמע, אבל הנוסחאות כן תומכות בו, גם כדי לשמור על תמציתיות. בקצרה, אם משתמשים במשתנה סקלר בהקשר שבו צפוי טינסור, המשתנה הסקלר מופץ לצורה הצפויה.

בהמשך לדוגמה של dot_general, הנה אילוץ נוסף: 0 <= lhs_batching_dimensions < rank(lhs). כפי שמוגדר במפרט של dot_general, lhs_batching_dimensions הוא טינסור, אבל גם 0 וגם rank(lhs) הם סקלריים. אחרי שנחיל שידור מרומז, הנוסחה תהיה [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)].

כשמחילים את הנוסחה הזו על פעולת dot_general מסוימת, הערך של הנוסחה הזו יהיה טנזור של בוליאניים. כשמשתמשים בנוסחאות כאילוצים, האילוץ תקף אם התוצאה של הנוסחה היא true או טינסור שיש בו רק רכיבים מסוג true.

שמות

בנוסחאות, ההיקף המילוני כולל: 1) פונקציות גלובליות, 2) הגדרות של משתמשים,

3) הגדרות מקומיות. בהמשך מופיעה רשימת הפונקציות הגלובליות. רשימת ההגדרות של הרכיבים תלויה ברכיב התוכנית שאליו מוחלת הסימון:

  • בהגדרות של אופרטורים, שמות המשתנים כוללים שמות שהוצגו בקטעים 'קלט' ו'פלט'.
  • בכל שאר המקרים, הגדרות המשתתפים כוללות חלקים מבניים של רכיב התוכנית, שנקראו על שם הסמלים הלא-סופיים התואמים של EBNF. ברוב המקרים, השמות של החלקים המבניים האלה מתקבלים על ידי המרת השמות של הרכיבים הלא-סופיים ל-snake case (למשל, IntegerLiteral => integer_literal), אבל לפעמים השמות מקוצרים בתהליך (למשל, QuantizationStorageType => storage_type). במקרה כזה, השמות מוצגים באופן מפורש, בדומה לקטעים 'קלט' או 'פלט' במפרטי הפעולה.
  • בנוסף, הגדרות החברות תמיד כוללות את self כדי להפנות לרכיב התוכנית המתאים.

ערכים

כשנוסחאות מוערכות, הן פועלות עם סוגי הערכים הבאים: 1) Value (ערכים בפועל, למשל dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>; הסוגים שלהם תמיד ידועים), 2) Placeholder (ערכים עתידיים, למשל lhs,‏ rhs או result; הערכים בפועל שלהם עדיין לא ידועים, רק הסוגים שלהם ידועים), 3) Type (סוגים כפי שמוגדרים בקטע 'סוגים'), 4) Function (פונקציות גלובליות כפי שמוגדרות בקטע 'פונקציות').

בהתאם להקשר, השמות עשויים להתייחס לערכים שונים. באופן ספציפי יותר, הקטע 'סמנטיקה' של אופרטורים (ויש מקבילים לרכיבים אחרים בתוכנית) מגדיר את הלוגיקה של זמן הריצה, כך שכל הקלט זמין כ-Value. לעומת זאת, הקטע 'אילוצים' של אופרטורים (וישויות מקבילות) מגדיר לוגיקה של 'זמן הידור', כלומר משהו שמתבצע בדרך כלל לפני זמן הריצה, כך שרק ערכים קבועים זמינים כ-Value, וערכים אחרים זמינים רק כ-Placeholder.

שמות בקטע 'סמנטיקה' בקטע 'מגבלות'
פונקציות גלובליות Function Function
קלט קבוע Value Value
קלט לא קבוע Value Placeholder
פלט Value Placeholder
הגדרות מקומיות תלוי בהגדרה תלוי בהגדרה

נבחן פעולה transpose לדוגמה:

%result = "stablehlo.transpose"(%operand) {
  permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>

בפעולה הזו, permutation הוא קבוע, ולכן הוא זמין כ-Value גם בסמינטיקה וגם באילוצים. לעומת זאת, הערכים operand ו-result זמינים כ-Value בסמינטיקה, אבל רק כ-Placeholder באילוצים.

פונקציות

בניית סוגים

אי אפשר להשתמש בפונקציות כדי לבנות סוגים. במקום זאת, אנחנו משתמשים ישירות בתחביר של הטיפוס כי הוא בדרך כלל תמציתי יותר. למשל, (tensor<E>, tensor<E>) -> (tensor<E>) במקום function_type( [tensor_type([], E), tensor_type([], E)], [tensor_type([], E)]).

פונקציות על סוגים

  • element_type מוגדר בסוגים של tenors ובסוגים של tenors ותחזירים כמותיים, בהתאמה, TensorElementType או QuantizedTensorElementType של TensorType או QuantizedTensorType המתאימים.
def element_type(x: Value | Placeholder | Type):
 if type(x) == TensorType:
    return tensor_element_type(x)
  if type(x) == QuantizedTensorType:
    return quantized_tensor_element_type(x)
  if type(x) is not Type:
    return element_type(type(x))
  • is_per_axis_quantized(x: Value | Placeholder | Type) -> Value הוא קיצור דרך של is_quantized(x) and quantization_dimension(x) is not None.

  • is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value הוא קיצור דרך ל-is_quantized(x) and quantization_dimension(x) is None.

  • is_promotable(x: Type, y: Type) -> bool בודק אם אפשר לשדרג את הסוג x לסוג y. כשהערך של x ו-y הוא QuantizedTensorElementType, המבצע יחול רק על storage_type. הגרסה הספציפית הזו של קידום המכירות משמשת כרגע בהקשר של חישוב ההנחות (פרטים נוספים זמינים ב-RFC).

def is_promotable(x: Type, y: Type) -> Value:
  is_same_type = (is_bool(x) and is_bool(y)) or
    (is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
    (is_complex(x) and is_complex(y)) or
    (is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))

  if is_same_type == False:
    return False

  if is_integer(x) or is_float(x):
    return bitwidth(x) <= bitwidth(y)

  if is_complex(x):
    return bitwidth(element_type(x)) <= bitwidth(element_type(y))

  if is_quantized(x):
    return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))

  return false
  • is_quantized(x: Value | Placeholder | Type) -> Value הוא קיצור דרך של is_quantized_tensor_element_type(x).

  • is_type_name(x: Value | Placeholder | Type) -> Value. זמין לכל הסוגים. לדוגמה, is_float(x) מחזירה true אם x הוא FloatType. אם x הוא ערך או placeholder, הפונקציה הזו היא קיצור דרך של is_type_name(type(x)).

  • max_value(x: Type) -> Value מחזירה את הערך המקסימלי של TensorElementType. אם x הוא לא TensorElementType, הפונקציה מחזירה את הערך None.

  • min_value(x: Type) -> Value מחזירה את הערך המינימלי האפשרי של TensorElementType. אם x הוא לא TensorElementType, הפונקציה מחזירה את הערך None.

  • member_name(x: Value | Placeholder | Type) -> Any. זמין לכל הגדרות החברים member_name מכל הסוגים. לדוגמה, הפונקציה tensor_element_type(x) מחזירה את החלק TensorElementType של TensorType התואם. אם x הוא ערך או placeholder, הפונקציה הזו היא קיצור דרך של member_name(type(x)). אם x הוא לא סוג שכולל איבר מתאים, או ערך או placeholder מסוג כזה, הפונקציה מחזירה את הערך None.

  • הפונקציה is_empty_algorithm(*args: Type) בודקת אם כל שדות האלגוריתם של הנקודות מוגדרים ל-None. הצורך בכך נובע מכך שלאלגוריתמים של נקודות יש התנהגויות ברירת מחדל שהוגדרו בהטמעה, ולכן ציון ערך ברירת מחדל יהיה שגוי.

בניית ערכים

  • operation_name(*xs: Value | Type) -> Value – זמין לכל הפעולות. לדוגמה, הפונקציה add(lhs, rhs) מקבלת שני ערכי טינסור lhs ו-rhs ומחזירה את הפלט של הערכת הפעולה add עם הקלט הזה. בפעולות מסוימות, למשל broadcast_in_dim, סוגי הפלט שלהן הם 'נושאי עומס', כלומר נדרשים להערכת הפעולה. במקרה הזה, הפונקציה משתמשת בסוגים האלה כארגומנטים.

פונקציות על ערכים

  • כל הפונקציות והאופרטורים של Python זמינים. למשל, אפשר ליצור מ-Python גם סימון subscription וגם באמצעות פילוח (slice), וכך אפשר להוסיף לאינדקס tensors, tendors tendors ו-tuples.

  • הפונקציה to_destination_type(x: Value, destination_type: Type) -> Value מוגדרת בטנסורים ומחזירה את הערך המומר של x על סמך type(x) ו-destination_type באופן הבא:

def to_destination_type(x: Value, destination_type: Type) -> Value:
  if type(x) == destination_type:
    return x

  if is_quantized(destination_type):
    if is_quantized(type(x)):
      return quantize(x, destination_type)
    assert is_float(type(x))
    return quantize(x, destination_type)

  if is_quantized(type(x)):
    assert destination_type = expressed_type(type(x))
    return dequantize(type(x))

  return convert(x, destination_type)

יש דיון ראשוני על מיזוג של הפעולות convert,‏ uniform_quantize ו-uniform_dequantize (#1576). אחרי המיזוג, לא נצטרך את הפונקציה שלמעלה ואפשר להשתמש בשם הפעולה במקום ב-convert.

  • הפונקציה is_nan(x: Value) -> Value מוגדרת בטנסורים ומחזירה את הערך true אם כל הרכיבים של x הם NaN, או את הערך false במקרים אחרים. אם x הוא לא טינסור, הפונקציה מחזירה את הערך None.

  • הפונקציה is_sorted(x: Value) -> Value מוגדרת בטנסורים ומחזירה את הערך true אם הרכיבים של x ממוינים בסדר עולה לפי הסדר האלפביתי של האינדקסים שלהם, או את הערך false במקרים אחרים. אם x הוא לא טנזור, הפונקציה מחזירה את None.

  • הפונקציה is_unique(x: Value) -> Value מוגדרת בטנסורים ומחזירה את הערך true אם ל-x אין רכיבים כפולים, או את הערך false במקרה אחר. אם x הוא לא טינסור, הפונקציה מחזירה את הערך None.

  • member_name(x: Value) -> Any מוגדר לכל הגדרות המשתתפים member_name של כל הערכים. לדוגמה, real_part(x) מחזירה את החלק RealPart של ComplexConstant תואם. אם x הוא לא ערך שיש לו חבר מתאים, הפונקציה מחזירה את None.

  • same(x: Value) -> Value מוגדר בסכסוכים ומחזיר את הערך true אם האלמנטים של x שווים זה לזה, או false, אחרת. אם לטרנספורמר אין רכיבים, זה נחשב כ'הכול שווה זה לזה', כלומר הפונקציה מחזירה את הערך true. אם x הוא לא טינסור, הפונקציה מחזירה את הערך None.

  • הפונקציה split(x: Value, num_results: Value, axis: Value) -> Value מוגדרת בדגשונים ומחזירה פרוסות num_results של x לאורך הציר axis. אם x הוא לא טינסור או dim(x, axis) % num_results != 0, הפונקציה מחזירה את הערך None.

  • הפונקציה is_defined_in_parent_scope(x: Value) -> Value מוגדרת למחרוזות ומחזירה את הערך true אם x הוא שם של פונקציה שמוגדרת באותו היקף כמו פונקציית ההורה של הפעולה הרלוונטית.

  • הפונקציה is_namespaced_op_name(x: Value) -> Value מוגדרת למחרוזות ומחזירה את הערך true אם x הוא שם אופרטור חוקי, כלומר הוא עומד בביטוי הרגולרי הבא: [a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+

חישובי צורות

  • axes(x: Value | Placeholder | Type) -> Value הוא קיצור דרך ל-range(rank(x)).

  • dim(x: Value | Placeholder | Type, axis: Value) -> Value הוא קיצור דרך ל-shape(x)[axis].

  • dims(x: Value | Placeholder | Type, axes: List) -> List הוא קיצור דרך של list(map(lambda axis: dim(x, axis), axes)).

  • הפונקציה index_space(x: Value | Placeholder | Type) -> Value מוגדרת בטנסורים ומחזירה את האינדקסים size(x) של TensorType התואם, שממוינים בסדר אלפביתי עולה, כלומר [0, ..., 0],‏ [0, ..., 1],‏ …,‏ shape(x) - 1. אם x הוא לא סוג טינסור, סוג טינסור מקודד, ערך או placeholder של אחד מהסוגים האלה, הפונקציה מחזירה את None.

  • rank(x: Value | Placeholder | Type) -> Value הוא קיצור דרך של size(shape(x)).

  • shape(x: Value | Placeholder | Type) -> Value מוגדר בקטע 'פונקציות על סוגים' באמצעות member_name.

  • size(x: Value | Placeholder | Type) -> Value הוא קיצור דרך ל-reduce(lambda x, y: x * y, shape(x)).

חישובי קציבה

  • def baseline_element_type(x: Value | Placeholder | Type) -> Type הוא קיצור דרך ל-element_type(baseline_type(x)).

  • baseline_type מוגדר בסוגים של t e n s o r f l o w, וסוגי t e n s t e n s t o w, וממיר אותם ל- Baseline, כלומר טיפוס עם אותה צורה אבל עם פרמטרי הקוונטיזציה של סוג הרכיב אופסו לערכי ברירת המחדל. זוהי דרך נוחה להשוות בין סוגי טינסורים לבין סוגי טינסורים מרובים באופן אחיד, ויש צורך בכך לעיתים קרובות. בסוגי הערכים האלה, ניתן להשוות בין סוגים שמתעלמים מפרמטרים של קוונטיזציה, כלומר shape, storage_type, expressed_type, storage_min, storage_max ו-quantization_dimension (לסוג כמותית לכל ציר) צריכים להיות זהים, אבל ייתכנו הבדלים בין scales ל-zero points.

def baseline_type(x: Value | Placeholder | Type) -> Type:
  if type(x) == TensorType:
    return x
  if type(x) == QuantizedTensorType:
    element_type = quantized_tensor_element_type(x)
    baseline_element_type = QuantizedTensorElementType(
      storage_type = storage_type(element_type),
      storage_min = storage_min(element_type),
      storage_max = storage_max(element_type),
      expressed_type = expressed_type(element_type),
      quantization_dimension = quantization_dimension(element_type),
      scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
      zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
    return QuantizedTensorType(shape(x), baseline_element_type)
  if type(x) is not Type:
    return baseline_element_type(type(x))
  • הפונקציה dequantize מוגדרת על סוגי טינסורים שעברו קצירה, והיא הופכת אותם לסוגי טינסורים של נקודה צפה. הפעולה הזו מתבצעת על ידי המרת רכיבים מרושתים שמייצגים ערכים שלמים מסוג האחסון לערכים תואמים של נקודה צפה מסוג הביטוי, באמצעות נקודת האפס והסולם המשויכים לסוג הרכיב המרושת.
def compute_zero_points(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      zero_points[i] = zero_points(quantized_type)[i[d]]
    return zero_points

def compute_scales(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
            type(result_type))
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      scales[i] = scales(quantized_type)[i[d]]
    return scales

def dequantize(x: Value) -> Value:
  assert is_quantized(x)
  x_storage = bitcast_convert(x, storage_type(x))
  x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
  x_expressed_sub = convert(x_storage_sub, expressed_type(x))
  return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
  • הפונקציה quantize מוגדרת על סוגי טינסורים של נקודה צפה, והיא הופכת אותם לסוגים של טינסורים מקוטעים. זה קורה על ידי המרה של ערכי נקודה צפה (floating-point) מהסוג המבוטא לערכי מספרים שלמים תואמים של סוג האחסון, באמצעות נקודת האפס וקנה המידה שמשויכים לסוג הרכיב הכמותי.
def quantize(x: Value, result_type: Type) -> Value:
  assert is_float(x) and is_quantized(result_type)
  zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
  converted_zero_points = convert(zero_points, expressed_type(result_type))
  converted_min = convert(storage_min(result_type), expressed_type(result_type))
  converted_max = convert(storage_max(result_type), expressed_type(result_type))

  x_scaled = x / compute_scales(result_type, type(x))
  x_scaled_add_zp = x_scaled + converted_zero_points
  x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
  x_rounded = round_nearest_even(x_clamped)
  return convert(x_rounded, result_type)
  • dequantize_op_quantize משמש לציון חישובים לפי רכיבים בטנסורים מקוטעים. הוא מבצע פעולת דקונטיזציה, כלומר הופך רכיבים מקונטיזים לסוגים המפורטים שלהם, מבצע פעולה ואז מבצע פעולת קונטיזציה, כלומר הופך את התוצאות חזרה לסוגים שלהם לאחסון. כרגע הפונקציה הזו פועלת רק לקונטיזציה לפי טנזור. אנחנו עובדים על קצירה לפי ציר (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
  inputs = inputs_and_output_type[:-1]
  output_type = inputs_and_output_type[-1]

  float_inputs = map(dequantize, inputs)
  float_result = op(*float_inputs)
  return quantize(float_result, output_type)

def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
  inputs = inputs_and_output_type[:-3]
  float_inputs = map(dequantize, inputs)
  float_results = op(*float_inputs)
  return map(quantize, float_results, inputs_and_output_type[-3:])

def dequantize_compare(lhs, rhs, comparison_direction):
  float_lhs = dequantize(lhs)
  float_rhs = dequantize(rhs)
  return compare(float_lhs, float_rhs, comparison_direction, FLOAT)

def dequantize_select_quantize(pred, on_true, on_false, output_type):
  float_on_true = dequantize(on_true)
  float_on_false = dequantize(on_false)
  float_result = select(pred, float_on_true, float_on_false)
  return quantize(float_result, output_type)
  • hybrid_dequantize_then_op משמש לציון כיווץ של משקלים בלבד לפעולה היברידית שמקבלת את ה-lhs בנקודת צפה ואת ה-rhs בסוגי נתונים מקובצים. הוא מבצע דקוונטיזציה של קלטים מקובצים לסוגים המפורטים שלהם ומבצע חישובים ב-float. סוג הרכיב של הטנזור הימני של המספרים הצפים וסוג הביטוי של הטנזור הימני המצטבר צריכים להיות זהים.
def hybrid_dequantize_then_op(op, lhs, rhs):
  assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
  return op(lhs, dequantize(rhs))

חישובים ברשת

  • cross_partition(replica_groups: Value) -> Value. אפשר לעיין בקטע 'cross_replica' למעלה.

  • cross_replica(replica_groups: Value) -> Value. אפשר לעיין בקטע 'cross_replica' למעלה.

  • cross_replica_and_partition(replica_groups: Value) -> Value. אפשר לעיין בקטע 'cross_replica_and_partition' למעלה.

  • flattened_ids(replica_groups: Value) -> Value. ראו את הקטעflattened_id שלמעלה.

דינמיות

ערכי StableHLO יכולים לכלול גדלים דינמיים של מאפיינים, למשל tensor<?xi64>. עם זאת, לערכי StableHLO לא יכול להיות מספר דינמי של מאפיינים (דינמיזמה ללא דירוג, למשל tensor<*xi64>). לישויות ולתוצאות יש הרשאה להשתמש בגדלים של מאפיינים דינמיים, גם אם יש מגבלות על הגדלים. אם אפשר, האילוצים יאומתו באופן סטטי. אחרת, הם יושהו לזמן הריצה, ואי-התאמות יובילו להתנהגות לא מוגדרת. בהמשך מפורטות דוגמאות.

אי-התאמה בצורות עבור פעולות שאינן בסיסיות

נבחן את תוכנית הצעצוע הבאה:

func.func @foo(%arg0: tensor<?xf64>) {
  %0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
  return
}

תוכנית כזו היא חריגה, כי לא מקובל לדעת את הצורה של התוצאה, אבל לא את צורת הקלט. עם זאת, זוהי תוכנית תקינה של StableHLO. אי אפשר לאמת באופן סטטי את הפעולה abs בתוכנית הזו, כי לא ידוע מהו הצורה המדויקת של האופרנד. עם זאת, הצורות תואמות בהחלט, וניתן לבדוק זאת באופן סטטי: ? יכול להפוך ל-2 בזמן הריצה, ולא תהיה בעיה. עם זאת, יכול להיות ש-? יהיה גם מספר שלם אחר, ובמקרה כזה ההתנהגות לא מוגדרת.

הערה: אם גודל המאפיין הוא דינמי בתוצאה, לא יכולה להיות התנהגות לא מוגדרת. אכן, אין גודל 'צפוי', ולכן לא יכולה להיות אי-התאמה.

אי-התאמה של צורות עבור פעולות בינאריות בסיסיות

נבחן את תוכנית הצעצועים הבאה:

func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
  %0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
  return
}

כשמדובר בפעולות בינאריות לפי רכיבים, הצורות של הקלט והתוצאה חייבות להיות זהות בסביבת זמן הריצה. בזמן הידור, המאפיינים הסטטיים חייבים להיות זהים, אחרת הם צריכים להיות רק תואמים. אם מימד כלשהו הוא דינמי בקלט, יכול להיות שתהיה התנהגות לא מוגדרת בזמן הריצה, כי יכול להיות שהגודל הדינמי לא יתאים לגודל התואם באופרנד האחר (בין אם הוא סטטי או דינמי). אם כל הקלט הוא סטטי, לא משנה אם התוצאה דינמית או לא: מאפיינים ידועים באופן סטטי ייבדקו באופן סטטי, ומאפיינים דינמיים לא מטילים אילוצים.

אי-התאמות בצורה של פעולות שמקבלות את צורת הפלט שלהן כאופרטנד

נבחן את תוכנית הצעצוע הבאה:

func.func @foo(%arg0: tensor<2xi32>) {
  %0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
  return
}

הערכים באופרנד בזמן הריצה צריכים להתאים לצורת התוצאה, אחרת ההתנהגות לא מוגדרת. כלומר, בזמן הריצה הערך של %arg0 חייב להיות dense<[3, 4]> : tensor<2xi32>. אם אופרנד הצורה הוא קבוע, אפשר לאמת זאת באופן סטטי. אם צורת התוצאה היא דינמית לחלוטין, לא יכולה להיות אי-התאמה.