StableHLO הוא קבוצת פעולות לפעולות ברמה גבוהה (HLO) במודלים של למידת מכונה (ML). StableHLO פועל כשכבת ניוד בין מסגרות שונות של למידת מכונה (ML) ומהדרים של למידת מכונה: מסגרות של למידת מכונה שמפיקות תוכניות StableHLO תואמות למהדרים של למידת מכונה שצורכים תוכניות StableHLO.
המטרה שלנו היא לפשט ולהאיץ את פיתוח ה-ML על ידי יצירת יכולת פעולה הדדית רבה יותר בין מסגרות שונות של ML (כמו TensorFlow, JAX ו-PyTorch) ובין קומפיילרים של ML (כמו XLA ו-IREE). לכן, במסמך הזה מפורטות ההגדרות של שפת התכנות StableHLO.
המפרט הזה מחולק לשלושה חלקים עיקריים. בקטע Programs מוסבר על המבנה של תוכניות StableHLO, שמורכבות מפונקציות StableHLO, שבתורן מורכבות מאופרטורים של StableHLO. במבנה הזה, בקטע Ops מפורטת הסמנטיקה של פעולות נפרדות. בקטע Execution (הרצה) מוסבר על הסמנטיקה של כל הפעולות האלה שמופעלות יחד בתוכנית. בסוף, בקטע Notation (סימון) מוסבר על הסימון שבו נעשה שימוש לאורך המפרט.
כדי לראות את המפרט מגרסה קודמת של StableHLO, פותחים את המאגר בגרסה המתויגת הרלוונטית. לדוגמה, StableHLO v0.19.0 Spec. כדי לראות את השינויים שחלו בכל עדכון של גרסה משנית של StableHLO, אפשר לעיין ביומן הגרסאות ב-VhloDialect.td.
תוכניות
Program ::= {Func}
תוכניות StableHLO מורכבות ממספר כלשהו של פונקציות StableHLO.
למטה מופיעה דוגמה לתוכנית עם פונקציה @main שיש לה 3 קלטים (%image, %weights ו-%bias) ופלט אחד. הגוף של הפונקציה
כולל 6 פעולות.
func.func @main(
%image: tensor<28x28xf32>,
%weights: tensor<784x10xf32>,
%bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
%0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
%1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
%2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
%3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
%4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
"func.return"(%4): (tensor<1x10xf32>) -> ()
}
פונקציות
Func ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput ::= ValueType
FuncBody ::= {Op}
פונקציות StableHLO (שנקראות גם פונקציות בעלות שם) כוללות מזהה, קלט/פלט וגוף. בעתיד, אנחנו מתכננים להוסיף מטא-נתונים לפונקציות כדי לשפר את התאימות ל-HLO (#425, #626, #740, #744).
מזהים
FuncId ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
| '%' letter {letter | digit}
letter ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit ::= '0' | ... | '9'
מזהי StableHLO דומים למזהים בהרבה שפות תכנות, עם שני מאפיינים ייחודיים: 1) לכל המזהים יש סיגילים שמבחינים בין סוגים שונים של מזהים, 2) מזהי ערכים יכולים להיות מספריים לחלוטין כדי לפשט את יצירת התוכניות של StableHLO.
סוגים
Type ::= ValueType | NonValueType
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType | BufferType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
סוגי StableHLO מסווגים לסוגי ערכים (שנקראים גם סוגים ממחלקה ראשונה) שמייצגים ערכי StableHLO, ולסוגים שהם לא ערכים שמתארים רכיבים אחרים של התוכנית. הסוגים ב-StableHLO דומים לסוגים בהרבה שפות תכנות, והמאפיין העיקרי שמייחד אותם הוא האופי הספציפי לתחום של StableHLO, שמוביל לתוצאות לא שגרתיות (למשל, סוגים סקלריים הם לא סוגי ערכים).
TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'
סוגי טנסורים מייצגים טנסורים, כלומר מערכים רב-ממדיים. יש להם צורה וסוג רכיב, כאשר צורה מייצגת גדלים של ממדים לא שליליים או לא ידועים בסדר עולה של הממדים המתאימים (שנקראים גם צירים) שממוספרים מ-0 עד R-1. מספר המאפיינים R נקרא דרגה. לדוגמה, tensor<2x3xf32> הוא סוג טנסור עם צורה 2x3 וסוג אלמנט f32. יש לו שני מאפיינים (או במילים אחרות, שני צירים) – מאפיין 0 ומאפיין 1 – שהגדלים שלהם הם 2 ו-3. הדירוג שלו הוא 2.
יכול להיות שצורות יהיו לא ידועות באופן חלקי או מלא (דינמיות), למשל tensor<?x2xf64> לא ידוע באופן חלקי ו-tensor<?x?xf64> לא ידוע באופן מלא. גדלים של מאפיינים דינמיים מיוצגים באמצעות ?. אי אפשר לבטל את הדירוג של צורות.
בעתיד, אנחנו מתכננים לבדוק אפשרות להרחיב את סוגי הטנסורים מעבר לגדלים של ממדים ולסוגי אלמנטים, למשל לכלול פריסות (#629) ודלילות (#1078).
QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
QuantizationStorageType
['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
':' QuantizationExpressedType
[':' QuantizationDimension]
',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerLiteral
QuantizationStorageMax ::= IntegerLiteral
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerLiteral
QuantizationParameters ::= QuantizationParameter
| '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale [':' QuantizationZeroPoint]
QuantizationScale ::= FloatLiteral
QuantizationZeroPoint ::= IntegerLiteral
| שם | סוג | מגבלות |
|---|---|---|
storage_type |
סוג מספר שלם | (C1-C3), (C8) |
storage_min |
קבוע מספרי שלם | (C1), (C3), (C7) |
storage_max |
קבוע מספרי שלם | (C2), (C3), (C7) |
expressed_type |
סוג נקודה צפה | (C4) |
quantization_dimension |
קבוע מספרי אופציונלי | (C10-C12) |
scales |
מספר משתנה של קבועים בשיטת נקודה צפה | (C4-C6), (C9), (C10), (C13) |
zero_points |
מספר משתנה של קבועים שלמים | (C7-C9) |
סוגי רכיבים שעברו קוונטיזציה מייצגים ערכים שלמים של סוג אחסון בטווח storage_min עד storage_max (כולל) שתואמים לערכים של נקודה צפה של סוג מבוטא. עבור ערך שלם נתון i,
אפשר לחשב את הערך התואם בנקודה צפה f באופן הבא:
f = (i - zero_point) * scale, כאשר scale ו-zero_point נקראים פרמטרים של קוונטיזציה. המאפיינים storage_min ו-storage_max הם אופציונליים בתחביר, אבל ערכי ברירת המחדל שלהם הם min_value(storage_type) ו-max_value(storage_type) בהתאמה. יש מגבלות על סוגי רכיבים שעברו קוונטיזציה:
- (C1)
type(storage_min) = storage_type. - (C2)
type(storage_max) = storage_type. - (C3)
min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type). - (C4)
type(scales...) = expressed_type. - (C5)
0 < scales. - (C6)
is_finite(scales...). - (C7)
storage_min <= zero_points <= storage_max. - (C8)
type(zero_points...) = storage_type. - (C9)
size(scales) = size(zero_points). - (C10) אם
is_empty(quantization_dimension), אזsize(scales) = 1. - (C11)
0 <= quantization_dimension.
בשלב הזה, QuantizationScale הוא קבוע בשיטת נקודה צפה, אבל יש עניין רב בסולמות שמבוססים על מספרים שלמים, שמיוצגים באמצעות מכפילים והזזות. אנחנו מתכננים לבדוק את האפשרות הזו בעתיד הקרוב
(#1404).
מתנהל דיון על הסמנטיקה של QuantizationZeroPoint, כולל הסוג, הערכים והשאלה אם יכולה להיות רק נקודת אפס אחת או כמה נקודות אפס פוטנציאליות בסוג של טנסור שעבר קוונטיזציה. על סמך התוצאות של הדיון הזה, יכול להיות שהמפרט לגבי אפס נקודות ישתנה בעתיד (#1405).
דיון נוסף מתנהל לגבי הסמנטיקה של QuantizationStorageMin
ושל QuantizationStorageMax כדי לקבוע אם צריך להגביל את הערכים האלה ואת הערכים של טנסורים שעברו קוונטיזציה (#1406).
בסוף, אנחנו מתכננים לבדוק איך לייצג סולמות לא ידועים ונקודות אפס, בדומה לאופן שבו אנחנו מתכננים לבדוק איך לייצג גדלים לא ידועים של מאפיינים (#1407).
סוגי טנסורים שעברו קוונטיזציה מייצגים טנסורים עם אלמנטים שעברו קוונטיזציה. הטנסורים האלה זהים בדיוק לטנסורים רגילים, אלא שהאלמנטים שלהם הם מסוג אלמנטים שעברו קוונטיזציה, במקום סוג אלמנטים רגיל.
בטנסורים שעברו קוונטיזציה, הקוונטיזציה יכולה להיות per-tensor, כלומר, עם scale ו-zero_point אחדים לכל הטנסור, או per-axis, כלומר, עם כמה scales ו-zero_points, זוג אחד לכל פרוסה של מימד מסוים quantization_dimension. באופן רשמי יותר, בטנזור t עם כימות לכל ציר, יש dim(t, quantization_dimension) פרוסות של quantization_dimension: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :] וכו'. כל הרכיבים בפרוסה ה-i משתמשים ב-scales[i] וב-zero_points[i] כפרמטרים של הכימות. יש מגבלות על סוגי טנסורים שעברו קוונטיזציה:
- לכימות לכל טנסור:
- אין אילוצים נוספים.
- לכימות לכל ציר:
- (C12)
quantization_dimension < rank(self). - (C13)
dim(self, quantization_dimension) = size(scales).
- (C12)
TokenType ::= 'token'
סוגי טוקנים מייצגים טוקנים, כלומר ערכים אטומים שנוצרים ונצרכים על ידי פעולות מסוימות. האסימונים משמשים להטלת סדר ביצוע על פעולות, כפי שמתואר בקטע ביצוע.
TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]
סוגי מאגרים מייצגים מאגרים. לדוגמה, ב-XLA, מאגרי נתונים זמניים הם מערכים רב-ממדיים עם אחסון עקבי. בדומה לסוגי טנסורים, לסוגי מאגרים יש צורה וסוג רכיב, כאשר הצורה מייצגת גדלים של ממדים לא שליליים או לא ידועים בסדר עולה של הממדים המתאימים (שנקראים גם צירים) שממוספרים מ-0 עד R-1. מספר המאפיינים R נקרא דרגה. לדוגמה, memref<2x3xf32> הוא סוג של מאגר עם צורה 2x3 וסוג רכיב f32. יש לו שני ממדים (או במילים אחרות, שני צירים) – ממד 0 וממד 1 – שהגדלים שלהם הם 2 ו-3. הדירוג שלו הוא 2.
אפשר להקצות מאגרי נתונים באמצעות custom_call עד CreateBuffer או Pin, ולבטל את ההקצאה באמצעות custom_call עד Unpin. רק פעולות custom_call יכולות לקרוא ולכתוב את התוכן בתוך מאגרי נתונים זמניים. פרטים נוספים מופיעים במאמר בנושא custom_call.
סוגי טאפלים מייצגים טאפלים, כלומר רשימות הטרוגניות. הטופלים הם תכונה מדור קודם שקיימת רק לצורך תאימות ל-HLO. ב-HLO, נעשה שימוש בטפילים כדי לייצג קלטים ופלט משתנים. ב-StableHLO יש תמיכה מובנית בקלט ובפלט עם מספר משתנה של ארגומנטים, והשימוש היחיד בטפלים ב-StableHLO הוא לייצוג מקיף של HLO ABI, שבו למשל T, tuple<T> ו-tuple<tuple<T>> עשויים להיות שונים באופן משמעותי בהתאם להטמעה מסוימת. בעתיד, אנחנו מתכננים לבצע שינויים ב-HLO ABI, שיאפשרו לנו להסיר סוגי טאפל מ-StableHLO (#598).
TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3'
| 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2'
| 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
סוגי רכיבים מייצגים רכיבים של סוגי טנסורים. בניגוד לשפות תכנות רבות, הסוגים האלה לא מוגדרים כסוגים ממחלקה ראשונה ב-StableHLO. המשמעות היא שתוכניות StableHLO לא יכולות לייצג ישירות ערכים מהסוגים האלה (כתוצאה מכך, מקובל לייצג ערכים סקלריים מהסוג T באמצעות ערכים טנסוריים אפסיים מהסוג tensor<T>).
- סוג בוליאני מייצג ערכים בוליאניים
trueו-false. - סוגי מספרים שלמים יכולים להיות עם סימן (
si) או ללא סימן (ui), ויכולים להיות בעלי אחת מרוחבי הסיביות הנתמכים (2,4,8,16,32או64). סוגיsiNעם סימן מייצגים ערכים של מספרים שלמים מ--2^(N-1)עד2^(N-1)-1כולל, וסוגיuiNללא סימן מייצגים ערכים של מספרים שלמים מ-0עד2^N-1כולל. - סוגי נקודה צפה יכולים להיות אחד מהסוגים הבאים:
- הערכים
f8E3M4,f8E4M3ו-f8E5M2הם מספרים שלמים של 8 ביט בפורמט נקודה צפה, בהתאם למוסכמות IEEE-754. -
f8E4M3FNו-f8E5M2, בהתאמה לקידודיםE4M3ו-E5M2של פורמט FP8 שמתואר במאמר פורמטים של FP8 ללמידה עמוקה. - הסוגים
f8E4M3FNUZו-f8E5M2FNUZתואמים לקידודיםE4M3ו-E5M2של פורמטי FP8 שמתוארים במאמר 8-bit Numerical Formats for Deep Neural Networks. - הסוג
f8E4M3B11FNUZשמתאים לקידודE4M3של פורמטי FP8 שמתוארים במאמר Hybrid 8-bit Floating Point (HFP8) Training and Inference for Deep Neural Networks. - הסוג
bf16שמתאים לפורמטbfloat16שמתואר במאמר BFloat16: The secret to high performance on Cloud TPUs. - הסוגים
f16,f32ו-f64תואמים לפורמטיםbinary16(דיוק חצי),binary32(דיוק יחיד) ו-binary64(דיוק כפול) שמתוארים בתקן IEEE 754. - הסוג
tf32תואם לפורמט TensorFloat32 ויש לו תמיכה מוגבלת ב-StableHLO. -
f4E2M1FN,f6E2M3FN,f6E3M2FNו-f8E8M0FNUסוגי MX (מיקרו-סקיילינג) שמתוארים במפרט הפורמטים של OCP Microscaling.
- הערכים
- סוגים מורכבים מייצגים ערכים מורכבים שיש להם חלק ממשי וחלק מדומה מאותו סוג אלמנט. הסוגים המורכבים הנתמכים הם
complex<f32>(שני החלקים הם מסוגf32) ו-complex<f64>(שני החלקים הם מסוגf64).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]
סוגי פונקציות מייצגים פונקציות בעלות שם ופונקציות אנונימיות. יש להם סוגי קלט (רשימת הסוגים בצד ימין של ->) וסוגי פלט (רשימת הסוגים בצד שמאל של ->). בשפות תכנות רבות, סוגי פונקציות הם מהסוג הראשון, אבל לא ב-StableHLO.
StringType ::= 'string'
סוג מחרוזת מייצג רצפים של בייטים. בניגוד לשפות תכנות רבות, סוג המחרוזת לא מוגדר כסוג בסיסי ב-StableHLO, והוא משמש רק לציון מטא-נתונים סטטיים של רכיבי תוכנית.
תפעול
פעולות StableHLO (שנקראות גם ops) מייצגות קבוצה סגורה של פעולות ברמה גבוהה במודלים של למידת מכונה. כמו שצוין למעלה, התחביר של StableHLO מבוסס במידה רבה על MLIR, שהוא לא בהכרח החלופה הכי נוחה, אבל אפשר לטעון שהוא הכי מתאים למטרה של StableHLO – ליצור יכולת פעולה הדדית טובה יותר בין מסגרות ML וקומפיילרים של ML.
Op ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic ::= 'abs' | 'add' | ...
לפעולות StableHLO (שנקראות גם ops) יש שם, קלט/פלט וחתימה. השם מורכב מהקידומת stablehlo. ומנימוניק שמזהה באופן ייחודי אחת מהפעולות הנתמכות. בהמשך מופיעה רשימה מקיפה של כל הפעולות הנתמכות.
OpInputs ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue ::= ValueId
OpInputFuncs ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs ::= [OpOutput {',' OpOutput} '=']
OpOutput ::= ValueId
פעולות צורכות קלטים ומפיקות פלטים. הקלט מסווג לערכי קלט (מחושבים במהלך ההפעלה), פונקציות קלט (מסופקות באופן סטטי, כי בפונקציות StableHLO לא מוגדרים ערכים מסוג first-class) ולמאפייני קלט (מסופקים גם באופן סטטי). סוגי הקלט והפלט שפעולה צורכת ומייצרת תלויים בסימן שלה. לדוגמה, הפעולה add צורכת 2 ערכי קלט ומפיקה ערך פלט אחד. לעומת זאת, הפעולה select_and_scatter צורכת 3 ערכי קלט, 2 פונקציות קלט ו-3 מאפייני קלט.
OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused ::= '^' digit {digit}
| '^' letter {letter | digit}
פונקציות קלט (שנקראות גם פונקציות אנונימיות) דומות מאוד לפונקציות בעלות שם, אבל יש ביניהן הבדלים: 1) אין להן מזהה (ומכאן השם 'אנונימיות'), 2) הן לא מצהירות על סוגי פלט (סוגי הפלט נגזרים מהאופרטור return בתוך הפונקציה).
התחביר של פונקציות הקלט כולל חלק שלא נמצא בשימוש כרגע (ראו את Unusedהייצור שלמעלה), והוא נועד לתאימות ל-MLIR. ב-MLIR, יש מושג כללי יותר של 'אזורים' שיכולים לכלול כמה 'בלוקים' של פעולות שמחוברים באמצעות פעולות קפיצה. לבלוקים האלה יש מזהים שתואמים לUnused הייצור, כדי שאפשר יהיה להבדיל ביניהם.
ב-StableHLO אין אופרטורים של קפיצה, ולכן החלק התואם בתחביר של MLIR לא נמצא בשימוש (אבל הוא עדיין קיים).
OpInputAttr ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName ::= letter {letter | digit}
OpInputAttrValue ::= Constant
מאפייני קלט כוללים שם וערך שהוא אחת מהקבועים הנתמכים. הם הדרך העיקרית לציין מטא-נתונים סטטיים עבור רכיבי תוכנה. לדוגמה, האופרטור concatenate משתמש במאפיין dimension כדי לציין את המימד שלפיו ערכי הקלט שלו משורשרים. באופן דומה,
האופרטור slice משתמש בכמה מאפיינים כמו start_indices ו-limit_indices
כדי לציין את הגבולות שמשמשים לפילוח של ערך הקלט.
בשלב הזה, תוכניות StableHLO בשימוש נרחב מכילות לפעמים מאפיינים שלא מתוארים במסמך הזה. בעתיד, אנחנו מתכננים להוסיף את המאפיינים האלה ל-opset של StableHLO או לאסור את השימוש בהם בתוכניות של StableHLO. בינתיים, הנה רשימת המאפיינים האלה:
layout(#629).mhlo.frontend_attributes(#628).mhlo.sharding(מס' 619).output_operand_aliases(#740).- מטא-נתונים של מיקום (#594).
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'
חתימת הפעולה מורכבת מסוגי כל ערכי הקלט (רשימת הסוגים בצד ימין של ->) ומסוגי כל ערכי הפלט (רשימת הסוגים בצד שמאל של ->). למעשה, סוגי הקלט הם מיותרים, וגם סוגי הפלט הם כמעט תמיד מיותרים (כי ברוב הפעולות של StableHLO אפשר להסיק את סוגי הפלט מהקלט). עם זאת, חתימת האופרטור היא חלק מכוון בתחביר של StableHLO לצורך תאימות ל-MLIR.
בהמשך מופיעה דוגמה לאופרטור שהנימוניק שלו הוא select_and_scatter. היא צורכת 3 ערכי קלט (%operand, %source ו-%init_value), 2 פונקציות קלט ו-3 מאפייני קלט (window_dimensions, window_strides ו-padding). שימו לב שהחתימה של הפעולה כוללת רק את הסוגים של ערכי הקלט שלה (אבל לא את הסוגים של פונקציות הקלט והמאפיינים שמוגדרים בשורה).
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>
ערכים קבועים
Constant ::= BooleanConstant
| IntegerConstant
| FloatConstant
| ComplexConstant
| TensorConstant
| QuantizedTensorConstant
| StringConstant
| EnumConstant
לקבועים של StableHLO יש ערך מילולי וסוג, שביחד מייצגים ערך של StableHLO. בדרך כלל, הסוג הוא חלק מהתחביר הקבוע, אלא אם הוא חד-משמעי (למשל, קבוע בוליאני הוא חד-משמעי מסוג i1, בעוד שקבוע של מספר שלם יכול להיות מכמה סוגים אפשריים).
BooleanConstant ::= BooleanLiteral
BooleanLiteral ::= 'true' | 'false'
קבועים בוליאניים מייצגים ערכים בוליאניים true ו-false. לקבועים בוליאניים יש סוג i1.
IntegerConstant ::= IntegerLiteral ':' IntegerType
IntegerLiteral ::= ['-' | '+'] DecimalDigits
| ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit ::= '0' | ... | '9'
hexadecimalDigit ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'
קבועים של מספרים שלמים מייצגים ערכים של מספרים שלמים באמצעות מחרוזות שמשתמשות בסימון עשרוני או הקסדצימלי. אין תמיכה בבסיסים אחרים, למשל בינארי או אוקטלי. ההגבלות על קבועים מסוג מספר שלם הן:
- (C1)
is_wellformed(integer_literal, integer_type).
FloatConstant ::= FloatLiteral ':' FloatType
FloatLiteral ::= SignPart IntegerPart FractionalPart ScientificPart
| '0x' [HexadecimalDigits]
SignPart ::= ['-' | '+']
IntegerPart ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]
קבועים של נקודה צפה מייצגים ערכים של נקודה צפה באמצעות מחרוזות שמשתמשות בסימון עשרוני או מדעי. בנוסף, אפשר להשתמש בסימון הקסדצימלי כדי לציין ישירות את הביטים הבסיסיים בפורמט הנקודה הצפה של הסוג המתאים. יש מגבלות על קבועים של נקודה צפה:
- (C1) אם משתמשים בסימון לא הקסדצימלי,
is_wellformed(float_literal, float_type). - (C2) אם משתמשים בסימון הקסדצימלי,
size(hexadecimal_digits) = num_bits(float_type) / 4.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral ::= '(' RealPart ',' ImaginaryPart ')'
RealPart ::= FloatLiteral
ImaginaryPart ::= FloatLiteral
קבועים מרוכבים מייצגים ערכים מרוכבים באמצעות רשימות של חלק ממשי (מופיע ראשון) וחלק דמיוני (מופיע שני). לדוגמה, (1.0, 0.0) : complex<f32> מייצג את 1.0 + 0.0i, ו-(0.0, 1.0) : complex<f32> מייצג את 0.0 + 1.0i. הסדר שבו החלקים האלה נשמרים בזיכרון מוגדר בהטמעה. יש הגבלות על קבועים מורכבים:
- (C1)
is_wellformed(real_part, complex_element_type(complex_type)). - (C2)
is_wellformed(imaginary_part, complex_element_type(complex_type)).
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral
קבועים של טנסור מייצגים ערכי טנסור באמצעות רשימות מקוננות שמוגדרות באמצעות סימון NumPy. לדוגמה, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> מייצג ערך טנזור עם המיפוי הבא מאינדקסים לרכיבים: {0, 0} => 1, {0, 1} => 2, {0, 2} => 3, {1, 0} => 4, {1, 1} => 5, {1, 2} => 6. הסדר שבו הרכיבים האלה נשמרים בזיכרון מוגדר בהטמעה. ההגבלות הבאות חלות על קבועים של Tensor:
- (C1)
has_syntax(tensor_literal, element_type(tensor_type)), כאשר:has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type).has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type).
- (C2)
has_shape(tensor_literal, shape(tensor_type)), כאשר:has_shape(element_literal: Syntax, []) = true.has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:]).- אחרת,
false.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
קבועים של טנסור שעבר קוונטיזציה מייצגים ערכים של טנסור שעבר קוונטיזציה באמצעות אותה שיטה כמו קבועים של טנסור, עם אלמנטים שמוגדרים כקבועים של סוג האחסון שלהם. יש את המגבלות הבאות על קבועים של טנסור שעברו קוונטיזציה:
- (C1)
has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type)). - (C2)
has_shape(quantized_tensor_literal, shape(quantized_tensor_type)).
StringConstant ::= StringLiteral
StringLiteral ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))
מחרוזות ליטרליות מורכבות מבייטים שצוינו באמצעות תווים בפורמט ASCII ורצפי escape. הן לא תלויות בקידוד, ולכן הפרשנות של הבייטים האלה מוגדרת בהטמעה. למחרוזות ליטרליות יש סוג string.
תפעול
מוחלט
סמנטיקה
מבצע פעולת abs לפי רכיבים בטנזור operand ומפיק טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:
- למספרים שלמים עם סימן: מודולוס של מספר שלם.
- למספרים ממשיים:
absמ-IEEE-754. - למספרים מרוכבים: מודולוס מרוכב.
- לסוגים כמותיים:
dequantize_op_quantize(abs, operand, type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור של מספר שלם עם סימן, מספר נקודה צפה או סוג מרוכב, או טנזור כמותי לכל טנזור | (C1-C2) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנסור של מספר שלם עם סימן או מספר בשיטת נקודה צפה, או טנסור שעבר קוונטיזציה לכל טנסור | (C1-C2) |
מגבלות
- (C1)
shape(result) = shape(operand). - (C2)
baseline_element_type(result)מוגדר כך:-
complex_element_type(element_type(operand))ifis_complex(operand). baseline_element_type(operand)אחרת.
-
דוגמאות
// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand)< : (t>ens>or3xi32<) - t>ensor3xi32
// %result: [2, 0, 2]
הוספה
סמנטיקה
מבצעת חיבור של שני טנסורים lhs ו-rhs לפי רכיבים ומפיקה טנסור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:
- לערכים בוליאניים: OR לוגי.
- למספרים שלמים: חיבור של מספרים שלמים.
- למספרים ממשיים:
additionמ-IEEE-754. - למספרים מרוכבים: חיבור מרוכב.
- לסוגים כמותיים:
dequantize_op_quantize(add, lhs, rhs, type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | lhs |
טנזור או טנזור שעבר קוונטיזציה | (C1-C6) |
| (I2) | rhs |
טנזור או טנזור שעבר קוונטיזציה | (C1-C5), (C7) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור או טנזור שעבר קוונטיזציה | (C1-C7) |
מגבלות
- אם הפעולה משתמשת בטנסורים לא מכומתים:
- (C1)
type(lhs) = type(rhs) = type(result).
- (C1)
- אם הפעולה משתמשת בטנסורים שעברו קוונטיזציה:
- (C2)
is_quantized(lhs) and is_quantized(rhs) and is_quantized(result). - (C3)
storage_type(lhs) = storage_type(rhs) = storage_type(result). - (C4)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result). - (C5)
(is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result). - (C6) אם
is_per_axis_quantized(lhs), אזquantization_dimension(lhs) = quantization_dimension(result). - (C7) אם
is_per_axis_quantized(rhs), אזquantization_dimension(rhs) = quantization_dimension(result).
- (C2)
דוגמאות
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[6, 8], [10, 12]]
after_all
סמנטיקה
ההגדרה הזו מבטיחה שהפעולות שיוצרות את inputs יבוצעו לפני פעולות שתלויות ב-result. הפעולה הזו לא מבצעת שום דבר, היא קיימת רק כדי ליצור תלות בנתונים מ-result ל-inputs.
קלט
| תווית | שם | סוג |
|---|---|---|
| (I1) | inputs |
מספר משתנה של token |
פלט
| שם | סוג |
|---|---|
result |
token |
דוגמאות
// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehl>o.token) - !stablehlo.token
all_gather
סמנטיקה
בכל קבוצת תהליכים ברשת התהליכים של StableHLO, הפונקציה מחברת את הערכים של טנסורים operands מכל תהליך לאורך all_gather_dim ומפיקה טנסורים results.
הפעולה מפצלת את רשת התהליכים של StableHLO ל-process_groups, שמוגדרת כך:
cross_replica(replica_groups)ifchannel_id <= 0 and use_global_device_ids = false.cross_replica_and_partition(replica_groups)ifchannel_id > 0 and use_global_device_ids = false.flattened_ids(replica_groups)ifchannel_id > 0 and use_global_device_ids = true.
לאחר מכן, בכל process_group:
-
operands...@receiver = [operand@sender for sender in process_group]לכלreceiverב-process_group. -
results...@process = concatenate(operands...@process, all_gather_dim)לכלprocessב-process_group.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operands |
מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור | (C1), (C6) |
| (I2) | all_gather_dim |
קבוע מהסוג si64 |
(C1), (C6) |
| (I3) | replica_groups |
קבוע טנסור דו-ממדי מסוג si64 |
(C2-C4) |
| (I4) | channel_id |
קבוע מהסוג si64 |
(C5) |
| (I5) | use_global_device_ids |
קבוע מהסוג i1 |
(C5) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
results |
מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור | (C6) |
מגבלות
- (C1)
0 <= all_gather_dim < rank(operands...). - (C2)
is_unique(replica_groups). - (C3)
size(replica_groups)מוגדר כך:-
num_replicasאם משתמשים ב-cross_replica. -
num_replicasאם משתמשים ב-cross_replica_and_partition. -
num_processesאם משתמשים ב-flattened_ids.
-
- (C4)
0 <= replica_groups < size(replica_groups). - (C5) אם
use_global_device_ids = true, אזchannel_id > 0. - (C6)
type(results...) = type(operands...)except:dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1).
דוגמאות
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
all_gather_dim = 1 : i64,
replica_grou<ps = den>se[[0, 1]<] : ten>sor1x2xi64,
// channel_id = 0
channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 0
// use_global_device_ids = false
}< : (ten>sor2x2xi<64, ten>sor>2x2xi64)< - (ten>sor2x4xi<64, ten>sor2x4xi64)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
all_reduce
סמנטיקה
בתוך כל קבוצת תהליכים ברשת התהליכים של StableHLO, הפונקציה מחילה צמצום computation על ערכי הטנסורים operands מכל תהליך ומפיקה טנסורים results.
הפעולה מפצלת את רשת התהליכים של StableHLO ל-process_groups, שמוגדרת כך:
cross_replica(replica_groups)ifchannel_id <= 0 and use_global_device_ids = false.cross_replica_and_partition(replica_groups)ifchannel_id > 0 and use_global_device_ids = false.flattened_ids(replica_groups)ifchannel_id > 0 and use_global_device_ids = true.
לאחר מכן, בכל process_group:
-
results...@process[result_index] = exec(schedule)עבור עץ בינארי מסוים scheduleכאשר:exec(node)=computation(exec(node.left), exec(node.right)).exec(leaf)=leaf.value.
-
scheduleהוא עץ בינארי שמוגדר בהטמעה, והמעבר שלו בסדר עולה הואto_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0])).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operands |
מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור | (C5), (C6) |
| (I2) | replica_groups |
מספר משתנה של קבועי טנסור חד-ממדיים מסוג si64 |
(C1-C3) |
| (I3) | channel_id |
קבוע מהסוג si64 |
(C4) |
| (I4) | use_global_device_ids |
קבוע מהסוג i1 |
(C4) |
| (I5) | computation |
פונקציה | (C5) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
results |
מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור | (C6-C7) |
מגבלות
- (C1)
is_unique(replica_groups). - (C2)
size(replica_groups)מוגדר כך:-
num_replicasאם משתמשים ב-cross_replica. -
num_replicasאם משתמשים ב-cross_replica_and_partition. -
num_processesאם משתמשים ב-flattened_ids.
-
- (C3)
0 <= replica_groups < size(replica_groups). - (C4) אם
use_global_device_ids = true, אזchannel_id > 0. - (C5)
computationהוא מסוג(tensor<E>, tensor<E>) -> (tensor<E>)כאשרis_promotable(element_type(operand), E). - (C6)
shape(results...) = shape(operands...). - (C7)
element_type(results...) = E.
דוגמאות
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
"stable<hlo>.re>turn"(%0) : (tensori64) - ()<
}) {
>replica_g<roups => dense[[0, 1]] : tensor1x2xi64,
// channel_id = 0
channel_hand<le = #stablehlo.chan>nel_handlehandle = 0, type = 0
// use_global_<devic>e_ids = <false>
} >: (tenso<r4xi6>4, tenso<r4xi6>4) - (tensor4xi64, tensor4xi64)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]
all_to_all
סמנטיקה
בכל קבוצת תהליכים ברשת התהליכים של StableHLO, הפונקציה מפצלת את הערכים של טנסורים operands לאורך split_dimension לחלקים, מפזרת את החלקים המפוצלים בין התהליכים, משרשרת את החלקים המפוזרים לאורך concat_dimension ומפיקה טנסורים results.
הפעולה מפצלת את רשת התהליכים של StableHLO ל-process_groups, שמוגדרת כך:
-
cross_replica(replica_groups)ifchannel_id <= 0. -
cross_partition(replica_groups)ifchannel_id > 0.
לאחר מכן, בכל process_group:
split_parts...@sender = split(operands...@sender, split_count, split_dimension)לכלsenderב-process_group.-
scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group]כשreceiver_index = process_group.index(receiver). results...@process = concatenate(scattered_parts...@process, concat_dimension).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operands |
מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור | (C1-C3), (C9) |
| (I2) | split_dimension |
קבוע מהסוג si64 |
(C1), (C2), (C9) |
| (I3) | concat_dimension |
קבוע מהסוג si64 |
(C3), (C9) |
| (I4) | split_count |
קבוע מהסוג si64 |
(C2), (C4), (C8), (C9) |
| (I5) | replica_groups |
קבוע טנסור דו-ממדי מסוג si64 |
(C5-C8) |
| (I6) | channel_id |
קבוע מהסוג si64 |
פלט
| שם | סוג | מגבלות |
|---|---|---|
results |
מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור | (C9) |
מגבלות
- (C1)
0 <= split_dimension < rank(operands...). - (C2)
dim(operands..., split_dimension) % split_count = 0. - (C3)
0 <= concat_dimension < rank(operands...). - (C4)
0 < split_count. - (C5)
is_unique(replica_groups). - (C6)
size(replica_groups)מוגדר כך:-
num_replicasאם משתמשים ב-cross_replica. -
num_partitionsאם משתמשים ב-cross_partition.
-
- (C7)
0 <= replica_groups < size(replica_groups). - (C8)
dim(replica_groups, 1) = split_count. - (C9)
type(results...) = type(operands...)למעט, אםsplit_dimension != concat_dimension:dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count.dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count.
דוגמאות
// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
// [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
// [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_grou<ps = den>se[[0, 1]<] : ten>sor1x2xi64
// channel_id = 0
}< : (ten>sor2x4xi<64, ten>sor>2x4xi64)< - (ten>sor4x2xi<64, ten>sor4x2xi64)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]
וגם
סמנטיקה
מבצעת פעולת AND לפי רכיבים של שני טנסורים, lhs ו-rhs, ומפיקה טנסור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:
- לערכים בוליאניים: AND לוגי.
- למספרים שלמים: ערך AND ברמת הביטים.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | lhs |
טנסור מסוג בוליאני או מספר שלם | (C1) |
| (I2) | rhs |
טנסור מסוג בוליאני או מספר שלם | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנסור מסוג בוליאני או מספר שלם | (C1) |
מגבלות
- (C1)
type(lhs) = type(rhs) = type(result).
דוגמאות
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[1, 2], [3, 0]]
atan2
סמנטיקה
מבצע פעולת atan2 לפי רכיבים בטנזורים lhs ו-rhs ומפיק טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:
- למספרים ממשיים:
atan2מ-IEEE-754. - למספרים מרוכבים: complex atan2.
- לסוגים כמותיים:
dequantize_op_quantize(atan2, lhs, rhs, type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | lhs |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1) |
| (I2) | rhs |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1) |
מגבלות
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
דוגמאות
// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs)< : (t>ensor3xf<64, t>ens>or3xf64<) - t>ensor3xf64
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]
batch_norm_grad
סמנטיקה
מחשב את הגרדיאנטים של כמה תשומות של batch_norm_training באמצעות backpropagation מ-grad_output, ומפיק טנסורים של grad_operand, grad_scale ו-grad_offset. באופן רשמי יותר, אפשר לבטא את הפעולה הזו כפירוק לפעולות קיימות של StableHLO באמצעות תחביר Python באופן הבא:
def compute_sum(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
return sum
def compute_mean(operand, feature_index):
sum = compute_sum(operand, feature_index)
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
# Broadcast inputs to type(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance`
# Intermediate values will be useful for computing gradients
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
# Use the implementation from batchnorm_expander.cc in XLA
# Temporary variables have exactly the same names as in the C++ code
elements_per_feature = broadcast_in_dim(
constant(divide(size(operand), dim(operand, feature_index)),
element_type(grad_output)),
[], type(operand))
i1 = multiply(grad_output, elements_per_feature)
i2 = broadcast_in_dim(
compute_sum(grad_output, feature_index), [feature_index], type(operand))
i3 = broadcast_in_dim(
compute_sum(multiply(grad_output, centered_operand), feature_index),
[feature_index], type(operand))
i4 = multiply(i3, centered_operand)
i5 = divide(i4, add(variance_bcast, epsilon_bcast))
i6 = subtract(subtract(i1, i2), i5)
grad_operand =
multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
grad_scale =
compute_sum(multiply(grad_output, normalized_operand), feature_index)
grad_offset = compute_sum(grad_output, feature_index)
return grad_operand, grad_scale, grad_offset
לסוגים כמותיים, הפונקציה מבצעת dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean,
variance, grad_output: batch_norm_grad(operand, scale, mean, variance,
grad_output, epsilon, feature_index), operand, scale, mean, variance,
grad_output, type(grad_operand), type(grad_scale), type(feature_index)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור | (C1-C3), (C5) |
| (I2) | scale |
טנזור חד-ממדי של נקודה צפה או סוג כמותי לכל טנזור | (C2), (C4), (C5) |
| (I3) | mean |
טנזור חד-ממדי של נקודה צפה או סוג כמותי לכל טנזור | (C2), (C4) |
| (I4) | variance |
טנזור חד-ממדי של נקודה צפה או סוג כמותי לכל טנזור | (C2), (C4) |
| (I5) | grad_output |
טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור | (C2), (C3) |
| (I6) | epsilon |
קבוע מהסוג f32 |
|
| (I7) | feature_index |
קבוע מהסוג si64 |
(C1), (C5) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
grad_operand |
טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור | (C2), (C3) |
grad_scale |
טנזור חד-ממדי של נקודה צפה או סוג כמותי לכל טנזור | (C2), (C4) |
grad_offset |
טנזור חד-ממדי של נקודה צפה או סוג כמותי לכל טנזור | (C2), (C4) |
מגבלות
- (C1)
0 <= feature_index < rank(operand). - (C2)
operand, scale, mean, variance, grad_output, grad_operand,grad_scaleו-grad_offsetכוללים את אותוbaseline_element_type. - (C3) ל-
operand, ל-grad_outputול-grad_operandיש את אותו הצורה. - (C4) ל-
scale, mean, variance, grad_scaleו-grad_offsetיש את אותה צורה. - (C5)
size(scale) = dim(operand, feature_index).
דוגמאות
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
// [[0.1, 0.1], [0.1, 0.1]],
// [[0.1, 0.1], [0.1, 0.1]]
// ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
}< : (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf<64, t>ensor2xf64,
< tenso>r2x>2x2xf64)< - (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf64)
// %grad_operand: [
// [[0.0, 0.0], [0.0, 0.0]],
// [[0.0, 0.0], [0.0, 0.0]]
// ]
// %grad_scale: [0.0, 0.0]
// %grad_offset: [0.4, 0.4]
batch_norm_inference
סמנטיקה
מבצעת נורמליזציה של טנסור operand בכל המימדים, למעט מימד feature_index, ומפיקה טנסור result. באופן רשמי יותר, אפשר לבטא את הפעולה הזו כפירוק לפעולות קיימות של StableHLO באמצעות תחביר Python באופן הבא:
def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
# Broadcast inputs to shape(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance` instead of
# computing them like `batch_norm_training` does.
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
return add(multiply(scale_bcast, normalized_operand), offset_bcast)
לסוגים כמותיים, הפונקציה מבצעת dequantize_op_quantize(lambda operand, scale, offset, mean, variance:
batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index), operand, scale, offset, mean, variance, type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור | (C1-C7) |
| (I2) | scale |
טנזור חד-ממדי של נקודה צפה או סוג כמותי לכל טנזור | (C2), (C3) |
| (I3) | offset |
טנזור חד-ממדי של נקודה צפה או סוג כמותי לכל טנזור | (C2), (C4) |
| (I4) | mean |
טנזור חד-ממדי של נקודה צפה או סוג כמותי לכל טנזור | (C5) |
| (I5) | variance |
טנזור חד-ממדי של נקודה צפה או סוג כמותי לכל טנזור | (C2), (C6) |
| (I6) | epsilon |
קבוע מהסוג f32 |
|
| (I7) | feature_index |
קבוע מהסוג si64 |
(C1), (C3-C6) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור | (C2), (C7) |
מגבלות
- (C1)
0 <= feature_index < rank(operand). - (C2)
operand,scale,offset,mean,varianceו-resultכולן בעלות אותוbaseline_element_type. - (C3)
size(scale) = dim(operand, feature_index). - (C4)
size(offset) = dim(operand, feature_index). - (C5)
size(mean) = dim(operand, feature_index). - (C6)
size(variance) = dim(operand, feature_index). - (C7)
baseline_type(operand) = baseline_type(result).
דוגמאות
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
}< : (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf<64, t>ensor2xf<64, t>ens>or2xf64<) - tenso>r2x2x2xf64
// %result: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
batch_norm_training
סמנטיקה
מחשבת את הממוצע והשונות בכל המאפיינים, למעט המאפיין feature_index
ומנרמלת את טנסור operand כדי ליצור את הטנסורים output, batch_mean ו-batch_var. באופן רשמי יותר, אפשר לבטא את הפעולה הזו כפירוק לפעולות קיימות של StableHLO באמצעות תחביר Python באופן הבא:
def compute_mean(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def compute_variance(operand, feature_index):
mean = compute_mean(operand, feature_index)
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
centered_operand = subtract(operand, mean_bcast)
return compute_mean(mul(centered_operand, centered_operand), feature_index)
def batch_norm_training(operand, scale, offset, epsilon, feature_index):
mean = compute_mean(operand, feature_index)
variance = compute_variance(operand, feature_index)
return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index),
mean, variance
לסוגים כמותיים, הפונקציה מבצעת dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset:
batch_norm_training(operand, scale, offset, epsilon, feature_index), operand,
scale, offset, type(output), type(batch_mean), type(batch_var)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור | (C1) |
| (I2) | scale |
טנזור חד-ממדי של נקודה צפה או של כימות לכל טנזור | (C2), (C3) |
| (I3) | offset |
טנזור חד-ממדי של נקודה צפה או של כימות לכל טנזור | (C2), (C4) |
| (I4) | epsilon |
קבוע מהסוג f32 |
(C1), (C3-C6) |
| (I5) | feature_index |
קבוע מהסוג si64 |
(C1), (C3-C6) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
output |
טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור | (C7) |
batch_mean |
טנזור חד-ממדי של נקודה צפה או של כימות לכל טנזור | (C2), (C5) |
batch_var |
טנזור חד-ממדי של נקודה צפה או של כימות לכל טנזור | (C2), (C6) |
מגבלות
- (C1)
0 <= feature_index < rank(operand). - (C2)
operand, scale, offset, batch_mean, batch_varו-outputכוללים את אותוbaseline_element_type. - (C3)
size(scale) = dim(operand, feature_index). - (C4)
size(offset) = dim(operand, feature_index). - (C5)
size(batch_mean) = dim(operand, feature_index). - (C6)
size(batch_var) = dim(operand, feature_index). - (C7)
baseline_type(output) = baseline_type(operand).
דוגמאות
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
}< : (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ens>or2xf64) -
< (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf64)
// %output: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]
bitcast_convert
סמנטיקה
מבצע פעולת bitcast על טנסור operand ומפיק טנסור result שבו הביטים של כל הטנסור operand מפורשים מחדש באמצעות הסוג של הטנסור result.
באופן רשמי יותר, בהינתן E = element_type(operand), E' = element_type(result),
ו-R = rank(operand):
- אם
num_bits(E') < num_bits(E),bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1]). - אם
num_bits(E') > num_bits(E),bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :]). - אם
num_bits(E') = num_bits(E),bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1]).
הפונקציה bits מחזירה ייצוג בזיכרון של ערך נתון, וההתנהגות שלה מוגדרת על ידי ההטמעה, כי הייצוג המדויק של טנסורים מוגדר על ידי ההטמעה, וגם הייצוג המדויק של סוגי רכיבים מוגדר על ידי ההטמעה.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור או טנזור שעבר קוונטיזציה | (C1-C2) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור או טנזור שעבר קוונטיזציה | (C1-C2) |
מגבלות
- (C1) בהינתן
E = is_quantized(operand) ? storage_type(operand) : element_type(operand), E' = is_quantized(result) ? storage_type(result) : element_type(result)ו-R = rank(operand):- אם
num_bits(E') = num_bits(E),shape(result) = shape(operand). - אם
num_bits(E') < num_bits(E): rank(result) = R + 1.dim(result, i) = dim(operand, i)לכל0 <= i < R.dim(result, R) * num_bits(E') = num_bits(E).- אם
num_bits(E') > num_bits(E): rank(result) = R - 1.dim(result, i) = dim(operand, i)לכל0 <= i < R.dim(operand, R - 1) * num_bits(E) = num_bits(E').
- אם
- (C2) אם
is_complex(operand) or is_complex(result), אזis_complex(operand) and is_complex(result).
דוגמאות
// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand)< : >(te>nsorf64<) - t>ensor4xf16
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation
broadcast_in_dim
סמנטיקה
מרחיב את המאפיינים או את הדירוג של טנסור קלט על ידי שכפול הנתונים בטנסור operand ומפיק טנסור result. באופן רשמי יותר,
result[result_index] = operand[operand_index] כאשר לכל d ב-
axes(operand):
-
operand_index[d] = 0ifdim(operand, d) = 1. operand_index[d] = result_index[broadcast_dimensions[d]]אחרת.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור או טנזור שעבר קוונטיזציה | (C1-C2), (C5-C6) |
| (I2) | broadcast_dimensions |
קבוע טנסור חד-ממדי מסוג si64 |
(C2-C6) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור או טנזור שעבר קוונטיזציה | (C1), (C3), (C5-C6) |
מגבלות
- (C1)
element_type(result)מחושב לפי הנוסחה הבאה:-
element_type(operand), אם!is_per_axis_quantized(operand). -
element_type(operand), למעטquantization_dimension(operand),scales(operand)ו-zero_points(operand)שעשויים להיות שונים מ-quantization_dimension(result),scales(result)ו-zero_points(result)בהתאמה, אחרת.
-
- (C2)
size(broadcast_dimensions) = rank(operand). - (C3)
0 <= broadcast_dimensions < rank(result). - (C4)
is_unique(broadcast_dimensions). - (C5) לכל
dבaxes(operand):dim(operand, d) = 1אוdim(operand, d) = dim(result, broadcast_dimensions[d]).
- (C6) אם
is_per_axis_quantized(result):quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].- אם
dim(operand, quantization_dimension(operand)) = 1, אזscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).
דוגמאות
// %operand: [
// [1, 2, 3]
// ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
broadcast_dimensio<ns = arra>yi64: 2, 1
}< : (ten>sor>1x3xi32<) - tenso>r2x3x2xi32
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
כיסוי
סמנטיקה
הפונקציה מפיקה את הפלט מהרצה של פונקציה אחת בדיוק מתוך branches
בהתאם לערך של index. באופן רשמי יותר, result = selected_branch()
כאשר:
-
selected_branch = branches[index]if0 <= index < size(branches). selected_branch = branches[-1]אחרת.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | index |
טנזור אפס-ממדי מסוג si32 |
|
| (I2) | branches |
מספר משתנה של פונקציות | (C1-C4) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
results |
מספר משתנה של טנסורים, טנסורים או טוקנים שעברו קוונטיזציה | (C4) |
מגבלות
- (C1)
0 < size(branches). - (C2)
input_types(branches...) = []. - (C3)
same(output_types(branches...)). - (C4)
type(results...) = output_types(branches[0]).
דוגמאות
// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
"stablehlo.return"(%result_branch0, %resul<t_bra>nch0) : <(tens>or2>xi64, tensor2xi64) - ()
}, {
"stablehlo.return"(%result_branc<h1, %>result_b<ranch>1) >: (tensor2xi64, <ten>sor>2xi64) -< ()
}>) : (ten<sori3>2) - (tensor2xi64, tensor2xi64)
// %result0: [1, 1]
// %result1: [1, 1]
cbrt
סמנטיקה
מבצעת פעולת שורש שלישי לפי רכיבים על טנסור operand ומפיקה טנסור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:
- למספרים ממשיים:
rootn(x, 3)מ-IEEE-754. - למספרים מרוכבים: שורש מעוקב מרוכב.
- לסוגים שעברו קוונטיזציה:
dequantize_op_quantize(cbrt, operand, type(result))
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1) |
מגבלות
- (C1)
baseline_type(operand) = baseline_type(result).
דוגמאות
// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand)< : (t>ens>or4xf64<) - t>ensor4xf64
// %result: [0.0, 1.0, 2.0, 3.0]
ceil
סמנטיקה
מבצעת פעולת ceil על כל רכיב בטנזור operand ומפיקה טנזור result.
מיישמת את הפעולה roundToIntegralTowardPositive מהמפרט IEEE-754. לסוגים כמותיים, הפונקציה מבצעת dequantize_op_quantize(ceil, operand, type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור | (C1) |
מגבלות
- (C1)
baseline_type(operand) = baseline_type(result).
דוגמאות
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand)< : (t>ens>or5xf32<) - t>ensor5xf32
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]
cholesky
סמנטיקה
מחשבת את פירוק Cholesky של קבוצת מטריצות.
באופן רשמי יותר, לכל i ב-index_space(result), result[i0, ..., iR-3, :, :] הוא פירוק שולסקי של a[i0, ..., iR-3, :, :], בצורה של מטריצה משולשת תחתונה (אם lower הוא true) או מטריצה משולשת עליונה (אם lower הוא false).
ערכי הפלט במשולש הנגדי, כלומר במשולש העליון או במשולש התחתון, בהתאמה, מוגדרים על ידי ההטמעה.
אם קיים i שבו מטריצת הקלט היא לא מטריצה הרמיטית חיובית מוגדרת, ההתנהגות לא מוגדרת.
לסוגים כמותיים, הפונקציה מבצעת dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | a |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1-C3) |
| (I2) | lower |
קבוע מהסוג i1 |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1) |
מגבלות
- (C1)
baseline_type(a) = baseline_type(result). - (C2)
2 <= rank(a). - (C3)
dim(a, -2) = dim(a, -1).
דוגמאות
// %a: [
// [1.0, 2.0, 3.0],
// [2.0, 20.0, 26.0],
// [3.0, 26.0, 70.0]
// ]
%result = "stablehlo.cholesky"(%a) {
lower = true
}< : (ten>sor>3x3xf32<) - ten>sor3x3xf64
// %result: [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
תיחום
סמנטיקה
הפונקציה מחזירה טנזור result שבו כל רכיב בטנזור operand מוגבל לערך מינימלי ומקסימלי. באופן רשמי יותר, result[result_index] =
minimum(maximum(operand[result_index], min_element), max_element), כאשר min_element = rank(min) = 0 ? min[] : min[result_index], max_element = rank(max) = 0 ? max[] : max[result_index]. לסוגים כמותיים, הפונקציה מבצעת dequantize_op_quantize(clamp, min, operand, max, type(result)).
הגדרת סדר למספרים מרוכבים כוללת סמנטיקה מפתיעה, ולכן בעתיד אנחנו מתכננים להסיר את התמיכה במספרים מרוכבים בפעולה הזו (#560).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | min |
tensor או per-tensor quantized tensor | (C1), (C3) |
| (I2) | operand |
tensor או per-tensor quantized tensor | (C1-C4) |
| (I3) | max |
tensor או per-tensor quantized tensor | (C2), (C3) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
tensor או per-tensor quantized tensor | (C4) |
מגבלות
- (C1)
rank(min) = 0 or shape(min) = shape(operand). - (C2)
rank(max) = 0 or shape(max) = shape(operand). - (C3)
baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max). - (C4)
baseline_type(operand) = baseline_type(result).
דוגמאות
// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max)< : (t>ensor3xi<32, t>ensor3xi<32, t>ens>or3xi32<) - t>ensor3xi32
// %result: [5, 13, 20]
collective_broadcast
סמנטיקה
בכל קבוצת תהליכים ברשת התהליכים של StableHLO, הערך של טנזור operand נשלח מתהליך המקור לתהליכי היעד ונוצר טנזור result.
הפעולה מפצלת את רשת התהליכים של StableHLO ל-process_groups, שמוגדרת כך:
-
cross_replica(replica_groups)ifchannel_id <= 0. -
cross_partition(replica_groups)ifchannel_id > 0.
לאחר מכן, result@process מחושב לפי הנוסחה הבאה:
-
operand@process_groups[i, 0]אם קייםiכך שהתהליך נמצא ב-process_groups[i]. broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))אחרת.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
tensor או per-tensor quantized tensor | (C3) |
| (I2) | replica_groups |
מספר משתנה של קבועי טנסור חד-ממדיים מסוג si64 |
(C1), (C2) |
| (I3) | channel_id |
קבוע מהסוג si64 |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
tensor או per-tensor quantized tensor | (C3) |
מגבלות
- (C1)
is_unique(replica_groups). - (C2)
0 <= replica_groups < NכאשרNמוגדר כך:-
num_replicasאם משתמשים ב-cross_replica. -
num_partitionsאם משתמשים ב-cross_partition.
-
- (C3)
type(result) = type(operand).
דוגמאות
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_grou<ps = den>se[[2, 1]<] : ten>sor1x2xi64,
channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 0
} : (ten>sor>1x2xi64<) - ten>sor1x2xi64
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
collective_permute
סמנטיקה
בכל קבוצת תהליכים ברשת התהליכים של StableHLO, הערך של טנזור operand נשלח מתהליך המקור לתהליך היעד ונוצר טנזור result.
הפעולה מפצלת את רשת התהליכים של StableHLO ל-process_groups, שמוגדרת כך:
-
cross_replica(source_target_pairs)ifchannel_id <= 0. -
cross_partition(source_target_pairs)ifchannel_id > 0.
לאחר מכן, result@process מחושב לפי הנוסחה הבאה:
-
operand@process_groups[i, 0], אם קייםiכך ש-process_groups[i, 1] = process. broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))אחרת.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
tensor או per-tensor quantized tensor | (C5) |
| (I2) | source_target_pairs |
קבוע טנסור דו-ממדי מסוג si64 |
(C1-C4) |
| (I3) | channel_id |
קבוע מהסוג si64 |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
tensor או per-tensor quantized tensor | (C1) |
מגבלות
- (C1)
dim(source_target_pairs, 1) = 2. - (C2)
is_unique(source_target_pairs[:, 0]). - (C3)
is_unique(source_target_pairs[:, 1]). - (C4)
0 <= source_target_pairs < N, כאשרNמוגדר כך:-
num_replicasאם משתמשים ב-cross_replica. -
num_partitionsאם משתמשים ב-cross_partition.
-
- (C5)
type(result) = type(operand).
דוגמאות
// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
source_target_pai<rs = dense[[0, 1>], [1, 2]<] : ten>sor2x2xi64,
channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 0
}< : (ten>sor>2x2xi64<) - ten>sor2x2xi64
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]
השוואה
סמנטיקה
מבצע השוואה בין רכיבים של טנסורים lhs ו-rhs בהתאם ל-comparison_direction ול-compare_type, ומפיק טנסור result.
המשמעות של הערכים comparison_direction ו-compare_type היא כדלקמן:
לסוגי רכיבים בוליאניים ומספריים שלמים:
EQ:lhs = rhs.NE:lhs != rhs.GE:lhs >= rhs.GT:lhs > rhs.LE:lhs <= rhs.LT:lhs < rhs.
עבור סוגי רכיבים של נקודה צפה עם compare_type = FLOAT, האופרטור מטמיע את הפעולות הבאות של IEEE-754:
EQ:compareQuietEqual.NE:compareQuietNotEqual.GE:compareQuietGreaterEqual.GT:compareQuietGreater.LE:compareQuietLessEqual.LT:compareQuietLess.
עבור סוגי רכיבים של נקודה צפה עם compare_type = TOTALORDER, האופרטור משתמש בשילוב של פעולות totalOrder ו-compareQuietEqual מ-IEEE-754.
לסוגי רכיבים מורכבים, מתבצעת השוואה לקסיקוגרפית של זוגות (real, imag) באמצעות comparison_direction ו-compare_type שסופקו.
הגדרת סדר למספרים מרוכבים כוללת סמנטיקה מפתיעה, ולכן בעתיד אנחנו מתכננים להסיר את התמיכה במספרים מרוכבים כש-comparison_direction הוא GE, GT, LE או LT (מספר 560).
לסוגים כמותיים. הפונקציה מבצעת dequantize_compare(lhs, rhs,
comparison_direction).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | lhs |
tensor או per-tensor quantized tensor | (C1-C3) |
| (I2) | rhs |
tensor או per-tensor quantized tensor | (C1-C2) |
| (I3) | comparison_direction |
enum of EQ, NE, GE, GT, LE, and LT |
|
| (I4) | compare_type |
enum of FLOAT, TOTALORDER, SIGNED, and UNSIGNED |
(C3) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
tensor of boolean type | (C2) |
מגבלות
- (C1)
baseline_element_type(lhs) = baseline_element_type(rhs). - (C2)
shape(lhs) = shape(rhs) = shape(result). - (C3)
compare_typeמוגדר כך:-
SIGNEDifis_signed_integer(element_type(lhs)). -
UNSIGNEDifis_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs)). -
FLOATאוTOTALORDERאםis_float(element_type(lhs)). -
FLOATifis_complex(element_type(lhs)).
-
דוגמאות
// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
comparison_direction = <#stablehlocomparison_di>rection LT,
compare_type = <#stablehlocomparison_>type FLOAT
}< : (t>ensor2xf<32, t>ens>or2xf32<) - >tensor2xi1
// %result: [true, false]
מורכב
סמנטיקה
מבצעת המרה של כל רכיב בנפרד לערך מרוכב מתוך זוג של ערכים ממשיים ומדומים, lhs ו-rhs, ומפיקה טנזור result.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | lhs |
טנסור מסוג f32 או f64 |
(C1-C3) |
| (I2) | rhs |
טנסור מסוג f32 או f64 |
(C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנסור מסוג מורכב | (C2), (C3) |
מגבלות
- (C1)
type(lhs) = type(rhs). - (C2)
shape(result) = shape(lhs). - (C3)
element_type(result)הוא מסוגcomplex<E>כאשרE = element_type(lhs).
דוגמאות
// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs)< : (t>ensor2xf<64, t>ens>or2xf64<) - tenso<r2x>>complexf64
// %result: [(1.0, 2.0), (3.0, 4.0)]
מורכב
סמנטיקה
פעולה שמורכבת מפעולות אחרות של StableHLO,
מקבלת inputs ו-composite_attributes ומפיקה results. הסמנטיקה של הפעולה מיושמת באמצעות המאפיין decomposition. אפשר להחליף את הפעולה composite בפירוק שלה בלי לשנות את הסמנטיקה של התוכנית. במקרים שבהם הוספת הפירוק לשורה לא מספקת את אותה סמנטיקה של פעולה, עדיף להשתמש ב-custom_call.
השדה version (ברירת המחדל היא 0) משמש לציון המועד שבו הסמנטיקה של רכיב מורכב משתנה.
קלט
| תווית | שם | סוג |
|---|---|---|
| (I1) | inputs |
מספר משתנה של ערכים |
| (I2) | name |
קבוע מהסוג string |
| (I3) | composite_attributes |
מילון מאפיינים |
| (I4) | decomposition |
קבוע מהסוג string |
| (I5) | version |
קבוע מהסוג si32 |
פלט
| שם | סוג |
|---|---|
results |
מספר משתנה של ערכים |
מגבלות
- (C1)
is_namespaced_op_name(name) - (C2)
is_defined_in_parent_scope(decomposition) - (C3)
types(inputs...) == input_types(decomposition) - (C4)
types(results...) == output_types(decomposition)
דוגמאות
%results = "stablehlo.composite"(%input0, %input1) {
name = "my_namespace.my_op",
composite_attributes = {
my_attribute = "my_value"
},
decomposition = @my_op,
< ve>rsion = <1 :> i3>2
} : (<ten>sorf32, tensorf32) - tensorf32
concatenate
סמנטיקה
משרשר את inputs לאורך המימד dimension באותו סדר של הארגומנטים שצוינו ומפיק טנסור result. בצורה רשמית יותר,
result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1], כאשר:
id = d0 + ... + dk-1 + kd.-
dשווה ל-dimension, ו-d0, ... הם גדלי המאפייניםdשלinputs.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | inputs |
מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור | (C1-C6) |
| (I2) | dimension |
קבוע מהסוג si64 |
(C2), (C4), (C6) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
tensor או per-tensor quantized tensor | (C5-C6) |
מגבלות
- (C1)
same(element_type(inputs...)). - (C2)
same(shape(inputs...))למעטdim(inputs..., dimension). - (C3)
0 < size(inputs). - (C4)
0 <= dimension < rank(inputs[0]). - (C5)
element_type(result) = element_type(inputs[0]). - (C6)
shape(result) = shape(inputs[0])except for:dim(result, dimension) = dim(inputs[0], dimension) + ....
דוגמאות
// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
dimension = 0 : i64
}< : (ten>sor3x2xi<64, ten>sor>1x2xi64<) - ten>sor4x2xi64
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
קבוע
סמנטיקה
הפונקציה יוצרת טנסור output מקבוע value.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | value |
קבוע | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
output |
טנזור או טנזור שעבר קוונטיזציה | (C1) |
מגבלות
- (C1)
type(value) = type(output).
דוגמאות
%output = "stablehlo.constant"() {
val<ue = dense[[0.0, 1.0], [>2.0, 3.0]<] : ten>sor2x2xf3>2
} : (<) - ten>sor2x2xf32
// %output: [[0.0, 1.0], [2.0, 3.0]]
להשלים המרה
סמנטיקה
מבצעת המרה מטיפוס אלמנט אחד לטיפוס אלמנט אחר בטנזור operand ומפיקה טנזור result.
בהמרות boolean-to-any-supported-type, הערך false מומר לאפס והערך true מומר לאחד. בהמרות של any-supported-type-to-boolean, ערך אפס מומר ל-false, וערכים שאינם אפס מומרים ל-true. בהמשך מוסבר איך זה עובד עם סוגים מורכבים.
בהמרות שכוללות מספר שלם למספר שלם, מספר שלם למספר נקודה צפה או מספר נקודה צפה למספר נקודה צפה, אם אפשר לייצג את ערך המקור בדיוק בסוג היעד, ערך התוצאה הוא הייצוג המדויק הזה. אחרת, ההתנהגות תוגדר בהמשך (#180).
בהמרות שכוללות floating-point-to-integer, החלק השברי מצומצם. אם אי אפשר לייצג את הערך שקוצץ בסוג היעד, ההתנהגות תוגדר בהמשך (#180).
המרות שכוללות מספרים מרוכבים למספרים מרוכבים מתבצעות באותו אופן כמו המרות של מספרים ממשיים למספרים ממשיים, כשממירים את החלקים הממשיים והדמיוניים.
בהמרות מסוג complex-to-any-other-type ו-any-other-type-to-complex, המערכת מתעלמת מהערך הדמיוני של המקור או מאפסת את הערך הדמיוני של היעד, בהתאמה. ההמרה של החלק האמיתי מתבצעת לפי המרות של נקודה צפה.
באופן עקרוני, הפעולה הזו יכולה לבצע דה-קוונטיזציה (המרה של טנסורים שעברו קוונטיזציה לטנסורים רגילים), קוונטיזציה (המרה של טנסורים רגילים לטנסורים שעברו קוונטיזציה) וקוונטיזציה מחדש (המרה בין טנסורים שעברו קוונטיזציה), אבל כרגע יש לנו פעולות ייעודיות לכך – uniform_dequantize לתרחיש השימוש הראשון ו-uniform_quantize לתרחישי השימוש השני והשלישי. יכול להיות שבעתיד שתי הפעולות האלה יאוחדו ל-convert (מס' 1576).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
tensor | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
tensor | (C1) |
מגבלות
- (C1)
shape(operand) = shape(result).
דוגמאות
// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand)< : (t>ens>or3xi64<) - tenso<r3x>>complexf64
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]
קונבולוציה
סמנטיקה
מחשבת מכפלה סקלרית בין חלונות של lhs ופרוסות של rhs ומפיקה result. בתרשים הבא מוצגות דוגמאות קונקרטיות שממחישות איך מחשבים את הרכיבים ב-result מ-lhs ומ-rhs.
באופן רשמי יותר, אפשר להגדיר מחדש את הקלט במונחים של lhs כדי להגדיר חלונות של lhs:
lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension)).lhs_window_strides = lhs_shape(1, window_strides, 1).lhs_padding = lhs_shape([0, 0], padding, [0, 0]).lhs_base_dilations = lhs_shape(1, lhs_dilation, 1).lhs_window_dilations = lhs_shape(1, rhs_dilation, 1).
השינוי הזה מתבצע באמצעות פונקציות העזר הבאות:
lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]).result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]).-
permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]כשj[d] = i[permutation[d]].
אם feature_group_count = 1 ו-batch_group_count = 1, אז לכל output_spatial_index ב-index_space(dim(result, output_spatial_dimensions...)), result[result_shape(:, output_spatial_index, :)] = dot_product כאשר:
padding_value = constant(0, element_type(lhs)).padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1).lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides.lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations).reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true]). נראה שהתכונה הזו לא נמצאת בשימוש, ולכן אנחנו מתכננים להסיר אותה בעתיד (מס' 1181).dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension]).
אם feature_group_count > 1:
lhses = split(lhs, feature_group_count, input_feature_dimension).rhses = split(rhs, feature_group_count, kernel_output_feature_dimension).results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...).result = concatenate(results, output_feature_dimension).
אם batch_group_count > 1:
lhses = split(lhs, batch_group_count, input_batch_dimension).rhses = split(rhs, batch_group_count, kernel_output_feature_dimension).results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...).result = concatenate(results, output_feature_dimension).
לסוגים כמותיים, הפונקציה מבצעת dequantize_op_quantize(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs,
type(result)).
עבור סוגים היברידיים של קוונטיזציה, הפעולה היא hybrid_dequantize_then_op(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | lhs |
tensor או per-tensor quantized tensor | (C1), (C10-C11), (C14) (C25), (C27-C28), (C31-C32), (C34) |
| (I2) | rhs |
טנזור או טנזור שעבר קוונטיזציה | (C1), (C14-C16), (C25), (C27-C29), (C31-C34) |
| (I3) | window_strides |
קבוע טנסור חד-ממדי מסוג si64 |
(C2-C3), (C25) |
| (I4) | padding |
קבוע טנסור דו-ממדי מסוג si64 |
(C4), (C25) |
| (I5) | lhs_dilation |
קבוע טנסור חד-ממדי מסוג si64 |
(C5-C6), (C25) |
| (I6) | rhs_dilation |
קבוע טנסור חד-ממדי מסוג si64 |
(C7-C8), (C25) |
| (I7) | window_reversal |
קבוע טנסור חד-ממדי מסוג i1 |
(C9) |
| (I8) | input_batch_dimension |
קבוע מהסוג si64 |
(C10), (C13), (C25) |
| (I9) | input_feature_dimension |
קבוע מהסוג si64 |
(C11), (C13-C14) |
| (I10) | input_spatial_dimensions |
קבוע טנסור חד-ממדי מסוג si64 |
(C12), (C13), (C25) |
| (I11) | kernel_input_feature_dimension |
קבוע מהסוג si64 |
(C14), (C18) |
| (I12) | kernel_output_feature_dimension |
קבוע מהסוג si64 |
(C15-C16), (C18), (C25), (C29) |
| (I13) | kernel_spatial_dimensions |
קבוע טנסור חד-ממדי מסוג si64 |
(C17-C18), (C25) |
| (I14) | output_batch_dimension |
קבוע מהסוג si64 |
(C20), (C25) |
| (I15) | output_feature_dimension |
קבוע מהסוג si64 |
(C20), (C25), (C30) |
| (I16) | output_spatial_dimensions |
קבוע טנסור חד-ממדי מסוג si64 |
(C19-C20), (C25) |
| (I17) | feature_group_count |
קבוע מהסוג si64 |
(C11), (C14), (C16), (C21), (C23) |
| (I18) | batch_group_count |
קבוע מהסוג si64 |
(C10), (C15), (C22), (C23), (C25) |
| (I19) | precision_config |
מספר משתנה של סוגי enum של DEFAULT, HIGH ו-HIGHEST |
(C24) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור או טנזור שעבר קוונטיזציה | (C25-C28), (C30), (C32-34) |
מגבלות
- (C1)
N = rank(lhs) = rank(rhs). - (C2)
size(window_strides) = N - 2. - (C3)
0 < window_strides. - (C4)
shape(padding) = [N - 2, 2]. - (C5)
size(lhs_dilation) = N - 2. - (C6)
0 < lhs_dilation. - (C7)
size(rhs_dilation) = N - 2. - (C8)
0 < rhs_dilation. - (C9)
size(window_reversal) = N - 2. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0. - (C12)
size(input_spatial_dimensions) = N - 2. - (C13) בהינתן
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:is_unique(input_dimensions).0 <= input_dimensions < N.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0. - (C17)
size(kernel_spatial_dimensions) = N - 2. - (C18) בהינתן
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:is_unique(kernel_dimensions).0 <= kernel_dimensions < N.
- (C19)
size(output_spatial_dimensions) = N - 2. - (C20) בהינתן
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:is_unique(output_dimensions).0 <= output_dimensions < N.
- (C21)
0 < feature_group_count. - (C22)
0 < batch_group_count. - (C23)
feature_group_count = 1 or batch_group_count = 1. - (C24)
size(precision_config) = 2. - (C25)
dim(result, result_dim)מוגדר כך:-
dim(lhs, input_batch_dimension) / batch_group_countifresult_dim = output_batch_dimension. -
dim(rhs, kernel_output_feature_dimension)ifresult_dim = output_feature_dimension. num_windowsאחרת, כאשר:output_spatial_dimensions[spatial_dim] = result_dim.lhs_dim = input_spatial_dimensions[spatial_dim].rhs_dim = kernel_spatial_dimensions[spatial_dim].dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
-
- (C26)
rank(result) = N. - אם הפעולה משתמשת בטנסורים לא מכומתים:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result).
- (C27)
- אם הפעולה משתמשת בטנסורים שעברו קוונטיזציה:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs). - (C29) אם
is_per_axis_quantized(rhs), אזquantization_dimension(rhs) = kernel_output_feature_dimension. - (C30) אם
is_per_axis_quantized(result), אזquantization_dimension(result) = output_feature_dimension. - אם
is_quantized(lhs): - (C31)
storage_type(lhs) = storage_type(rhs). - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result). - (C33) אם
is_per_tensor_quantized(rhs), אזis_per_tensor_quantized(result). - אם
!is_quantized(lhs): - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result).
- (C28)
דוגמאות
// %lhs: [[
// [
// [1], [2], [5], [6]
// ],
// [
// [3], [4], [7], [8]
// ],
// [
// [10], [11], [14], [15]
// ],
// [
// [12], [13], [16], [17]
// ]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strid<es = arra>yi64: 4, 4,
paddi<n>g = dense<0 : ten>sor2x2xi64,
lhs_dilati<on = arra>yi64: 2, 2,
rhs_dilati<on = arra>yi64: 1, 1,
window_revers<al = arrayi1: fa>lse, false,
// In the StableHLO dialect, dimension numbers are encoded vi<a:
// `[input >dim<ensions]x[kernel >di>mensions]-[output dimensions]`.
// "b" is batch dimension, "f" is feature dimension,
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" a<re spatial dimensions.
d>imension_num>bers = #stablehlo.conv[b, 0, 1, f]x[0, 1, i, o]-[b, 0, 1, f],
batch_group_count = 1 : i64,
fea<ture_group_count >= 1 : i64,
< precision_config> = [#stablehl<oprecision >DEFAULT,< #stablehlo>pre>cision <DEFAULT]
} >: (tensor1x4x4x1xi64, tensor3x3x1x1xi64) - tensor1x2x2x1xi64
// %result: [[
// [[10], [26]],
// [[46], [62]]
// ]]
קוסינוס
סמנטיקה
מבצעת פעולת קוסינוס לפי רכיבים בטנזור operand ומפיקה טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:
- למספרים ממשיים:
cosמ-IEEE-754. - למספרים מרוכבים: קוסינוס מרוכב.
- לסוגים כמותיים:
dequantize_op_quantize(cosine, operand, type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1) |
מגבלות
- (C1)
baseline_type(operand) = baseline_type(result).
דוגמאות
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.cosine"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[1.0, 0.0], [-1.0, 0.0]]
count_leading_zeros
סמנטיקה
מבצעת ספירה של מספר הביטים המובילים של אפס בכל רכיב בטנזור operand
ומפיקה טנזור result.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור מסוג מספר שלם | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור מסוג מספר שלם | (C1) |
מגבלות
- (C1)
type(operand) = type(result).
דוגמאות
// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand)< : (ten>sor>2x2xi64<) - ten>sor2x2xi64
// %result: [[64, 63], [56, 0]]
custom_call
סמנטיקה
מכיל פעולה שמוגדרת בהטמעה call_target_name שמקבלת את inputs ואת called_computations ומפיקה את results. אפשר להשתמש ב-has_side_effect, backend_config ו-api_version כדי לספק מטא-נתונים נוספים שמוגדרים על ידי ההטמעה.
בשלב הזה, הפעולה הזו מכילה אוסף לא מאורגן של מטא-נתונים שמשקפים את ההתפתחות האורגנית של הפעולה המקבילה לה בקומפיילר XLA. בעתיד, אנחנו מתכננים לאחד את המטא-נתונים האלה (#741).
קלט
| תווית | שם | סוג |
|---|---|---|
| (I1) | inputs |
מספר משתנה של ערכים |
| (I2) | call_target_name |
קבוע מהסוג string |
| (I3) | has_side_effect |
קבוע מהסוג i1 |
| (I4) | backend_config |
קבוע מהסוג string או מילון מאפיינים |
| (I5) | api_version |
קבוע מהסוג si32 |
| (I6) | called_computations |
מספר משתנה של קבועים מהסוג string |
| (I7) | output_operand_aliases |
מציינים את החלקים של ה-aliasing בפלטים ובאופרנדים |
פלט
| שם | סוג |
|---|---|
results |
מספר משתנה של ערכים |
(XLA GPU Support) Special custom_call targets
יש שלושה סוגים מיוחדים של call_target_name שקשורים ל-buffer:
CreateBuffer יוצר buffer לא מאותחל, Pin יוצר buffer מאותחל ו-Unpin מבטל את ההקצאה של buffer ומחזיר את התוכן של buffer.
%uninitialized_buffer = "stablehlo.custom_call"() {
call_target_name = "CreateBuffer",
api_version> = 4 : <i32,
>} : () - memref4xf64
%initialized_buffer = "stablehlo.custom_call"(%init_value) {
call_target_name = "Pin&quo<t;,
> ap>i_versi<on = >4 : i32,
} : (tensor4xf64) - memref4xf64
%dealloc_buffer = "stablehlo.custom_call"(%initialized_buffer) {
call_target_na<me = >&qu>ot;Unpi<n&quo>t;,
api_version = 4 : i32,
} : (memref4xf64) - tensor4xf64
כינוי
יכול להיות שחלק מהפעולות של custom_call ידרשו שחלק מהפלט וחלק מהאופרנדים ישתפו את אותו הזיכרון. אפשר לבטא את זה באמצעות
output_operand_aliases. ייצוג של זוג כינויים מורכב מרשימה של אינדקסים של טאפלים של פלט שמייצגים את חלק הפלט, ומ-operand_index יחד עם רשימה של אינדקסים של טאפלים של אופרנד שמייצגים את חלק האופרנד. רשימת הפלט או המדדים של טופל האופרנד ריקה אם הסוג המתאים הוא לא סוג tuple, והיא יכולה להיות ארוכה ככל שרוצים עבור סוג טופל מקונן. הייצוג הזה דומה לייצוג הכינוי XLA.
הסוג של חלק הפלט וחלק הקלט בצמד כינויים חייב להיות זהה. בפעולות custom_call שאינן קריאה אל CreateBuffer, Pin ו-Unpin, אופרנד buffer יכול להופיע בזוג אחד לכל היותר של כינויים, ופלט buffer חייב להופיע בזוג אחד של כינויים.
דוגמאות
%results = "stablehlo.custom_call"(%input0) {
call_target_name = "foo",
has_side_effect = false,
backend_config = {bar = 42 : i32},
api_version = 4 : i32,
called_computations <= [>@fo>o]
} : <(te>nsorf64) - tensorf64
%updated_buffer = "stablehlo.custom_call"(%buffer) {
call_target_name = "Update",
api_version = 4 : i32,
output_operand_aliases< = [
#stablehlo.output_operand_aliasoutput_tuple_indices = [],
operand_ind>ex = 0,
< oper>and>_tuple_<indic>es = []]
} : (memref4xf64) - memref4xf64
חילוק
סמנטיקה
מבצעת חילוק של כל רכיב בטנזור המחולק lhs בכל רכיב בטנזור המחלק rhs, ומחזירה טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:
- למספרים שלמים: חלוקת מספרים שלמים שמניבה את המנה האלגברית בלי החלק השברי.
- למספרים ממשיים:
divisionמ-IEEE-754. - למספרים מרוכבים: חילוק מרוכב.
- לסוגים שעברו קוונטיזציה:
dequantize_op_quantize(divide, lhs, rhs, type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | lhs |
טנזור של מספרים שלמים, מספרים ממשיים או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור | (C1) |
| (I2) | rhs |
טנזור של מספרים שלמים, מספרים ממשיים או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור של מספרים שלמים, מספרים עם נקודה עשרונית או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור | (C1) |
מגבלות
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
דוגמאות
// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs)< : (t>ensor4xf<32, t>ens>or4xf32<) - t>ensor4xf32
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]
dot_general
סמנטיקה
מחשב מכפלה סקלרית בין פרוסות של lhs לבין פרוסות של rhs ומפיק טנזור result.
באופן רשמי יותר, result[result_index] = dot_product, כאשר:
lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions].rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions].-
result_batching_index + result_lhs_index + result_rhs_index = result_indexכאשרsize(result_batching_index) = size(lhs_batching_dimensions),size(result_lhs_index) = size(lhs_result_dimensions)ו-size(result_rhs_index) = size(rhs_result_dimensions). transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions).transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :]).reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions)).transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions).transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :]).reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions)).dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y)).
לסוגים כמותיים, הפונקציה מבצעת dequantize_op_quantize(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs, type(result)).
עבור סוגים היברידיים של קוונטיזציה, הפעולה היא hybrid_dequantize_then_op(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs).
precision_config שולט בפשרה בין מהירות לדיוק בחישובים במערכות עורפיות של מאיצים. אפשר להשתמש באחד מהערכים הבאים (בשלב הזה, הסמנטיקה של ערכי ה-enum האלה לא מוגדרת באופן מלא, אבל אנחנו מתכננים לטפל בזה ב-#755):
-
DEFAULT: החישוב הכי מהיר, אבל הקירוב הכי פחות מדויק למספר המקורי. HIGH: חישוב איטי יותר, אבל קירוב מדויק יותר למספר המקורי.-
HIGHEST: החישוב הכי איטי, אבל הקירוב הכי מדויק למספר המקורי.
DotAlgorithm מגדיר את המאפיינים העיקריים של האלגוריתם שמשמש להטמעה של פעולת הנקודה, ומגדיר גם את הדיוק. אם שדות מאפייני האלגוריתם
מוגדרים, הערך של precision_config חייב להיות DEFAULT. DotAlgorithms
אין להם ערך ברירת מחדל, כי פרמטרי ברירת המחדל מוגדרים בהטמעה. לכן, אפשר להגדיר את כל השדות של אלגוריתם הנקודות לערך None כדי לציין אלגוריתם נקודות ריק, שבמקומו ישמש הערך precision_config.
השדות DotAlgorithm כוללים:
-
lhs_precision_typeו-rhs_precision_type, רמות הדיוק שאליהן מעוגלים הצד הימני והצד השמאלי של הפעולה. סוגי הדיוק לא תלויים בסוגי האחסון של נתוני הקלט והפלט. accumulation_typeרמת הדיוק שמשמשת לחישוב המצטבר.-
lhs_component_count,rhs_component_countו-num_primitive_operationsרלוונטיים כשאנחנו מפעילים אלגוריתם שמפרק את הצד הימני או השמאלי של המשוואה למספר רכיבים ומבצע מספר פעולות מכפלה סקלרית 'פרימיטיביות' על הערכים האלה – בדרך כלל כדי לדמות דיוק גבוה יותר (לדוגמה, שימוש בסוג הנתונים של בינה מלאכותית bfloat16 לחישובים ברמת דיוק גבוהה יותר: bf16_6x tf32_3x וכו'). עבור אלגוריתמים ללא פירוק, הערכים האלה צריכים להיות1. -
allow_imprecise_accumulationכדי לציין אם מותר לצבור נתונים ברמת דיוק נמוכה יותר בחלק מהשלבים (למשלCUBLASLT_MATMUL_DESC_FAST_ACCUM).
מאפיינים DotAlgorithm לדוגמה:
// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false}
// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
rhs_precision_type = bf16,
accumulation_type = f32,
lhs_component_count = 3,
rhs_component_count = 3,
num_primitive_operations = 6,
allow_imprecise_accumulation = false}
// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
rhs_precision_type = f8e5m2,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = true}
ההטמעות יקבעו אילו שילובים נתמכים. באופן כללי, אין ערובה לכך שכל אלגוריתם נתמך בכל סוג של מאיץ על ידי הצרכן של StableHLO. אם אלגוריתם מסוים לא נתמך, צריך להציג שגיאה ולא לחזור לאלגוריתם חלופי. האימות של StableHLO יספק אימות מיטבי, וימנע שימוש באלגוריתמים שלא ידוע שהם נתמכים בחומרה כלשהי.
במאמר xla_data.proto > Algorithm מפורטים כמה ערכים נתמכים של אלגוריתמים. בכרטיס מספר 2483 מתואר התוכנית ליצירת מסמך מרכזי בנושא אלגוריתמים נתמכים לפי קצה עורפי.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | lhs |
tensor או per-tensor quantized tensor | (C5-C6), (C9-C10), (C12-C14), (C17-C18), (C20) |
| (I2) | rhs |
טנזור או טנזור שעבר קוונטיזציה | (C7-C10), (C12-C20) |
| (I3) | lhs_batching_dimensions |
קבוע טנסור חד-ממדי מסוג si64 |
(C1), (C3), (C5), (C9), (C12) |
| (I4) | rhs_batching_dimensions |
קבוע טנסור חד-ממדי מסוג si64 |
(C1), (C4), (C7), (C9) |
| (I5) | lhs_contracting_dimensions |
קבוע טנסור חד-ממדי מסוג si64 |
(C2), (C3), (C6), (C10) |
| (I6) | rhs_contracting_dimensions |
קבוע טנסור חד-ממדי מסוג si64 |
(C2), (C4), (C8), (C10), (C16) |
| (I7) | precision_config |
מספר משתנה של סוגי enum של DEFAULT, HIGH ו-HIGHEST |
(C11), (C21) |
| (I8) | lhs_precision_type |
FloatType או TensorFloat32 | (C21) |
| (I9) | rhs_precision_type |
FloatType או TensorFloat32 | (C21) |
| (I10) | accumulation_type |
FloatType או TensorFloat32 | (C21) |
| (I11) | lhs_component_count |
קבוע מהסוג si32 |
(C21), (C22) |
| (I12) | rhs_component_count |
קבוע מהסוג si32 |
(C21), (C23) |
| (I13) | num_primitive_operations |
קבוע מהסוג si32 |
(C21), (C24) |
| (I14) | allow_imprecise_accumulation |
קבוע מהסוג bool |
(C21) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור או טנזור שעבר קוונטיזציה | (C12), (C14), (C18-C20) |
מגבלות
- (C1)
size(lhs_batching_dimensions) = size(rhs_batching_dimensions). - (C2)
size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions). - (C3)
is_unique(lhs_batching_dimensions + lhs_contracting_dimensions). - (C4)
is_unique(rhs_batching_dimensions + rhs_contracting_dimensions). - (C5)
0 <= lhs_batching_dimensions < rank(lhs). - (C6)
0 <= lhs_contracting_dimensions < rank(lhs). - (C7)
0 <= rhs_batching_dimensions < rank(rhs). - (C8)
0 <= rhs_contracting_dimensions < rank(rhs). - (C9)
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...). - (C10)
dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...). - (C11)
size(precision_config) = 2. - (C12)
shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions). - אם הפעולה משתמשת בטנסורים לא מכומתים:
- (C13)
element_type(lhs) = element_type(rhs).
- (C13)
- אם הפעולה משתמשת בטנסורים שעברו קוונטיזציה:
- (C14)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs). - (C15)
zero_points(rhs) = 0. - (C16) אם
is_per_axis_quantized(rhs), אזquantization_dimension(rhs)לא ב-rhs_contracting_dimensions. - אם
is_quantized(lhs): - (C17)
storage_type(lhs) = storage_type(rhs). - (C18)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result). - (C19) אם
is_per_tensor_quantized(rhs), אזis_per_tensor_quantized(result). - אם
!is_quantized(lhs): - (C20)
element_type(lhs) = expressed_type(rhs) = element_type(result).
- (C14)
- אם
!is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation):- (C21)
precision_config... = DEFAULT. - (C22)
0 < lhs_component_count. - (C23)
0 < rhs_component_count. - (C24)
0 < num_primitive_operations.
- (C21)
דוגמאות
// %lhs: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
// %rhs: [
// [[1, 0],
// [0, 1]],
// [[1, 0],
// [0, 1]]
// ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
dot_dimension_numbers = #sta<blehlo.dot
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimension>s = [1]
,
precision_config = [<#stablehloprecisi>on DEFAULT, <#stablehloprecisi>on DEFAULT],
algorithm = #stablehlo.dot<_algorithm
lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation >= false
}< : (tenso>r2x2x2xi<64, tenso>r2x>2x2xi64<) - tenso>r2x2x2xi64
// %result: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
dynamic_broadcast_in_dim
סמנטיקה
הפעולה הזו זהה מבחינת הפונקציונליות ל-broadcast_in_dim op, אבל צורת התוצאה מוגדרת באופן דינמי באמצעות output_dimensions.
הפעולה מקבלת גם מאפיינים אופציונליים known_expanding_dimensions, known_nonexpanding_dimensions כדי לבטא ידע סטטי לגבי התנהגות ההרחבה של המאפיינים.
אם לא מציינים מאפיינים, המערכת מניחה שכל המאפיינים יכולים להתרחב.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור או טנזור שעבר קוונטיזציה | (C1-C2), (C5-C6), (C9) |
| (I2) | output_dimensions |
טנזור חד-ממדי מסוג מספר שלם | (C7) |
| (I3) | broadcast_dimensions |
טנזור קבוע חד-ממדי מסוג מספר שלם | (C2-C6) |
| (I4) | known_expanding_dimensions |
טנזור קבוע חד-ממדי מסוג מספר שלם | (C8-C9) |
| (I5) | known_nonexpanding_dimensions |
טנזור קבוע חד-ממדי מסוג מספר שלם | (C8-C9) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור או טנזור שעבר קוונטיזציה | (C1), (C3), (C5-C7) |
מגבלות
- (C1)
element_type(result)מחושב לפי הנוסחה הבאה:-
element_type(operand), אם!is_per_axis_quantized(operand). -
element_type(operand), למעטquantization_dimension(operand),scales(operand)ו-zero_points(operand)שעשויים להיות שונים מ-quantization_dimension(result),scales(result)ו-zero_points(result)בהתאמה, אחרת.
-
- (C2)
size(broadcast_dimensions) = rank(operand). - (C3)
0 <= broadcast_dimensions < rank(result). - (C4)
is_unique(broadcast_dimensions). - (C5) לכל
dבaxes(operand):dim(operand, d) = 1אוdim(operand, d) = dim(result, broadcast_dimensions[d]).
- (C6) אם
is_per_axis_quantized(result):quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].- אם
dim(operand, quantization_dimension(operand)) = 1, אזscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).
- (C7)
size(output_dimensions) = rank(result). - (C8)
is_unique(known_expanding_dimensions + known_nonexpanding_dimensions). - (C9)
0 <= known_expanding_dimensions < rank(operand). - (C10)
0 <= known_nonexpanding_dimensions < rank(operand).
דוגמאות
// %operand: [
// [1, 2, 3]
// ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
broadcast_dimensio<ns = arra>yi64: 2, 1,
known_expanding_dimensio<ns = a>rrayi64: 0,
known_nonexpanding_dimensio<ns = a>rrayi64: 1
}< : (ten>sor1x3xi<64, t>ens>or3xi64<) - tenso>r2x3x2xi64
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
dynamic_conv
סמנטיקה
הפעולה הזו זהה מבחינת הפונקציונליות לפעולת הקונבולוציה, אבל הריפוד מוגדר באופן דינמי באמצעות padding.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | lhs |
tensor או per-tensor quantized tensor | (C1), (C10-C11), (C14) (C25), (C26-C27), (C30-C31), (C33) |
| (I2) | rhs |
טנזור או טנזור שעבר קוונטיזציה | (C1), (C14-C16), (C26-C28), (C30-C33) |
| (I3) | padding |
טנזור דו-ממדי מסוג מספר שלם | (C4) |
| (I4) | window_strides |
קבוע טנסור חד-ממדי מסוג si64 |
(C2-C3) |
| (I5) | lhs_dilation |
קבוע טנסור חד-ממדי מסוג si64 |
(C5-C6) |
| (I6) | rhs_dilation |
קבוע טנסור חד-ממדי מסוג si64 |
(C7-C8) |
| (I7) | window_reversal |
קבוע טנסור חד-ממדי מסוג i1 |
(C9) |
| (I8) | input_batch_dimension |
קבוע מהסוג si64 |
(C10), (C13) |
| (I9) | input_feature_dimension |
קבוע מהסוג si64 |
(C11), (C13-C14) |
| (I10) | input_spatial_dimensions |
קבוע טנסור חד-ממדי מסוג si64 |
(C12), (C13) |
| (I11) | kernel_input_feature_dimension |
קבוע מהסוג si64 |
(C14), (C18) |
| (I12) | kernel_output_feature_dimension |
קבוע מהסוג si64 |
(C15-C16), (C18), (C28) |
| (I13) | kernel_spatial_dimensions |
קבוע טנסור חד-ממדי מסוג si64 |
(C17-C18) |
| (I14) | output_batch_dimension |
קבוע מהסוג si64 |
(C20) |
| (I15) | output_feature_dimension |
קבוע מהסוג si64 |
(C20), (C29) |
| (I16) | output_spatial_dimensions |
קבוע טנסור חד-ממדי מסוג si64 |
(C19-C20) |
| (I17) | feature_group_count |
קבוע מהסוג si64 |
(C11), (C14), (C16), (C21), (C23) |
| (I18) | batch_group_count |
קבוע מהסוג si64 |
(C10), (C15), (C22), (C23) |
| (I19) | precision_config |
מספר משתנה של סוגי enum של DEFAULT, HIGH ו-HIGHEST |
(C24) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור או טנזור שעבר קוונטיזציה | (C25-C27), (C29), (C31-C33) |
מגבלות
- (C1)
N = rank(lhs) = rank(rhs). - (C2)
size(window_strides) = N - 2. - (C3)
0 < window_strides. - (C4)
shape(padding) = [N - 2, 2]. - (C5)
size(lhs_dilation) = N - 2. - (C6)
0 < lhs_dilation. - (C7)
size(rhs_dilation) = N - 2. - (C8)
0 < rhs_dilation. - (C9)
size(window_reversal) = N - 2. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0. - (C12)
size(input_spatial_dimensions) = N - 2. - (C13) בהינתן
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:is_unique(input_dimensions).0 <= input_dimensions < N.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0. - (C17)
size(kernel_spatial_dimensions) = N - 2. - (C18) בהינתן
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:is_unique(kernel_dimensions).0 <= kernel_dimensions < N.
- (C19)
size(output_spatial_dimensions) = N - 2. - (C20) בהינתן
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:is_unique(output_dimensions).0 <= output_dimensions < N.
- (C21)
0 < feature_group_count. - (C22)
0 < batch_group_count. - (C23)
feature_group_count = 1 or batch_group_count = 1. - (C24)
size(precision_config) = 2. - (C25)
dim(result, result_dim)מוגדר כך:-
dim(lhs, input_batch_dimension) / batch_group_countifresult_dim = output_batch_dimension. -
dim(rhs, kernel_output_feature_dimension)ifresult_dim = output_feature_dimension. num_windowsאחרת, כאשר:output_spatial_dimensions[spatial_dim] = result_dim.lhs_dim = input_spatial_dimensions[spatial_dim].rhs_dim = kernel_spatial_dimensions[spatial_dim].dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
-
- (C26)
rank(result) = N. - אם הפעולה משתמשת בטנסורים לא מכומתים:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result).
- (C27)
- אם הפעולה משתמשת בטנסורים שעברו קוונטיזציה:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs). - (C29) אם
is_per_axis_quantized(rhs), אזquantization_dimension(rhs) = kernel_output_feature_dimension. - (C30) אם
is_per_axis_quantized(result), אזquantization_dimension(result) = output_feature_dimension. - אם
is_quantized(lhs): - (C31)
storage_type(lhs) = storage_type(rhs). - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result). - (C33) אם
is_per_tensor_quantized(rhs), אזis_per_tensor_quantized(result). - אם
!is_quantized(lhs): - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result).
- (C28)
דוגמאות
// %lhs: [[
// [[1], [2], [5], [6]],
// [[3], [4], [7], [8]],
// [[10], [11], [14], [15]],
// [[12], [13], [16], [17]]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
// %padding: [[1, 1],
// [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
window_strid<es = arra>yi64: 4, 4,
lhs_dilati<on = arra>yi64: 2, 2,
rhs_dilati<on = arra>yi64: 1, 1,
window_revers<al = arrayi1: fa>lse, false,
dimension_numbers = #stab<lehlo.convraw
input_batch_dimension = 0,
input_feature_dimension = 3,
input_spatial_dimensions = [0, 1],
kernel_input_feature_dimension = 2,
kernel_output_feature_dimension = 3,
kernel_spatial_dimensions = [0, 1],
output_batch_dimension = 0,
output_feature_dimension = 3,
output_spatial_dimensions => [1, 2]
,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
precision_config = [<#stablehloprecisi>on DEFAULT, <#stablehloprecisi>on DEFAULT]
}< : (tensor1>x4x4x1xi<64, tensor3>x3x1x1xi<64, ten>sor>2x2xi64<) - tensor1>x2x2x1xi64
// %result: [[
// [[1], [5]],
// [[10], [14]]
// ]]
dynamic_gather
סמנטיקה
הפעולה הזו זהה מבחינת הפונקציונליות שלה לפעולה gather, כאשר slice_sizes מוגדר באופן דינמי כערך.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
tensor או per-tensor quantized tensor | (C1), (C7), (C10-C12), (C14) |
| (I2) | start_indices |
טנזור מסוג מספר שלם | (C2), (C3), (C13) |
| (I3) | slice_sizes |
טנזור חד-ממדי מסוג מספר שלם | (C8), (C11-C13) |
| (I4) | offset_dims |
קבוע טנסור חד-ממדי מסוג si64 |
(C1), (C4-C5), (C13) |
| (I5) | collapsed_slice_dims |
קבוע טנסור חד-ממדי מסוג si64 |
(C1), (C6-C8), (C13) |
| (I6) | start_index_map |
קבוע טנסור חד-ממדי מסוג si64 |
(C3), (C9), (C10) |
| (I7) | index_vector_dim |
קבוע מהסוג si64 |
(C2), (C3), (C13) |
| (I8) | indices_are_sorted |
קבוע מהסוג i1 |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
tensor או per-tensor quantized tensor | (C5), (C13-C14) |
מגבלות
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims). - (C2)
0 <= index_vector_dim <= rank(start_indices). - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims). - (C5)
0 <= offset_dims < rank(result). - (C6)
is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims). - (C7)
0 <= collapsed_slice_dims < rank(operand). - (C8)
slice_sizes[collapsed_slice_dims...] <= 1. - (C9)
is_unique(start_index_map). - (C10)
0 <= start_index_map < rank(operand). - (C11)
size(slice_sizes) = rank(operand). - (C12)
0 <= slice_sizes <= shape(operand). - (C13)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)כאשר:batch_dim_sizes = shape(start_indices)חוץ מזה שגודל המאפייןstart_indicesשמתאים ל-index_vector_dimלא נכלל.-
offset_dim_sizes = shape(slice_sizes), למעט גדלי המימדים ב-slice_sizesשתואמים ל-collapsed_slice_dims. -
combineממקם אתbatch_dim_sizesעל צירים שתואמים ל-batch_dimsואתoffset_dim_sizesעל צירים שתואמים ל-offset_dims.
- (C14)
element_type(operand) = element_type(result).
דוגמאות
// %operand: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %start_indices: [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 2]]
// ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
dimension_numbers = #stable<hlo.gather
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vect>or_dim = 2,
indices_are_sorted = false
}< : (tenso>r3x4x2xi<64, tenso>r2x3x2xi<64, t>ens>or3xi64<) - tensor2>x3x2x2xi64
// %result: [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[9, 10], [11, 12]],
// [[11, 12], [13, 14]],
// [[17, 18], [19, 20]]
// ]
// ]
dynamic_iota
סמנטיקה
הפעולה הזו זהה מבחינת הפונקציונליות ל-iota op, אבל צורת התוצאה מוגדרת באופן דינמי באמצעות output_shape.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | output_shape |
טנזור חד-ממדי מסוג מספר שלם | (C1), (C2) |
| (I2) | iota_dimension |
si64 |
(C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור של מספרים שלמים, מספרים עם נקודה עשרונית או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור | (C2) |
מגבלות
- (C1)
0 <= iota_dimension < size(output_shape). - (C2)
rank(result) = size(output_shape).
דוגמאות
%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
iota_dimension = 0 : i64
}< : (t>ens>or2xi64<) - ten>sor4x5xi64
// %result: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
dynamic_pad
סמנטיקה
הפעולה הזו זהה מבחינת הפונקציונליות לפעולה pad, אבל הערכים edge_padding_low, edge_padding_high ו-interior_padding מוגדרים באופן דינמי כערכים.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
tensor או per-tensor quantized tensor | (C1), (C2), (C4) |
| (I2) | padding_value |
טנזור אפס-ממדי או טנזור שעבר קוונטיזציה ברמת הטנזור | (C1) |
| (I3) | edge_padding_low |
טנזור חד-ממדי מסוג מספר שלם | (C1), (C4) |
| (I4) | edge_padding_high |
טנזור חד-ממדי מסוג מספר שלם | (C1), (C4) |
| (I5) | interior_padding |
טנזור חד-ממדי מסוג מספר שלם | (C2-C4) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
tensor או per-tensor quantized tensor | (C3-C6) |
מגבלות
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result). - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand). - (C3)
0 <= interior_padding. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.
דוגמאות
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
%edge_padding_low, %edge_padding_high, %interior_padding
)< : (ten>sor2x3xi<64,> tensori<64, t>ensor2xi<64, t>ensor2xi<64, t>ens>or2xi64<) - ten>sor5x9xi64
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
dynamic_reshape
סמנטיקה
הפעולה הזו זהה מבחינת הפונקציונליות ל-op reshape, אבל צורת התוצאה מוגדרת באופן דינמי באמצעות output_shape.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור או טנזור שעבר קוונטיזציה | (C1-C3) |
| (I2) | output_shape |
טנזור חד-ממדי מסוג מספר שלם | (C4) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור או טנזור שעבר קוונטיזציה | (C1-C4) |
מגבלות
- (C1)
element_type(result)מחושב לפי הנוסחה הבאה:-
element_type(operand), אם!is_per_axis_quantized(operand). element_type(operand)חוץ מהמקרים שבהםquantization_dimension(operand)וquantization_dimension(result)שונים.
-
- (C2)
size(operand) = size(result). - (C3) אם
is_per_axis_quantized(operand):reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
- (C4)
size(output_shape) = rank(result).
דוגמאות
// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape)< : (ten>sor2x3xi<64, t>ens>or2xi64<) - ten>sor3x2xi64
// %result: [[1, 2], [3, 4], [5, 6]]
dynamic_slice
סמנטיקה
מחלקים את operand לחלקים באמצעות חישוב דינמי של אינדקסים התחלתיים ומפיקים טנזור result. start_indices מכיל את אינדקס ההתחלה של הפלח לכל מאפיין, בכפוף להתאמה פוטנציאלית, ו-slice_sizes מכיל את הגודל של הפלח לכל מאפיין. באופן רשמי יותר,
result[result_index] = operand[operand_index] כאשר:
adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes).operand_index = adjusted_start_indices + result_index.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
tensor או per-tensor quantized tensor | (C1), (C2), (C4) |
| (I2) | start_indices |
מספר משתנה של טנסורים אפס-ממדיים מסוג מספר שלם | (C2), (C3) |
| (I3) | slice_sizes |
קבוע טנסור חד-ממדי מסוג si64 |
(C2), (C4), (C5) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
tensor או per-tensor quantized tensor | (C1), (C5) |
מגבלות
- (C1)
element_type(operand) = element_type(result). - (C2)
size(start_indices) = size(slice_sizes) = rank(operand). - (C3)
same(type(start_indices...)). - (C4)
0 <= slice_sizes <= shape(operand). - (C5)
shape(result) = slice_sizes.
דוגמאות
// %operand: [
// [0, 0, 1, 1],
// [0, 0, 1, 1],
// [0, 0, 0, 0],
// [0, 0, 0, 0]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
slice_siz<es = arra>yi64: 2, 2
}< : (ten>sor4x4xi<32,> tensori<64,> te>nsori64<) - ten>sor2x2xi32
// %result: [
// [1, 1],
// [1, 1]
// ]
dynamic_update_slice
סמנטיקה
יוצר טנזור result ששווה לטנזור operand, למעט הפרוסה שמתחילה ב-start_indices, שמעודכנת עם הערכים ב-update.
באופן רשמי יותר, result[result_index] מוגדר כך:
update[update_index]אם0 <= update_index < shape(update)כאשר:adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update)).update_index = result_index - adjusted_start_indices.
operand[result_index]אחרת.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
tensor או per-tensor quantized tensor | (C1-C4), (C6) |
| (I2) | update |
tensor או per-tensor quantized tensor | (C2), (C3), (C6) |
| (I3) | start_indices |
מספר משתנה של טנסורים אפס-ממדיים מסוג מספר שלם | (C4), (C5) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
tensor או per-tensor quantized tensor | (C1) |
מגבלות
- (C1)
type(operand) = type(result). - (C2)
element_type(update) = element_type(operand). - (C3)
rank(update) = rank(operand). - (C4)
size(start_indices) = rank(operand). - (C5)
same(type(start_indices...)). - (C6)
0 <= shape(update) <= shape(operand).
דוגמאות
// %operand: [
// [1, 1, 0, 0],
// [1, 1, 0, 0],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
// %update: [
// [1, 1],
// [1, 1]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
< : (ten>sor4x4xi<32, ten>sor2x2xi<32,> tensori<64,> te>nsori64<) - ten>sor4x4xi32
// %result: [
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
מעריכיות
סמנטיקה
מבצעת פעולה מעריכית על כל רכיב בטנזור operand ומפיקה טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:
- למספרים ממשיים:
expמ-IEEE-754. - למספרים מרוכבים: חזקה מרוכבת.
- לסוגים כמותיים:
dequantize_op_quantize(exponential, operand, type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1) |
מגבלות
- (C1)
baseline_type(operand) = baseline_type(result).
דוגמאות
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]
exponential_minus_one
סמנטיקה
מבצע פעולת חיסור של 1 מכל רכיב של טנסור operand, ומחזיר טנסור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:
- למספרים ממשיים:
expm1מ-IEEE-754. - למספרים מרוכבים: פונקציית האקספוננט המרוכב פחות אחד.
- לסוגים כמותיים:
dequantize_op_quantize(exponential_minus_one, operand, type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1) |
מגבלות
- (C1)
baseline_type(operand) = baseline_type(result).
דוגמאות
// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand)< : (t>ens>or2xf64<) - t>ensor2xf64
// %result: [0.0, 1.71828187]
fft
סמנטיקה
מבצעת את התמרת פורייה הישירה וההפוכה עבור קלט/פלט ממשי ומרוכב.
fft_type הוא אחד מהבאים:
-
FFT: העברה של FFT מורכב למורכב. -
IFFT: FFT מורכב-למורכב הפוך. -
RFFT: העברה של FFT מורכב בזמן אמת. -
IRFFT: FFT הפוך של מספרים ממשיים למספרים מרוכבים (כלומר, מקבל מספרים מרוכבים ומחזיר מספרים ממשיים).
באופן רשמי יותר, בהינתן הפונקציה fft שמקבלת טנסורים חד-ממדיים של סוגים מורכבים כקלט, יוצרת טנסורים חד-ממדיים של אותם סוגים כפלט ומחשבת את התמרת פורייה הדיסקרטית:
במקרה של fft_type = FFT, result מוגדר כתוצאה הסופית של סדרת חישובים שבהם L = size(fft_length). לדוגמה, עבור L = 3:
result1[i0, ..., :] = fft(operand[i0, ..., :]).result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).
בנוסף, בהינתן הפונקציה ifft שיש לה אותה חתימת סוג והיא מחשבת את ההופכי של fft:
במקרה של fft_type = IFFT, result מוגדרת כהיפוך של החישובים של fft_type = FFT. לדוגמה, עבור L = 3:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).result[i0, ..., :] = ifft(result2[i0, ..., :]).
בנוסף, נתונה הפונקציה rfft שמקבלת טנסורים חד-ממדיים של סוגי נתונים מסוג נקודה צפה, יוצרת טנסורים חד-ממדיים של סוגים מורכבים עם אותה סמנטיקה של נקודה צפה ופועלת באופן הבא:
rfft(real_operand) = truncated_resultwherecomplex_operand... = (real_operand..., 0.0).complex_result = fft(complex_operand).truncated_result = complex_result[:(rank(complex_result) / 2 + 1)].
(כשמחשבים את התמרת פורייה הדיסקרטית עבור אופרנדים ממשיים, N/2 + 1 הרכיבים הראשונים של התוצאה מגדירים באופן חד-משמעי את שאר התוצאה, ולכן התוצאה של rfft נחתכת כדי להימנע מחישוב של רכיבים מיותרים).
במקרה של fft_type = RFFT, result מוגדר כתוצאה הסופית של סדרת חישובים שבהם L = size(fft_length). לדוגמה, עבור L = 3:
result1[i0, ..., :] = rfft(operand[i0, ..., :]).result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).
לבסוף, בהינתן הפונקציה irfft שיש לה אותה חתימת סוג והיא מחשבת את ההופכי של rfft:
במקרה של fft_type = IRFFT, result מוגדרת כהיפוך של החישובים של fft_type = RFFT. לדוגמה, עבור L = 3:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).result[i0, ..., :] = irfft(result2[i0, ..., :]).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור מסוג נקודה צפה או מסוג מרוכב | (C1), (C2), (C4), (C5) |
| (I2) | fft_type |
enum of FFT, IFFT, RFFT, and IRFFT |
(C2), (C5) |
| (I3) | fft_length |
קבוע טנסור חד-ממדי מסוג si64 |
(C1), (C3), (C4) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור מסוג נקודה צפה או מסוג מרוכב | (C2), (C4), (C5) |
מגבלות
- (C1)
size(fft_length) <= rank(operand). - (C2) הקשר בין סוגי הרכיבים
operandו-resultמשתנה:- אם
fft_type = FFT, element_type(operand)ו-element_type(result)הם מאותו סוג מורכב. - אם
fft_type = IFFT, element_type(operand)ו-element_type(result)הם מאותו סוג מורכב. - אם
fft_type = RFFT,element_type(operand)הוא סוג של נקודה צפה ו-element_type(result)הוא סוג מורכב של אותה סמנטיקה של נקודה צפה. - אם
fft_type = IRFFT,element_type(operand)הוא סוג מורכב ו-element_type(result)הוא סוג של נקודה צפה עם אותה סמנטיקה של נקודה צפה.
- אם
- (C3)
1 <= size(fft_length) <= 3. - (C4) אם בין
operandל-resultיש טנזורrealמסוג נקודה צפה, אזshape(real)[-size(fft_length):] = fft_length. - (C5)
shape(result) = shape(operand)except for:- אם
fft_type = RFFT,dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1. - אם
fft_type = IRFFT,dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1.
- אם
דוגמאות
// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
fft_type = <#stablehloff>t_type FFT,
fft_leng<th = a>rrayi64: 4
}< : (tenso<r4x>>com>plexf32<) - tenso<r4x>>complexf32
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]
קומה
סמנטיקה
מבצעת פעולת floor על כל רכיב בטנזור operand ומפיקה טנזור result.
מיישמת את הפעולה roundToIntegralTowardNegative מהמפרט IEEE-754. לסוגים כמותיים, הפונקציה מבצעת dequantize_op_quantize(floor, operand, type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור | (C1) |
מגבלות
- (C1)
baseline_type(operand) = baseline_type(result).
דוגמאות
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand)< : (t>ens>or5xf32<) - t>ensor5xf32
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]
לרכז
סמנטיקה
אוספת פרוסות מטנזור operand מההיסטים שצוינו ב-start_indices ומפיקה טנזור result.
בתרשים הבא מוצג מיפוי של רכיבים ב-result לרכיבים ב-operand באמצעות דוגמה קונקרטית. בתרשים מוצגים כמה מדדים לדוגמה result
והסבר מפורט לאילו מדדים operand הם מתאימים.
באופן רשמי יותר, result[result_index] = operand[operand_index] כאשר:
batch_dims = [d for d in axes(result) and d not in offset_dims].batch_index = result_index[batch_dims...].- ההגדרה של
start_indexהיא:-
start_indices[bi0, ..., :, ..., biN]כאשרbiהם רכיבים נפרדים ב-batch_indexו-:מוכנס באינדקסindex_vector_dim, אםindex_vector_dim<rank(start_indices). [start_indices[batch_index]]אחרת.
-
- for
d_operandinaxes(operand),full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])ifd_operand = start_index_map[d_start].full_start_index[d_operand] = 0אחרת.
- for
d_operandinaxes(operand),full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)]אםd_operand = operand_batching_dims[i_batching]וגםd_start = start_indices_batching_dims[i_batching].full_batching_index[d_operand] = 0אחרת.
offset_index = result_index[offset_dims...].-
full_offset_index = [oi0, ..., 0, ..., oiN]כאשרoiהם רכיבים נפרדים ב-offset_index, ו-0מוכנס באינדקסים מ-collapsed_slice_dimsועדoperand_batching_dims. operand_index = full_start_index + full_batching_index + full_offset_index.
אם הערך של indices_are_sorted הוא true, ההטמעה יכולה להניח שהערכים של start_indices ממוינים לפי start_index_map. אחרת, ההתנהגות לא מוגדרת. באופן רשמי יותר, לכל i1 < i2 מ-indices(result),
full_start_index(i1) <= full_start_index(i2).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
tensor או per-tensor quantized tensor | (C1), (C8), (C11), (C17), (C19-C21), (C23) |
| (I2) | start_indices |
טנזור מסוג מספר שלם | (C2-C3), (C14), (C17), (C22) |
| (I3) | offset_dims |
קבוע טנסור חד-ממדי מסוג si64 |
(C1), (C4-C5), (C22) |
| (I4) | collapsed_slice_dims |
קבוע טנסור חד-ממדי מסוג si64 |
(C1), (C6-C9), (C22) |
| (I5) | operand_batching_dims |
קבוע טנסור חד-ממדי מסוג si64 |
(C1), (C6), (C10-C12), (C16-C18), (C22) |
| (I6) | start_indices_batching_dims |
קבוע טנסור חד-ממדי מסוג si64 |
(C13-C17) |
| (I7) | start_index_map |
קבוע טנסור חד-ממדי מסוג si64 |
(C3), (C18-C19) |
| (I8) | index_vector_dim |
קבוע מהסוג si64 |
(C2-C3), (C15), (C22) |
| (I9) | slice_sizes |
קבוע טנסור חד-ממדי מסוג si64 |
(C9), (C12), (C20-C22) |
| (I10) | indices_are_sorted |
קבוע מהסוג i1 |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
tensor או per-tensor quantized tensor | (C5), (C22-C23) |
מגבלות
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims). - (C2)
0 <= index_vector_dim <= rank(start_indices). - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims). - (C5)
0 <= offset_dims < rank(result). - (C6)
is_unique(concatenate(collapsed_slice_dims, operand_batching_dims)) - (C7)
is_sorted(collapsed_slice_dims). - (C8)
0 <= collapsed_slice_dims < rank(operand). - (C9)
slice_sizes[collapsed_slice_dims...] <= 1. - (C10)
is_sorted(operand_batching_dims). - (C11)
0 <= operand_batching_dims < rank(operand). - (C12)
slice_sizes[operand_batching_dims...] <= 1. - (C13)
is_unique(start_indices_batching_dims). - (C14)
0 <= start_indices_batching_dims < rank(start_indices). - (C15)
index_vector_dim not in start_indices_batching_dims. - (C16)
size(operand_batching_dims) == size(start_indices_batching_dims). - (C17)
dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...). - (C18)
is_unique(concatenate(start_index_map, operand_batching_dims)). - (C19)
0 <= start_index_map < rank(operand). - (C20)
size(slice_sizes) = rank(operand). - (C21)
0 <= slice_sizes <= shape(operand). - (C22)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)כאשר:batch_dim_sizes = shape(start_indices)חוץ מזה שגודל המאפייןstart_indicesשמתאים ל-index_vector_dimלא נכלל.-
offset_dim_sizes = slice_sizesexcept that the dimension sizes inslice_sizescorresponding tocollapsed_slice_dimsandoperand_batching_dimsare not included. -
combineממקם אתbatch_dim_sizesעל צירים שתואמים ל-batch_dimsואתoffset_dim_sizesעל צירים שתואמים ל-offset_dims.
- (C23)
element_type(operand) = element_type(result).
דוגמאות
// %operand: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %start_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
%result = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stable<hlo.gather
offset_dims = [3, 4],
collapsed_slice_dims = [1],
operand_batching_dims = [0],
start_indices_batching_dims = [1],
start_index_map = [2, 1],
index_vect>or_dim = 3,
slice_siz<es = arrayi64: >1, 1, 2, 2,
indices_are_sorted = false
}< : (tensor2>x3x4x2xi<32, tensor2>x2x>3x2xi64<) - tensor2x2>x3x2x2xi32
// %result: [
// [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[33, 34], [35, 36]],
// [[35, 36], [37, 38]],
// [[41, 42], [43, 44]]
// ]
// ],
// [
// [
// [[1, 2], [3, 4]],
// [[13, 14], [15, 16]],
// [[21, 22], [23, 24]]
// ],
// [
// [[43, 44], [45, 46]],
// [[33, 34], [35, 36]],
// [[27, 28], [29, 30]]
// ]
// ]
// ]
get_dimension_size
סמנטיקה
הפונקציה מחזירה את הגודל של dimension שצוין ב-operand. באופן רשמי יותר,
result = dim(operand, dimension). הסמנטיקה מתייחסת רק לרכיב הצורה של הסוג. סוג הרכיב יכול להיות כל דבר.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור או טנזור שעבר קוונטיזציה | (C1) |
| (I2) | dimension |
קבוע מהסוג si64 |
(C1) |
פלט
| שם | סוג |
|---|---|
result |
טנזור אפס-ממדי מסוג si32 |
מגבלות
- (C1)
0 <= dimension < rank(operand).
דוגמאות
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
dimension = 1 : i64
}< : (ten>sor>2x3xi64<) -> tensori32
// %result: 3
get_tuple_element
סמנטיקה
מחזירה את הרכיב במיקום index של ה-tuple operand ויוצרת result. באופן רשמי יותר, result = operand[index].
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
tuple | (C1), (C2) |
| (I2) | index |
קבוע מהסוג si32 |
(C1), (C2) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
כל ערך | (C2) |
מגבלות
- (C1)
0 <= index < size(operand). - (C2)
type(result) = tuple_element_types(operand)[index].
דוגמאות
// %operand: ([1.0, 2.0], (3))
%result = "stablehlo.get_tuple_element"(<%operand) {index >= 0 : i32<} : (t<uplet>ensor2x<f64, t<upl>>>ete>nsori64<) - t>ensor2xf64
// %result: [1.0, 2.0]
אם
סמנטיקה
הפונקציה מפיקה את הפלט מהפעלת פונקציה אחת בדיוק מתוך true_branch או false_branch, בהתאם לערך של pred. באופן רשמי יותר, result =
pred ? true_branch() : false_branch().
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | pred |
טנזור אפס-ממדי מסוג i1 |
|
| (I2) | true_branch |
פונקציה | (C1-C3) |
| (I3) | false_branch |
פונקציה | (C1), (C2) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
results |
מספר משתנה של טנסורים, טנסורים או טוקנים שעברו קוונטיזציה | (C3) |
מגבלות
- (C1)
input_types(true_branch) = input_types(false_branch) = []. - (C2)
output_types(true_branch) = output_types(false_branch). - (C3)
type(results...) = output_types(true_branch).
דוגמאות
// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
"stablehlo.return"(%result_tr<ue_>bra>nch) : (tensori32) - ()
}, {
"stablehlo.return"(%<res>ult>_false_branch) :< (>ten>sori32)< - >()
}) : (tensori1) - tensori32
// %result: 10
imag
סמנטיקה
הפונקציה מחלצת את החלק המדומה, לפי רכיב, מ-operand ומפיקה טנסור result. באופן רשמי יותר, לכל רכיב x:
imag(x) = is_complex(x) ? imaginary_part(x) :
constant(0, element_type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור מסוג נקודה צפה או מסוג מרוכב | (C1), (C2) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנסור מסוג נקודה צפה | (C1), (C2) |
מגבלות
- (C1)
shape(result) = shape(operand). - (C2)
element_type(result)מוגדר כך:-
complex_element_type(element_type(operand))ifis_complex(operand). element_type(operand)אחרת.
-
דוגמאות
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand)< : (tenso<r2x>>com>plexf32<) - t>ensor2xf32
// %result: [2.0, 4.0]
בפיד
סמנטיקה
קורא נתונים מהפיד ומפיק results.
הסמנטיקה של infeed_config מוגדרת בהטמעה.
results מורכב מערכי מטען ייעודי (payload) שמופיעים ראשונים וטוקן שמופיע אחרון. בעתיד, אנחנו מתכננים לפצל את מטען הנתונים ואת האסימון לשני פלטים נפרדים כדי לשפר את הבהירות (#670).
קלט
| תווית | שם | סוג |
|---|---|---|
| (I1) | token |
token |
| (I2) | infeed_config |
קבוע מהסוג string |
פלט
| שם | סוג | מגבלות |
|---|---|---|
results |
מספר משתנה של טנסורים, טנסורים או טוקנים שעברו קוונטיזציה | (C1-C3) |
מגבלות
- (C1)
0 < size(results). - (C2)
is_empty(result[:-1])אוis_tensor(type(results[:-1])). - (C3)
is_token(type(results[-1])).
דוגמאות
// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : >(!stable<hlo.tok>en) - (tensor2x2xi64, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
infeed_config> = "<;">
} : (!stablehlo.token) - (tensor2x2xi64, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]
iota
סמנטיקה
ממלאת טנסור output בערכים בסדר עולה החל מאפס לאורך המימד iota_dimension. באופן רשמי יותר,
output[output_index] = constant(is_quantized(output) ?
quantize(output_index[iota_dimension], element_type(output)) :
output_index[iota_dimension], element_type(output)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | iota_dimension |
si64 |
(C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
output |
טנזור של מספרים שלמים, מספרים עם נקודה עשרונית או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור | (C1) |
מגבלות
- (C1)
0 <= iota_dimension < rank(output).
דוגמאות
%output = "stablehlo.iota"() {
iota_dimension = 0 : i6>4
} : (<) - ten>sor4x5xi32
// %output: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
%output = "stablehlo.iota"() {
iota_dimensio>n = 1 :< i64
} >: () - tensor4x5xi32
// %output: [
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4]
// ]
is_finite
סמנטיקה
הפונקציה מבצעת בדיקה של כל רכיב כדי לראות אם הערך ב-x הוא סופי (כלומר, הוא לא +Inf, -Inf או NaN) ומפיקה טנסור y. מיישמת את הפעולה isFinite
מהמפרט IEEE-754. בסוגים כמותיים, התוצאה היא תמיד true.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | x |
טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
y |
tensor of boolean type | (C1) |
מגבלות
- (C1)
shape(x) = shape(y).
דוגמאות
// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x)< : (tens>or7xf64<) - >tensor7xi1
// %y: [false, false, false, true, true, true, true]
log
סמנטיקה
מבצעת פעולת לוגריתם לפי רכיבים בטנזור operand ומפיקה טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:
- למספרים ממשיים:
logמ-IEEE-754. - למספרים מרוכבים: לוגריתם מרוכב.
- לסוגים כמותיים:
dequantize_op_quantize(log, operand, type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1) |
מגבלות
- (C1)
baseline_type(operand) = baseline_type(result).
דוגמאות
// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]
log_plus_one
סמנטיקה
מבצעת פעולה של לוגריתם פלוס אחד על כל רכיב בטנזור operand ומפיקה טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:
- למספרים ממשיים:
logp1מ-IEEE-754. - למספרים מרוכבים:
complex(log(hypot(real(x) + 1, imag(x))), atan2(imag(x), real(x) + 1)) - לסוגים כמותיים:
dequantize_op_quantize(log_plus_one, operand, type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1) |
מגבלות
- (C1)
baseline_type(operand) = baseline_type(result).
דוגמאות
// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
לוגיסטי
סמנטיקה
מבצע פעולה לוגיסטית ברמת הרכיב בטנזור operand ומפיק טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:
- למספרים ממשיים:
division(1, addition(1, exp(-x)))מ-IEEE-754. - למספרים מרוכבים: לוגיסטיקה מורכבת.
- לסוגים כמותיים:
dequantize_op_quantize(logistic, operand, type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1) |
מגבלות
- (C1)
baseline_type(operand) = baseline_type(result).
דוגמאות
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]
מפה
סמנטיקה
מחילים פונקציית מיפוי computation על inputs לאורך dimensions ומפיקים טנזור result.
באופן רשמי יותר, result[result_index] = computation(inputs...[result_index]).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | inputs |
מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור | (C1-C4) |
| (I2) | dimensions |
קבוע טנסור חד-ממדי מסוג si64 |
(C3) |
| (I3) | computation |
פונקציה | (C4) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
tensor או per-tensor quantized tensor | (C1), (C4) |
מגבלות
- (C1)
shape(inputs...) = shape(result). - (C2)
0 < size(inputs) = N. - (C3)
dimensions = range(rank(inputs[0])). - (C4)
computationהוא מסוג(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>כאשרEi = element_type(inputs[i])ו-E' = element_type(result).
דוגמאות
// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = stablehlo.multiply %arg0, %arg<1 :> tensori64
stablehlo.return %<0 :> tensori64
}) {
dimensio<ns = arra>yi64: 0, 1
}< : (ten>sor2x2xi<64, ten>sor>2x2xi64<) - ten>sor2x2xi64
// %result: [[0, 5], [12, 21]]
מקסימום
סמנטיקה
מבצע פעולת מקסימום לפי רכיבים בטנסורים lhs ו-rhs ומפיק טנסור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:
- לערכים בוליאניים: OR לוגי.
- למספרים שלמים: מספר שלם מקסימלי.
- למספרים ממשיים:
maximumמ-IEEE-754. - למספרים מרוכבים: המקסימום הלקסיקוגרפי של הצמד
(real, imaginary). הגדרת סדר למספרים מרוכבים כוללת סמנטיקה מפתיעה, ולכן בעתיד אנחנו מתכננים להסיר את התמיכה במספרים מרוכבים בפעולה הזו (#560). - לסוגים שעברו קוונטיזציה:
dequantize_op_quantize(maximum, lhs, rhs, type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | lhs |
tensor או per-tensor quantized tensor | (C1) |
| (I2) | rhs |
tensor או per-tensor quantized tensor | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
tensor או per-tensor quantized tensor | (C1) |
מגבלות
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
דוגמאות
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 6], [7, 8]]
מינימום
סמנטיקה
מבצעת פעולת min לפי רכיבים בטנסורים lhs ו-rhs ומפיקה טנסור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:
- לערכים בוליאניים: AND לוגי.
- למספרים שלמים: מספר שלם מינימלי.
- למספרים ממשיים:
minimumמ-IEEE-754. - למספרים מורכבים: המינימום הלקסיקוגרפי עבור הצמד
(real, imaginary). הגדרת סדר למספרים מרוכבים כוללת סמנטיקה מפתיעה, ולכן בעתיד אנחנו מתכננים להסיר את התמיכה במספרים מרוכבים בפעולה הזו (#560). - לסוגים שעברו קוונטיזציה:
dequantize_op_quantize(minimum, lhs, rhs, type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | lhs |
tensor או per-tensor quantized tensor | (C1) |
| (I2) | rhs |
tensor או per-tensor quantized tensor | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
tensor או per-tensor quantized tensor | (C1) |
מגבלות
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
דוגמאות
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[1, 2], [3, 4]]
כפל
סמנטיקה
מבצעת מכפלה של שני טנסורים lhs ו-rhs לפי רכיבים ומפיקה טנסור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:
- לערכים בוליאניים: AND לוגי.
- למספרים שלמים: כפל של מספרים שלמים.
- למספרים ממשיים:
multiplicationמ-IEEE-754. - למספרים מרוכבים: כפל מרוכב.
- לסוגים שעברו קוונטיזציה:
dequantize_op_quantize(multiply, lhs, rhs, type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | lhs |
tensor או per-tensor quantized tensor | (C1) |
| (I2) | rhs |
tensor או per-tensor quantized tensor | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
tensor או per-tensor quantized tensor | (C1) |
מגבלות
- (C1)
baseline_type(operand) = baseline_type(result).
דוגמאות
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 12], [21, 32]]
החלפת חיובי/שלילי
סמנטיקה
מבצעת שלילת כל רכיב בטנזור operand ומפיקה טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:
- למספרים שלמים עם סימן: שלילה של מספר שלם.
- למספרים שלמים ללא סימן: bitcast למספר שלם עם סימן, שלילת מספר שלם, bitcast בחזרה למספר שלם ללא סימן.
- למספרים ממשיים:
negateמ-IEEE-754. - למספרים מרוכבים: שלילה מרוכבת.
- לסוגים כמותיים:
dequantize_op_quantize(negate, operand, type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור של מספרים שלמים, מספרים עם נקודה עשרונית או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור של מספרים שלמים, מספרים עם נקודה עשרונית או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור | (C1) |
מגבלות
- (C1)
baseline_type(operand) = baseline_type(result).
דוגמאות
// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand)< : (t>ens>or2xi32<) - t>ensor2xi32
// %result: [0, 2]
// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"<(%operand<) :>> (t>ensor1x<complexf3<2) >>- tensor1xcomplexf32
// %result: [-2.5, -0.0]
לא
סמנטיקה
מבצעת פעולת NOT על כל רכיב בטנזור operand ומפיקה טנזור result.
בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:
- לערכים בוליאניים: שלילה לוגית.
- למספרים שלמים: NOT ברמת הסיביות.
ארגומנטים
| שם | סוג | מגבלות |
|---|---|---|
operand |
טנסור מסוג בוליאני או מספר שלם | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנסור מסוג בוליאני או מספר שלם | (C1) |
מגבלות
- (C1)
type(operand) = type(result).
דוגמאות
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand)< : (ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[-2, -3], [-4, -5]]
// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"<(%op>era>nd) : (<tens>or2xi1) - tensor2xi1
// %result: [false, true]
optimization_barrier
סמנטיקה
התכונה הזו מבטיחה שהפעולות שמייצרות את operand יבוצעו לפני כל הפעולות שתלויות ב-result, ומונעת מהטרנספורמציות של הקומפיילר להעביר פעולות מעבר למחסום. במקרים אחרים, הפעולה היא זהות, כלומר result = operand.
ארגומנטים
| שם | סוג | מגבלות |
|---|---|---|
operand |
מספר משתנה של טנסורים, טנסורים או טוקנים שעברו קוונטיזציה לכל טנסור | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
מספר משתנה של טנסורים, טנסורים או טוקנים שעברו קוונטיזציה לכל טנסור | (C1) |
מגבלות
- (C1)
type(operand...) = type(result...).
דוגמאות
// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1)< : >(tensorf<32,> te>nsorf32)< - >(tensorf<32,> tensorf32)
// %result0: 0.0
// %result1: 1.0
או
סמנטיקה
מבצעת פעולת OR ברמת הרכיב של שני טנסורים, lhs ו-rhs, ומפיקה טנסור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:
- לערכים בוליאניים: OR לוגי.
- למספרים שלמים: ערך OR ברמת הסיביות.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | lhs |
טנזור מסוג מספר שלם או בוליאני | (C1) |
| (I2) | rhs |
טנזור מסוג מספר שלם או בוליאני | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור מסוג מספר שלם או בוליאני | (C1) |
מגבלות
- (C1)
type(lhs) = type(rhs) = type(result).
דוגמאות
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 6], [7, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%<lhs, %>rhs) : (<tensor>2x2>xi1, te<nsor2x>2xi1) - tensor2x2xi1
// %result: [[false, true], [true, true]]
outfeed
סמנטיקה
כותב inputs ל-outfeed ויוצר אסימון result.
הסמנטיקה של outfeed_config מוגדרת בהטמעה.
קלט
| תווית | שם | סוג |
|---|---|---|
| (I1) | inputs |
מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה |
| (I2) | token |
token |
| (I3) | outfeed_config |
קבוע מהסוג string |
פלט
| שם | סוג |
|---|---|
result |
token |
דוגמאות
%result = "stablehlo.outfeed"(%input0, %token) {
outfeed_config = &quo<t;"<>/span>
} : (tensor2x2x2xi64,> !stablehlo.token) - !stablehlo.token
pad
סמנטיקה
מרחיב את operand על ידי הוספת ריפוד סביב הטנזור וגם בין הרכיבים של הטנזור עם הערך padding_value שצוין.
הערכים edge_padding_low ו-edge_padding_high מציינים את כמות הריווח שנוספה בקצה הנמוך (ליד אינדקס 0) ובקצה הגבוה (ליד האינדקס הגבוה ביותר) של כל מימד, בהתאמה. הערך של הריווח יכול להיות שלילי. הערך המוחלט של ריווח שלילי מציין את מספר הרכיבים שיוסרו מהמאפיין שצוין.
interior_padding מציין את כמות הריווח שנוספת בין שני רכיבים בכל מימד, והיא לא יכולה להיות שלילית. ריווח פנימי מתרחש לפני ריווח שוליים, כך שריווח שוליים שלילי יסיר אלמנטים מהאופרנד עם הריווח הפנימי.
באופן רשמי יותר, result[result_index] מוגדר כך:
-
operand[operand_index]ifresult_index = edge_padding_low + operand_index * (interior_padding + 1). padding_valueאחרת.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
tensor או per-tensor quantized tensor | (C1), (C2), (C4) |
| (I2) | padding_value |
טנזור אפס-ממדי או טנזור שעבר קוונטיזציה ברמת הטנזור | (C1) |
| (I3) | edge_padding_low |
קבוע טנסור חד-ממדי מסוג si64 |
(C1), (C4) |
| (I4) | edge_padding_high |
קבוע טנסור חד-ממדי מסוג si64 |
(C1), (C4) |
| (I5) | interior_padding |
קבוע טנסור חד-ממדי מסוג si64 |
(C2-C4) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
tensor או per-tensor quantized tensor | (C3-C6) |
מגבלות
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result). - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand). - (C3)
0 <= interior_padding. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.
דוגמאות
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
edge_padding_l<ow = arra>yi64: 0, 1,
edge_padding_hi<gh = arra>yi64: 2, 1,
interior_paddi<ng = arra>yi64: 1, 2
}< : (ten>sor2x3xi<32,> te>nsori32<) - ten>sor5x9xi32
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
partition_id
סמנטיקה
יוצרת partition_id של התהליך הנוכחי.
פלט
| שם | סוג |
|---|---|
result |
טנזור אפס-ממדי מסוג ui32 |
דוגמאות
%result = "stablehlo.partition_id">;() : (<) - >tensorui32
popcnt
סמנטיקה
מבצעת ספירה של מספר הביטים שמוגדרים בטנזור operand לפי רכיב, ומפיקה טנזור result.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור מסוג מספר שלם | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור מסוג מספר שלם | (C1) |
מגבלות
- (C1)
type(operand) = type(result).
דוגמאות
// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand)< : (t>ens>or4xi64<) - t>ensor4xi64
// %result: [0, 1, 1, 7]
כוח
סמנטיקה
מבצע העלאה בחזקה של כל אחד מהרכיבים בטנזור lhs בטנזור rhs ומפיק טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:
- למספרים שלמים: העלאה בחזקה של מספר שלם.
- למספרים ממשיים:
powמ-IEEE-754. - למספרים מרוכבים: העלאה בחזקה של מספר מרוכב.
- לסוגים כמותיים:
dequantize_op_quantize(power, lhs, rhs, type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | lhs |
טנזור של מספרים שלמים, מספרים עם נקודה עשרונית או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור | (C1) |
| (I2) | rhs |
טנזור של מספרים שלמים, מספרים עם נקודה עשרונית או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור של מספרים שלמים, מספרים עם נקודה עשרונית או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור | (C1) |
מגבלות
- (C1)
baseline_type(operand) = baseline_type(result).
דוגמאות
// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs)< : (t>ensor6xf<64, t>ens>or6xf64<) - t>ensor6xf64
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]
אמיתי
סמנטיקה
הפונקציה מחלצת את החלק הממשי, לפי רכיב, מ-operand ומפיקה טנסור result. באופן רשמי יותר, לכל רכיב x:
real(x) = is_complex(x) ? real_part(x) : x.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור מסוג נקודה צפה או מסוג מרוכב | (C1), (C2) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנסור מסוג נקודה צפה | (C1), (C2) |
מגבלות
- (C1)
shape(result) = shape(operand). - (C2)
element_type(result)מוגדר כך:-
complex_element_type(element_type(operand))ifis_complex(operand). element_type(operand)אחרת.
-
דוגמאות
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand)< : (tenso<r2x>>com>plexf32<) - t>ensor2xf32
// %result: [1.0, 3.0]
recv
סמנטיקה
מקבל נתונים מערוץ עם channel_id ומפיק results.
אם is_host_transfer הוא true, הפעולה מעבירה נתונים מהמארח. אחרת, הנתונים מועברים ממכשיר אחר על סמך הערכים של source_target_pairs. הסימון הזה משכפל את המידע שמופיע ב-channel_type, ולכן בעתיד אנחנו מתכננים להשאיר רק אחד מהם (מס' 666). אם is_host_transfer
= false ו-source_target_pairs הוא None או ריק, ההתנהגות נחשבת ללא מוגדרת.
results מורכב מערכי מטען ייעודי (payload) שמופיעים ראשונים וטוקן שמופיע אחרון. בעתיד, אנחנו מתכננים לפצל את מטען הנתונים ואת האסימון לשני פלטים נפרדים כדי לשפר את הבהירות (#670).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | token |
token |
|
| (I2) | channel_id |
קבוע מהסוג si64 |
|
| (I3) | channel_type |
enum של DEVICE_TO_DEVICE ו-DEVICE_TO_HOST |
(C5) |
| (I4) | is_host_transfer |
קבוע מהסוג i1 |
(C5-C6) |
| (I5) | source_target_pairs |
קבוע טנסור דו-ממדי מסוג si64 |
(C1-C4), (C6) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
results |
מספר משתנה של טנסורים, טנסורים או טוקנים שעברו קוונטיזציה | (C2-C4) |
מגבלות
- (C1)
dim(source_target_pairs, 1) = 2. - (C2)
is_unique(source_target_pairs[:, 0]). - (C3)
is_unique(source_target_pairs[:, 1]). - (C4)
0 <= source_target_pairs < N, כאשרNמוגדר כך:-
num_replicasאם משתמשים ב-cross_replica. -
num_partitionsאם משתמשים ב-cross_partition.
-
- (C5)
channel_typeמוגדר כך:DEVICE_TO_HOSTifis_host_transfer = true,DEVICE_TO_DEVICEאחרת.
דוגמאות
%results0, %results1 = "stablehlo.recv"(%token) {
channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 1,
is_host_transfer = false,
source_target_pai<rs = dense[[0, 1>], [1, 2]<] : ten>sor2x2xi64
} : (!stablehl>o.token)< - (ten>sor2x2xi64, !stablehlo.token)
הקטנה
סמנטיקה
מחיל פונקציית צמצום body על inputs ועל init_values לאורך dimensions ומפיק טנסורים של results.
סדר הפעולות מוגדר בהטמעה, כלומר body ו-init_values חייבים ליצור מונואיד כדי להבטיח שהפעולה תפיק את אותן תוצאות לכל הקלטים בכל ההטמעות. עם זאת, התנאי הזה לא מתקיים בהרבה צמצומים פופולריים. לדוגמה, פעולת החיבור של מספרים ממשיים ב-body ושל אפס ב-init_values לא יוצרת מונואיד, כי פעולת החיבור של מספרים ממשיים היא לא אסוציאטיבית.
באופן רשמי יותר, results...[j0, ..., jR-1] = reduce(input_slices_converted) כאשר:
-
input_slices = inputs...[j0, ..., :, ..., jR-1], כאשר:מוכנסים ב-dimensions. input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...).init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...).-
reduce(input_slices_converted) = exec(schedule)עבור עץ בינארי מסוים scheduleכאשר:exec(node) = body(exec(node.left), exec(node.right)).exec(leaf) = leaf.value.
-
scheduleהוא עץ בינארי מלא שמוגדר בהטמעה, והמעבר שלו בסדר עולה כולל:- ערכים של
input_slices_converted...[index], לכלindexב-index_space(input_slices_converted)בסדר לקסיקוגרפי עולה שלindex. - במרווחים של כמות
init_values_convertedשמוגדרת בהטמעה, במיקומים שמוגדרים בהטמעה.
- ערכים של
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | inputs |
מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור | (C1-C4), (C6), (C7) |
| (I2) | init_values |
מספר משתנה של טנסורים אפסי-ממדיים או טנסורים שעברו קוונטיזציה לכל טנסור | (C2), (C3) |
| (I3) | dimensions |
קבוע טנסור חד-ממדי מסוג si64 |
(C4), (C5), (C7) |
| (I4) | body |
פונקציה | (C6) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
results |
מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור | (C3), (C7), (C8) |
מגבלות
- (C1)
same(shape(inputs...)). - (C2)
element_type(inputs...) = element_type(init_values...). - (C3)
0 < size(inputs) = size(init_values) = size(results) = N. - (C4)
0 <= dimensions < rank(inputs[0]). - (C5)
is_unique(dimensions). - (C6)
bodyהוא מסוג(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)כאשרis_promotable(element_type(inputs[i]), Ei). - (C7)
shape(results...) = shape(inputs...)למעט גדלי המאפיינים שלinputs...שמתאימים ל-dimensions, שלא נכללים. - (C8)
element_type(results[i]) = Eiלכלiב-[0,N).
דוגמאות
// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
"stable<hlo>.re>turn"(%0) : (tensori64) <- ()
}>) {
dimens<ions = >arrayi64<: 1>
} >: (tens<or1x6>xi64, tensori64) - tensor1xi64
// %result = [15]
reduce_precision
סמנטיקה
מבצע המרה של operand לאלמנטים לסוג אחר של נקודה צפה שמשתמש ב-exponent_bits וב-mantissa_bits, ואז בחזרה לסוג המקורי של נקודה צפה, ומפיק טנסור output.
באופן רשמי יותר:
- ביטים של המנטיסה של הערך המקורי מעודכנים כדי לעגל את הערך המקורי לערך הקרוב ביותר שאפשר לייצג באמצעות
mantissa_bitsעם סמנטיקה שלroundToIntegralTiesToEven. - לאחר מכן, אם
mantissa_bitsקטן ממספר הביטים של המנטיסה של הערך המקורי, הביטים של המנטיסה נחתכים ל-mantissa_bits. - לאחר מכן, אם הביטים של המעריך בתוצאת הביניים לא מתאימים לטווח שצוין על ידי
exponent_bits, מתרחש גלישה בתוצאת הביניים לאינסוף באמצעות הסימן המקורי, או גלישה תחתונה לאפס באמצעות הסימן המקורי. - לסוגים כמותיים, הפונקציה מבצעת
dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור | (C1) |
| (I2) | exponent_bits |
קבוע מהסוג si32 |
(C2) |
| (I3) | mantissa_bits |
קבוע מהסוג si32 |
(C3) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
output |
טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור | (C1) |
מגבלות
- (C1)
baseline_type(operand) = baseline_type(output). - (C2)
1 <= exponent_bits. - (C3)
0 <= mantissa_bits.
דוגמאות
// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
exponent_bits = 5 : i32,
mantissa_bits = 10 : i32
}< : (t>ens>or6xf64<) - t>ensor6xf64
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]
reduce_scatter
סמנטיקה
בכל קבוצת תהליכים ברשת התהליכים של StableHLO, הפונקציה מבצעת צמצום באמצעות computations על הערכים של טנסור operand מכל תהליך, מפצלת את תוצאת הצמצום לאורך scatter_dimension לחלקים, ומפיצה את החלקים המפוצלים בין התהליכים כדי ליצור את result.
הפעולה מפצלת את רשת התהליכים של StableHLO ל-process_groups, שמוגדרת כך:
cross_replica(replica_groups)ifchannel_id <= 0 and use_global_device_ids = false.cross_replica_and_partition(replica_groups)ifchannel_id > 0 and use_global_device_ids = false.flattened_ids(replica_groups)ifchannel_id > 0 and use_global_device_ids = true.
לאחר מכן, בכל process_group:
reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation).parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension).-
result@receiver = parts@sender[receiver_index]לכלsenderב-process_group, כאשרreceiver_index = process_group.index(receiver).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
tensor או per-tensor quantized tensor | (C1), (C2), (C7), (C8) |
| (I2) | scatter_dimension |
קבוע מהסוג si64 |
(C1), (C2), (C8) |
| (I3) | replica_groups |
קבוע טנסור דו-ממדי מסוג si64 |
(C3-C5) |
| (I4) | channel_id |
קבוע מהסוג si64 |
(C6) |
| (I5) | use_global_device_ids |
קבוע מהסוג i1 |
(C6) |
| (I6) | computation |
פונקציה | (C7) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
tensor או per-tensor quantized tensor | (C8-C9) |
מגבלות
- (C1)
dim(operand, scatter_dimension) % dim(process_groups, 1) = 0. - (C2)
0 <= scatter_dimension < rank(operand). - (C3)
is_unique(replica_groups). - (C4)
size(replica_groups)מוגדר כך:-
num_replicasאם משתמשים ב-cross_replica. -
num_replicasאם משתמשים ב-cross_replica_and_partition. -
num_processesאם משתמשים ב-flattened_ids.
-
- (C5)
0 <= replica_groups < size(replica_groups). - (C6) אם
use_global_device_ids = true, אזchannel_id > 0. - (C7)
computationהוא מסוג(tensor<E>, tensor<E>) -> (tensor<E>)כאשרis_promotable(element_type(operand), E). - (C8)
shape(result) = shape(operand)except:dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1).
- (C9)
element_type(result) = E.
דוגמאות
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
"stable<hlo>.re>turn"(%0) : (tensori64) - ()
}) {
scatter_dimension = 1 :< i64,
>replica_g<roups => dense[[0, 1]] : tensor1x2xi64,
channel_hand<le = #stablehlo.chan>nel_handleha<ndle = >0, >type = <0
} : (>tensor2x4xi64) - tensor2x2xi64
//
// %result@(0, 0): [[10, 12],
// [18, 20]]
// %result@(1, 0): [[14, 16],
// [22, 24]]
reduce_window
סמנטיקה
מחיל פונקציית צמצום body על חלונות של inputs ו-init_values
ומפיק results.
התרשים הבא מציג דוגמה קונקרטית שממחישה איך מחשבים את הרכיבים ב-results... מתוך inputs....
באופן רשמי יותר,
results...[result_index] = reduce(windows, init_values, axes(inputs...), body)
(ראו reduce) כאשר:
padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1).window_start = result_index * window_strides.window_end = window_start + (window_dimensions - 1) * window_dilations + 1.windows = slice(padded_inputs..., window_start, window_end, window_dilations).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | inputs |
מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור | (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15) |
| (I2) | init_values |
מספר משתנה של טנסורים אפסי-ממדיים או טנסורים שעברו קוונטיזציה לכל טנסור | (C1), (C13) |
| (I3) | window_dimensions |
קבוע טנסור חד-ממדי מסוג si64 |
(C4), (C5), (C15) |
| (I4) | window_strides |
קבוע טנסור חד-ממדי מסוג si64 |
(C6), (C7), (C15) |
| (I5) | base_dilations |
קבוע טנסור חד-ממדי מסוג si64 |
(C8), (C9), (C15) |
| (I6) | window_dilations |
קבוע טנסור חד-ממדי מסוג si64 |
(C10), (C11), (C15) |
| (I7) | padding |
קבוע טנסור דו-ממדי מסוג si64 |
(C12), (C15) |
| (I8) | body |
פונקציה | (C13) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
results |
מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור | (C1), (C14-C16) |
מגבלות
- (C1)
0 < size(inputs) = size(init_values) = size(results) = N. - (C2)
same(shape(inputs...)). - (C3)
element_type(inputs...) = element_type(init_values...). - (C4)
size(window_dimensions) = rank(inputs[0]). - (C5)
0 < window_dimensions. - (C6)
size(window_strides) = rank(inputs[0]). - (C7)
0 < window_strides. - (C8)
size(base_dilations) = rank(inputs[0]). - (C9)
0 < base_dilations. - (C10)
size(window_dilations) = rank(inputs[0]). - (C11)
0 < window_dilations. - (C12)
shape(padding) = [rank(inputs[0]), 2]. - (C13)
bodyהוא מסוג(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)כאשרis_promotable(element_type(inputs[i]), Ei). - (C14)
same(shape(results...)). - (C15)
shape(results[0]) = num_windowsכאשר:dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1.padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1].dilated_window_shape = (window_dimensions - 1) * window_dilations + 1.is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape.num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1.
- (C16)
element_type(results[i]) = Eiלכלiב-[0,N).
דוגמאות
// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
"stable<hlo>.re>turn"(%0) : (tensori64) - ()
})< {
wind>ow_dimensions = arrayi64: <2, 1,
w>indow_strides = arrayi64: <4, 1,
b>ase_dilations = arrayi64: 2,< 1,
win>dow_dilations = arr<ayi64: 3, 1,
p>adding = <dense[[>2, 1], [0, 0<]] : te>nsor2x2x<i64>
} >: (tens<or3x2xi>64, tensori64) - tensor2x2xi64
// %result = [[0, 0], [3, 4]]
שארית
סמנטיקה
הפונקציה מחשבת את השארית של כל אחד מהרכיבים של טנזור המחולק lhs וטנזור המחלק rhs, ומחזירה טנזור result.
באופן רשמי יותר, הסימן של התוצאה נלקח מהמחולק, והערך המוחלט של התוצאה תמיד קטן מהערך המוחלט של המחלק.
השארית מחושבת כך: lhs - d * rhs, כאשר d מחושב כך:
- למספרים שלמים:
stablehlo.divide(lhs, rhs). - למספרים ממשיים:
division(lhs, rhs)מ-IEEE-754 עם מאפיין עיגולroundTowardZero. - למספרים מרוכבים: ייקבע בהמשך (#997).
- לסוגים שעברו קוונטיזציה:
dequantize_op_quantize(remainder, lhs, rhs, type(result)).
במקרה של סוגי רכיבים של נקודה צפה, הפעולה הזו שונה מהפעולה remainder שמוגדרת במפרט IEEE-754, שבה d הוא ערך שלם הכי קרוב לערך המדויק של lhs/rhs, וההכרעה במקרה של שוויון היא לטובת המספר הזוגי.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | lhs |
טנזור של מספרים שלמים, מספרים ממשיים או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור | (C1) |
| (I2) | rhs |
טנזור של מספרים שלמים, מספרים ממשיים או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור של מספרים שלמים, מספרים ממשיים או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור | (C1) |
מגבלות
- (C1)
baseline_type(operand) = baseline_type(result).
דוגמאות
// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs)< : (t>ensor4xi<64, t>ens>or4xi64<) - t>ensor4xi64
// %result: [2, -2, 2, -2]
replica_id
סמנטיקה
יוצרת replica_id של התהליך הנוכחי.
פלט
| שם | סוג |
|---|---|
result |
טנזור אפס-ממדי מסוג ui32 |
דוגמאות
%result = "stablehlo.replica_id">;() : (<) - >tensorui32
עיצוב מחדש
סמנטיקה
מבצעת שינוי צורה של טנסור operand לטנסור result. מבחינה רעיונית, המשמעות היא לשמור על אותו ייצוג קנוני אבל לשנות את הצורה, למשל מ-tensor<2x3xf32> ל-tensor<3x2xf32> או ל-tensor<6xf32>.
באופן רשמי יותר, result[result_index] = operand[operand_index] כאשר
result_index ו-operand_index נמצאים באותו מיקום בסדר הלקסיקוגרפי של index_space(result) ו-index_space(operand).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור או טנזור שעבר קוונטיזציה | (C1-C3) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור או טנזור שעבר קוונטיזציה | (C1-C3) |
מגבלות
- (C1)
element_type(result)מחושב לפי הנוסחה הבאה:-
element_type(operand), אם!is_per_axis_quantized(operand). element_type(operand)חוץ מהמקרים שבהםquantization_dimension(operand)וquantization_dimension(result)שונים.
-
- (C2)
size(operand) = size(result). - (C3) אם
is_per_axis_quantized(operand):reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
דוגמאות
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand)< : (ten>sor>2x3xi32<) - ten>sor3x2xi32
// %result: [[1, 2], [3, 4], [5, 6]]
הפוך
סמנטיקה
הפונקציה הופכת את סדר הרכיבים ב-operand לאורך dimensions שצוין
ומפיקה טנזור result. באופן רשמי יותר,
result[result_index] = operand[operand_index] כאשר:
operand_index[d] = dim(result, d) - result_index[d] - 1ifdindimensions.operand_index[d] = result_index[d]אחרת.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
tensor או per-tensor quantized tensor | (C1), (C3) |
| (I2) | dimensions |
קבוע טנסור חד-ממדי מסוג si64 |
(C2), (C3) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
tensor או per-tensor quantized tensor | (C1), (C3) |
מגבלות
- (C1)
type(operand) = type(result). - (C2)
is_unique(dimensions). - (C3)
0 <= dimensions < rank(result).
דוגמאות
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
dimensio<ns = a>rrayi64: 1
}< : (ten>sor>3x2xi32<) - ten>sor3x2xi32
// %result: [[2, 1], [4, 3], [6, 5]]
rng
סמנטיקה
יוצרת מספרים אקראיים באמצעות אלגוריתם rng_distribution ומפיקה טנזור result בצורה נתונה shape.
אם rng_distribution = UNIFORM, המספרים האקראיים נוצרים לפי התפלגות אחידה במרווח [a, b). אם a >= b,
ההתנהגות לא מוגדרת.
אם הערך הוא rng_distribution = NORMAL, המספרים האקראיים נוצרים לפי ההתפלגות הנורמלית עם ממוצע = a וסטיית תקן = b.
אם b < 0, ההתנהגות לא מוגדרת.
הדרך המדויקת שבה נוצרים מספרים אקראיים מוגדרת בהטמעה. לדוגמה, יכול להיות שהם דטרמיניסטיים או לא, ויכול להיות שהם משתמשים במצב מוסתר או לא.
בשיחות עם בעלי עניין רבים, עלה שהאופציה הזו הוצאה משימוש בפועל, ולכן בעתיד אנחנו מתכננים לבדוק את האפשרות להסיר אותה (#597).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | a |
טנזור אפס-ממדי מסוג מספר שלם, בוליאני או נקודה צפה | (C1), (C2) |
| (I2) | b |
טנזור אפס-ממדי מסוג מספר שלם, בוליאני או נקודה צפה | (C1), (C2) |
| (I3) | shape |
קבוע טנסור חד-ממדי מסוג si64 |
(C3) |
| (I4) | rng_distribution |
enum of UNIFORM and NORMAL |
(C2) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור מסוג מספר שלם, בוליאני או נקודה צפה | (C1-C3) |
מגבלות
- (C1)
element_type(a) = element_type(b) = element_type(result). - (C2) אם
rng_distribution = NORMAL, אזis_float(a). - (C3)
shape(result) = shape.
דוגמאות
// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
rng_distribution = <#stablehlorng_distributi>on UNIFORM
}< : >(tensori<32,> tensori<32, t>ens>or2xi64<) - ten>sor3x3xi32
// %result: [
// [1, 0, 1],
// [1, 1, 1],
// [0, 0, 0]
// ]
rng_bit_generator
סמנטיקה
הפונקציה מחזירה output מלא בביטים אקראיים אחידים ומצב פלט מעודכן output_state באמצעות האלגוריתם של מחולל מספרים פסאודו-אקראיים rng_algorithm, בהינתן מצב התחלתי initial_state. הפלט הוא פונקציה דטרמיניסטית של initial_state, אבל לא מובטח שהוא יהיה דטרמיניסטי בין יישומים שונים.
rng_algorithm הוא אחד מהבאים:
-
DEFAULT: אלגוריתם שמוגדר בהטמעה. -
THREE_FRY: וריאנט של אלגוריתם Threefry שמוגדר בהטמעה.* -
PHILOX: וריאציה של אלגוריתם Philox שמוגדרת בהטמעה.*
* ראו: Salmon et al. SC 2011. מספרים אקראיים מקבילים: קל ופשוט כמו 1, 2, 3.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | rng_algorithm |
enum של DEFAULT, THREE_FRY ו-PHILOX |
(C2) |
| (I2) | initial_state |
טנזור חד-ממדי מסוג ui64 |
(C1), (C2) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
output_state |
טנזור חד-ממדי מסוג ui64 |
(C1) |
output |
טנסור מסוג מספר שלם או מספר בשיטת נקודה צפה |
מגבלות
- (C1)
type(initial_state) = type(output_state). - (C2)
size(initial_state)מוגדר כך:- מוגדרת בהטמעה אם
rng_algorithm = DEFAULT. -
2ifrng_algorithm = THREE_FRY. -
2או3אםrng_algorithm = PHILOX.
- מוגדרת בהטמעה אם
דוגמאות
// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
rng_algorithm = <#stablehlorng_algorithm> THREE_FRY
}< : (te>nso>r2xui64)< - (te>nsor2xui<64, tens>or2x2xui64)
// %output_state: [1, 6]
// %output: [
// [9236835810183407956, 16087790271692313299],
// [18212823393184779219, 2658481902456610144]
// ]
round_nearest_afz
סמנטיקה
מעגלת כל רכיב בטנזור operand למספר השלם הקרוב ביותר, ומעגלת מספרים שהם בדיוק באמצע בין שני מספרים שלמים לכיוון הרחוק מאפס. התוצאה היא טנזור result. מטמיע את הפעולה roundToIntegralTiesToAway מהמפרט IEEE-754. עבור סוגים שעברו קוונטיזציה, הפעולה dequantize_op_quantize(round_nearest_afz, operand, type(result)) מתבצעת.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור | (C1) |
מגבלות
- (C1)
baseline_type(operand) = baseline_type(result).
דוגמאות
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]
round_nearest_even
סמנטיקה
הפונקציה מבצעת עיגול של כל רכיב בטנזור operand למספר השלם הקרוב ביותר, ומעגלת מספרים שהם בדיוק באמצע בין שני מספרים שלמים למספר השלם הזוגי. התוצאה היא טנזור result. מיישמת את הפעולה roundToIntegralTiesToEven מהמפרט IEEE-754. לסוגים כמותיים, הפונקציה מבצעת dequantize_op_quantize(round_nearest_even, operand, type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור מסוג נקודה צפה או טנזור כמותי לכל טנזור | (C1) |
מגבלות
- (C1)
baseline_type(operand) = baseline_type(result).
דוגמאות
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
rsqrt
סמנטיקה
מבצעת פעולת שורש ריבועי הדדי על טנסור operand ומפיקה טנסור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:
- למספרים ממשיים:
rSqrtמ-IEEE-754. - למספרים מרוכבים: שורש ריבועי הפוך מרוכב.
- לסוגים כמותיים:
dequantize_op_quantize(rsqrt, operand, type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1) |
מגבלות
- (C1)
baseline_type(operand) = baseline_type(result).
דוגמאות
// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[1.0, 0.5], [0.33333343, 0.2]]
פיזור
סמנטיקה
הפונקציה יוצרת טנסורים results ששווים לטנסורים inputs, אלא שכמה פרוסות שצוינו על ידי scatter_indices מתעדכנות עם הערכים updates באמצעות update_computation.
בתרשים הבא מוצג מיפוי של רכיבים ב-updates... לרכיבים ב-results... באמצעות דוגמה קונקרטית. בתרשים מוצגים כמה מדדים לדוגמה של updates... ומפורט לאילו מדדים של results... הם מתאימים.
באופן רשמי יותר, לכל update_index ב-index_space(updates[0]):
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims].update_scatter_index = update_index[update_scatter_dims...].- ההגדרה של
start_indexהיא:-
scatter_indices[si0, ..., :, ..., siN]כאשרsiהם רכיבים נפרדים ב-update_scatter_indexו-:מוכנס באינדקסindex_vector_dim, אםindex_vector_dim<rank(scatter_indices). [scatter_indices[update_scatter_index]]אחרת.
-
-
d_inputב-axes(inputs[0]),-
full_start_index[d_input] = start_index[d_start]ifd_input = scatter_dims_to_operand_dims[d_start]. full_start_index[d_input] = 0אחרת.
-
-
d_inputב-axes(inputs[0]),full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)]אםd_input = input_batching_dims[i_batching]וגםd_start = scatter_indices_batching_dims[i_batching].full_batching_index[d_input] = 0אחרת.
update_window_index = update_index[update_window_dims...].-
full_window_index = [wi0, ..., 0, ..., wiN]כאשרwiהם רכיבים נפרדים ב-update_window_index, ו-0מוכנס באינדקסים מ-inserted_window_dimsועדinput_batching_dims. result_index = full_start_index + full_batching_index + full_window_index.
בהתאם לכך, results = exec(schedule, inputs), כאשר:
-
scheduleהוא תמורה שמוגדרת בהטמעה שלindex_space(updates[0]). exec([update_index, ...], results) = exec([...], updated_results)כאשר:- אם
result_indexנמצא בטווח שלshape(results...) updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )updated_values = update_computation(results...[result_index], updates_converted)-
updated_resultsהוא עותק שלresultsעםresults...[result_index]שמוגדר כ-updated_values.... - אחרת
updated_results = results.
- אם
exec([], results) = results.
אם indices_are_sorted הוא true, ההטמעה יכולה להניח ש-scatter_indices ממוינים לפי scatter_dims_to_operand_dims, אחרת ההתנהגות לא מוגדרת. באופן רשמי יותר, לכל i1 < i2 מ-indices(result), full_start_index(i1) <= full_start_index(i2).
אם הערך של unique_indices הוא true, ההטמעה יכולה להניח שכל האינדקסים של result_index שמופצים הם ייחודיים. אם הערך של unique_indices הוא true אבל האינדקסים שאליהם מתבצע הפיזור לא ייחודיים, ההתנהגות לא מוגדרת.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | inputs |
מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור | (C1), (C2), (C4-C6), (C11), (C13), (C18), (C21), (C23-C24) |
| (I2) | scatter_indices |
טנזור מסוג מספר שלם | (C4), (C15), (C19), (C22) |
| (I3) | updates |
מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור | (C3-C6), (C8) |
| (I4) | update_window_dims |
קבוע טנסור חד-ממדי מסוג si64 |
(C2), (C4), (C7-C8) |
| (I5) | inserted_window_dims |
קבוע טנסור חד-ממדי מסוג si64 |
(C2), (C4), (C9-C11) |
| (I6) | input_batching_dims |
קבוע טנסור חד-ממדי מסוג si64 |
(C2), (C4), (C9), (C12-13), (C17-18), (C20) |
| (I7) | scatter_indices_batching_dims |
קבוע טנסור חד-ממדי מסוג si64 |
(C14-C18) |
| (I8) | scatter_dims_to_operand_dims |
קבוע טנסור חד-ממדי מסוג si64 |
(C19-C21) |
| (I9) | index_vector_dim |
קבוע מהסוג si64 |
(C4), (C16), (C19), (C22) |
| (I10) | indices_are_sorted |
קבוע מהסוג i1 |
|
| (I11) | unique_indices |
קבוע מהסוג i1 |
|
| (I12) | update_computation |
פונקציה | (C23) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
results |
מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור | (C24-C25) |
מגבלות
- (C1)
same(shape(inputs...)). - (C2)
rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims) + size(input_batching_dims). - (C3)
same(shape(updates...)). - (C4)
shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes)כאשר:-
update_scatter_dim_sizes = shape(scatter_indices), אלא שגודל המימד שלscatter_indicesשמתאים ל-index_vector_dimלא נכלל. -
update_window_dim_sizes <= shape(inputs[0]), רק שגודלי המאפיינים ב-inputs[0]שמתאימים ל-inserted_window_dimsול-input_batching_dimsלא נכללים. - הפונקציה
combineממקמת אתupdate_scatter_dim_sizesבצירים שתואמים ל-update_scatter_dimsואתupdate_window_dim_sizesבצירים שתואמים ל-update_window_dims.
-
- (C5)
0 < size(inputs) = size(updates) = N. - (C6)
element_type(updates...) = element_type(inputs...). - (C7)
is_unique(update_window_dims) and is_sorted(update_window_dims). - (C8)
0 <= update_window_dims < rank(updates[0]). - (C9)
is_unique(concatenate(inserted_window_dims, input_batching_dims)) - (C10)
is_sorted(inserted_window_dims). - (C11)
0 <= inserted_window_dims < rank(inputs[0]). - (C12)
is_sorted(input_batching_dims). - (C13)
0 <= input_batching_dims < rank(inputs[0])). - (C14)
is_unique(scatter_indices_batching_dims). - (C15)
0 <= scatter_indices_batching_dims < rank(scatter_indices). - (C16)
index_vector_dim not in scatter_indices_batching_dims. - (C17)
size(input_batching_dims) == size(scatter_indices_batching_dims). - (C18)
dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...). - (C19)
size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1. - (C20)
is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims)). - (C21)
0 <= scatter_dims_to_operand_dims < rank(inputs[0]). - (C22)
0 <= index_vector_dim <= rank(scatter_indices). - (C23)
update_computationהוא מסוג(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), כאשרis_promotable(element_type(inputs[i]), Ei). - (C24)
shape(inputs...) = shape(results...). - (C25)
element_type(results[i]) = Eiלכלiב-[0,N).
דוגמאות
// %input: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %scatter_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
// %update: [
// [
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
// ],
// [
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
// ]
// ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
"stable<hlo>.re>turn"(%0) : (tensori64) - ()
}) {
scatter_dimensio<n_numbers = #stablehlo.scatter
update_window_dims = [3, 4],
inserted_window_dims = [1],
input_batching_dims = [0],
scatter_indices_batching_dims = [1],
scatter_dims_to_operand_dims = [2>, 1],
index_vector_dim = 3,
indices_are_sorted = false,
uniq<ue_indices >= false
<} : (tensor>2x3x4x2x<i64, tensor2x>2x3>x2xi64,< tensor2x2x>3x2x2xi64) - tensor2x3x4x2xi64
// %result: [
// [
// [[3, 4], [6, 7], [6, 7], [7, 8]],
// [[9, 10],[11, 12], [15, 16], [17, 18]],
// [[17, 18], [19, 20], [22, 23], [24, 25]]
// ],
// [
// [[25, 26], [28, 29], [30, 31], [31, 32]],
// [[35, 36], [38, 39], [38, 39], [39, 40]],
// [[41, 42], [44, 45], [46, 47], [47, 48]]
// ]
// ]
בחירה
סמנטיקה
הפונקציה יוצרת טנסור result שבו כל רכיב נבחר מתוך טנסור on_true או on_false על סמך הערך של הרכיב התואם בטנסור pred.
באופן רשמי יותר, result[result_index] = pred_element ? on_true[result_index] :
on_false[result_index], כאשר pred_element = rank(pred) = 0 ? pred[] :
pred[result_index]. לסוגים כמותיים, הפונקציה מבצעת dequantize_select_quantize(pred, on_true, on_false, type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | pred |
טנסור מסוג i1 |
(C1) |
| (I2) | on_true |
tensor או per-tensor quantized tensor | (C1-C2) |
| (I3) | on_false |
tensor או per-tensor quantized tensor | (C2) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
tensor או per-tensor quantized tensor | (C2) |
מגבלות
- (C1)
rank(pred) = 0 or shape(pred) = shape(on_true). - (C2)
baseline_type(on_true) = baseline_type(on_false) = baseline_type(result).
דוגמאות
// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false)< : (te>nsor2x2x<i1, ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 2], [3, 8]]
select_and_scatter
סמנטיקה
הפעולה מפזרת את הערכים מטנסור source באמצעות scatter על סמך התוצאה של reduce_window של טנסור input באמצעות select, ומפיקה טנסור result.
בתרשים הבא מוצגות דוגמאות קונקרטיות שממחישות איך מחשבים את הרכיבים ב-result מ-operand ומ-source.
באופן רשמי יותר:
selected_values = reduce_window_without_init(...)עם הקלט הבא:inputs = [operand].-
window_dimensions, window_stridesו-paddingשמשמשים כמו שהם. base_dilations = windows_dilations = 1.- ההגדרה של
bodyהיא:
def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>: return select(arg0, arg1) ? arg0 : arg1;כאשר
E = element_type(operand)ו-reduce_window_without_initפועלים בדיוק כמוreduce_window, אלא ש-scheduleשלreduceהבסיסי (ראו reduce) לא כולל ערכי init. בשלב הזה לא ברור מה קורה אם בחלון המתאים אין ערכים (#731).result[result_index] = reduce([source_values], [init_value], [0], scatter)where:source_values = [source[source_index] for source_index in source_indices].-
selected_index(source_index) = operand_indexif selected_values[source_index]has theoperandelement fromoperand_index. source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index].
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
tensor או per-tensor quantized tensor | (C1-C4), (C6), (C8-C11) |
| (I2) | source |
tensor או per-tensor quantized tensor | (C1), (C2) |
| (I3) | init_value |
טנזור אפס-ממדי או טנזור שעבר קוונטיזציה לכל טנזור | (C3) |
| (I4) | window_dimensions |
קבוע טנסור חד-ממדי מסוג si64 |
(C2), (C4), (C5) |
| (I5) | window_strides |
קבוע טנסור חד-ממדי מסוג si64 |
(C2), (C6), (C7) |
| (I6) | padding |
קבוע טנסור דו-ממדי מסוג si64 |
(C2), (C8) |
| (I7) | select |
פונקציה | (C9) |
| (I8) | scatter |
פונקציה | (C10) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
tensor או per-tensor quantized tensor | (C11-C12) |
מגבלות
- (C1)
element_type(operand) = element_type(source). - (C2)
shape(source) = num_windowsכאשר:padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1].is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape.num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1.
- (C3)
element_type(init_value) = element_type(operand). - (C4)
size(window_dimensions) = rank(operand). - (C5)
0 < window_dimensions. - (C6)
size(window_strides) = rank(operand). - (C7)
0 < window_strides. - (C8)
shape(padding) = [rank(operand), 2]. - (C9)
selectהוא מסוג(tensor<E>, tensor<E>) -> tensor<i1>כאשרE = element_type(operand). - (C10)
scatterהוא מסוג(tensor<E>, tensor<E>) -> tensor<E>כאשרis_promotable(element_type(operand), E). - (C11)
shape(operand) = shape(result). - (C12)
element_type(result) = E.
דוגמאות
// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_di<rection = #stablehlocom>parison_directio<n G>E
} <: (>ten>sori64,< t>ensori64) - tensori1
"stable<hl>o.r>eturn"(%0) : (tensori1) <- (>)
}, {
^bb0(%<arg>0: tensori64, %arg1: tensori64):
%0 = "sta<ble>hlo.add&<quo>t;(>%arg0, <%ar>g1) : (tensori64, tensori64) - tensor<i64>
> "stablehlo.return"(%0) :< (tensori>64) - ()
}) {
window_dim<ensions => arrayi64: 3, 1,
<window_strides => arrayi64<: 2, 1,>
padding =< dense[>[0, 1], <[0, 0]]> : tenso<r2x>2xi>64
} : <(tensor>4x2xi64, tensor2x2xi64, tensori64) - tensor4x2xi64
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]
שליחה
סמנטיקה
שליחת inputs לערוץ channel_id. הקלט נשלח למכשירים אחרים לפי הסדר שצוין ב-source_target_pairs. הפעולה יוצרת טוקן result.
אם is_host_transfer הוא true, הפעולה מעבירה נתונים למארח. אחרת, הנתונים מועברים למכשיר אחר על סמך הערכים של source_target_pairs. הסימון הזה משכפל את המידע שמופיע ב-channel_type, ולכן בעתיד אנחנו מתכננים להשאיר רק אחד מהם (מס' 666). אם is_host_transfer
= false ו-source_target_pairs הוא None או ריק, ההתנהגות נחשבת ללא מוגדרת.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | inputs |
מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה | |
| (I2) | token |
token |
|
| (I3) | channel_id |
קבוע מהסוג si64 |
|
| (I4) | channel_type |
enum של DEVICE_TO_DEVICE ו-DEVICE_TO_HOST |
(C5) |
| (I5) | is_host_transfer |
קבוע מהסוג i1 |
(C5-C6) |
| (I6) | source_target_pairs |
קבוע טנסור דו-ממדי מסוג si64 |
(C1-C4), (C6) |
פלט
| שם | סוג |
|---|---|
result |
token |
מגבלות
- (C1)
dim(source_target_pairs, 1) = 2. - (C2)
is_unique(source_target_pairs[:, 0]). - (C3)
is_unique(source_target_pairs[:, 1]). - (C4)
0 <= source_target_pairs < N, כאשרNמוגדר כך:-
num_replicasאם משתמשים ב-cross_replica. -
num_partitionsאם משתמשים ב-cross_partition.
-
- (C5)
channel_typeמוגדר כך:DEVICE_TO_HOSTifis_host_transfer = true,DEVICE_TO_DEVICEאחרת.
דוגמאות
%result = "stablehlo.send"(%operand, %token) {
channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 1,
is_host_transfer = false,
source_target_pai<rs = dense[[0, 1>], [1, 2]<] : ten>sor2x2xi64
}< : (ten>sor2x2xi64, !stablehl>o.token) - !stablehlo.token
shift_left
סמנטיקה
מבצעת פעולת הזזה שמאלה ברמת הרכיב בטנזור lhs במספר rhs של סיביות ומפיקה טנזור result.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | lhs |
טנזור מסוג מספר שלם | (C1) |
| (I2) | rhs |
טנזור מסוג מספר שלם | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור מסוג מספר שלם | (C1) |
מגבלות
- (C1)
type(lhs) = type(rhs) = type(result).
דוגמאות
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs<): (t>ensor3xi<64, t>ens>or3xi64<) - t>ensor3xi64
// %result: [-2, 0, 8]
shift_right_arithmetic
סמנטיקה
מבצעת פעולת הזזה אריתמטית ימינה ברמת הרכיב בטנזור lhs לפי מספר הסיביות rhs ומפיקה טנזור result.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | lhs |
טנזור מסוג מספר שלם | (C1) |
| (I2) | rhs |
טנזור מסוג מספר שלם | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור מסוג מספר שלם | (C1) |
מגבלות
- (C1)
type(lhs) = type(rhs) = type(result).
דוגמאות
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs<): (t>ensor3xi<64, t>ens>or3xi64<) - t>ensor3xi64
// %result: [-1, 0, 1]
shift_right_logical
סמנטיקה
מבצעת פעולת הזזה לוגית ימינה ברמת הרכיב בטנזור lhs במספר rhs של סיביות ומפיקה טנזור result.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | lhs |
טנזור מסוג מספר שלם | (C1) |
| (I2) | rhs |
טנזור מסוג מספר שלם | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור מסוג מספר שלם | (C1) |
מגבלות
- (C1)
type(lhs) = type(rhs) = type(result).
דוגמאות
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs<): (t>ensor3xi<64, t>ens>or3xi64<) - t>ensor3xi64
// %result: [9223372036854775807, 0, 1]
סמל
סמנטיקה
הפונקציה מחזירה את הסימן של כל רכיב בטנזור operand ומפיקה טנזור result.
באופן רשמי יותר, אפשר לבטא את הסמנטיקה של כל רכיב x באמצעות תחביר Python באופן הבא:
def sign(x):
if is_integer(x):
if compare(x, 0, LT, SIGNED): return -1
if compare(x, 0, EQ, SIGNED): return 0
return 1
elif is_float(x):
if is_nan(x): return NaN
if compare(x, -0.0, EQ, FLOAT): return -0.0
if compare(x, +0.0, EQ, FLOAT): return +0.0
if compare(x, 0.0, LT, FLOAT): return -1.0
return 1.0
elif is_complex(x):
if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
return divide(x, convert(abs(x), type(x)))
לסוגים כמותיים, הפונקציה מבצעת dequantize_op_quantize(sign, operand, type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור של מספר שלם עם סימן, מספר נקודה צפה או סוג מרוכב, או טנזור כמותי לכל טנזור | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור של מספר שלם עם סימן, מספר נקודה צפה או סוג מרוכב, או טנזור כמותי לכל טנזור | (C1) |
מגבלות
- (C1)
baseline_type(operand) = baseline_type(result).
דוגמאות
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
סינוס
סמנטיקה
מבצעת פעולת סינוס לפי רכיבים בטנזור operand ומפיקה טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:
- למספרים ממשיים:
sinמ-IEEE-754. - למספרים מרוכבים: סינוס מרוכב.
- לסוגים כמותיים:
dequantize_op_quantize(sine, operand, type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1) |
מגבלות
- (C1)
baseline_type(operand) = baseline_type(result).
דוגמאות
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.sine"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[0.0, 1.0], [0.0, -1.0]]
פרוסה (slice)
סמנטיקה
מחזירה פרוסה מ-operand באמצעות אינדקסים התחלתיים שמחושבים באופן סטטי ומפיקה טנזור result. start_indices מכיל את אינדקס ההתחלה של הפרוסה לכל מאפיין, limit_indices מכיל את אינדקס הסיום (לא כולל) של הפרוסה לכל מאפיין, ו-strides מכיל את הצעדים לכל מאפיין.
באופן רשמי יותר, result[result_index] = operand[operand_index] כאשר
operand_index = start_indices + result_index * strides.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
tensor או per-tensor quantized tensor | (C1-C3), (C5) |
| (I2) | start_indices |
קבוע טנסור חד-ממדי מסוג si64 |
(C2), (C3), (C5) |
| (I3) | limit_indices |
קבוע טנסור חד-ממדי מסוג si64 |
(C2), (C3), (C5) |
| (I4) | strides |
קבוע טנסור חד-ממדי מסוג si64 |
(C2), (C4) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
tensor או per-tensor quantized tensor | (C1), (C5) |
מגבלות
- (C1)
element_type(operand) = element_type(result). - (C2)
size(start_indices) = size(limit_indices) = size(strides) = rank(operand). - (C3)
0 <= start_indices <= limit_indices <= shape(operand). - (C4)
0 < strides. - (C5)
shape(result) = ceil((limit_indices - start_indices) / strides).
דוגמאות
// %operand: [
// [0, 0, 0, 0],
// [0, 0, 1, 1],
// [0, 0, 1, 1]
// ]
%result = "stablehlo.slice"(%operand) {
start_indic<es = arra>yi64: 1, 2,
limit_indic<es = arra>yi64: 3, 4,
strid<es = arra>yi64: 1, 1
}< : (ten>sor>3x4xi64<) - ten>sor2x2xi64
// % result: [
// [1, 1],
// [1, 1]
// ]
מיון
סמנטיקה
ממיינת פרוסות חד-ממדיות של inputs לאורך המימד dimension יחד,
בהתאם ל-comparator ומפיקה results.
בניגוד לקלט דומה בפעולות אחרות, dimension מאפשרת ערכים שליליים,
עם הסמנטיקה שמתוארת בהמשך. יכול להיות שבעתיד לא נאפשר את זה מסיבות של עקביות (#1377).
אם is_stable הוא true, המיון יציב, כלומר הסדר היחסי של הרכיבים שנחשבים שווים על ידי פונקציית ההשוואה נשמר. במקרה שיש קלט יחיד, שני האלמנטים e1 ו-e2 נחשבים שווים על ידי הפונקציה להשוואה אם ורק אם comparator(e1, e2) = comparator(e2, e1) = false. בהמשך מוסבר איך להכליל את זה למספר קלטים.
באופן רשמי יותר, לכל result_index ב-index_space(results[0]):
adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension.-
result_slice = [ri0, ..., :, ..., riR-1]כשriNהם רכיבים נפרדים ב-result_index, ו-:מוכנס ב-adjusted_dimension. inputs_together = (inputs[0]..., ..., inputs[N-1]...).results_together[result_slice] = sort(inputs_together[result_slice], comparator_together).- כאשר
sortממיין פרוסה חד-ממדית בסדר לא יורד, ומצפה ש-comparator_togetherיחזירtrueאם הארגומנט בצד ימין קטן מהארגומנט השני בצד שמאל. def comparator_together(lhs_together, rhs_together): args = [] for (lhs_el, rhs_el) in zip(lhs_together, rhs_together): args.append(lhs_el) args.append(rhs_el) return comparator(*args)(results[0]..., ..., results[N-1]...) = results_together.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | inputs |
מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור | (C1-C5) |
| (I2) | dimension |
קבוע מהסוג si64 |
(C4) |
| (I3) | is_stable |
קבוע מהסוג i1 |
|
| (I4) | comparator |
פונקציה | (C5) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
results |
מספר משתנה של טנסורים או טנסורים שעברו קוונטיזציה לכל טנסור | (C2), (C3) |
מגבלות
- (C1)
0 < size(inputs). - (C2)
type(inputs...) = type(results...). - (C3)
same(shape(inputs...) + shape(results...)). - (C4)
-R <= dimension < R, כשR = rank(inputs[0]). - (C5)
comparatorהוא מסוג(tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>, כאשרEi = element_type(inputs[i]).
דוגמאות
// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64, %ar<g2:> tensori64, %ar<g3:> tensori64):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_di<rection = #stablehlocom>parison_directio<n G>T
} <: (>ten>sori64,< t>ensori64) - tensori1
"stablehlo.retu<rn>&qu>ot;(%predicate) : (tensori1) - ()
}) {
dimension = 0 : i64,
< is_st>able = t<rue
} :> (t>ensor2x3<xi64, t>ensor2x3<xi64) -> (tensor2x3xi64, tensor2x3xi64)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]
sqrt
סמנטיקה
מבצעת פעולת שורש ריבועי על כל רכיב בטנזור operand ומפיקה טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:
- למספרים ממשיים:
squareRootמ-IEEE-754. - למספרים מרוכבים: שורש ריבועי מרוכב.
- לסוגים כמותיים:
dequantize_op_quantize(sqrt, operand, type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1) |
מגבלות
- (C1)
baseline_type(operand) = baseline_type(result).
דוגמאות
// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[0.0, 1.0], [2.0, 3.0]]
חיסור
סמנטיקה
מבצעת חיסור של שני טנסורים, lhs ו-rhs, ומפיקה טנסור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:
- למספרים שלמים: חיסור מספרים שלמים.
- למספרים ממשיים:
subtractionמ-IEEE-754. - למספרים מרוכבים: חיסור מרוכב.
- לסוגים שעברו קוונטיזציה:
dequantize_op_quantize(subtract, lhs, rhs, type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | lhs |
טנזור של מספרים שלמים, מספרים עם נקודה עשרונית או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור | (C1) |
| (I2) | rhs |
טנזור של מספרים שלמים, מספרים עם נקודה עשרונית או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור של מספרים שלמים, מספרים עם נקודה עשרונית או מספרים מרוכבים, או טנזור שעבר קוונטיזציה לכל טנזור | (C1) |
מגבלות
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
דוגמאות
// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs)< : (ten>sor2x2xf<32, ten>sor>2x2xf32)< - (ten>sor2x2xf32)
// %result: [[1, 2], [3, 4]]
tan
סמנטיקה
מבצעת פעולת טנגנס לפי רכיבים בטנזור operand ומפיקה טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:
- למספרים ממשיים:
tanמ-IEEE-754. - למספרים מרוכבים: טנגנס מרוכב.
- לסוגים כמותיים:
dequantize_op_quantize(tan, operand, type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1) |
מגבלות
- (C1)
baseline_type(operand) = baseline_type(result).
דוגמאות
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.tan"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [
// [0.0, 1.63312e+16],
// [0.0, 5.44375e+15]
// ]
tanh
סמנטיקה
מבצעת פעולת טנגנס היפרבולי לפי רכיבים בטנזור operand ומפיקה טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:
- למספרים ממשיים:
tanhמ-IEEE-754. - למספרים מרוכבים: טנגנס היפרבולי מרוכב.
- לסוגים שעברו קוונטיזציה:
dequantize_op_quantize(tanh, operand, type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1) |
מגבלות
- (C1)
baseline_type(operand) = baseline_type(result).
דוגמאות
// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand)< : (t>ens>or3xf32<) - t>ensor3xf32
// %result: [-0.76159416, 0.0, 0.76159416]
החלפה
סמנטיקה
מבצעת פרמוטציה של המימדים של טנזור operand באמצעות permutation ומפיקה טנזור result. באופן רשמי יותר, result[result_index] = operand[operand_index]
כאשר result_index[d] = operand_index[permutation[d]].
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנזור או טנזור שעבר קוונטיזציה | (C1-C4) |
| (I2) | permutation |
קבוע טנסור חד-ממדי מסוג si64 |
(C2-C4) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור או טנזור שעבר קוונטיזציה | (C1), (C3-C4) |
מגבלות
- (C1)
element_type(result)מחושב לפי הנוסחה הבאה:-
element_type(operand), אם!is_per_axis_quantized(operand). element_type(operand)חוץ מהמקרים שבהםquantization_dimension(operand)וquantization_dimension(result)שונים.
-
- (C2)
permutationהיא תמורה שלrange(rank(operand)). - (C3)
shape(result) = dim(operand, permutation...). - (C4) אם
is_per_axis_quantized(result), אזquantization_dimension(operand) = permutation(quantization_dimension(result)).
דוגמאות
// %operand: [
// [[1,2], [3,4], [5,6]],
// [[7,8], [9,10], [11,12]]
// ]
%result = "stablehlo.transpose"(%operand) {
permutati<on = arrayi6>4: 2, 1, 0
}< : (tenso>r2x>3x2xi32<) - tenso>r2x3x2xi32
// %result: [
// [[1,7], [3,9], [5,11]],
// [[2,8], [4,10], [6,12]]
// ]
triangular_solve
סמנטיקה
פותרים קבוצות של מערכות של משוואות ליניאריות עם מטריצות מקדמים משולשות עליונות או תחתונות.
באופן פורמלי יותר, בהינתן a ו-b, result[i0, ..., iR-3, :, :] הוא הפתרון של op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :] כאשר left_side הוא true או x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :] כאשר left_side הוא false, ופתרון המשתנה x כאשר op(a) נקבע על ידי transpose_a, שיכול להיות אחד מהבאים:
-
NO_TRANSPOSE: ביצוע הפעולה באמצעותaכמו שהוא. -
TRANSPOSE: ביצוע פעולה על טרנספוזיציה שלa. -
ADJOINT: ביצוע פעולה על המטריצה המשוחפת שלa.
נתוני הקלט נקראים רק מהמשולש התחתון של a, אם lower הוא true או מהמשולש העליון של a, אחרת. נתוני הפלט מוחזרים באותו משולש. הערכים במשולש השני מוגדרים על ידי ההטמעה.
אם unit_diagonal הוא true, ההטמעה יכולה להניח שהאלמנטים האלכסוניים של a שווים ל-1, אחרת ההתנהגות לא מוגדרת.
לסוגים כמותיים, הפונקציה מבצעת dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower,
unit_diagonal, transpose_a), a, b, type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | a |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1-C3) |
| (I2) | b |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1-C4) |
| (I3) | left_side |
קבוע מהסוג i1 |
(C3) |
| (I4) | lower |
קבוע מהסוג i1 |
|
| (I5) | unit_diagonal |
קבוע מהסוג i1 |
|
| (I6) | transpose_a |
enum של NO_TRANSPOSE, TRANSPOSE ו-ADJOINT |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנזור מסוג נקודה צפה או מספר מרוכב, או טנזור כמותי לכל טנזור | (C1) |
מגבלות
- (C1)
baseline_element_type(a) = baseline_element_type(b). - (C2)
2 <= rank(a) = rank(b) = R. - (C3) הקשר בין
shape(a)ל-shape(b)מוגדר באופן הבא:shape(a)[:-3] = shape(b)[:-3].dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1).
- (C4)
baseline_type(b) = baseline_type(result).
דוגמאות
// %a = [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
// %b = [
// [2.0, 0.0, 0.0],
// [4.0, 8.0, 0.0],
// [6.0, 10.0, 12.0]
// ]
%result = "stablehlo.triangular_solve"(%a, %b) {
left_side = true,
lower = true,
unit_diagonal = false,
transpose_a = <#stablehlotranspose NO>_TRANSPOSE
}< : (ten>sor3x3xf<32, ten>sor>3x3xf32<) - ten>sor3x3xf32
// %result: [
// [2.0, 0.0, 0.0],
// [0.0, 2.0, 0.0],
// [0.0, 0.0, 2.0]
// ]
tuple
סמנטיקה
הפונקציה יוצרת טאפל result מהערכים val.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | val |
מספר משתנה של ערכים | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
tuple | (C1) |
מגבלות
- (C1)
resultהוא מסוגtuple<E0, ..., EN-1>כאשרEi = type(val[i]).
דוגמאות
// %val0: memref[1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1)< : (m>emref2x<f32, t<upl>>ete>nsori3<2) - t<uplem>emref2x<f32, t<upl>>>etensori32
// %result: (memref[1.0, 2.0], (3))
uniform_dequantize
סמנטיקה
הפונקציה מבצעת המרה של טנסור עם כימות operand לטנסור עם נקודה צפה result לפי פרמטרי הכימות שמוגדרים בסוג operand.
באופן רשמי יותר, result = dequantize(operand).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנסור שעבר קוונטיזציה | (C1), (C2) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנסור מסוג נקודה צפה | (C1), (C2) |
מגבלות
- (C1)
shape(operand) = shape(result). - (C2)
element_type(result) = expressed_type(operand).
דוגמאות
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand)< : (tensor2x!qua<nt.uniformi8:f32:0, {0.1:-3>>0,0>.5:-20}<) - t>ensor2xf32
// %result: [4.0, 15.0]
uniform_quantize
סמנטיקה
מבצע המרה של טנזור של נקודה צפה או טנזור שעבר קוונטיזציה operand לטנזור שעבר קוונטיזציה result לפי פרמטרי הקוונטיזציה שהוגדרו על ידי הסוג result.
באופן רשמי יותר,
- אם
is_float(operand):result = quantize(operand, type(result)).
- אם
is_quantized(operand):float_result = dequantize(operand).result = quantize(float_result, type(result)).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
טנסור מסוג נקודה צפה או מסוג כמותי | (C1), (C2) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנסור שעבר קוונטיזציה | (C1), (C2) |
מגבלות
- (C1)
shape(operand) = shape(result). - (C2)
expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand).
דוגמאות
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand)< : (t>ens>or2xf32<) - tensor2x!qua<nt.uniformi8:f32:0, {0.1:-3>>0,0.5:-20}
// %result: [10, 10]
// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"<(%operand) : (te<nsor2x!quant.uniformi8:f32:>>0, >{0.1:-3<0,0.5:-20}) - te<nsor2x!quant.uniformi8:f32:>>0, {0.1:-20,0.2:-30}
// %result: [20, 45]
בזמן ש
סמנטיקה
מפיקה את הפלט מהפעלת הפונקציה body 0 או יותר פעמים בזמן שהפונקציה cond מפיקה את הפלט true. באופן רשמי יותר, אפשר לבטא את הסמנטיקה באמצעות תחביר Python באופן הבא:
internal_state = operand
while cond(*internal_state):
internal_state = body(*internal_state)
results = internal_state
ההתנהגות של לולאה אינסופית תיקבע בהמשך (#383).
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | operand |
מספר משתנה של ערכים | (C1-C3) |
| (I2) | cond |
פונקציה | (C1) |
| (I3) | body |
פונקציה | (C2) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
results |
מספר משתנה של ערכים | (C3) |
מגבלות
- (C1)
condהוא מסוג(T0, ..., TN-1) -> tensor<i1>, כאשרTi = type(operand[i]). - (C2)
bodyהוא מסוג(T0, ..., TN-1) -> (T0, ..., TN-1), כאשרTi = type(operand[i]). - (C3)
type(results...) = type(operand...).
דוגמאות
// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%cond = "stablehlo.compare"(%arg0, %ten) {
comparison_di<rection = #stablehlocom>parison_directio<n L>T
} <: (>ten>sori64,< t>ensori64) - tensori1
stablehlo.r<et>urn %cond : tensori1
}, {
< ^>bb0(%arg0: tens<ori>64, %arg1: tensori64):
%new_sum = stablehlo.add <%ar>g1, %one : tensori64
%new_i = stablehlo.add <%ar>g0, %one : tensori64
stablehlo.return %new_<i, >%new_sum< : >tensori64, te<nso>ri64
}) <: (>ten>sori64, <ten>sori64) <- (>tensori64, tensori64)
// %results0: 10
// %results1: 10
xor
סמנטיקה
מבצעת XOR ברמת הרכיבים של שני טנסורים, lhs ו-rhs, ומפיקה טנסור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:
- עבור ערכים בוליאניים: XOR לוגי.
- למספרים שלמים: ערך XOR ברמת הביטים.
קלט
| תווית | שם | סוג | מגבלות |
|---|---|---|---|
| (I1) | lhs |
טנסור מסוג בוליאני או מספר שלם | (C1) |
| (I2) | rhs |
טנסור מסוג בוליאני או מספר שלם | (C1) |
פלט
| שם | סוג | מגבלות |
|---|---|---|
result |
טנסור מסוג בוליאני או מספר שלם | (C1) |
מגבלות
- (C1)
type(lhs) = type(rhs) = type(result).
דוגמאות
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[4, 4], [4, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%<lhs, %>rhs) : (<tensor>2x2>xi1, te<nsor2x>2xi1) - tensor2x2xi1
// %result: [[false, true], [true, false]]
פעולה הדדית בין ניבים
בשלב הזה, תוכניות StableHLO בשימוש נרחב מכילות לפעמים פעולות שלא מוגדרות על ידי StableHLO.
Module, Function, Call and Return
StableHLO משתמש בפעולות MLIR במעלה הזרם עבור ModuleOp, FuncOp, CallOp ו-ReturnOp. השינוי הזה בוצע כדי לשפר את יכולת הפעולה ההדדית עם מנגנון MLIR קיים, כי הרבה מעברים שימושיים נכתבים עם טירגוט של FuncOp ו-ModuleOp, והרבה צינורות של קומפילציה מצפים שהאופרטורים האלה יהיו נוכחים. התאימות המלאה מובטחת לפעולות האלה. אם יחולו שינויים בפעולות האלה באופן לא תואם (כלומר, הסרה), יתווספו מקבילות של StableHLO כדי לשמור על התאימות.
CHLO
קבוצת הפעולות של CHLO מכילה פעולות ברמה גבוהה יותר שמפורקות ל-StableHLO. נכון לעכשיו, אין התחייבות לתאימות של CHLO. כדי להבטיח תאימות, צריך להשתמש בהעברת chlo-legalize-to-stablehlo לפני הסריאליזציה.
פעולות על צורות
מקרה שימוש נפוץ בקהילה הוא שימוש בפעולות מסוימות מניבים בסיסיים של MLIR בתוכניות דינמיות של StableHLO כדי לבצע חישובים של צורות.
בדרך כלל, אלה כוללים פעולות של shape דיאלקט כמו shape_of או num_elements, פעולות של tensor דיאלקט כמו dim או from_elements, וסוג ה-build-in index.
במסמך Dynamism RFC > O2
הסוגים האלה מוגדרים כלא רלוונטיים, אבל יש תמיכה מסוימת בסוגים של index לצורך פעולה הדדית. אין התחייבות לתאימות של הפעולות או הסוגים האלה. אפשר להשתמש במעבר shape-legalize-to-stablehlo כדי להמיר את הפעולות האלה לפעולות StableHLO נתמכות באופן מלא.
פעולות שהוצאו משימוש
יש כמה פעולות StableHLO שהתקבלו בירושה מ-MHLO, שהוצאו משימוש ועומדות להיעלם מ-StableHLO. הפרטים המלאים על ההסרות האלה זמינים בניקוי StableHLO v1.0 #2283. מספר הבעיה בכלי המעקב לגבי הוצאה משימוש של התכונות האלה הוא 2340.
הפעולות האלה נחלקות לכמה קטגוריות:
- הקטגוריה 'לא ב-HLO' של פעולות StableHLO – הפעולות האלה היו במקור חלק מה-opset של StableHLO, אבל בהמשך נקבע שהן לא מתאימות לו:
broadcast,create_token,cross-replica-sum,dot,einsum,torch_index_select,unary_einsum(מס' 3). - פעולות לא בשימוש – יכול להיות שהפעולות האלה היו שימושיות בשלב מסוים, אבל הן לא פותחו מספיק או שהצינורות שמשתמשים בהן עברו שינוי מבנה כך שהן כבר לא נדרשות. השינויים כוללים את
map,tuple(#598), השוואותget_tuple_element,rng,complex(#560), וקונבולוציהwindow_reversal(#1181).
חלק מהפעולות האלה אפשר להסיר בקלות, כי אפשר לבטא אותן באמצעות פעולות קיימות (broadcast, create_token, cross-replica-sum, dot, unary_einsum), והן יוסרו אחרי שתקופת התאימות הקיימת תסתיים (6 חודשים). אנחנו עדיין בודקים אפשרות להסיר פעולות אחרות (einsum, get_tuple_element, map, rng torch_index_select, tuple, complex השוואות, window_reversal). בהתאם למשוב מהקהילה, הפעולות האלה יוסרו או יתווספו למפרט עם תמיכה מלאה. עד שיידעו מה יהיו הגרסאות העתידיות האלה, מובטחת תאימות רק ל-6 חודשים.
ביצוע
ביצוע רציף
כדי להריץ תוכנית StableHLO, צריך לספק ערכי קלט לפונקציה main ולחשב את ערכי הפלט. ערכי הפלט של פונקציה מחושבים על ידי הפעלת הגרף של פעולות שמושרשות בפעולת return המתאימה.
סדר ההפעלה מוגדר בהטמעה, כל עוד הוא תואם לזרימת הנתונים, כלומר אם פעולות מופעלות לפני השימוש בהן. ב-StableHLO, כל פעולות ה-side-effecting צורכות טוקן אחד ומפיקות טוקן אחד (אפשר לבצע מולטיפלקס של כמה טוקנים לטוקן אחד באמצעות after_all), ולכן סדר הביצוע של side-effects תואם גם לזרימת הנתונים. לדוגמה, בתוכנית שלמטה יש שני סדרי ביצוע אפשריים: %0 ← %1 ← %2 ← return ו-%1 ← %0 ← %2 ← return.
func.func @main() -> tensor<f64> {
%0 = stablehlo.constant dense<1.0> : tensor<f64>
%1 = stablehlo.constant dense<2.0> : tensor<f64>
%2 = stablehlo.add %0, %1 : tensor<f64>
return %2 : tensor<f64>
}
באופן רשמי יותר, תהליך StableHLO הוא שילוב של:
1) תוכנית StableHLO, 2) סטטוסים של פעולות (עדיין לא בוצעו, כבר בוצעו) ו-3) ערכי ביניים שהתהליך פועל עליהם.
התהליך מתחיל בערכי קלט לפונקציה main, ממשיך בתרשים של פעולות שמעדכנות את סטטוסי הפעולות וערכי הביניים, ומסתיים בערכי פלט. המשך הפורמליזציה ייקבע בהמשך (#484).
ביצוע מקביל
אפשר להריץ תוכניות StableHLO במקביל, והן מאורגנות ברשת תהליכים דו-ממדית של num_replicas על num_partitions, ששניהם הם מסוג ui32.
ברשת התהליכים של StableHLO, num_replicas * num_partitions תהליכים של StableHLO
פועלים בו-זמנית. לכל תהליך יש process_id = (replica_id, partition_id) ייחודי, כאשר replica_id ב-replica_ids = range(num_replicas) ו-partition_id ב-partition_ids = range(num_partitions), ולשניהם יש סוג ui32.
גודל רשת התהליכים ידוע באופן סטטי לכל תוכנית (בעתיד, אנחנו מתכננים להפוך אותו לחלק מפורש בתוכניות StableHLO #650), והמיקום ברשת התהליכים ידוע באופן סטטי לכל תהליך. לכל תהליך יש גישה למיקום שלו ברשת התהליכים באמצעות הפעולות replica_id ו-partition_id.
ברשת התהליכים, כל התוכניות יכולות להיות זהות (בסגנון Single Program, Multiple Data), כולן יכולות להיות שונות (בסגנון Multiple Program, Multiple Data) או משהו באמצע. בעתיד, אנחנו מתכננים להוסיף תמיכה בניבים אחרים להגדרת תוכניות מקבילות של StableHLO, כולל GSPMD (#619).
בתוך רשת התהליכים, התהליכים הם ברובם בלתי תלויים זה בזה – יש להם סטטוסי פעולה נפרדים, ערכי קלט/ביניים/פלט נפרדים ורוב הפעולות מבוצעות בנפרד בין התהליכים, למעט מספר קטן של פעולות קולקטיביות שמתוארות בהמשך.
בהתחשב בכך שרוב הפעולות משתמשות רק בערכים מאותו תהליך, בדרך כלל ברור למה מתייחסים כשמציינים את השמות של הערכים האלה.
עם זאת, כשמתארים סמנטיקה של פעולות קולקטיביות, זה לא מספיק, ולכן משתמשים בסימון name@process_id כדי להתייחס לערך name בתהליך מסוים. (מנקודת המבט הזו, אפשר לראות ב-name לא מתאים קיצור של name@(replica_id(), partition_id())).
סדר הביצוע בין התהליכים מוגדר בהטמעה, למעט הסנכרון שנוצר על ידי תקשורת ישירה ופעולות קולקטיביות, כפי שמתואר בהמשך.
תקשורת מנקודה לנקודה
תהליכי StableHLO יכולים לתקשר זה עם זה באמצעות ערוצי StableHLO. ערוץ מיוצג על ידי מזהה חיובי מסוג si64. באמצעות פעולות שונות, אפשר לשלוח ערכים לערוצים ולקבל אותם מהערוצים.
צריך להגדיר באופן רשמי יותר, למשל מאיפה מגיעים מזהי הערוצים האלה, איך תוכנות תהליכים מודעות להם ואיזה סוג של סנכרון מוצג על ידם. הנושא הזה נמצא בבדיקה (#484).
שידור תקשורת
לכל תהליך StableHLO יש גישה לשני ממשקי סטרימינג:
- בפיד שאפשר לקרוא ממנו.
- Outfeed שאפשר לכתוב אליו.
בניגוד לערוצים, שמשמשים לתקשורת בין תהליכים ולכן יש להם תהליכים בשני הקצוות שלהם, ל-infeed ול-outfeed יש קצה אחר שמוגדר על ידי היישום.
המשך הפורמליזציה, למשל איך תקשורת בסטרימינג משפיעה על סדר הביצוע ואיזה סוג של סנכרון מוצג על ידי זה, יפורט בהמשך (#484).
תפעול קולקטיבי
יש שש פעולות קולקטיביות ב-StableHLO: all_gather, all_reduce, all_to_all, collective_broadcast, collective_permute ו-reduce_scatter. כל הפעולות האלה מפצלות את התהליכים ברשת של תהליך StableHLO לקבוצות של תהליכי StableHLO ומבצעות חישוב משותף בכל קבוצת תהליכים, באופן עצמאי מקבוצות תהליכים אחרות.
בתוך כל קבוצת תהליכים, פעולות קולקטיביות עשויות ליצור מחסום סינכרון. צריך להוסיף עוד פרטים, למשל מתי בדיוק מתבצע הסנכרון הזה, איך בדיוק התהליכים מגיעים למחסום הזה ומה קורה אם הם לא מגיעים אליו. הפרטים האלה יתווספו בהמשך (#484).
אם קבוצת התהליכים כוללת תקשורת בין מחיצות, כלומר יש תהליכים בקבוצת התהליכים שמזהי המחיצות שלהם שונים, אז הביצוע של פעולת ה-Collective צריך ערוץ, ופעולת ה-Collective צריכה לספק channel_id חיובי מסוג si64. לא צריך ערוצים לתקשורת בין העתקים.
החישובים שמבוצעים על ידי הפעולות הקולקטיביות הם ספציפיים לפעולות בודדות, והם מתוארים בסעיפים של הפעולות הבודדות שלמעלה. עם זאת, השיטות שלפיהן רשת התהליכים מחולקת לקבוצות תהליכים משותפות בין הפעולות האלה ומתוארות בקטע הזה. באופן רשמי יותר, StableHLO תומך בארבע השיטות הבאות.
cross_replica
רק תקשורת בין רפליקות מתרחשת בתוך כל קבוצת תהליכים. האסטרטגיה הזו מקבלת את replica_groups – רשימה של רשימות של מזהי עותקים – ומחשבת את המכפלה הקרטזית של replica_groups ב-partition_ids. replica_groups
חייב להכיל רכיבים ייחודיים ולכלול את כל replica_ids. באופן רשמי יותר, באמצעות תחביר Python:
def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
for partition_id in partition_ids:
process_group = []
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
לדוגמה, עבור replica_groups = [[0, 1], [2, 3]] ו-num_partitions = 2, הפונקציה cross_replica תחזיר [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]].
cross_partition
רק תקשורת בין מחיצות מתרחשת בתוך כל קבוצת תהליכים. האסטרטגיה הזו מקבלת את partition_groups – רשימה של רשימות של מזהי מחיצות – ומחשבת את המכפלה הקרטזית של partition_groups ב-replica_ids.
partition_groups חייב להכיל רכיבים ייחודיים ולכלול את כל partition_ids.
באופן רשמי יותר, באמצעות תחביר Python:
def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
for partition_group in partition_groups:
for replica_id in replica_ids:
process_group = []
for partition_id in partition_group:
process_group.append((replica_id, partition_id))
yield process_group
לדוגמה, עבור partition_groups = [[0, 1]] ו-num_replicas = 4, הפונקציה cross_partition תחזיר [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]].
cross_replica_and_partition
תקשורת בין עותקים ותקשורת בין מחיצות יכולות להתרחש בכל קבוצת תהליכים. שיטת הבידינג הזו מקבלת replica_groups – רשימה של רשימות של מזהי עותקים – ומחשבת את המכפלה הקרטזית של כל replica_group לפי partition_ids. replica_groups חייב להכיל רכיבים ייחודיים ולכלול את כל replica_ids. באופן רשמי יותר, באמצעות תחביר Python:
def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
process_group = []
for partition_id in partition_ids:
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
לדוגמה, עבור replica_groups = [[0, 1], [2, 3]] ו-num_partitions = 2, הפונקציה cross_replica_and_partition תחזיר [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]].
flattened_ids
שיטת הבידינג הזו מקבלת flattened_id_groups – רשימה של רשימות של מזהי תהליכים 'שטוחים' בצורה של replica_id * num_partitions + partition_id – והופכת אותם למזהי תהליכים. flattened_id_groups חייב להכיל רכיבים ייחודיים ולכלול את כל process_ids. באופן רשמי יותר, באמצעות תחביר Python:
def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
for flattened_id_group in flattened_id_groups:
process_group = []
for flattened_id in flattened_id_group:
replica_id = flattened_id // num_partitions
partition_id = flattened_id % num_partitions
process_group.append((replica_id, partition_id))
yield process_group
לדוגמה, עבור flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]], num_replicas = 4 ו-num_partitions = 2, הפונקציה flattened_ids תחזיר [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]].
דיוק
בשלב הזה, StableHLO לא מספקת ערבויות לגבי דיוק מספרי, אבל יכול להיות שזה ישתנה בעתיד (#1156).
סמנטיקה של ביצוע פעולה עם כימות
הפרשנות של פעולות כמותיות ב-StableHLO עשויה להשתנות בהתאם לדרישות וליכולות של החומרה. לדוגמה, יכול להיות שחלק מהחומרה יבחר לפרש פעולות קוונטיות באמצעות אסטרטגיה של 'ביטול קוונטיזציה, ביצוע פעולה של נקודה צפה ולבסוף קוונטיזציה'. אחרים עשויים לבצע את כל החישוב באמצעות אריתמטיקה של מספרים שלמים. לכן, הפרשנות של פעולות StableHLO שעברו קוונטיזציה נקבעת באופן בלעדי על ידי ההטמעה הספציפית. הפרשנות של הכמותיות ההיברידית (#1575) צריכה להתבסס על הסמנטיקה שלה כפי שמוגדר במפרט (דרך 1792).
שגיאות
תוכניות StableHLO עוברות אימות באמצעות קבוצה נרחבת של אילוצים עבור פעולות נפרדות, מה שמונע הרבה סוגים של שגיאות לפני זמן הריצה. עם זאת, עדיין יכולות להיות שגיאות, למשל בגלל הצפה של מספרים שלמים, גישה מחוץ לגבולות וכו'. אלא אם מצוין אחרת, כל השגיאות האלה גורמות להתנהגות שמוגדרת בהטמעה, אבל זה עשוי להשתנות בעתיד (#1157).
חריגים של נקודה צפה
יוצא מן הכלל לכלל הזה הוא חריג של נקודה צפה בתוכניות StableHLO, שבהן ההתנהגות מוגדרת היטב. פעולות שגורמות לחריגות שמוגדרות בתקן IEEE-754 (פעולה לא חוקית, חלוקה באפס, הצפה, חוסר קיבולת או חריגות לא מדויקות) מפיקות תוצאות ברירת מחדל (כפי שמוגדר בתקן) וממשיכות את הביצוע בלי להעלות את דגל הסטטוס המתאים. זה דומה לטיפול בחריגות raiseNoFlag מהתקן. חריגים לפעולות לא סטנדרטיות (למשל, אריתמטיקה מורכבת ופונקציות טרנסצנדנטיות מסוימות) מוגדרים בהטמעה.
אי התאמות בצורות
StableHLO תומך בטנסורים עם צורה דינמית. עם זאת, הצורות צריכות להיות זהות בזמן הריצה, אחרת ההתנהגות לא מוגדרת. StableHLO לא מספקת באופן מפורש אופרטור שיכול לקבוע שלטנסור יש צורה מסוימת בזמן הריצה. המפיק אחראי ליצירת קוד נכון.
לדוגמה, התוכנית הבאה תקינה. עם זאת, בזמן הריצה, הצורות המדויקות של %arg0 ושל %arg1 חייבות להיות זהות, אחרת ההתנהגות של התוכנית לא מוגדרת:
func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
return %0 : tensor<?xi32>
}
סימון
כדי לתאר את התחביר, במסמך הזה נעשה שימוש בגרסה של תחביר EBNF לפי תקן ISO (ISO/IEC 14977:1996, Wikipedia) עם שני שינויים: 1) כללים מוגדרים באמצעות ::= ולא באמצעות =,
2) שרשור מבוטא באמצעות הצמדה ולא באמצעות ,.
כדי לתאר סמנטיקה (כלומר, בקטעים Types (סוגים), Constants (קבועים) ו-Ops (פעולות)), אנחנו משתמשים בנוסחאות שמבוססות על תחביר Python עם תמיכה מורחבת בביטוי תמציתי של פעולות על מערכים, כפי שמתואר בהמשך. השיטה הזו מתאימה לקטעי קוד קטנים, אבל במקרים נדירים שבהם נדרשים קטעי קוד גדולים יותר, אנחנו משתמשים בתחביר Python רגיל, שתמיד מוצג באופן מפורש.
נוסחאות
נבחן איך נוסחאות פועלות על סמך דוגמה מהdot_general
מפרט. אחת המגבלות של הפעולה הזו נראית כך:
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).
השמות שמשמשים בנוסחה הזו מגיעים משני מקורות: 1) פונקציות גלובליות, כלומר dim, 2) הגדרות של רכיב התוכנית המתאים, כלומר lhs, lhs_batching_dimensions, rhs וקלט rhs_batching_dimensions שמוגדרים בקטע Inputs (קלט) של dot_general.
כמו שצוין למעלה, התחביר של הנוסחה הזו מבוסס על Python עם כמה הרחבות שנועדו לפשט את הנוסחה. כדי להבין את הנוסחה, נמיר אותה לתחביר של Python רגיל.
א) בנוסחאות האלה, אנחנו משתמשים ב-= כדי לייצג שוויון, ולכן השלב הראשון לקבלת תחביר Python הוא להחליף את = ב-==, כמו בדוגמה הבאה:
dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...).
ב) הנוסחאות האלה תומכות גם בסימן שלוש הנקודות (...), שהופך ביטויים סקלריים לביטויים טנסוריים. בקיצור, f(xs...) בערך אומר "לכל סקלר x בטנזור xs, מחשבים סקלר f(x) ואז מחזירים את כל התוצאות הסקלריות האלה יחד כתוצאת טנזור". בתחביר Python רגיל, הנוסחה לדוגמה הופכת ל:
[dim(lhs, dim1) for dim1 in lhs_batching_dimensions] ==
[dim(rhs, dim2) for dim2 in rhs_batching_dimensions].
בזכות סימני האליפסה, לעיתים קרובות אפשר להימנע מעבודה ברמה של סקלרים בודדים. עם זאת, במקרים מורכבים מסוימים, יכול להיות שנעשה שימוש בתחביר לא רשמי ברמה נמוכה יותר, כמו בנוסחה start_indices[bi0, ..., :, ..., biN] מהמפרט gather. כדי לשמור על תמציתיות, אנחנו לא מספקים פורמליזם מדויק לתרגום תחביר כזה ל-Python רגיל, בתקווה שעדיין אפשר להבין אותו באופן אינטואיטיבי על בסיס כל מקרה לגופו.
אם יש נוסחאות ספציפיות שנראות לכם לא ברורות, אתם יכולים להודיע לנו ואנחנו ננסה לשפר אותן.
בנוסף, תשימו לב שבנוסחאות נעשה שימוש בסימן שלוש הנקודות כדי להרחיב כל מיני רשימות, כולל טנסורים, רשימות של טנסורים (שיכולות להיווצר ממספר משתנה של טנסורים) וכו'. זה עוד תחום שבו אנחנו לא מספקים פורמליזם מדויק (למשל, רשימות הן לא חלק ממערכת הסוגים של StableHLO), אלא מסתמכים על מובנות אינטואיטיבית.
ג) כלי נוסף וחשוב שבו אנחנו משתמשים הוא שידור מרומז. למרות שערכת הפעולות (opset) של StableHLO לא תומכת בשידור מרומז, הנוסחאות כן תומכות, גם כדי לשמור על תמציתיות. בקיצור, אם נעשה שימוש בסקלר בהקשר שבו צפוי טנסור, הסקלר מועבר לשידור לצורה הצפויה.
כדי להמשיך את הדוגמה של dot_general, הנה עוד אילוץ:
0 <= lhs_batching_dimensions < rank(lhs). כפי שמוגדר במפרט dot_general, lhs_batching_dimensions הוא טנזור, אבל 0 ו-rank(lhs) הם סקלרים. אחרי שנחיל שידור מרומז, הנוסחה תהפוך ל-[0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)].
כשמחילים את הנוסחה הזו על פעולה מסוימת של dot_general, היא מחזירה טנזור של ערכים בוליאניים. כשמשתמשים בנוסחאות כאילוצים, האילוץ מתקיים אם הנוסחה מחזירה את הערך true או טנסור שמכיל רק רכיבים עם הערך true.
שמות
בנוסחאות, היקף לקסיקלי כולל: 1) פונקציות גלובליות, 2) הגדרות של חברים,
3) הגדרות מקומיות. בהמשך מופיעה רשימה של פונקציות גלובליות. רשימת ההגדרות של הרכיבים תלויה ברכיב התוכנית שאליו מוחל הסימון:
- בפעולות, ההגדרות של חברים כוללות שמות שמופיעים בקטעים 'קלט' ו'פלט'.
- בכל שאר המקרים, ההגדרות של חברי המועדון כוללות חלקים מבניים של רכיב התוכנית, שנקראים על שם המונחים הלא-סופיים המתאימים של EBNF. ברוב המקרים, השמות של החלקים המבניים האלה מתקבלים על ידי המרת השמות של הלא-טרמינלים ל-snake case (למשל,
IntegerLiteral=>integer_literal), אבל לפעמים השמות מקוצרים בתהליך (למשל,QuantizationStorageType=>storage_type). במקרה כזה, השמות מוצגים באופן מפורש בדומה לקטעים 'קלט' / 'פלט' במפרטים של פעולות. - בנוסף, הגדרות החברות תמיד כוללות את התו
selfכדי להפנות לרכיב התוכנית המתאים.
ערכים
כשמבצעים הערכה של נוסחאות, הן פועלות עם סוגי הערכים הבאים:
1) Value (ערכים בפועל, למשל dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>; תמיד ידועים הסוגים שלהם),
2) Placeholder (ערכים עתידיים, למשל lhs, rhs או result; הערכים בפועל שלהם עדיין לא ידועים, רק הסוגים שלהם ידועים),
3) Type (סוגים כפי שהוגדרו בקטע 'סוגים'),
4) Function (פונקציות גלובליות כפי שהוגדרו בקטע 'פונקציות').
בהתאם להקשר, השמות עשויים להתייחס לערכים שונים. באופן ספציפי יותר, בקטע 'סמנטיקה' של פעולות (ומקבילות לרכיבי תוכנה אחרים) מוגדרת לוגיקה של זמן ריצה, כך שכל הקלט זמין כ-Value.
לעומת זאת, בקטע 'מגבלות' של פעולות (ומקבילות) מוגדרת לוגיקה של 'זמן קומפילציה', כלומר משהו שבדרך כלל מבוצע לפני זמן הריצה, ולכן רק קלטים קבועים זמינים כ-Value וקלטים אחרים זמינים רק כ-Placeholder.
| שמות | בקטע 'סמנטיקה' | בקטע 'מגבלות' |
|---|---|---|
| פונקציות גלובליות | Function |
Function |
| קלט קבוע | Value |
Value |
| קלט לא קבוע | Value |
Placeholder |
| פלט | Value |
Placeholder |
| הגדרות מקומיות | תלוי בהגדרה | תלוי בהגדרה |
נבחן עכשיו דוגמה לפעולה transpose:
%result = "stablehlo.transpose"(%operand) {
permutati<on = dens>e[2, 1, 0<] : t>ensor3xi64
}< : (tenso>r2x>3x2xi32<) - tenso>r2x3x2xi32
בפעולה הזו, permutation הוא קבוע, ולכן הוא זמין כ-Value גם בסמנטיקה וגם באילוצים. לעומת זאת, operand ו-result זמינים כ-Value בסמנטיקה, אבל רק כ-Placeholder במגבלות.
פונקציות
בנייה של סוגים
אין פונקציות שאפשר להשתמש בהן כדי ליצור סוגים. במקום זאת, אנחנו משתמשים ישירות בתחביר של סוגים כי הוא בדרך כלל תמציתי יותר. לדוגמה:
(tensor<E>, tensor<E>) -> (tensor<E>) במקום function_type(
[tensor_type([], E), tensor_type([], E)], [tensor_type([], E)]).
פונקציות לפי סוגים
- הפונקציה
element_typeמוגדרת בסוגי טנסורים ובסוגי טנסורים שעברו קוונטיזציה, ומחזירה את החלקTensorElementTypeאוQuantizedTensorElementTypeשלTensorTypeאוQuantizedTensorTypeהמתאימים, בהתאמה.
def element_type(x: Value | Placeholder | Type):
if type(x) == TensorType:
return tensor_element_type(x)
if type(x) == QuantizedTensorType:
return quantized_tensor_element_type(x)
if type(x) is not Type:
return element_type(type(x))
is_per_axis_quantized(x: Value | Placeholder | Type) -> Valueהוא קיצור דרך ל-is_quantized(x) and quantization_dimension(x) is not None.
is_per_tensor_quantized(x: Value | Placeholder | Type) -> Valueהיא קיצור דרך לis_quantized(x) and quantization_dimension(x) is None.
is_promotable(x: Type, y: Type) -> boolבודק אם אפשר להמיר את הסוגxלסוגy. אםxו-yהםQuantizedTensorElementType, המבצע חל רק עלstorage_type. הגרסה הספציפית הזו של קידום מכירות משמשת כרגע בהקשר של חישוב הפחתה (פרטים נוספים זמינים ב-RFC).
def is_promotable(x: Type, y: Type) -> Value:
is_same_type = (is_bool(x) and is_bool(y)) or
(is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
(is_complex(x) and is_complex(y)) or
(is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
if is_same_type == False:
return False
if is_integer(x) or is_float(x):
return bitwidth(x) <= bitwidth(y)
if is_complex(x):
return bitwidth(element_type(x)) <= bitwidth(element_type(y))
if is_quantized(x):
return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))
return false
is_quantized(x: Value | Placeholder | Type) -> Valueהוא קיצור דרך ל-is_quantized_tensor_element_type(x).is_type_name(x: Value | Placeholder | Type) -> Value. זמין לכל הסוגים. לדוגמה, הפונקציהis_float(x)מחזירהtrueאםxהואFloatType. אםxהוא ערך או placeholder, הפונקציה הזו היא קיצור שלis_type_name(type(x)).
max_value(x: Type) -> Valueמחזירה את הערך המקסימלי שלTensorElementType. אםxהוא לאTensorElementType, הפונקציה מחזירהNone.
min_value(x: Type) -> Valueמחזירה את הערך המינימלי האפשרי שלTensorElementType. אםxהוא לאTensorElementType, הפונקציה מחזירהNone.member_name(x: Value | Placeholder | Type) -> Any. זמין לכל הגדרות החבריםmember_nameמכל הסוגים. לדוגמה, הפקודהtensor_element_type(x)תחזיר את החלקTensorElementTypeשלTensorTypeתואם. אםxהוא ערך או placeholder, הפונקציה הזו היא קיצור שלmember_name(type(x)). אםxהוא לא סוג שיש לו חבר מתאים, או ערך או placeholder מסוג כזה, הפונקציה מחזירה אתx.None
is_empty_algorithm(*args: Type)בודקת אם כל השדות של אלגוריתם הנקודות מוגדרים ל-None. הסיבה לכך היא שלאלגוריתמים של נקודות יש התנהגויות ברירת מחדל שמוגדרות בהטמעה, ולכן ציון ערך ברירת מחדל יהיה שגוי.
בניית ערכים
-
operation_name(*xs: Value | Type) -> Value. זמין לכל הפעולות. לדוגמה, הפונקציהadd(lhs, rhs)מקבלת שני ערכי טנסורlhsו-rhsומחזירה את הפלט של הערכת הפעולהaddעם הקלטים האלה. לפעולות מסוימות, למשלbroadcast_in_dim, סוגי הפלט שלהן הם 'נושאי עומס', כלומר נדרשים להערכת פעולה. במקרה כזה, הפונקציה מקבלת את הסוגים האלה כארגומנטים.
פונקציות על ערכים
כל האופרטורים והפונקציות של Python זמינים. לדוגמה, אפשר להשתמש בסימון subscription וגם בסימון slicing מ-Python כדי ליצור אינדקס לטנסורים, לטנסורים שעברו קוונטיזציה ולטפלים.
to_destination_type(x: Value, destination_type: Type) -> Valueמוגדר בטנסורים ומחזיר את הערך המומר שלxעל סמךtype(x)ו-destination_typeבאופן הבא:
def to_destination_type(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x
if is_quantized(destination_type):
if is_quantized(type(x)):
return quantize(x, destination_type)
assert is_float(type(x))
return quantize(x, destination_type)
if is_quantized(type(x)):
assert destination_type = expressed_type(type(x))
return dequantize(type(x))
return convert(x, destination_type)
מתנהל דיון ראשוני על מיזוג הפעולות convert, uniform_quantize ו-uniform_dequantize (#1576).
אחרי המיזוג, אין צורך בפונקציה שלמעלה ואפשר להשתמש בשם הפעולה convert במקום.
הפונקציה
is_nan(x: Value) -> Valueמוגדרת על טנסורים ומחזירהtrueאם כל האלמנטים שלxהםNaNאוfalseאחרת. אםxהוא לא טנסור, הפונקציה מחזירהNone.הפונקציה
is_sorted(x: Value) -> Valueמוגדרת על טנסורים ומחזירהtrueאם הרכיבים שלxממוינים בסדר עולה ביחס לסדר הלקסיקוגרפי העולה של האינדקסים שלהם, אוfalseאחרת. אםxהוא לא טנזור, הפונקציה מחזירהNone.הפונקציה
is_unique(x: Value) -> Valueמוגדרת לטנסורים ומחזירהtrueאם ל-xאין רכיבים כפולים, אוfalseאחרת. אםxהוא לא טנסור, הפונקציה מחזירהNone.הערך
member_name(x: Value) -> Anyמוגדר לכל הגדרות החבריםmember_nameשל כל הערכים. לדוגמה, הפקודהreal_part(x)מחזירה את החלקRealPartשלComplexConstantהתואם. אםxהוא לא ערך שיש לו חבר מתאים, הפונקציה מחזירה אתNone.הפונקציה
same(x: Value) -> Valueמוגדרת על טנסורים ומחזירהtrueאם כל הרכיבים שלxשווים זה לזה, אוfalseאחרת. אם לטנזור אין אלמנטים, זה נחשב כ'כל האלמנטים שווים זה לזה', כלומר הפונקציה מחזירהtrue. אםxהוא לא טנסור, הפונקציה מחזירהNone.הפונקציה
split(x: Value, num_results: Value, axis: Value) -> Valueמוגדרת על טנסורים ומחזירה פרוסותnum_resultsשלxלאורך הצירaxis. אםxהוא לא טנסור אוdim(x, axis) % num_results != 0, הפונקציה מחזירהNone.הפונקציה
is_defined_in_parent_scope(x: Value) -> Valueמוגדרת במחרוזות ומחזירה את הערךtrueאםxהוא השם של פונקציה שמוגדרת באותו היקף כמו פונקציית ההורה של הפעולה הרלוונטית.הפונקציה
is_namespaced_op_name(x: Value) -> Valueמוגדרת במחרוזות ומחזירהtrueאםxהוא שם אופרטור תקין, כלומר הוא תואם לביטוי הרגולרי הבא:[a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+
חישובים של צורות
axes(x: Value | Placeholder | Type) -> Valueהוא קיצור דרך ל-range(rank(x)).
dim(x: Value | Placeholder | Type, axis: Value) -> Valueהוא קיצור דרך ל-shape(x)[axis].
dims(x: Value | Placeholder | Type, axes: List) -> Listהוא קיצור דרך ל-list(map(lambda axis: dim(x, axis), axes)).
index_space(x: Value | Placeholder | Type) -> Valueמוגדר על טנסורים ומחזיר את האינדקסים שלsize(x)המתאימים, שמוינו בסדר לקסיקוגרפי עולה, כלומר[0, ..., 0],[0, ..., 1]וכן הלאה עדshape(x) - 1.TensorTypeאםxהוא לא סוג טנזור, סוג טנזור שעבר קוונטיזציה, ערך או placeholder של אחד מהסוגים האלה, הפונקציה מחזירה אתx.None
rank(x: Value | Placeholder | Type) -> Valueהוא קיצור דרך ל-size(shape(x)).
shape(x: Value | Placeholder | Type) -> Valueמוגדר בקטע 'פונקציות על סוגים' באמצעותmember_name.
size(x: Value | Placeholder | Type) -> Valueהוא קיצור דרך ל-reduce(lambda x, y: x * y, shape(x)).
חישובי קוונטיזציה
def baseline_element_type(x: Value | Placeholder | Type) -> Typeהיא קיצור דרך לelement_type(baseline_type(x)).
baseline_typeמוגדר בסוגי טנסור ובסוגי טנסור שעברו קוונטיזציה, והוא משנה אותם ל'בסיס', כלומר לסוג עם אותה צורה אבל עם פרמטרים של קוונטיזציה של סוג האלמנט שמאופסים לערכי ברירת מחדל. הטריק הזה שימושי להשוואה אחידה בין שני סוגים של טנסורים: טנסור רגיל וטנסור שעבר קוונטיזציה. השוואה כזו נדרשת לעיתים קרובות. עבור סוגים שעברו קוונטיזציה, האפשרות הזו מאפשרת להשוות בין סוגים תוך התעלמות מפרמטרים של קוונטיזציה, כלומר, כל הפרמטריםshape,storage_type,expressed_type,storage_min,storage_maxו-quantization_dimension(עבור סוג שעבר קוונטיזציה לכל ציר) חייבים להיות זהים, אבל הפרמטריםscalesו-zero pointsיכולים להיות שונים.
def baseline_type(x: Value | Placeholder | Type) -> Type:
if type(x) == TensorType:
return x
if type(x) == QuantizedTensorType:
element_type = quantized_tensor_element_type(x)
baseline_element_type = QuantizedTensorElementType(
storage_type = storage_type(element_type),
storage_min = storage_min(element_type),
storage_max = storage_max(element_type),
expressed_type = expressed_type(element_type),
quantization_dimension = quantization_dimension(element_type),
scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
return QuantizedTensorType(shape(x), baseline_element_type)
if type(x) is not Type:
return baseline_element_type(type(x))
-
dequantizeמוגדר בסוגי טנסורים שעברו קוונטיזציה והופך אותם לסוגי טנסורים של נקודה צפה. ההמרה מתבצעת על ידי המרת רכיבים כמותיים שמייצגים ערכים שלמים של סוג האחסון לערכים תואמים של נקודה צפה של הסוג המבוטא, באמצעות נקודת האפס והקנה מידה שמשויכים לסוג הרכיב הכמותי.
def compute_zero_points(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
zero_points[i] = zero_points(quantized_type)[i[d]]
return zero_points
def compute_scales(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
type(result_type))
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
scales[i] = scales(quantized_type)[i[d]]
return scales
def dequantize(x: Value) -> Value:
assert is_quantized(x)
x_storage = bitcast_convert(x, storage_type(x))
x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
x_expressed_sub = convert(x_storage_sub, expressed_type(x))
return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
-
quantizeמוגדר בסוגי טנסור של נקודה צפה והופך אותם לסוגי טנסור שעברו קוונטיזציה. ההמרה מתבצעת על ידי המרת ערכים מסוג נקודה צפה (floating-point) של הסוג המפורש לערכים שלמים תואמים של סוג האחסון, באמצעות נקודת האפס והקנה מידה שמשויכים לסוג האלמנט המכומת.
def quantize(x: Value, result_type: Type) -> Value:
assert is_float(x) and is_quantized(result_type)
zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
converted_zero_points = convert(zero_points, expressed_type(result_type))
converted_min = convert(storage_min(result_type), expressed_type(result_type))
converted_max = convert(storage_max(result_type), expressed_type(result_type))
x_scaled = x / compute_scales(result_type, type(x))
x_scaled_add_zp = x_scaled + converted_zero_points
x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
x_rounded = round_nearest_even(x_clamped)
return convert(x_rounded, result_type)
-
dequantize_op_quantizeמשמשת לציון חישובים לפי רכיבים בטנסורים שעברו קוונטיזציה. היא מבצעת דה-קוונטיזציה, כלומר הופכת רכיבים שעברו קוונטיזציה לסוגים המפורשים שלהם, ואז מבצעת פעולה, ואז מבצעת קוונטיזציה, כלומר הופכת את התוצאות בחזרה לסוגי האחסון שלהן. בשלב הזה, הפונקציה הזו פועלת רק עבור קוונטיזציה לכל טנסור. אנחנו עובדים על כימות לכל ציר (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
inputs = inputs_and_output_type[:-1]
output_type = inputs_and_output_type[-1]
float_inputs = map(dequantize, inputs)
float_result = op(*float_inputs)
return quantize(float_result, output_type)
def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
inputs = inputs_and_output_type[:-3]
float_inputs = map(dequantize, inputs)
float_results = op(*float_inputs)
return map(quantize, float_results, inputs_and_output_type[-3:])
def dequantize_compare(lhs, rhs, comparison_direction):
float_lhs = dequantize(lhs)
float_rhs = dequantize(rhs)
return compare(float_lhs, float_rhs, comparison_direction, FLOAT)
def dequantize_select_quantize(pred, on_true, on_false, output_type):
float_on_true = dequantize(on_true)
float_on_false = dequantize(on_false)
float_result = select(pred, float_on_true, float_on_false)
return quantize(float_result, output_type)
-
hybrid_dequantize_then_opמשמש לציון קוונטיזציה רק של משקלים עבור פעולה היברידית שמקבלת lhs בנקודה צפה ו-rhs בסוגים שעברו קוונטיזציה. היא מבצעת דה-קוונטיזציה של קלטים שעברו קוונטיזציה לסוגים שלהם ומבצעת חישוב בערך float. סוג הרכיב של טנזור lhs מסוג float וסוג הרכיב של טנזור rhs מסוג quantized צריכים להיות זהים.
def hybrid_dequantize_then_op(op, lhs, rhs):
assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
return op(lhs, dequantize(rhs))
חישובים ברשת
cross_partition(replica_groups: Value) -> Value. אפשר לעיין בקטע cross_replica למעלה.cross_replica(replica_groups: Value) -> Value. אפשר לעיין בקטע cross_replica למעלה.cross_replica_and_partition(replica_groups: Value) -> Value. אפשר לעיין בקטע cross_replica_and_partition למעלה.flattened_ids(replica_groups: Value) -> Value. אפשר לעיין בקטע 'flattened_ids' למעלה.
דינמיות
לערכים ב-StableHLO יכולים להיות גדלים דינמיים של מאפיינים, למשל tensor<?xi64>.
עם זאת, לערכים ב-StableHLO לא יכול להיות מספר דינמי של מימדים (דינמיות לא מדורגת, למשל tensor<*xi64>). מותר להשתמש באופרנדים ובתוצאות בגדלים דינמיים של מימדים, גם אם יש אילוצים על הגדלים. אם אפשר, המערכת תאמת את האילוצים באופן סטטי. אחרת, היא תדחה את האימות לזמן הריצה, ואי-התאמות יובילו להתנהגות לא מוגדרת. למטה מפורטות דוגמאות.
אי התאמה בצורות של פעולות אלמנטריות אוניטריות
נניח שיש לנו את תוכנית הצעצוע הבאה:
func.func @foo(%arg0: tensor<?xf64>) {
%0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
return
}
תוכנית כזו היא לא שגרתית, כי בדרך כלל יודעים מה הצורה של התוצאה אבל לא מה הצורה של הקלט. למרות זאת, זו תוכנית StableHLO תקינה. אי אפשר לאמת באופן סטטי את הפעולה abs בתוכנית הזו, כי הצורה המדויקת של האופרנד לא ידועה. עם זאת, הצורות תואמות בוודאות, ואפשר לבדוק את זה באופן סטטי: יכול להיות ש-? יהפוך ל-2 בזמן הריצה, ולא תהיה בעיה. עם זאת, יכול להיות ש-? יהיה מספר שלם אחר, ובמקרה כזה ההתנהגות לא מוגדרת.
הערה: אם הגודל של מאפיין הוא דינמי בתוצאה, לא יכול להיות מצב של התנהגות לא מוגדרת. למעשה, אין גודל 'צפוי', ולכן לא יכול להיות חוסר התאמה.
אי התאמה בצורות של פעולות בינאריות ברמת הרכיב
נניח שיש לנו את תוכנית הצעצוע הבאה:
func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
%0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
return
}
בפעולות בינאריות ברמת הרכיב, הצורות של הקלט והתוצאה חייבות להיות זהות בזמן הריצה. בזמן ההידור, המימדים הסטטיים צריכים להיות שווים, אחרת הם צריכים להיות רק תואמים. אם אחד מהמאפיינים הוא דינמי בקלט, יכול להיות שההתנהגות בזמן הריצה לא תהיה מוגדרת, כי הגודל הדינמי לא יכול להתאים לגודל התואם באופרנד השני (סטטי או דינמי). אם כל נתוני הקלט הם סטטיים, לא משנה אם התוצאה דינמית או לא: מאפיינים סטטיים ידועים ייבדקו באופן סטטי, ומאפיינים דינמיים לא יחייבו אילוצים.
אי התאמות בצורות של פעולות שצורת הפלט שלהן משמשת כאופרנד
נניח שיש לנו את תוכנית הצעצוע הבאה:
func.func @foo(%arg0: tensor<2xi32>) {
%0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
return
}
הערכים באופרנד של הצורה בזמן הריצה צריכים להתאים לצורה של התוצאה, אחרת ההתנהגות לא מוגדרת. כלומר, בזמן הריצה הערך של %arg0 חייב להיות dense<[3, 4]> : tensor<2xi32>. אם האופרנד של הצורה הוא קבוע, אפשר לאמת את זה באופן סטטי. אם צורת התוצאה דינמית לחלוטין, לא יכול להיות חוסר התאמה.