מפרט StableHLO

StableHLO היא פעולה שמוגדרת לפעולות ברמה גבוהה (HLO) במודלים של למידת מכונה (ML). הכלי StableHLO פועל כשכבה של ניידות בין מסגרות ML שונות ומהדרים של ML: מסגרות ML שיוצרות תוכניות StableHLO תואמות למהדרי ML שצורכים תוכנות StableHLO.

המטרה שלנו היא לפשט ולהאיץ את הפיתוח של למידת מכונה באמצעות יצירת יכולת פעולה הדדית רבה יותר בין מסגרות שונות של למידת מכונה (כמו TensorFlow, JAX ו-PyTorch) ומהדרים של למידת מכונה (כמו XLA ו-IREE). לשם כך, המסמך מספק מפרט לשפת התכנות StableHLO.

המפרט הזה כולל שלושה קטעים עיקריים. קודם כול, בקטע Programs מתואר המבנה של תוכניות StableHLO, שכוללות פונקציות של StableHLO, שכוללות פעולות של StableHLO. בתוך המבנה הזה, הקטע Ops מציין את הסמנטיקה של פעולות נפרדות. הקטע Execution מספק סמנטיקה של כל הפעולות המתבצעות יחד בתוך תוכנית. לבסוף, הקטע Notation עוסק בסימון שבו נעשה שימוש במפרט.

תוכניות

Program ::= {Func}

תוכנות יציבות (StableHLO) מורכבות ממספר שרירותי של פונקציות StableHLO. לפניכם דוגמה לתוכנית עם פונקציה @main שיש לה 3 מקורות קלט (%image, %weights ו-%bias) ופלט אחד. לגוף הפונקציה יש 6 פעולות.

func.func @main(
  %image: tensor<28x28xf32>,
  %weights: tensor<784x10xf32>,
  %bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
  %0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
  %1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
  %2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  %3 = "stablehlo.constant"() { value = dense<0.0> : tensor<1x10xf32> } : () -> tensor<1x10xf32>
  %4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  "func.return"(%4): (tensor<1x10xf32>) -> ()
}

פונקציות

Func        ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs  ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput   ::= '%' ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput  ::= ValueType
FuncBody    ::= {Op}

לפונקציות יציבות (שנקראות גם פונקציות בעלות שם) יש מזהה, קלט/פלט וגוף. בעתיד, אנחנו מתכננים להוסיף מטא-נתונים נוספים לפונקציות כדי להשיג תאימות טובה יותר ל-HLO (#425, #626, #740, #744).

מזהים

FuncId  ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
          | '%' letter {letter | digit}
letter  ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit   ::= '0' | ... | '9'

מזהים יציבים (StableHLO) דומים למזהים בשפות תכנות רבות, בשתי תכונות ייחודיות: 1) לכל המזהים יש סימנים שמבדילים בין סוגים שונים של מזהים, 2) מזהי ערכים יכולים להיות מספריים לחלוטין, כדי לפשט את היצירה של תוכניות StableHLO.

סוגים

Type         ::= ValueType | NonValueType
ValueType    ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType

סוגים יציבים (StableHLO) מסווגים לסוגי ערכים (שנקראים גם סוגי מחלקה ראשונה), שמייצגים ערכי StableHLO וסוגים שאינם ערכים, שמתארים רכיבים אחרים בתוכנית. סוגי StableHLO דומים לסוגים השונים בשפות תכנות רבות, התכונה העיקרית היא האופי הספציפי לדומיין של StableHLO, שמוביל לתוצאות יוצאות דופן (למשל, סוגים סקלרים הם לא סוגי ערכים).

TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit}

סוגי טנסורים מייצגים tensors, כלומר מערכים רב-ממדיים. יש להם צורה וסוג אלמנט, כאשר צורה מייצגת גדלי מאפיינים לא שליליים בסדר עולה של המאפיינים התואמים (שמכונים גם צירים) הממוספרים מ-0 עד R-1. מספר המאפיינים R נקרא rank. לדוגמה, tensor<2x3xf32> הוא סוג tensor עם הצורה 2x3 וסוג הרכיב f32. יש לו שני מימדים (או, במילים אחרות, שני צירים) - מימד 0 ומימד ראשון, שהגדלים שלו הם 2 ו-3. הדירוג שלו הוא 2.

כך מגדירים תמיכה בצורות סטטיות שבהן גדלי המימדים ידועים באופן סטטי. בעתיד אנחנו מתכננים להוסיף תמיכה גם בצורות דינמיות, שבהן גודלי המימדים לא ידועים באופן חלקי או מלא (#8). בנוסף, אנחנו מתכננים לחקור הרחבה של סוגי tensor מעבר לגדלים של מאפיינים ולסוגי אלמנטים, למשל, לכלול פריסות (#629) וגמישות (#1078).

QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
                  QuantizationStorageType
                  ['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
                  ':' QuantizationExpressedType
                  [':' QuantizationDimension]
                  ',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerConstant
QuantizationStorageMax ::= IntegerConstant
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerConstant
QuantizationParameters ::= QuantizationParameter
                         | '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale ':' QuantizationZeroPoint
QuantizationScale ::= FloatConstant
QuantizationZeroPoint ::= IntegerConstant
שם סוג מגבלות
storage_type סוג מספר שלם (C1-C4), (C9)
storage_min קבוע מספר שלם (C2), (C4), (C8)
storage_max קבוע מספר שלם (C3), (C4), (C8)
expressed_type סוג נקודה צפה (floating-point) (C1), (C5)
quantization_dimension קבוע מספר שלם אופציונלי (C11-C13)
scales מספר וריאנטים של קבועים בנקודה צפה (C5-C7), (C10), (C11), (C13)
zero_points מספר הווריאנטים של קבועים במספרים שלמים (C8-C10)

סוגי רכיבים מכווננים מייצגים ערכים של מספרים שלמים של סוג אחסון בטווח מ-storage_min עד storage_max (כולל) שתואמים לערכי נקודות צפות של סוג ביטוי. עבור ערך נתון של מספר שלם i, ניתן לחשב את הערך התואם של הנקודה הצפה f כ-f = (i - zero_point) * scale, כאשר scale ו-zero_point נקראים פרמטרים של קוונטיזציה. הערכים storage_min ו-storage_max הם אופציונליים בדקדוק, אבל יש להם ערכי ברירת מחדל של min_value(storage_type) ו-max_value(storage_type) בהתאמה. סוגי הרכיבים המרובעים כוללים את המגבלות הבאות:

  • (C1) num_bits(storage_type) < num_bits(expressed_type).
  • (ג2) type(storage_min) = storage_type.
  • (C3) type(storage_max) = storage_type.
  • (C4) min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type).
  • (C5) type(scales...) = expressed_type.
  • (C6) 0 < scales.
  • (C7) is_finite(scales...).
  • (C8) storage_min <= zero_points <= storage_max.
  • (C9) type(zero_points...) = storage_type.
  • (C10) size(scales) = size(zero_points).
  • (C11) אם is_empty(quantization_dimension), אז size(scales) = 1.
  • (C12) 0 <= quantization_dimension.

נכון לעכשיו, הערך QuantizationScale הוא קבוע בנקודה צפה, אבל יש התעניינות רבה בסולמות שמבוססים על מספרים שלמים, שמיוצגים באמצעות מכפילים ותזוזות. אנחנו מתכננים לבדוק את הנושא הזה בעתיד הקרוב (#1404).

מתקיים דיון מתמשך על הסמנטיקה של QuantizationZeroPoint, כולל הסוג, הערכים ואם יכולה להיות רק נקודת אפס אחת או אפילו מספר אפסים בסוג t e n pnantor tensor. על סמך התוצאות של הדיון הזה, המפרט לגבי אפס נקודות עשוי להשתנות בעתיד (#1405).

דיון מתמשך נוסף עוסקים בסמנטיקה של QuantizationStorageMin ושל QuantizationStorageMax, כדי לקבוע אם צריך לכפות מגבלות על הערכים האלה ועל הערכים של רכיבי ה-tensor שמככבים (#1406).

לבסוף, אנחנו מתכננים לבחון ייצוג של קני מידה ואפס נקודות לא ידועים, בדומה לאופן שבו אנחנו מתכננים לבחון מידות לא ידועות (#1407).

סוגי tensorות קוונטים מייצגים tensor עם רכיבים שמדמים. ה-tensors האלה זהים ל-tensors הרגילים, אלא שהרכיבים שלהם כוללים סוגי רכיבים שמבוססים על קוונטים, במקום סוגי רכיבים רגילים.

כשמדובר במותחנים ממוספרים, הקוונטיזציה יכולה להיות לכל טנזור. כלומר, צריך להשתמש ב-scale וב-zero_point לכל ה-Tenor או לכל ציר. כלומר, ריבוי scales ו-zero_points, צמד אחד לכל פרוסה של מאפיין מסוים quantization_dimension. באופן רשמי יותר, בשיטה טנזור t עם קוונטיזציה לכל ציר, יש dim(t, quantization_dimension) פרוסות של quantization_dimension: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :] וכו'. כל הרכיבים בפרוסה הi משתמשים ב-scales[i] וב-zero_points[i] כפרמטרים של הקוונטיזציה. סוגי ה-tensor שממוינים באמצעות קוונטים כוללים את האילוצים הבאים:

  • לקוונטיזציה לכל טנסור:
    • ללא מגבלות נוספות.
  • לקוונטיזציה לפי ציר:
    • (C12) quantization_dimension < rank(self).
    • (C13) dim(self, quantization_dimension) = size(scales).
TokenType ::= 'token'

סוגי אסימונים מייצגים אסימונים, כלומר ערכים אטומים שנוצרים ונצרכות על ידי פעולות מסוימות. האסימונים משמשים לקביעת סדר ביצוע לפעולות, כפי שמתואר בקטע Execution.

TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]

סוגי הצמדים מייצגים צמדים, כלומר רשימות הטרוגניות. Tuples היא תכונה מדור קודם הקיימת רק לצורך תאימות ל-HLO. ב-HLO, צמדים משמשים לייצוג של משתני קלט ופלט שונים. ב-SableHLO יש תמיכה מובנית בקלט ובפלט שונים, והשימוש היחיד ב-Tuples ב-StableHLO הוא ייצוג מקיף של ממשק HLO ABI שבו למשל T, tuple<T> ו-tuple<tuple<T>> עשויים להיות שונים מהותית בהתאם להטמעה מסוימת. בעתיד אנחנו מתכננים לבצע שינויים ב-HLO ABI שעשויים לאפשר לנו להסיר סוגי tuple מ-StableHLO (#598).

TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ'
            | 'f8E4M3B11FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'

סוגי רכיבים מייצגים אלמנטים של סוגי tensor. שלא כמו בשפות תכנות רבות, הסוגים האלה הם לא מהשורה הראשונה ב-StableHLO. המשמעות היא שתוכניות יציבות לא יכולות לייצג באופן ישיר ערכים מהסוגים האלה (כתוצאה מכך, אידיומטי לייצג ערכים סקלריים מסוג T עם ערכי tensor 0-ממדיים מסוג tensor<T>).

  • סוג בוליאני מייצג ערכים בוליאניים true ו-false.
  • סוגי מספרים שלמים יכולים להיות חתומים (si) או לא חתומים (ui) ויש להם אחד מרוחבי הביטים הנתמכים (4, 8, 16, 32 או 64). סוגים חתומים מסוג siN מייצגים ערכים של מספרים שלמים מ--2^(N-1) עד 2^(N-1)-1 כולל, וסוגים לא חתומים של uiN מייצגים ערכים של מספרים שלמים מ-0 עד 2^N-1 כולל.
  • סוגים של נקודות צפות יכולות להיות אחת מהאפשרויות הבאות:
  • סוגים מורכבים מייצגים ערכים מורכבים שיש להם חלק ממשי וחלק דמיוני מאותו סוג רכיב. הסוגים המורכבים הנתמכים הם complex<f32> (שני החלקים הם מסוג f32) ו-complex<f64> (שני החלקים הם מסוג f64).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]

סוגי פונקציות מייצגים פונקציות בעלות שם וגם פונקציות אנונימיות. יש להם סוגי קלט (רשימת הסוגים בצד שמאל של ->) וסוגי פלט (רשימת הסוגים בצד ימין של ->). בשפות תכנות רבות, סוגי הפונקציות הם מחלקה ראשונה, אבל לא ב-StableHLO.

StringType ::= 'string'

סוג המחרוזת מייצג רצפים של בייטים. שלא כמו בשפות תכנות רבות, סוג המחרוזת הוא לא מהמחלקה הראשונה ב-StableHLO, והוא משמש רק לציון מטא-נתונים סטטיים לרכיבי תוכנה.

פעולות

פעולות יציבות (שנקראות גם ops) מייצגות קבוצה סגורה של פעולות ברמה גבוהה במודלים של למידת מכונה. כפי שצוין למעלה, התחביר של נתונים יציבים (StableHLO) מבוסס במידה רבה על השימוש ב-MLIR, שהוא לא בהכרח החלופה הארגונומית ביותר, אבל יש ספק שההתאמה הטובה ביותר למטרה של StableHLO היא ליצור יותר יכולת פעולה הדדית בין מסגרות של ML לבין מהדרים של ML.

Op            ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName        ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic    ::= 'abs' | 'add' | ...

לפעולות יציבות (שנקראות גם ops) יש שם, קלט/פלט וחתימה. השם מורכב מקידומת stablehlo. וממונה שמזהה באופן ייחודי את אחת מהפעולות הנתמכות. בהמשך מופיעה רשימה מקיפה של כל הפעולות הנתמכות.

כרגע, תוכניות של StableHLO בטבע מכילות לפעמים פעולות שלא מתוארות במסמך הזה. בעתיד אנחנו מתכננים לשלב את הפעולות האלה בקבוצת StableHLO, או למנוע את הופעתן בתוכניות StableHLO. בינתיים, הנה רשימת הפעולות הבאות:

  • builtin.module, func.func, func.call ו-func.return (#425).
  • פעולות chlo (#602).
  • "לא בקטגוריית HLO" של פעולות StableHLO – בהתחלה הן היו חלק מקבוצת StableHLO, אבל מאוחר יותר זוהו כלא מתאימים לה: broadcast, create_token, cross-replica-sum, dot, einsum, torch_index_select, unary_einsum (#3).
  • קטגוריית "Dynamism" של פעולות StableHLO – בוצעה אתחול מ-MHLO, אבל עדיין לא בדקנו אותן: compute_reshape_shape, cstr_reshapable, dynamic_broadcast_in_dim, dynamic_conv, dynamic_gather, dynamic_iota, dynamic_pad, dynamic_reshape, real_dynamic_slice, set_dimension_size (#8).
  • חישובי צורות, כולל פעולות של arith, shape ו-tensor (#8).
OpInputs        ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues   ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue    ::= ValueId
OpInputFuncs    ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs    ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs       ::= [OpOutput {',' OpOutput} '=']
OpOutput        ::= ValueId

פעילויות התפעול צורכות קלט ומפיקות פלטים. הקלט מסווגים לערכי קלט (מחושבים במהלך הביצוע), לפונקציות קלט (מסופקות באופן סטטי, כי בפונקציות StableHLO הן לא ערכים ממחלקה ראשונה) ולמאפייני קלט (גם סטטיות). סוג הקלט והפלט שצורך ויוצר על ידי אופ תלוי במבנה שלו. לדוגמה, הפעולה add צורכת שני ערכי קלט ומפיקה ערך פלט אחד. לשם השוואה, הפעולה select_and_scatter צורכת 3 ערכי קלט, 2 פונקציות קלט ו-3 מאפייני קלט.

OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused      ::= '^' digit {digit}
              | '^' letter {letter | digit}

פונקציות קלט (שנקראות גם פונקציות אנונימיות) דומות מאוד לפונקציות בעלות שם, אבל: 1) אין להן מזהה (ולכן השם 'אנונימי'), 2) הן לא מצהירות על סוגי הפלט (סוגי הפלט מוסקים מהפעולה return בתוך הפונקציה).

התחביר של פונקציות הקלט כולל חלק שלא נמצא כרגע בשימוש (ראו גרסת הייצור Unused למעלה), שנמצא תואם ל-MLIR. ב-MLIR יש תפיסה כללית יותר של "אזורים", שיכולים לכלול מספר "בלוקים" של פעולות המחוברות זה לזה באמצעות פעולות קפיצה. לבלוקים האלה יש מזהים שתואמים לסביבת הייצור של Unused, כך שאפשר להבחין ביניהם. ב-SableHLO אין פעולות קפיצה, ולכן החלק התואם של תחביר MLIR לא בשימוש (אבל עדיין קיים).

OpInputAttr      ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName  ::= letter {letter | digit}
OpInputAttrValue ::= Constant

למאפייני קלט יש שם וערך, שהוא אחד מהקבועים הנתמכים. זוהי הדרך העיקרית לציון מטא-נתונים סטטיים לרכיבי תוכנה. לדוגמה, הפעולה concatenate משתמשת במאפיין dimension כדי לציין את המאפיין שלאורכו ערכי הקלט משורשרים. באופן דומה, הפעולה slice משתמשת במספר מאפיינים כמו start_indices ו-limit_indices כדי לציין את הגבולות שישמשו לחיתוך של ערך הקלט.

כרגע, תוכניות של StableHLO בטבע מכילות לפעמים מאפיינים שלא מתוארים במסמך הזה. בעתיד אנחנו מתכננים להטמיע את המאפיינים האלה בקבוצת StableHLO או למנוע את הופעתם בתוכניות StableHLO. בינתיים, ריכזנו כאן רשימה של המאפיינים הבאים:

  • layout (#629).
  • mhlo.frontend_attributes (#628).
  • mhlo.sharding (#619).
  • output_operand_aliases (#740).
  • מטא-נתונים של מיקום (#594).
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'

חתימת Op מורכבת מהסוגים של כל ערכי הקלט (רשימת הסוגים בצד שמאל של ->) ומהסוגים של כל ערכי הפלט (רשימת הסוגים בצד הימני של ->). באופן כללי, סוגי הקלט הם מיותרים, וכמו כן סוגי הפלט כמעט תמיד מיותרים (מכיוון שלרוב פעולות StableHLO אפשר להסיק מהם סוגי הפלט). עם זאת, החתימה הדיגיטלית היא חלק מתחביר StableHLO שמתאים ל-MLIR.

למטה מוצגת דוגמה לפעולה שהזיכרון שלה הוא select_and_scatter. היא צורכת 3 ערכי קלט (%operand, %source ו-%init_value), 2 פונקציות קלט ו-3 מאפייני קלט (window_dimensions, window_strides ו-padding). חשוב לשים לב שהחתימה של הפעולה כוללת רק את הסוגים של ערכי הקלט שלה (אבל לא את הסוגים של פונקציות ומאפייני קלט שסופקו בתוך השורה).

%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i32>, tensor<i32>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
    "stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>

קבועים

Constant ::= BooleanConstant
           | IntegerConstant
           | FloatConstant
           | ComplexConstant
           | TensorConstant
           | QuantizedTensorConstant
           | StringConstant
           | EnumConstant

לקבועים יציבים ב-StableHLO יש ליטרל וסוג, שביחד מייצגים ערך StableHLO. בדרך כלל, הסוג הוא חלק מהתחביר הקבוע, אלא אם הוא חד-משמעי (למשל, קבוע בוליאני מכיל את הסוג i1 באופן חד-משמעי, בעוד שלקבוע מספר שלם יכולים להיות כמה סוגים אפשריים).

BooleanConstant ::= BooleanLiteral
BooleanLiteral  ::= 'true' | 'false'

קבועים בוליאניים מייצגים ערכים בוליאניים true ו-false. לקבועים בוליאניים יש סוג i1.

IntegerConstant   ::= IntegerLiteral ':' IntegerType
IntegerLiteral    ::= ['-' | '+'] DecimalDigits
                    | ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits     ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit      ::= '0' | ... | '9'
hexadecimalDigit  ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'

קבועים במספרים שלמים מייצגים ערכים של מספרים שלמים באמצעות מחרוזות שמשתמשות בסימון עשרוני או הקסדצימלי. אין תמיכה בבסיסים אחרים, למשל בינארי או אוקטלי. לקבועים במספרים שלמים יש את המגבלות הבאות:

  • (C1) is_wellformed(integer_literal, integer_type).
FloatConstant  ::= FloatLiteral ':' FloatType
FloatLiteral   ::= SignPart IntegerPart FractionalPart ScientificPart
                 | '0x' [HexadecimalDigits]
SignPart       ::= ['-' | '+']
IntegerPart    ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]

קבועים נקודה צפה (floating-point) מייצגים ערכים של נקודה צפה באמצעות מחרוזות שמשתמשות בסימון עשרוני או מדעי. אפשר גם להשתמש בסימון הקסדצימלי כדי לציין ישירות את הסיביות הבסיסיות בפורמט נקודה צפה (floating-point) מהסוג המתאים. על קבועים נקודה צפה יש את האילוצים הבאים:

  • (C1) אם נעשה שימוש בייצוג לא הקסדצימלי, is_wellformed(float_literal, float_type).
  • (C2) אם משתמשים בייצוג הקסדצימלי, size(hexadecimal_digits) = num_bits(float_type) / 4.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral  ::= '(' RealPart ',' ImaginaryPart ')'
RealPart        ::= FloatLiteral
ImaginaryPart   ::= FloatLiteral

קבועים מורכבים מייצגים ערכים מורכבים באמצעות רשימות של חלק ממשי (קודם) וחלק דמיוני (מופיע שני). לדוגמה, הקוד (1.0, 0.0) : complex<f32> מייצג את 1.0 + 0.0i, והחלק (0.0, 1.0) : complex<f32> מייצג את 0.0 + 1.0i. ההטמעה מתבצעת לפי הסדר שבו החלקים האלה מאוחסנים בזיכרון. לקבועים מורכבים יש את המגבלות הבאות:

  • (C1) is_wellformed(real_part, complex_element_type(complex_type)).
  • (ג2) is_wellformed(imaginary_part, complex_element_type(complex_type)).
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral   ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements  ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral

קבועים של tensor מייצגים ערכי tensor באמצעות רשימות מקננות שמצוינות באמצעות הסימון NumPy. לדוגמה, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> מייצג ערך tensor עם המיפוי הבא מאינקים לרכיבים: {0, 0} => 1, {0, 1} => 2, {0, 2} => 3, {1, 0} => 4, {1, 1} => 5, {1, 2} => 6. ההטמעה היא הסדר שבו מאוחסנים הרכיבים האלה בזיכרון. על קבועי Tensor יש את האילוצים הבאים:

  • (C1) has_syntax(tensor_literal, element_type(tensor_type)), כאשר:
    • has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type).
    • has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type).
  • (C2) has_shape(tensor_literal, shape(tensor_type)), כאשר:
    • has_shape(element_literal: Syntax, []) = true.
    • has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:]).
    • אחרת, false.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'

קבועי טנזור מכווננים מייצגים ערכי tensor מכווננים באמצעות אותו סימון של קבועי tensor, עם אלמנטים שצוינו כקבועים בסוג האחסון שלהם. על קבועי טנזור מחושבים יש את האילוצים הבאים:

  • (C1) has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type)).
  • (ג2) has_shape(quantized_tensor_literal, shape(quantized_tensor_type)).
StringConstant  ::= StringLiteral
StringLiteral   ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence  ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))

ליטרלים של מחרוזת מורכבים מבייטים שצוינו באמצעות תווי ASCII ורצפי בריחה. אלו הם קידודים אגנוסטיים, ולכן הפירוש של הבייטים האלה מוגדר לפי היישום. ליטרלים של מחרוזת הם מסוג string.

תפעול

abs

סמנטיקה

מבצע את פעולת ה-AB של כל אלמנט על טנסור operand ויוצר tensor result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • למספרים שלמים וחתומים: מודולוס של מספר שלם.
  • לצפים: abs מ-IEEE-754.
  • למספרים מרוכבים: מודול מורכב.
  • לסוגים שמבוססים על כמות: dequantize_op_quantize(abs, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand t e n s o r f l o w, t e n s o d s l o re ג, או ר בל ט ר נו ה בשביל ת מור ט אן ט 1 ט 1 ש 1111111111111111 כבר{/1} ה ה (C1-C2)

פלט

שם סוג מגבלות
result tensor של מספר שלם חתום או סוג של נקודה צפה (floating-point) או img_tensor p-tensor (C1-C2)

מגבלות

  • (C1) shape(result) = shape(operand).
  • (C2) baseline_element_type(result) מוגדר כך:
    • complex_element_type(element_type(operand)) אם is_complex(operand).
    • baseline_element_type(operand) אחרת.

דוגמאות

// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]

דוגמאות נוספות

add

סמנטיקה

הפעולה הזו מבצעת הוספה של שני טנזורים של lhs ו-rhs ברמת הרכיב, ויוצרת טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • עבור בוליאנים: OR לוגי.
  • למספרים שלמים: חיבור של מספר שלם.
  • לצפים: addition מ-IEEE-754.
  • למספרים מרוכבים: חיבור מרוכב.
  • לסוגים שמבוססים על כמות: dequantize_op_quantize(add, lhs, rhs, type(result)).

קלט

תווית שם סוג מגבלות
(I1) lhs tensor, או tensor, quanted tensor (C1)
(I2) rhs tensor, או tensor, quanted tensor (C1)

פלט

שם סוג מגבלות
result tensor, או tensor, quanted tensor (C1)

מגבלות

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

דוגמאות

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]

דוגמאות נוספות

after_all

סמנטיקה

מוודא שהפעולות שמייצרות את inputs יבוצעו לפני כל פעולה שתלויה ב-result. הביצוע של הפעולה הזאת לא משנה דבר, אלא רק כדי ליצור יחסי תלות של נתונים מ-result עד inputs.

קלט

תווית שם סוג
(I1) inputs מספר הווריאנטים של token

פלט

שם סוג
result token

דוגמאות

// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token

דוגמאות נוספות

all_gather

סמנטיקה

בתוך כל קבוצת תהליכים ברשת התהליך של StableHLO, משורשרים הערכים של ה-tensor operand מכל תהליך לאורך all_gather_dim ויוצרים טנטור result.

הפעולה תפצל את רשת התהליך ב-StableHLO ל-process_groups, שמוגדר כך:

  • cross_replica(replica_groups) אם channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) אם channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) אם channel_id > 0 and use_global_device_ids = true.

לאחר מכן, בכל process_group:

  • operands@receiver = [operand@sender for sender in process_group] לכל receiver באפליקציה process_group.
  • result@process = concatenate(operands@process, all_gather_dim) לכל process באפליקציה process_group.

קלט

תווית שם סוג מגבלות
(I1) operand tensor, או tensor, quanted tensor (C1), (C6)
(I2) all_gather_dim קבוע מסוג si64 (C1), (C6)
(I3) replica_groups קבוע tensor דו-ממדי מסוג si64 (C2-C4)
(I4) channel_id קבוע מסוג si64 (C5)
(I5) use_global_device_ids קבוע מסוג i1 (C5)

פלט

שם סוג מגבלות
result tensor, או tensor, quanted tensor (C6)

מגבלות

  • (C1) 0 <= all_gather_dim < rank(operand).
  • (ג2) is_unique(replica_groups).
  • (C3) השדה size(replica_groups) מוגדר כך:
    • num_replicas אם נעשה שימוש בcross_replica.
    • num_replicas אם נעשה שימוש בcross_replica_and_partition.
    • num_processes אם נעשה שימוש בflattened_ids.
  • (C4) 0 <= replica_groups < size(replica_groups).
  • (C5) אם הערך הוא use_global_device_ids = true, אז channel_id > 0.
  • (C6) type(result) = type(operand) מלבד:
    • dim(result, all_gather_dim) = dim(operand, all_gather_dim) * dim(process_groups, 1).

דוגמאות

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
%result = "stablehlo.all_gather"(%operand) {
  all_gather_dim = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  // channel_id = 0
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  // use_global_device_ids = false
} : (tensor<2x2xi64>) -> tensor<2x4xi64>
// %result@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]

דוגמאות נוספות

all_reduce

סמנטיקה

בתוך כל קבוצת תהליכים ברשת התהליך של StableHLO, ההחלה של פונקציית הפחתה computation על הערכים של ה-tensor operand מכל תהליך יוצרת טנזור result.

הפעולה תפצל את רשת התהליך ב-StableHLO ל-process_groups, שמוגדר כך:

  • cross_replica(replica_groups) אם channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) אם channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) אם channel_id > 0 and use_global_device_ids = true.

לאחר מכן, בכל process_group:

  • result@process[result_index] = exec(schedule) לעץ בינארי מסוים schedule שבו:
    • exec(node) = computation(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule הוא עץ בינארי שמוגדר על ידי הטמעה, שהמעבר שלו לפי סדר הוא to_destination_type(operands@process_group...[result_index], type(func_inputs(computation)[0])).

קלט

תווית שם סוג מגבלות
(I1) operand tensor, או tensor, quanted tensor (C5), (C6)
(I2) replica_groups מספר הווריאנטים של קבועי טנזור חד-ממדיים מסוג si64 (C1-C3)
(I3) channel_id קבוע מסוג si64 (C4)
(I4) use_global_device_ids קבוע מסוג i1 (C4)
(I5) computation פונקציה (C5)

פלט

שם סוג מגבלות
result tensor, או tensor, quanted tensor (C6-C7)

מגבלות

  • (C1) is_unique(replica_groups).
  • (C2) size(replica_groups) מוגדר כך:
    • num_replicas אם נעשה שימוש בcross_replica.
    • num_replicas אם נעשה שימוש בcross_replica_and_partition.
    • num_processes אם נעשה שימוש בflattened_ids.
  • (C3) 0 <= replica_groups < size(replica_groups).
  • (C4) אם הערך הוא use_global_device_ids = true, אז channel_id > 0.
  • (C5) ב-computation יש סוג (tensor<E>, tensor<E>) -> (tensor<E>) כאשר is_promotable(element_type(operand), E).
  • (C6) shape(result) = shape(operand).
  • (C7) element_type(result) = E.

דוגמאות

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [1, 2, 3, 4]
// %operand@(1, 0): [5, 6, 7, 8]
%result = "stablehlo.all_reduce"(%operand) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<i64>) -> tensor<i64>
// %result@(0, 0): [6, 8, 10, 12]
// %result@(1, 0): [6, 8, 10, 12]

דוגמאות נוספות

all_to_all

סמנטיקה

בתוך כל קבוצת תהליכים ברשת התהליך של StableHLO, הפיצול של הערכים של הטנזור operand לאורך split_dimension לחלקים, פיזור של החלקים המפוצלים בין התהליכים, שרשור החלקים המפוזרים לאורך concat_dimension ויוצרת טנזור result.

הפעולה תפצל את רשת התהליך ב-StableHLO ל-process_groups, שמוגדר כך:

  • cross_replica(replica_groups) אם channel_id <= 0.
  • cross_partition(replica_groups) אם channel_id > 0.

לאחר מכן, בכל process_group:

  • split_parts@sender = split(operand@sender, split_count, split_dimension) לכל ה-sender ב-process_group.
  • scattered_parts@receiver = [split_parts@sender[receiver_index] for sender in process_group] כאשר receiver_index = process_group.index(receiver).
  • result@process = concatenate(scattered_parts@process, concat_dimension).

קלט

תווית שם סוג מגבלות
(I1) operand tensor, או tensor, quanted tensor (C1-C3), (C9)
(I2) split_dimension קבוע מסוג si64 (C1), (C2), (C9)
(I3) concat_dimension קבוע מסוג si64 (C3), (C9)
(I4) split_count קבוע מסוג si64 (C2), (C4), (C8), (C9)
(I5) replica_groups קבוע tensor דו-ממדי מסוג si64 (C5-C8)
(I6) channel_id קבוע מסוג si64

פלט

שם סוג מגבלות
result tensor, או tensor, quanted tensor (C9)

מגבלות

  • (C1) 0 <= split_dimension < rank(operand).
  • (ג2) dim(operand, split_dimension) % split_count = 0.
  • (C3) 0 <= concat_dimension < rank(operand).
  • (C4) 0 < split_count.
  • (C5) is_unique(replica_groups).
  • (C6) size(replica_groups) מוגדר כך:
    • num_replicas אם נעשה שימוש בcross_replica.
    • num_partitions אם נעשה שימוש בcross_partition.
  • (C7) 0 <= replica_groups < size(replica_groups).
  • (C8) dim(replica_groups, 1) = split_count.
  • (C9) type(result) = type(operand) מלבד:
    • dim(result, split_dimension) = dim(operand, split_dimension) / split_count.
    • dim(result, concat_dimension) = dim(operand, concat_dimension) * split_count.

דוגמאות

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.all_to_all"(%operand) {
  split_dimension = 1 : i64,
  concat_dimension = 0 : i64,
  split_count = 2 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
} : (tensor<2x4xi64>) -> tensor<4x2xi64>
// %result@(0, 0): [[1, 2],
//                  [5, 6],
//                  [9, 10],
//                  [13, 14]]
// %result@(1, 0): [[3, 4],
//                  [7, 8],
//                  [11, 12],
//                  [15, 16]]

דוגמאות נוספות

וגם

סמנטיקה

הפונקציה מבצעת את הפונקציה AND ברמת הרכיב של שני tensor lhs ו-rhs ויוצרת result tensor. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • בשביל בוליאנים: logic AND.
  • למספרים שלמים: AND ברמת הסיביות.

קלט

תווית שם סוג מגבלות
(I1) lhs tensor מסוג בוליאני או מספר שלם (C1)
(I2) rhs tensor מסוג בוליאני או מספר שלם (C1)

פלט

שם סוג מגבלות
result tensor מסוג בוליאני או מספר שלם (C1)

מגבלות

  • (C1) type(lhs) = type(rhs) = type(result).

דוגמאות

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]

atan2

סמנטיקה

הפונקציה מבצעת פעולת atan2 של הרכיב ברמת הרכיב ב-lhs ו-rhs tensor, ויוצרת tanor result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • לצפים: atan2 מ-IEEE-754.
  • למספרים מרוכבים: atan2 מרוכבים.
  • לסוגים שמבוססים על כמות: dequantize_op_quantize(atan2, lhs, rhs, type(result)).

קלט

תווית שם סוג מגבלות
(I1) lhs t e n s o r f l o w, או t e n s n o l o li bre t e n s o l l o w, (C1)
(I2) rhs t e n s o r f l o w, או t e n s n o l o li bre t e n s o l l o w, (C1)

פלט

שם סוג מגבלות
result t e n s o r f l o w, או t e n s n o l o li bre t e n s o l l o w, (C1)

מגבלות

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

דוגמאות

// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]

דוגמאות נוספות

batch_norm_grad

סמנטיקה

מחשבת את ההדרגתיות של מספר מקורות קלט של batch_norm_training שהפצה לאחור מ-grad_output, ומפיקה את הטנזורים grad_operand, grad_scale ו-grad_offset. באופן רשמי יותר, אפשר לבטא את הפעולה הזו פירוק לפעולות StableHLO קיימות באמצעות תחביר Python, באופן הבא:

def compute_sum(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  return sum

def compute_mean(operand, feature_index):
  sum = compute_sum(operand, feature_index)
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
  # Broadcast inputs to type(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance`
  # Intermediate values will be useful for computing gradients
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)

  # Use the implementation from batchnorm_expander.cc in XLA
  # Temporary variables have exactly the same names as in the C++ code
  elements_per_feature = broadcast_in_dim(
      constant(divide(size(operand), dim(operand, feature_index)),
               element_type(grad_output)),
      [], type(operand))
  i1 = multiply(grad_output, elements_per_feature)
  i2 = broadcast_in_dim(
      compute_sum(grad_output, feature_index), [feature_index], type(operand))
  i3 = broadcast_in_dim(
      compute_sum(multiply(grad_output, centered_operand), feature_index),
      [feature_index], type(operand))
  i4 = multiply(i3, centered_operand)
  i5 = divide(i4, add(variance_bcast, epsilon_bcast))
  i6 = subtract(subtract(i1, i2), i5)

  grad_operand =
      multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
  grad_scale =
      compute_sum(multiply(grad_output, normalized_operand), feature_index)
  grad_offset = compute_sum(grad_output, feature_index)

  return grad_operand, grad_scale, grad_offset

לסוגים שמחושבים לפי כמות, הביצועים של dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean, variance, grad_output: batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index), operand, scale, mean, variance, grad_output, type(grad_operand), type(grad_scale), type(feature_index)).

קלט

תווית שם סוג מגבלות
(I1) operand t e n s o r f l o w, או של t e n s o l l o w, (C1-C3), (C5)
(I2) scale tensor חד-מימדי מסוג נקודה צפה (floating-point) או לכל טנזור (C2), (C4), (C5)
(I3) mean tensor חד-מימדי מסוג נקודה צפה (floating-point) או לכל טנזור (C2), (C4)
(I4) variance tensor חד-מימדי מסוג נקודה צפה (floating-point) או לכל טנזור (C2), (C4)
(I5) grad_output t e n s o r f l o w, או של t e n s o l l o w, (C2), (C3)
(I6) epsilon קבוע מסוג f32
(I7) feature_index קבוע מסוג si64 (C1), (C5)

פלט

שם סוג מגבלות
grad_operand t e n s o r f l o w, או של t e n s o l l o w, (C2), (C3)
grad_scale tensor חד-מימדי מסוג נקודה צפה (floating-point) או לכל טנזור (C2), (C4)
grad_offset tensor חד-מימדי מסוג נקודה צפה (floating-point) או לכל טנזור (C2), (C4)

מגבלות

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, mean, variance, grad_output, grad_operand, grad_scale ו-grad_offset כוללים את אותה baseline_element_type.
  • (C3) operand, grad_output ו-grad_operand הם בצורה זהה.
  • (C4) scale, mean, variance, grad_scale ו-grad_offset הם בצורה זהה.
  • (C5) size(scale) = dim(operand, feature_index).

דוגמאות

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
//                [[0.1, 0.1], [0.1, 0.1]],
//                [[0.1, 0.1], [0.1, 0.1]]
//               ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
     tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
//                 [[0.0, 0.0], [0.0, 0.0]],
//                 [[0.0, 0.0], [0.0, 0.0]]
//                ]
// %grad_scale:  [0.0, 0.0]
// %grad_offset: [0.4, 0.4]

batch_norm_inference

סמנטיקה

מנרמלת את הטנזור operand בכל המימדים מלבד המימד feature_index ויוצרת טנזור result. באופן רשמי יותר, אפשר לבטא את הפעולה הזו פירוק לפעולות קיימות ב-StableHLO באמצעות תחביר Python:

def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
  # Broadcast inputs to shape(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance` instead of
  # computing them like `batch_norm_training` does.
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)
  return add(multiply(scale_bcast, normalized_operand), offset_bcast)

לסוגים שמחושבים לפי כמות, הביצועים של dequantize_op_quantize(lambda operand, scale, offset, mean, variance: batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index), operand, scale, offset, mean, variance, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand t e n s o r f l o w, או של t e n s o l l o w, (C1-C7)
(I2) scale tensor חד-מימדי מסוג נקודה צפה (floating-point) או לכל טנזור (C2), (C3)
(I3) offset tensor חד-מימדי מסוג נקודה צפה (floating-point) או לכל טנזור (C2), (C4)
(I4) mean tensor חד-מימדי מסוג נקודה צפה (floating-point) או לכל טנזור (C5)
(I5) variance tensor חד-מימדי מסוג נקודה צפה (floating-point) או לכל טנזור (C2), (C6)
(I6) epsilon קבוע מסוג f32
(I7) feature_index קבוע מסוג si64 (C1), (C3-C6)

פלט

שם סוג מגבלות
result t e n s o r f l o w, או של t e n s o l l o w, (C2), (C7)

מגבלות

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, offset, mean, variance ו-result כוללים את אותו baseline_element_type.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(mean) = dim(operand, feature_index).
  • (C6) size(variance) = dim(operand, feature_index).
  • (C7) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]

batch_norm_training

סמנטיקה

מחשבת את הממוצע והשונות בכל המימדים מלבד המאפיין feature_index, ומנרמלת את הטנזור operand שיוצר את ה-output, ה-batch_mean וה-batch_var. באופן רשמי יותר, אפשר לבטא את הפעולה הזו פירוק לפעולות קיימות ב-StableHLO באמצעות תחביר Python, באופן הבא:

def compute_mean(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def compute_variance(operand, feature_index):
  mean = compute_mean(operand, feature_index)
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  centered_operand = subtract(operand, mean_bcast)
  return compute_mean(mul(centered_operand, centered_operand), feature_index)

def batch_norm_training(operand, scale, offset, epsilon, feature_index):
  mean = compute_mean(operand, feature_index)
  variance = compute_variance(operand, feature_index)
  return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
                              feature_index),
         mean, variance

לסוגים שמחושבים לפי כמות, הביצועים של dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset: batch_norm_training(operand, scale, offset, epsilon, feature_index), operand, scale, offset, type(output), type(batch_mean), type(batch_var)).

קלט

תווית שם סוג מגבלות
(I1) operand t e n s o r f l o w, או של t e n s o l l o w, (C1)
(I2) scale טנזור חד-ממדי של נקודה צפה (floating-point) או ר-זמן-מספר (C2), (C3)
(I3) offset טנזור חד-ממדי של נקודה צפה (floating-point) או ר-זמן-מספר (C2), (C4)
(I4) epsilon קבוע מסוג f32 (C1), (C3-C6)
(I5) feature_index קבוע מסוג si64 (C1), (C3-C6)

פלט

שם סוג מגבלות
output t e n s o r f l o w, או של t e n s o l l o w, (C7)
batch_mean טנזור חד-ממדי של נקודה צפה (floating-point) או ר-זמן-מספר (C2), (C5)
batch_var טנזור חד-ממדי של נקודה צפה (floating-point) או ר-זמן-מספר (C2), (C6)

מגבלות

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, offset, batch_mean, batch_var ו-output יש את אותו baseline_element_type.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(batch_mean) = dim(operand, feature_index).
  • (C6) size(batch_var) = dim(operand, feature_index).
  • (C7) baseline_type(output) = baseline_type(operand).

דוגמאות

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
    (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]

bitcast_convert

סמנטיקה

מבצע פעולת Bitcast ב-Tenor operand ויוצר tenor של result, שבו הסיביות של ה-Tenor המלא של operand מפורשות מחדש באמצעות הסוג של ה-Tenor result.

באופן רשמי יותר, ניתן לך: E = element_type(operand), E' = element_type(result) ו-R = rank(operand):

  • אם num_bits(E') < num_bits(E), bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1]).
  • אם num_bits(E') > num_bits(E), bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :]).
  • אם num_bits(E') = num_bits(E), bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1]).

bits מחזירה ייצוג בזיכרון של ערך נתון, וההתנהגות שלו מוגדרת על ידי היישום, כי הייצוג המדויק של רכיבי ה-tensor מוגדר על ידי היישום, וגם הייצוג המדויק של סוגי הרכיבים מוגדר על ידי ההטמעה.

קלט

תווית שם סוג מגבלות
(I1) operand את img_tensor, או tensor, pantor (C1-C2)

פלט

שם סוג מגבלות
result את img_tensor, או tensor, pantor (C1-C2)

מגבלות

  • (C1) בהינתן E = is_quantized(operand) ? storage_type(operand) : element_type(operand), E' = is_quantized(result) ? storage_type(result) : element_type(result) ו-R = rank(operand):
    • אם num_bits(E') = num_bits(E), shape(result) = shape(operand).
    • אם num_bits(E') < num_bits(E):
    • rank(result) = R + 1.
    • dim(result, i) = dim(operand, i) לכל 0 <= i < R.
    • dim(result, R) * num_bits(E') = num_bits(E).
    • אם num_bits(E') > num_bits(E):
    • rank(result) = R - 1.
    • dim(result, i) = dim(operand, i) לכל 0 <= i < R.
    • dim(operand, R - 1) * num_bits(E) = num_bits(E').
  • (C2) אם הערך הוא is_complex(operand) or is_complex(result), אז is_complex(operand) and is_complex(result).

דוגמאות

// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation

דוגמאות נוספות

broadcast_in_dim

סמנטיקה

היא מרחיבה את המימדים ו/או את הדירוג של טנסור קלט על ידי שכפול הנתונים ב-Tenor של operand ויוצרת טנסור result. באופן רשמי יותר, result[result_index] = operand[operand_index] שבו כל הd ב-axes(operand):

  • operand_index[d] = 0 אם dim(operand, d) = 1.
  • operand_index[d] = result_index[broadcast_dimensions[d]] אחרת.

קלט

תווית שם סוג מגבלות
(I1) operand את img_tensor, או tensor, pantor (C1-C2), (C5-C6)
(I2) broadcast_dimensions קבוע tensor חד-ממדי מסוג si64 (C2-C6)

פלט

שם סוג מגבלות
result את img_tensor, או tensor, pantor (C1), (C3), (C5-C6)

מגבלות

  • (C1) הערך element_type(result) ניתן על ידי:
    • element_type(operand), אם !is_per_axis_quantized(operand).
    • element_type(operand) מלבד האופן שבו quantization_dimension(operand), scales(operand) ו-zero_points(operand) עשויים להיות שונים מהתשובה quantization_dimension(result), scales(result) ו-zero_points(result), אחרת.
  • (ג2) size(broadcast_dimensions) = rank(operand).
  • (C3) 0 <= broadcast_dimensions < rank(result).
  • (C4) is_unique(broadcast_dimensions).
  • (C5) לכל d ביחידה הארגונית axes(operand):
    • dim(operand, d) = 1 או
    • dim(operand, d) = dim(result, broadcast_dimensions[d]).
  • (C6) אם is_per_axis_quantized(result):
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].
    • אם הערך הוא dim(operand, quantization_dimension(operand)) = 1, אז scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).

דוגמאות

// %operand: [
//            [1, 2, 3]
//           ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
  broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

דוגמאות נוספות

כיסוי

סמנטיקה

הפונקציה מפיקה את הפלט מהפעלה של פונקציה אחת בדיוק מ-branches, בהתאם לערך של index. בצורה רשמית יותר, result = selected_branch() כאשר:

  • selected_branch = branches[index] אם 0 <= index < size(branches).
  • selected_branch = branches[-1] אחרת.

קלט

תווית שם סוג מגבלות
(I1) index Tensor 0 ממדי מסוג si32
(I2) branches מספר הווריאנטים של הפונקציות (C1-C4)

פלט

שם סוג מגבלות
results מספר שונה של טנזורים, טנזורים או אסימונים קוונטיים (C4)

מגבלות

  • (C1) 0 < size(branches).
  • (ג2) input_types(branches...) = [].
  • (C3) same(output_types(branches...)).
  • (C4) type(results...) = output_types(branches[0]).

דוגמאות

// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
  "stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
  "stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]

דוגמאות נוספות

Cbrt

סמנטיקה

מבצע פעולה של שורש מעוקב ברמת הרכיב על טנסור operand ויוצר טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • לצפים: rootn(x, 3) מ-IEEE-754.
  • למספרים מרוכבים: שורש מעוקב מורכב.
  • לסוגים שמחושבים לפי כמות: dequantize_op_quantize(cbrt, operand, type(result))

קלט

תווית שם סוג מגבלות
(I1) operand t e n s o r f l o w, או t e n s n o l o li bre t e n s o l l o w, (C1)

פלט

שם סוג מגבלות
result t e n s o r f l o w, או t e n s n o l o li bre t e n s o l l o w, (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]

דוגמאות נוספות

CEil

סמנטיקה

ביצוע של CEil ברמת הרכיב של טנסור operand ומפיק טנזור result. מטמיע את הפעולה roundToIntegralTowardPositive ממפרט IEEE-754. לסוגים שמחושבים לפי כמות, הביצועים של dequantize_op_quantize(ceil, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand t e n s o r f l o w, או של t e n s o l l o w, (C1)

פלט

שם סוג מגבלות
result t e n s o r f l o w, או של t e n s o l l o w, (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]

דוגמאות נוספות

צ'ולסקי

סמנטיקה

מחשבת את פירוק כולסקי של קבוצת מטריצות.

באופן רשמי יותר, עבור כל הi ב-index_space(result), result[i0, ..., iR-3, :, :] הוא פירוק כולסקי של a[i0, ..., iR-3, :, :], בצורת מטריצה של משולש תחתון (אם lower הוא true) או מטריצת משולשת עליונה (אם lower הוא false). ערכי הפלט במשולש הנגדי, כלומר המשולש העליון המחמיר או המשולש הקפדני התחתון, מוגדרים בהתאם.

אם קיים i שבו מטריצת הקלט אינה מטריצה מוגדרת החיובית הרמיטית, אז ההתנהגות לא מוגדרת.

לסוגים שמחושבים לפי כמות, הביצועים של dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result)).

קלט

תווית שם סוג מגבלות
(I1) a t e n s o r f l o w, או t e n s n o l o li bre t e n s o l l o w, (C1-C3)
(I2) lower קבוע tensor 0 ממדי מסוג i1

פלט

שם סוג מגבלות
result t e n s o r f l o w, או t e n s n o l o li bre t e n s o l l o w, (C1)

מגבלות

  • (C1) baseline_type(a) = baseline_type(result).
  • (ג2) 2 <= rank(a).
  • (C3) dim(a, -2) = dim(a, -1).

דוגמאות

// %a: [
//      [1.0, 2.0, 3.0],
//      [2.0, 20.0, 26.0],
//      [3.0, 26.0, 70.0]
//     ]
%result = "stablehlo.cholesky"(%a) {
  lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
//           [1.0, 0.0, 0.0],
//           [2.0, 4.0, 0.0],
//           [3.0, 5.0, 6.0]
//          ]

מהדק

סמנטיקה

מצמידת כל רכיב של הטנזור operand בין ערך מינימלי למקסימום, ויוצרת טנזור result. באופן רשמי יותר, result[result_index] = minimum(maximum(operand[result_index], min_element), max_element), כאשר min_element = rank(min) = 0 ? min[] : min[result_index], max_element = rank(max) = 0 ? max[] : max[result_index]. לסוגים שמחושבים לפי כמות, מבצעים את dequantize_op_quantize(clamp, min, operand, max, type(result)).

יצירת סדר מספרים מרוכבים כרוכה בסמנטיקה מפתיעה, ולכן בעתיד אנחנו מתכננים להסיר את התמיכה במספרים מרוכבים בפעולה זו (#560).

קלט

תווית שם סוג מגבלות
(I1) min tensor, או tensor, quanted tensor (C1), (C3)
(I2) operand tensor, או tensor, quanted tensor (C1-C4)
(I3) max tensor, או tensor, quanted tensor (C2), (C3)

פלט

שם סוג מגבלות
result tensor, או tensor, quanted tensor (C4)

מגבלות

  • (C1) rank(min) = 0 or shape(min) = shape(operand).
  • (ג2) rank(max) = 0 or shape(max) = shape(operand).
  • (C3) baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max).
  • (C4) baseline_type(operand) = baseline_type(result).

דוגמאות

// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]

דוגמאות נוספות

collective_broadcast

סמנטיקה

בתוך כל קבוצת תהליכים ברשת התהליך של StableHLO, שולחים את הערך של ה-tensor operand מתהליך המקור לתהליכי היעד ויוצרים t tensor result.

הפעולה תפצל את רשת התהליך ב-StableHLO ל-process_groups, שמוגדר כך:

  • cross_replica(replica_groups) אם channel_id <= 0.
  • cross_partition(replica_groups) אם channel_id > 0.

לאחר מכן, result@process ניתן על ידי:

  • operand@process_groups[i, 0] אם קיים i שהתהליך הוא ב-process_groups[i].
  • broadcast_in_dim(constant(0, element_type(result)), [], type(result)) אחרת.

קלט

תווית שם סוג מגבלות
(I1) operand Tensor (C3)
(I2) replica_groups מספר הווריאנטים של קבועי טנזור חד-ממדיים מסוג si64 (C1), (C2)
(I3) channel_id קבוע מסוג si64

פלט

שם סוג מגבלות
result Tensor (C3)

מגבלות

  • (C1) is_unique(replica_groups).
  • (C2) 0 <= replica_groups < N כאשר N מוגדר כך:
    • num_replicas אם נעשה שימוש בcross_replica.
    • num_partitions אם נעשה שימוש בcross_partition.
  • (C3) type(result) = type(operand).

דוגמאות

// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
  replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]

collective_permute

סמנטיקה

בתוך כל קבוצת תהליכים ברשת התהליך של StableHLO, נשלחת הערך של ה-tensor operand מתהליך המקור לתהליך היעד, ויוצרת טנזור result.

הפעולה תפצל את רשת התהליך ב-StableHLO ל-process_groups, שמוגדר כך:

  • cross_replica(source_target_pairs) אם channel_id <= 0.
  • cross_partition(source_target_pairs) אם channel_id > 0.

לאחר מכן, result@process ניתן על ידי:

  • operand@process_groups[i, 0], אם קיים i כזה process_groups[i, 1] = process.
  • broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result)) אחרת.

קלט

תווית שם סוג מגבלות
(I1) operand tensor, או tensor, quanted tensor (C5)
(I2) source_target_pairs קבוע tensor דו-ממדי מסוג si64 (C1-C4)
(I3) channel_id קבוע מסוג si64

פלט

שם סוג מגבלות
result tensor, או tensor, quanted tensor (C1)

מגבלות

  • (C1) dim(source_target_pairs, 1) = 2.
  • (ג2) is_unique(source_target_pairs[:, 0]).
  • (C3) is_unique(source_target_pairs[:, 1]).
  • (C4) 0 <= source_target_pairs < N, כאשר N מוגדר כך:
    • num_replicas אם נעשה שימוש בcross_replica.
    • num_partitions אם נעשה שימוש בcross_partition.
  • (C5) type(result) = type(operand).

דוגמאות

// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
  source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]

דוגמאות נוספות

השוואה

סמנטיקה

מבצעת השוואה ברמת הרכיב של lhs ו-rhs tensor לפי comparison_direction ו-compare_type, ויוצרת טנזור result.

לערכים של comparison_direction ו-compare_type יש את הסמנטיקה הבאה:

לרכיבים בוליאניים ומספרים שלמים:

  • EQ: lhs = rhs.
  • NE: lhs != rhs.
  • GE: lhs >= rhs.
  • GT: lhs > rhs.
  • LE: lhs <= rhs.
  • LT: lhs < rhs.

בסוגי רכיבים של נקודה צפה עם compare_type = FLOAT, הפעולה מטמיעה את פעולות IEEE-754 הבאות:

  • EQ: compareQuietEqual.
  • NE: compareQuietNotEqual.
  • GE: compareQuietGreaterEqual.
  • GT: compareQuietGreater.
  • LE: compareQuietLessEqual.
  • LT: compareQuietLess.

עבור סוגי רכיבים עם נקודה צפה (floating-point) עם compare_type = TOTALORDER, הפעולה משתמשת בשילוב של הפעולות totalOrder ו-compareQuietEqual מ-IEEE-754. נראה שתכונה זו אינה בשימוש ולכן בעתיד אנחנו מתכננים להסיר אותה (#584).

בסוגי אלמנטים מורכבים, ההשוואה הליקסיקוגרפית של צמדי (real, imag) מתבצעת באמצעות comparison_direction ו-compare_type שסופקו. יצירת סדר מספרים מרוכבים כרוכה בסמנטיקה מפתיעה, ולכן בעתיד אנחנו מתכננים להסיר את התמיכה במספרים מרוכבים כאשר comparison_direction הוא GE, GT, LE או LT (#560).

לסוגים שמחושבים לפי כמות. הביצועים של dequantize_compare(lhs, rhs, comparison_direction).

קלט

תווית שם סוג מגבלות
(I1) lhs tensor, או tensor, quanted tensor (C1-C3)
(I2) rhs tensor, או tensor, quanted tensor (C1-C2)
(I3) comparison_direction טיפוסים בני מנייה (enum) של EQ, NE, GE, GT, LE ו-LT
(I4) compare_type טיפוסים בני מנייה (enum) של FLOAT, TOTALORDER, SIGNED ו-UNSIGNED (C3)

פלט

שם סוג מגבלות
result tensor מסוג בוליאני (C2)

מגבלות

  • (C1) baseline_element_type(lhs) = baseline_element_type(rhs).
  • (ג2) shape(lhs) = shape(rhs) = shape(result).
  • (C3) השדה compare_type מוגדר כך:
    • SIGNED אם is_signed_integer(element_type(lhs)).
    • UNSIGNED אם is_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs)).
    • FLOAT או TOTALORDER אם is_float(element_type(lhs)).
    • FLOAT אם is_complex(element_type(lhs)).

דוגמאות

// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
  comparison_direction = #stablehlo<comparison_direction LT>,
  compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]

דוגמאות נוספות

מורכב

סמנטיקה

הפונקציה מבצעת המרה של רכיב מסוים לערך מורכב מזוג ערכים ממשיים ומדומים, lhs ו-rhs, ויוצרת טנסור result.

קלט

תווית שם סוג מגבלות
(I1) lhs tensor מסוג f32 או f64 (C1-C3)
(I2) rhs tensor מסוג f32 או f64 (C1)

פלט

שם סוג מגבלות
result Tensor מסוג מורכב (C2), (C3)

מגבלות

  • (C1) type(lhs) = type(rhs).
  • (ג2) shape(result) = shape(lhs).
  • (C3) ב-element_type(result) יש סוג complex<E> כאשר E = element_type(lhs).

דוגמאות

// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]

דוגמאות נוספות

concatenate

סמנטיקה

משורשרת את inputs לאורך המימד dimension באותו סדר של הארגומנטים הנתונים ויוצרת טנזור result. באופן רשמי יותר, result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1], שבו:

  1. id = d0 + ... + dk-1 + kd.
  2. d שווה ל-dimension, ו-d0, ... הן מידות ה-d של inputs.

קלט

תווית שם סוג מגבלות
(I1) inputs מספר וריאנטים (tensor) של tensor או tensor, p-tensor, (C1-C6)
(I2) dimension קבוע מסוג si64 (C2), (C4), (C6)

פלט

שם סוג מגבלות
result tensor, או tensor, quanted tensor (C5-C6)

מגבלות

  • (C1) same(element_type(inputs...)).
  • (C2) same(shape(inputs...)) מלבד dim(inputs..., dimension).
  • (C3) 0 < size(inputs).
  • (C4) 0 <= dimension < rank(inputs[0]).
  • (C5) element_type(result) = element_type(inputs[0]).
  • (C6) shape(result) = shape(inputs[0]) למעט:
    • dim(result, dimension) = dim(inputs[0], dimension) + ....

דוגמאות

// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
  dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]

דוגמאות נוספות

קבוע

סמנטיקה

הפונקציה מפיקה טנזור output מ-value קבוע.

קלט

תווית שם סוג מגבלות
(I1) value קבוע (C1)

פלט

שם סוג מגבלות
output את img_tensor, או tensor, pantor (C1)

מגבלות

  • (C1) type(value) = type(output).

דוגמאות

%output = "stablehlo.constant"() {
  value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]

דוגמאות נוספות

להשלים המרה

סמנטיקה

הפונקציה מבצעת המרה של הרכיב מסוג אחד לסוג אחר ב-Tenor של operand ויוצרת tensor result.

בהמרות מסוג boolean-to-any-supported-type, הערך false מומר לאפס והערך true מומר ל-1. במקרה של המרות any-supported-type-to-boolean, ערך אפס מומר ל-false והערכים שאינם אפס עוברים המרה ל-true. בהמשך מוסבר איך זה עובד בסוגים מורכבים.

להמרות שכוללות מספר שלם למספר שלם, integer-to-floating-point, או floating-point-to-floating-point, אם אפשר לייצוג מדויק של ערך המקור בסוג היעד, ערך התוצאה הוא הייצוג המדויק. אחרת, ההתנהגות היא TBD (#180).

בהמרות שכוללות floating-point-to-integer, החלק השבר ייחתך. אם אי אפשר לייצג את הערך שנחתך בסוג היעד, ההתנהגות תהיה TBD (#180).

המרות מסוג מורכבות ל-complex פועלות באותה התנהגות של המרות מסוג floating-point-to-floating-point להמרת חלקים ממשיים ומדומים.

בהמרות מסוג complex-to-any-other-type וcomplex-to-any-other-type, המערכת מתעלמת מהערך המדומה המקורי או שהערך המדומה של היעד לא מקבל ערך, בהתאמה. ההמרה של החלק האמיתי מתרחשת בהמרות של הנקודה הצפה.

בעיקרון, הפעולה הזו יכולה לבטא את פעולת ביטול הקוונטיזציה (המרה מ-tenors tensors (tensor) במאוזן ל-tensors (tensors) רגילים), quantization (המרה מ-tensors רגילים - tensors מותנים), ו-Requantization (המרה בין tensors quantated), אך כרגע הגדרנו פעולות לכך - uniform_dequantize לתרחיש לדוגמה הראשון ו-uniform_quantize לתרחיש לדוגמה השלישי. בעתיד, ייתכן ששתי הפעולות האלה ימוזגו ל-convert (#1576).

קלט

תווית שם סוג מגבלות
(I1) operand Tensor (C1)

פלט

שם סוג מגבלות
result Tensor (C1)

מגבלות

  • (C1) shape(operand) = shape(result).

דוגמאות

// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]

דוגמאות נוספות

קונבולציה

סמנטיקה

הפונקציה מחשבת נקודות מוצרים בין חלונות של lhs ופרוסות של rhs ויוצרת result. התרשים הבא מראה כיצד מחושבים הרכיבים ב-result מ-lhs ומ-rhs באמצעות דוגמה קונקרטית.

באופן יותר רשמי, כדי לבטא חלונות של lhs, כדאי לשנות את הפריים של הקלט בעזרת המונחים lhs:

  • lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension)).
  • lhs_window_strides = lhs_shape(1, window_strides, 1).
  • lhs_padding = lhs_shape([0, 0], padding, [0, 0]).
  • lhs_base_dilations = lhs_shape(1, lhs_dilation, 1).
  • lhs_window_dilations = lhs_shape(1, rhs_dilation, 1).

שינוי הפריים הזה משתמש בפונקציות המסייעות הבאות:

  • lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]).
  • result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]).
  • permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1] כאשר j[d] = i[permutation[d]].

אם feature_group_count = 1 וגם batch_group_count = 1, אז לכל output_spatial_index ב-index_space(dim(result, output_spatial_dimensions...)), result[result_shape(:, output_spatial_index, :)] = dot_product כאשר:

  • padding_value = constant(0, element_type(lhs)).
  • padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1).
  • lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides.
  • lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations).
  • reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true]). נראה שהתכונה הזו לא בשימוש, כך שבעתיד אנחנו מתכננים להסיר אותה (#1181).
  • dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension]).

אם feature_group_count > 1:

  • lhses = split(lhs, feature_group_count, input_feature_dimension).
  • rhses = split(rhs, feature_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension).

אם batch_group_count > 1:

  • lhses = split(lhs, batch_group_count, input_batch_dimension).
  • rhses = split(rhs, batch_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension).

לסוגים שמחושבים לפי כמות, הביצועים של dequantize_op_quantize( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs, type(result)).

קלט

תווית שם סוג מגבלות
(I1) lhs tensor, או tensor, quanted tensor (C1), (C10-C11), (C14) (C25), (C27-C30)
(I2) rhs את img_tensor, או tensor, pantor (C1), (C14-C16), (C25), (C27-C32)
(I3) window_strides קבוע tensor חד-ממדי מסוג si64 (C2-C3), (C25)
(I4) padding קבוע tensor דו-ממדי מסוג si64 (C4), (C25)
(I5) lhs_dilation קבוע tensor חד-ממדי מסוג si64 (C5-C6), (C25)
(I6) rhs_dilation קבוע tensor חד-ממדי מסוג si64 (C7-C8), (C25)
(I7) window_reversal קבוע tensor חד-ממדי מסוג i1 (C9)
(I8) input_batch_dimension קבוע מסוג si64 (C10), (C13), (C25)
(I9) input_feature_dimension קבוע מסוג si64 (C11), (C13-C14)
(I10) input_spatial_dimensions קבוע tensor חד-ממדי מסוג si64 (C12), (C13), (C25)
(I11) kernel_input_feature_dimension קבוע מסוג si64 (C14), (C18)
(I12) kernel_output_feature_dimension קבוע מסוג si64 (C15-C16), (C18), (C25), (C32)
(I13) kernel_spatial_dimensions קבוע tensor חד-ממדי מסוג si64 (C17-C18), (C25)
(I14) output_batch_dimension קבוע מסוג si64 (C20), (C25)
(I15) output_feature_dimension קבוע מסוג si64 (C20), (C25), (C33)
(I16) output_spatial_dimensions קבוע tensor חד-ממדי מסוג si64 (C19-C20), (C25)
(I17) feature_group_count קבוע מסוג si64 (C11), (C14), (C16), (C21), (C23)
(I18) batch_group_count קבוע מסוג si64 (C10), (C15), (C22), (C23), (C25)
(I19) precision_config מספר וריאנטים של טיפוסים בני DEFAULT, HIGH ו-HIGHEST (C24)

פלט

שם סוג מגבלות
result את img_tensor, או tensor, pantor (C25-C28), (C30-C31), (C33)

מגבלות

  • (C1) N = rank(lhs) = rank(rhs).
  • (ג2) size(window_strides) = N - 2.
  • (C3) 0 < window_strides.
  • (C4) shape(padding) = [N - 2, 2].
  • (C5) size(lhs_dilation) = N - 2.
  • (C6) 0 < lhs_dilation.
  • (C7) size(rhs_dilation) = N - 2.
  • (C8) 0 < rhs_dilation.
  • (C9) size(window_reversal) = N - 2.
  • (C10) dim(lhs, input_batch_dimension) % batch_group_count = 0.
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0.
  • (C12) size(input_spatial_dimensions) = N - 2.
  • (C13) בהינתן input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:
    • is_unique(input_dimensions).
    • 0 <= input_dimensions < N.
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count.
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0.
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0.
  • (C17) size(kernel_spatial_dimensions) = N - 2.
  • (C18) בהינתן kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:
    • is_unique(kernel_dimensions).
    • 0 <= kernel_dimensions < N.
  • (C19) size(output_spatial_dimensions) = N - 2.
  • (C20) בהינתן output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:
    • is_unique(output_dimensions).
    • 0 <= output_dimensions < N.
  • (C21) 0 < feature_group_count.
  • (C22) 0 < batch_group_count.
  • (C23) feature_group_count = 1 or batch_group_count = 1.
  • (C24) size(precision_config) = 2.
  • (C25) dim(result, result_dim) מוגדר כך:
    • dim(lhs, input_batch_dimension) / batch_group_count אם result_dim = output_batch_dimension.
    • dim(rhs, kernel_output_feature_dimension) אם result_dim = output_feature_dimension.
    • num_windows, אחרת, כאשר:
    • output_spatial_dimensions[spatial_dim] = result_dim.
    • lhs_dim = input_spatial_dimensions[spatial_dim].
    • rhs_dim = kernel_spatial_dimensions[spatial_dim].
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
  • (C26) rank(result) = N.
  • אם בפעולה נעשה שימוש בתנורים לא מכווננים:
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result).
  • אם בפעולה נעשה שימוש במנזרים קוונטיים:
    • (C28) is_quantized_tensor(lhs) and is_quantized_tensor(rhs) and is_quantized_tensor(result).
    • (C29) storage_type(lhs) = storage_type(rhs).
    • (C30) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C31) אם הערך הוא is_per_tensor_quantized(rhs), אז is_per_tensor_quantized(result).
    • (C32) אם הערך הוא is_per_axis_quantized(rhs), אז quantization_dimension(rhs) = kernel_output_feature_dimension.
    • (C33) אם is_per_axis_quantized(result), אז quantization_dimension(result) = output_feature_dimension.

דוגמאות

// %lhs: [[
//        [
//          [1], [2], [5], [6]
//        ],
//        [
//          [3], [4], [7], [8]
//        ],
//        [
//          [10], [11], [14], [15]
//        ],
//        [
//          [12], [13], [16], [17]
//        ]
//      ]]
//
// %rhs : [
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]]
//        ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
  window_strides = dense<4> : tensor<2xi64>,
  padding = dense<0> : tensor<2x2xi64>,
  lhs_dilation = dense<2> : tensor<2xi64>,
  rhs_dilation = dense<1> : tensor<2xi64>,
  window_reversal = dense<false> : tensor<2xi1>,
  // In the StableHLO dialect, dimension numbers are encoded via:
  // `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
  // "b" is batch dimension, "f" is feature dimension,
  // "i" is input feature dimension, "o" is output feature dimension,
  // "0/1/etc" are spatial dimensions.
  dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
  feature_group_count = 1 : i64,
  batch_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi32>, tensor<3x3x1x1xi32>) -> tensor<1x2x2x1xi32>
// %result: [[
//            [[10], [26]],
//            [[46], [62]]
//          ]]

קוסינוס

סמנטיקה

מבצעת פעולת קוסינוס של רכיב מסוים על טנסור operand ויוצרת טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • לצפים: cos מ-IEEE-754.
  • למספרים מרוכבים: קוסינוס מרוכב.
  • לסוגים שמבוססים על כמות: dequantize_op_quantize(cosine, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand t e n s o r f l o w, או t e n s n o l o li bre t e n s o l l o w, (C1)

פלט

שם סוג מגבלות
result t e n s o r f l o w, או t e n s n o l o li bre t e n s o l l o w, (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]

דוגמאות נוספות

count_leading_zeros

סמנטיקה

הפונקציה מבצעת ספירה של מספר הביטים המובילים ב-Tenor של operand ויוצרת טנזור result.

קלט

תווית שם סוג מגבלות
(I1) operand tensor מסוג מספר שלם (C1)

פלט

שם סוג מגבלות
result tensor מסוג מספר שלם (C1)

מגבלות

  • (C1) type(operand) = type(result).

דוגמאות

// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]

דוגמאות נוספות

custom_call

סמנטיקה

הפונקציה מחשבת פעולה מוגדרת על ידי הטמעה call_target_name, שמתחילה ב-inputs וב-called_computations ויוצרת results. ייתכן שייעשה שימוש ב-has_side_effect, ב-backend_config וב-api_version כדי לספק מטא-נתונים נוספים שהוגדרו על ידי ההטמעה.

נכון לעכשיו, הפעולה הזו מכילה אוסף לא מאורגן של מטא-נתונים, המשקף את ההתפתחות האורגנית של פעולות המקבילות שלו במהדר XLA. בעתיד, אנחנו מתכננים לאחד את המטא-נתונים האלה (#741).

קלט

תווית שם סוג
(I1) inputs מספר הווריאנטים של הערכים
(I2) call_target_name קבוע מסוג string
(I3) has_side_effect קבוע מסוג i1
(I4) backend_config קבוע מסוג string
(I5) api_version קבוע מסוג si32
(I6) called_computations מספר קבוע של קבועים מסוג string

פלט

שם סוג
results מספר הווריאנטים של הערכים

דוגמאות

%results = "stablehlo.custom_call"(%input0) {
  call_target_name = "foo",
  has_side_effect = false,
  backend_config = "bar",
  api_version = 1 : i32,
  called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>

חילוק

סמנטיקה

הפונקציה מבצעת את החילוק של המחלק של lhs ואת המחלק rhs של טנזור המחלק, ברמת הרכיב, ויוצרת טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • למספרים שלמים: חלוקה של מספרים שלמים, שיוצרת את המנה האלגברית בלי להסיר חלק חלקי.
  • לצפים: division מ-IEEE-754.
  • למספרים מרוכבים: חילוק מורכב.
  • בסוגים שמבוססים על כמות:
    • dequantize_op_quantize(divide, lhs, rhs, type(result)).

קלט

תווית שם סוג מגבלות
(I1) lhs tensor של מספר שלם, נקודה צפה (floating-point) או סוג מורכב, או img_tensor p-tensor (C1)
(I2) rhs tensor של מספר שלם, נקודה צפה (floating-point) או סוג מורכב, או img_tensor p-tensor (C1)

פלט

שם סוג מגבלות
result t e n s o r f l o w, t e n s o י , או ב ר ח מ מ מור, ב ר 1 - ט נזור ט 1 1 1 1 1 1 1 m או ב ת יש ט en t בלי ט נזור (C1)

מגבלות

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

דוגמאות

// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]

דוגמאות נוספות

dot_general

סמנטיקה

הפונקציה מחשבת את המכפלות המנוקדות בין פרוסות של lhs לפרוסות rhs ויוצרת tensor result.

באופן רשמי יותר, result[result_index] = dot_product, איפה:

  • lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions].
  • rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions].
  • result_batching_index + result_lhs_index + result_rhs_index = result_index כאשר size(result_batching_index) = size(lhs_batching_dimensions), size(result_lhs_index) = size(lhs_result_dimensions) ו-size(result_rhs_index) = size(rhs_result_dimensions).
  • transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions).
  • transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :]).
  • reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions)).
  • transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions).
  • transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :]).
  • reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions)).
  • dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y)).

לסוגים שמחושבים לפי כמות, הביצועים של dequantize_op_quantize( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs, type(result)).

מציין סמנטיקה רק עבור קוונטיזציה לכל טנזור. אנחנו עובדים על קוונטיזציה לפי ציר (#1574). כמו כן, בעתיד יכול להיות שנשקול להוסיף תמיכה בקוונטיזציה היברידית (#1575).

precision_config קובע את האיזון בין מהירות לדיוק בחישובים בקצוות העורפיים של מאיץ. זה יכול להיות אחת מהאפשרויות הבאות (בשלב הזה, הסמנטיקה של ערכי ה-enum לא צוינה, אבל אנחנו מתכננים לטפל בכך ב-#755):

  • DEFAULT: החישוב המהיר ביותר, אבל ההערכה פחות מדויקת למספר המקורי.
  • HIGH: חישוב איטי יותר, אבל הערכה מדויקת יותר למספר המקורי.
  • HIGHEST: החישוב האיטי ביותר, אבל ההערכה המדויקת ביותר למספר המקורי.

קלט

תווית שם סוג מגבלות
(I1) lhs tensor, או tensor, quanted tensor (C5-C6), (C9-C10), (C12-C16)
(I2) rhs tensor, או tensor, quanted tensor (C7-C10), (C12)
(I3) lhs_batching_dimensions קבוע tensor חד-ממדי מסוג si64 (C1), (C3), (C5), (C9), (C12)
(I4) rhs_batching_dimensions קבוע tensor חד-ממדי מסוג si64 (C1), (C4), (C7), (C9)
(I5) lhs_contracting_dimensions קבוע tensor חד-ממדי מסוג si64 (C2), (C3), (C6), (C10)
(I6) rhs_contracting_dimensions קבוע tensor חד-ממדי מסוג si64 (C2), (C4), (C8), (C10)
(I7) precision_config מספר וריאנטים של טיפוסים בני DEFAULT, HIGH ו-HIGHEST (C11)

פלט

שם סוג מגבלות
result tensor, או tensor, quanted tensor (C12), (C14), (C16)

מגבלות

  • (C1) size(lhs_batching_dimensions) = size(rhs_batching_dimensions).
  • (ג2) size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions).
  • (C3) is_unique(lhs_batching_dimensions + lhs_contracting_dimensions).
  • (C4) is_unique(rhs_batching_dimensions + rhs_contracting_dimensions).
  • (C5) 0 <= lhs_batching_dimensions < rank(lhs).
  • (C6) 0 <= lhs_contracting_dimensions < rank(lhs).
  • (C7) 0 <= rhs_batching_dimensions < rank(rhs).
  • (C8) 0 <= rhs_contracting_dimensions < rank(rhs).
  • (C9) dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).
  • (C10) dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...).
  • (C11) size(precision_config) = 2.
  • (C12) shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions).
  • אם בפעולה נעשה שימוש בתנורים לא מכווננים:
    • (C13) element_type(lhs) = element_type(rhs).
  • אם בפעולה נעשה שימוש במנזרים קוונטיים:
    • (C14) is_quantized(lhs) and is_quantized(rhs) and is_quantized(result).
    • (C15) storage_type(lhs) = storage_type(rhs).
    • (C16) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C17) zero_points(rhs) = 0.

דוגמאות

// %lhs: [
//        [[1, 2],
//         [3, 4]],
//        [[5, 6],
//         [7, 8]]
//       ]
// %rhs: [
//        [[1, 0],
//         [0, 1]],
//        [[1, 0],
//         [0, 1]]
//       ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
  dot_dimension_numbers = #stablehlo.dot<
    lhs_batching_dimensions = [0],
    rhs_batching_dimensions = [0],
    lhs_contracting_dimensions = [2],
    rhs_contracting_dimensions = [1]
  >,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
//           [[1, 2],
//            [3, 4]],
//           [[5, 6],
//            [7, 8]]
//          ]

דוגמאות נוספות

dynamic_slice

סמנטיקה

מוציאה פרוסה מה-operand באמצעות אינדקסים ראשוניים שמחושבים באופן דינמי ויוצרת טנזור result. start_indices מכילים את האינדקסים המתחילים של הפרוסה לכל מאפיין הכפוף להתאמה פוטנציאלית, ו-slice_sizes כוללים את גודלי הפלח בכל מאפיין. בצורה רשמית יותר, result[result_index] = operand[operand_index] כאשר:

  • adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes).
  • operand_index = adjusted_start_indices + result_index.

קלט

תווית שם סוג מגבלות
(I1) operand tensor, או tensor, quanted tensor (C1), (C2), (C4)
(I2) start_indices מספר וריאנטים של 0 מימדים מסוג מספר שלם (C2), (C3)
(I3) slice_sizes קבוע tensor חד-ממדי מסוג si64 (C2), (C4), (C5)

פלט

שם סוג מגבלות
result tensor, או tensor, quanted tensor (C1), (C5)

מגבלות

  • (C1) element_type(operand) = element_type(result).
  • (ג2) size(start_indices) = size(slice_sizes) = rank(operand).
  • (C3) same(type(start_indices...)).
  • (C4) 0 <= slice_sizes <= shape(operand).
  • (C5) shape(result) = slice_sizes.

דוגמאות

// %operand: [
//            [0, 0, 1, 1],
//            [0, 0, 1, 1],
//            [0, 0, 0, 0],
//            [0, 0, 0, 0]
//           ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
  slice_sizes = dense<[2, 2]> : tensor<2xi64>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
//           [1, 1],
//           [1, 1]
//          ]

דוגמאות נוספות

dynamic_update_slice

סמנטיקה

הפונקציה יוצרת טנסור result ששווה לטנזור operand, אבל הקטע שמתחיל ב-start_indices מתעדכן בערכים שב-update. באופן רשמי יותר, ההגדרה result[result_index] היא:

  • update[update_index] אם 0 <= update_index < shape(update) כאשר:
    • adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update)).
    • update_index = result_index - adjusted_start_indices.
  • operand[result_index] אחרת.

קלט

תווית שם סוג מגבלות
(I1) operand tensor, או tensor, quanted tensor (C1-C4), (C6)
(I2) update tensor, או tensor, quanted tensor (C2), (C3), (C6)
(I3) start_indices מספר וריאנטים של 0 מימדים מסוג מספר שלם (C4), (C5)

פלט

שם סוג מגבלות
result tensor, או tensor, quanted tensor (C1)

מגבלות

  • (C1) type(operand) = type(result).
  • (ג2) element_type(update) = element_type(operand).
  • (C3) rank(update) = rank(operand).
  • (C4) size(start_indices) = rank(operand).
  • (C5) same(type(start_indices...)).
  • (C6) 0 <= shape(update) <= shape(operand).

דוגמאות

// %operand: [
//            [1, 1, 0, 0],
//            [1, 1, 0, 0],
//            [1, 1, 1, 1],
//            [1, 1, 1, 1]
//           ]
// %update: [
//           [1, 1],
//           [1, 1]
//          ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
  : (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1]
//          ]

דוגמאות נוספות

מעריכיות

סמנטיקה

הפונקציה מבצעת פעולה מעריכית ברמת הרכיב על טנסור operand ויוצרת טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • לצפים: exp מ-IEEE-754.
  • למספרים מרוכבים: מעריכיים מרוכבים.
  • לסוגים שמבוססים על כמות: dequantize_op_quantize(exponential, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand t e n s o r f l o w, או t e n s n o l o li bre t e n s o l l o w, (C1)

פלט

שם סוג מגבלות
result t e n s o r f l o w, או t e n s n o l o li bre t e n s o l l o w, (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]

דוגמאות נוספות

exponential_minus_one

סמנטיקה

הפונקציה מבצעת מעריכיות של רכיבים בקבוצה ומחסירה פעולה אחת על טנזור operand, ויוצרת טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • לצפים: expm1 מ-IEEE-754.
  • למספרים מרוכבים: מספר מעריכי מרוכב פחות 1.
  • לסוגים שמבוססים על כמות: dequantize_op_quantize(exponential_minus_one, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand t e n s o r f l o w, או t e n s n o l o li bre t e n s o l l o w, (C1)

פלט

שם סוג מגבלות
result t e n s o r f l o w, או t e n s n o l o li bre t e n s o l l o w, (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]

דוגמאות נוספות

fft

סמנטיקה

מבצע את התמרות פורייה קדימה והיפוך, לקלט/פלט אמיתיים ומורכבים.

הערך של fft_type הוא אחד מהבאים:

  • FFT: העברת FFT מורכב למורכב.
  • IFFT: פונקציית FFT הפוכה, מורכבת למורכב.
  • RFFT: העברה של FFT אמיתית למורכב.
  • IRFFT: ערך FFT ההופכי מאמת למורכב (כלומר, לוקח מורכב, מחזיר ערך ממשי).

בצורה יותר רשמית, בהינתן הפונקציה fft שמקבלת טנזורים חד-ממדיים של סוגים מורכבים כקלט, מייצרת טנזורים חד-ממדיים מאותם סוגים כמו הפלט, ומחשבת את התמרת פורייה הנפרדת:

עבור fft_type = FFT, result מוגדר כתוצאה הסופית של סדרה של חישובי L, כאשר L = size(fft_length). לדוגמה, עבור L = 3:

  • result1[i0, ..., :] = fft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

בנוסף, בהינתן הפונקציה ifft שיש לה חתימה מסוג זהה, ומחשבת את ההופכי של fft:

עבור fft_type = IFFT, result מוגדר בתור ההופכי של החישובים עבור fft_type = FFT. לדוגמה, עבור L = 3:

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = ifft(result2[i0, ..., :]).

בנוסף, בהינתן הפונקציה rfft שלוקחת tensors חד-ממדיים של סוגי נקודות צפות, מייצרת טנזורים חד-ממדיים של סוגים מורכבים של אותה סמנטיקה של נקודה צפה (floating-point) באופן הבא:

  • rfft(real_operand) = truncated_result כאשר
  • complex_operand... = (real_operand..., 0.0).
  • complex_result = fft(complex_operand).
  • truncated_result = complex_result[:(rank(complex_result) / 2 + 1)].

(כאשר התמרת פורייה הנפרדת מחושבת לאופרנדים אמיתיים, N/2 + 1 האלמנטים הראשונים של התוצאה מגדירים באופן חד-משמעי את שאר התוצאה, כך שהתוצאה של rfft תיחתך כדי להימנע מחישוב רכיבים מיותרים).

עבור fft_type = RFFT, result מוגדר כתוצאה הסופית של סדרה של חישובי L, כאשר L = size(fft_length). לדוגמה, עבור L = 3:

  • result1[i0, ..., :] = rfft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

לבסוף, בהינתן הפונקציה irfft שיש לה חתימה מסוג זהה, ומחשבת את ההופכי של rfft:

עבור fft_type = IRFFT, result מוגדר בתור ההופכי של החישובים עבור fft_type = RFFT. לדוגמה, עבור L = 3:

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = irfft(result2[i0, ..., :]).

קלט

תווית שם סוג מגבלות
(I1) operand Tensor של נקודה צפה או סוג מורכב (C1), (C2), (C4), (C5)
(I2) fft_type טיפוסים בני מנייה (enum) של FFT, IFFT, RFFT ו-IRFFT (C2), (C5)
(I3) fft_length קבוע tensor חד-ממדי מסוג si64 (C1), (C3), (C4)

פלט

שם סוג מגבלות
result Tensor של נקודה צפה או סוג מורכב (C2), (C4), (C5)

מגבלות

  • (C1) size(fft_length) <= rank(operand).
  • (C2) הקשר בין סוגי הרכיבים operand ו-result משתנה:
    • אם fft_type = FFT, element_type(operand) ו-element_type(result) הם מאותו סוג מורכב.
    • אם fft_type = IFFT, element_type(operand) ו-element_type(result) הם מאותו סוג מורכב.
    • אם fft_type = RFFT, הערך element_type(operand) הוא מסוג נקודה צפה (floating-point) ו-element_type(result) הוא סוג מורכב של אותה סמנטיקה של נקודה צפה (floating-point).
    • אם הערך fft_type = IRFFT הוא מסוג element_type(operand), ו-element_type(result) הוא סוג של נקודה צפה (floating-point) של אותה סמנטיקה של נקודה צפה (floating-point).
  • (C3) 1 <= size(fft_length) <= 3.
  • (C4) אם בין operand ו-result, יש t tensor real מטיפוס של נקודה צפה, אז shape(real)[-size(fft_length):] = fft_length.
  • (C5) shape(result) = shape(operand) למעט:
    • אם fft_type = RFFT, dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1.
    • אם fft_type = IRFFT, dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1.

דוגמאות

// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
  fft_type = #stablehlo<fft_type FFT>,
  fft_length = dense<4> : tensor<1xi64>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]

פונקציית הרצפה

סמנטיקה

ביצוע של פונקציית הרצפה של operand טנסור והפקה של טנזור מסוג result. מטמיע את הפעולה roundToIntegralTowardNegative ממפרט IEEE-754. לסוגים שמחושבים לפי כמות, הביצועים של dequantize_op_quantize(floor, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand t e n s o r f l o w, או של t e n s o l l o w, (C1)

פלט

שם סוג מגבלות
result t e n s o r f l o w, או של t e n s o l l o w, (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]

דוגמאות נוספות

לאסוף

סמנטיקה

איסוף מקטעים מ-tenor operand מהקיזוזים שצוינו ב-start_indices ומפיק tenor result.

בתרשים הבא מוצגים אופן המיפוי של אלמנטים ב-result על אלמנטים ב-operand באמצעות דוגמה קונקרטית. בתרשים נבחרים כמה דוגמאות של אינדקסים מסוג result ומסבירים בפירוט לאילו אינדקסים מסוג operand הם מתאימים.

בצורה יותר רשמית, result[result_index] = operand[operand_index] שבו:

  • batch_dims = [d for d in axes(result) and d not in offset_dims].
  • batch_index = result_index[batch_dims...].
  • ההגדרה של start_index היא:
    • start_indices[bi0, ..., :, ..., biN] כאשר bi הם רכיבים נפרדים ב-batch_index וב-: מוכנס לאינדקס index_vector_dim, אם index_vector_dim < rank(start_indices).
    • [start_indices[batch_index]] אחרת.
  • עבור d_operand ב-axes(operand),
    • full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand]) אם d_operand = start_index_map[d_start].
    • full_start_index[d_operand] = 0 אחרת.
  • offset_index = result_index[offset_dims...].
  • full_offset_index = [oi0, ..., 0, ..., oiN] כאשר oi הם רכיבים בודדים ב-offset_index, ו-0 מוכנס לאינדקסים מתוך collapsed_slice_dims.
  • operand_index = full_start_index + full_offset_index.

אם הערך של indices_are_sorted הוא true, ההטמעה יכולה להניח שהערך start_indices ממוינים ביחס ל-start_index_map, אחרת ההתנהגות לא מוגדרת. באופן רשמי יותר, לכל הi1 < i2 מ-indices(result), full_start_index(i1) <= full_start_index(i2).

קלט

תווית שם סוג מגבלות
(I1) operand tensor, או tensor, quanted tensor (C1), (C7), (C10-C12), (C14)
(I2) start_indices tensor מסוג מספר שלם (C2), (C3), (C13)
(I3) offset_dims קבוע tensor חד-ממדי מסוג si64 (C1), (C4-C5), (C13)
(I4) collapsed_slice_dims קבוע tensor חד-ממדי מסוג si64 (C1), (C6-C8), (C13)
(I5) start_index_map קבוע tensor חד-ממדי מסוג si64 (C3), (C9), (C10)
(I6) index_vector_dim קבוע מסוג si64 (C2), (C3), (C13)
(I7) slice_sizes קבוע tensor חד-ממדי מסוג si64 (C8) (C11-C13)
(I8) indices_are_sorted קבוע מסוג i1

פלט

שם סוג מגבלות
result tensor, או tensor, quanted tensor (C5), (C13-C14)

מגבלות

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims).
  • (ג2) 0 <= index_vector_dim <= rank(start_indices).
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1.
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims).
  • (C5) 0 <= offset_dims < rank(result).
  • (C6) is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims).
  • (C7) 0 <= collapsed_slice_dims < rank(operand).
  • (C8) slice_sizes[collapsed_slice_dims...] <= 1.
  • (C9) is_unique(start_index_map).
  • (C10) 0 <= start_index_map < rank(operand).
  • (C11) size(slice_sizes) = rank(operand).
  • (C12) 0 <= slice_sizes <= shape(operand).
  • (C13) shape(result) = combine(batch_dim_sizes, offset_dim_sizes) כאשר:
    • batch_dim_sizes = shape(start_indices) חוץ מהעובדה שגודל המאפיין start_indices שתואם ל-index_vector_dim לא נכלל.
    • offset_dim_sizes = shape(slice_sizes) מלבד גודלי המאפיינים ב-slice_sizes שתואמים ל-collapsed_slice_dims לא נכללים.
    • combine מציבה את batch_dim_sizes בצירים שתואמים ל-batch_dims ול-offset_dim_sizes בצירים שתואמים ל-offset_dims.
  • (C14) element_type(operand) = element_type(result).

דוגמאות

// %operand: [
//            [[1, 2], [3, 4], [5, 6], [7, 8]],
//            [[9, 10],[11, 12], [13, 14], [15, 16]],
//            [[17, 18], [19, 20], [21, 22], [23, 24]]
//           ]
// %start_indices: [
//                  [[0, 0], [1, 0], [2, 1]],
//                  [[0, 1], [1, 1], [0, 2]]
//                 ]
%result = "stablehlo.gather"(%operand, %start_indices) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [2, 3],
    collapsed_slice_dims = [0],
    start_index_map = [1, 0],
    index_vector_dim = 2>,
  slice_sizes = dense<[1, 2, 2]> : tensor<3xi64>,
  indices_are_sorted = false
} : (tensor<3x4x2xi32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xi32>
// %result: [
//            [
//              [[1, 2], [3, 4]],
//              [[3, 4], [5, 6]],
//              [[13, 14], [15, 16]]
//            ],
//            [
//              [[9, 10], [11, 12]],
//              [[11, 12], [13, 14]],
//              [[17, 18], [19, 20]]
//            ]
//          ]

דוגמאות נוספות

get_dimension_size

סמנטיקה

הפונקציה מציגה את הגודל של dimension הנתון מתוך operand. בצורה יותר רשמית, result = dim(operand, dimension).

קלט

תווית שם סוג מגבלות
(I1) operand Tensor (C1)
(I2) dimension קבוע מסוג si64 (C1)

פלט

שם סוג
result Tensor 0 ממדי מסוג si32

מגבלות

  • (C1) 0 <= dimension < rank(operand).

דוגמאות

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
  dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3

דוגמאות נוספות

get_tuple_element

סמנטיקה

חילוץ רכיב במיקום index של ה-tuple operand ויוצר result. בצורה יותר רשמית, result = operand[index].

קלט

תווית שם סוג מגבלות
(I1) operand tuple (C1), (C2)
(I2) index קבוע מסוג si32 (C1), (C2)

פלט

שם סוג מגבלות
result כל סוג נתמך (C2)

מגבלות

  • (C1) 0 <= index < size(operand).
  • (ג2) type(result) = tuple_element_types(operand)[index].

דוגמאות

// %operand: ([1.0, 2.0], (3))
%result = "stablehlo.get_tuple_element"(%operand) {
  index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]

דוגמאות נוספות

if

סמנטיקה

הפונקציה יוצרת את הפלט מהפעלה של פונקציה אחת בדיוק מ-true_branch או false_branch, בהתאם לערך של pred. בצורה יותר רשמית, result = pred ? true_branch() : false_branch().

קלט

תווית שם סוג מגבלות
(I1) pred Tensor 0 ממדי מסוג i1
(I2) true_branch פונקציה (C1-C3)
(I3) false_branch פונקציה (C1), (C2)

פלט

שם סוג מגבלות
results מספר שונה של טנזורים, טנזורים או אסימונים קוונטיים (C3)

מגבלות

  • (C1) input_types(true_branch) = input_types(false_branch) = [].
  • (ג2) output_types(true_branch) = output_types(false_branch).
  • (C3) type(results...) = output_types(true_branch).

דוגמאות

// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
  "stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
  "stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10

דוגמאות נוספות

Imag

סמנטיקה

מחלצת את החלק המדומה, מבחינת רכיב, מה-operand ומפיקה טנזור result. באופן רשמי יותר, לכל רכיב x: imag(x) = is_complex(x) ? imaginary_part(x) : constant(0, element_type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand Tensor של נקודה צפה או סוג מורכב (C1), (C2)

פלט

שם סוג מגבלות
result Tensor מסוג נקודה צפה (floating-point) (C1), (C2)

מגבלות

  • (C1) shape(result) = shape(operand).
  • (C2) element_type(result) מוגדר כך:
    • complex_element_type(element_type(operand)) אם is_complex(operand).
    • element_type(operand) אחרת.

דוגמאות

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]

דוגמאות נוספות

בגוף הפיד

סמנטיקה

קריאת נתונים מהפיד ויוצרת results.

הסמנטיקה של infeed_config מוגדרת על ידי היישום.

results מורכב מערכי מטען ייעודי (payload) שמגיעים ראשונים, ואסימון הקודם. בעתיד, אנחנו מתכננים לפצל את המטען הייעודי (payload) והאסימון לשני פלטים נפרדים כדי לשפר את הבהירות (#670).

קלט

תווית שם סוג
(I1) token token
(I2) infeed_config קבוע מסוג string

פלט

שם סוג מגבלות
results מספר שונה של טנזורים, טנזורים או אסימונים קוונטיים (C1-C3)

מגבלות

  • (C1) 0 < size(results).
  • (C2) is_empty(result[:-1]) או is_tensor(type(results[:-1])).
  • (C3) is_token(type(results[-1])).

דוגמאות

// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]

דוגמאות נוספות

iota

סמנטיקה

מילוי טנזור output בערכים בסדר עולה, החל מאפס לאורך המאפיין iota_dimension. בצורה יותר רשמית,

output[result_index] = constant(is_quantized(output) ? quantize(result_index[iota_dimension], element_type(output)) : result_index[iota_dimension], element_type(output)).

קלט

תווית שם סוג מגבלות
(I1) iota_dimension si64 (C1)

פלט

שם סוג מגבלות
output tensor של מספר שלם, נקודה צפה (floating-point) או סוג מורכב, או img_tensor p-tensor (C1)

מגבלות

  • (C1) 0 <= iota_dimension < rank(output).

דוגמאות

%output = "stablehlo.iota"() {
  iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

%output = "stablehlo.iota"() {
  iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4]
//          ]

דוגמאות נוספות

is_finite

סמנטיקה

הפונקציה בודקת אם הערך ב-x הוא סופי (כלומר, לא +Inf, -Inf, או NaN) ומפיקה טנזור y. מטמיע את הפעולה isFinite ממפרט IEEE-754. בסוגים שמחושבים כנתונים, התוצאה היא תמיד true.

קלט

תווית שם סוג מגבלות
(I1) x t e n s o r f l o w, או של t e n s o l l o w, (C1)

פלט

שם סוג מגבלות
y tensor מסוג בוליאני (C1)

מגבלות

  • (C1) shape(x) = shape(y).

דוגמאות

// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]

דוגמאות נוספות

log

סמנטיקה

מבצע פעולת לוגריתמים ברמת הרכיב על טנסור operand ויוצר טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • לצפים: log מ-IEEE-754.
  • למספרים מרוכבים: לוגריתם מרוכב.
  • לסוגים שמבוססים על כמות: dequantize_op_quantize(log, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand t e n s o r f l o w, או t e n s n o l o li bre t e n s o l l o w, (C1)

פלט

שם סוג מגבלות
result t e n s o r f l o w, או t e n s n o l o li bre t e n s o l l o w, (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]

דוגמאות נוספות

log_plus_one

סמנטיקה

הפונקציה מבצעת לוגריתם של רכיבים מבחינת רכיב ופעולה אחת על הטנזור operand, ויוצרת tensor result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • לצפים: logp1 מ-IEEE-754.
  • למספרים מרוכבים: לוגריתם מרוכב ועוד 1.
  • לסוגים שמבוססים על כמות: dequantize_op_quantize(log_plus_one, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand t e n s o r f l o w, או t e n s n o l o li bre t e n s o l l o w, (C1)

פלט

שם סוג מגבלות
result t e n s o r f l o w, או t e n s n o l o li bre t e n s o l l o w, (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]

דוגמאות נוספות

לוגיסטיקה

סמנטיקה

מבצע פעולה לוגיסטית ברמת הרכיבים ב-Tenor של operand ויוצר t tensor result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • לצפים: division(1, addition(1, exp(-x))) מ-IEEE-754.
  • במספרים מרוכבים: לוגיסטיקה מורכבת.
  • לסוגים שמבוססים על כמות: dequantize_op_quantize(logistic, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand t e n s o r f l o w, או t e n s n o l o li bre t e n s o l l o w, (C1)

פלט

שם סוג מגבלות
result t e n s o r f l o w, או t e n s n o l o li bre t e n s o l l o w, (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]

דוגמאות נוספות

מפה

סמנטיקה

מחילה פונקציית מפה computation על inputs לאורך ה-dimensions ויוצרת טנזור result.

בצורה יותר רשמית, result[result_index] = computation(inputs...[result_index]). לתשומת ליבכם: הדומיין dimensions לא בשימוש כרגע, וסביר להניח שהוא יוסר בעתיד (#487).

קלט

תווית שם סוג מגבלות
(I1) inputs מספר וריאנטים (tensor) של tensor או tensor, p-tensor, (C1-C4)
(I2) dimensions קבוע tensor חד-ממדי מסוג si64 (C3)
(I3) computation פונקציה (C4)

פלט

שם סוג מגבלות
result tensor, או tensor, quanted tensor (C1), (C4)

מגבלות

  • (C1) shape(inputs...) = shape(result).
  • (ג2) 0 < size(inputs) = N.
  • (C3) dimensions = range(rank(inputs[0])).
  • (C4) ב-computation יש סוג (tensor<E0>, ..., tensor<EN-1>) -> tensor<E'> כאשר Ei = element_type(inputs[i]) ו-E' = element_type(result).

דוגמאות

// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
    stablehlo.return %0 : tensor<i64>
}) {
  dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]

דוגמאות נוספות

מקסימום

סמנטיקה

מבצע פעולה מקסימלית של הרכיבים ב-Tenors lhs ו-rhs, ומפיק את ה-result tensor. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • עבור בוליאנים: OR לוגי.
  • למספרים שלמים: מספר שלם מקסימלי.
  • לצפים: maximum מ-IEEE-754.
  • למספרים מרוכבים: ערך מקסימלי מילולי לצמד (real, imaginary). יצירת סדר מספרים מרוכבים כרוכה בסמנטיקה מפתיעה, ולכן בעתיד אנחנו מתכננים להסיר את התמיכה במספרים מרוכבים בפעולה זו (#560).
  • בסוגים שמבוססים על כמות:
    • dequantize_op_quantize(maximum, lhs, rhs, type(result)).

קלט

תווית שם סוג מגבלות
(I1) lhs tensor, או tensor, quanted tensor (C1)
(I2) rhs tensor, או tensor, quanted tensor (C1)

פלט

שם סוג מגבלות
result tensor, או tensor, quanted tensor (C1)

מגבלות

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

דוגמאות

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]

דוגמאות נוספות

מינימום

סמנטיקה

מבצע פעולה מינימלית של רכיב מסוים על המנזרים lhs ו-rhs ומפיק את ה-result tensor. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • בשביל בוליאנים: logic AND.
  • למספרים שלמים: מינימום של מספר שלם.
  • לצפים: minimum מ-IEEE-754.
  • למספרים מרוכבים: המינימום המילולי לצמד (real, imaginary). יצירת סדר מספרים מרוכבים כרוכה בסמנטיקה מפתיעה, ולכן בעתיד אנחנו מתכננים להסיר את התמיכה במספרים מרוכבים בפעולה זו (#560).
  • בסוגים שמבוססים על כמות:
    • dequantize_op_quantize(minimum, lhs, rhs, type(result)).

קלט

תווית שם סוג מגבלות
(I1) lhs tensor, או tensor, quanted tensor (C1)
(I2) rhs tensor, או tensor, quanted tensor (C1)

פלט

שם סוג מגבלות
result tensor, או tensor, quanted tensor (C1)

מגבלות

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

דוגמאות

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]

דוגמאות נוספות

הכפלה

סמנטיקה

הפונקציה מבצעת את המכפלה של שני tensor (רכיב) מסוג lhs ו-rhs, ויוצרת את result tensor. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • בשביל בוליאנים: logic AND.
  • למספרים שלמים: כפל של מספרים שלמים.
  • לצפים: multiplication מ-IEEE-754.
  • למספרים מרוכבים: כפל מרוכב.
  • בסוגים שמבוססים על כמות:
    • dequantize_op_quantize(multiply, lhs, rhs, type(result)).

קלט

תווית שם סוג מגבלות
(I1) lhs tensor, או tensor, quanted tensor (C1)
(I2) rhs tensor, או tensor, quanted tensor (C1)

פלט

שם סוג מגבלות
result tensor, או tensor, quanted tensor (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]

דוגמאות נוספות

שלילה

סמנטיקה

מבצע שלילה ברמת הרכיב של טנסור operand ויוצר tensor result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • למספרים שלמים וחתומים: שלילה של מספרים שלמים.
  • למספרים שלמים לא חתומים: bitcast למספר שלם עם עקבות, שלילה של מספרים שלמים, bitcast בחזרה למספר שלם לא מסומן.
  • לצפים: negate מ-IEEE-754.
  • למספרים מרוכבים: שלילה מורכבת.
  • לסוגים שמבוססים על כמות: dequantize_op_quantize(negate, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand t e n s o r f l o w, t e n s o י , או ב ר ח מ מ מור, ב ר 1 - ט נזור ט 1 1 1 1 1 1 1 m או ב ת יש ט en t בלי ט נזור (C1)

פלט

שם סוג מגבלות
result t e n s o r f l o w, t e n s o י , או ב ר ח מ מ מור, ב ר 1 - ט נזור ט 1 1 1 1 1 1 1 m או ב ת יש ט en t בלי ט נזור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]

// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]

דוגמאות נוספות

לא

סמנטיקה

מבצע את הפונקציה של NOT של tensor operand ויוצר tenor של result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • עבור בוליאנים: logic NOT.
  • למספרים שלמים: NOT.

ארגומנטים

שם סוג מגבלות
operand tensor מסוג בוליאני או מספר שלם (C1)

פלט

שם סוג מגבלות
result tensor מסוג בוליאני או מספר שלם (C1)

מגבלות

  • (C1) type(operand) = type(result).

דוגמאות

// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]

// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]

optimization_barrier

סמנטיקה

המדיניות מוודאת שהפעולות שמייצרות את operand יבוצעו לפני פעולות שתלויות ב-result, ומונעת מטרנספורמציות מהדר להעביר פעולות על המחסום. מלבד זאת, הפעולה היא זהות, כלומר result = operand.

ארגומנטים

שם סוג מגבלות
operand מספר וריאדי של טנזורים, טנזורים או אסימונים מכווצים לכל טנזור (C1)

פלט

שם סוג מגבלות
result מספר וריאדי של טנזורים, טנזורים או אסימונים מכווצים לכל טנזור (C1)

מגבלות

  • (C1) type(operand...) = type(result...).

דוגמאות

// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0

דוגמאות נוספות

או

סמנטיקה

הפעולה של OR ברמת הרכיב של שני טנזורים lhs ו-rhs ויוצרת result של img_tensor. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • עבור בוליאנים: OR לוגי.
  • למספרים שלמים: OR ברמת הסיביות.

קלט

תווית שם סוג מגבלות
(I1) lhs tensor של מספר שלם או סוג בוליאני (C1)
(I2) rhs tensor של מספר שלם או סוג בוליאני (C1)

פלט

שם סוג מגבלות
result tensor של מספר שלם או סוג בוליאני (C1)

מגבלות

  • (C1) type(lhs) = type(rhs) = type(result).

דוגמאות

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]

מודעות Outstream

סמנטיקה

המערכת כותבת את הערך inputs בפיד היוצא ויוצרת אסימון result.

הסמנטיקה של outfeed_config מוגדרת על ידי היישום.

קלט

תווית שם סוג
(I1) inputs מספר וריאנטים של טנזורים או מותנים קוונטיים
(I2) token token
(I3) outfeed_config קבוע מסוג string

פלט

שם סוג
result token

דוגמאות

%result = "stablehlo.outfeed"(%inputs0, %token) {
  outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token

דוגמאות נוספות

רפידה

סמנטיקה

מרחיב את operand על ידי מרווח מסביב לממתח וגם בין רכיבי ה-Tenor עם padding_value הנתון.

edge_padding_low ו-edge_padding_high מציינים את כמות המרווח הפנימי שנוספה ברמה הנמוכה (לצד אינדקס 0) ובחלק העליון (לצד האינדקס הגבוה ביותר) של כל מאפיין, בהתאמה. מידת המרווח הפנימי יכולה להיות שלילית, כאשר הערך המוחלט של מרווח פנימי שלילי מציין את מספר הרכיבים שיש להסיר מהמאפיין שצוין.

interior_padding מציין את כמות המרווח הפנימי שנוסף בין שני רכיבים בכל מאפיין, שאינה שלילית. מרווח פנימי פנימי נוצר לפני המרווח הפנימי בקצוות, כך שהמרווח הפנימי בקצוות מסיר אלמנטים מהאופרנד שמרופד בפנים.

באופן רשמי יותר, ההגדרה result[result_index] היא:

  • operand[operand_index] אם result_index = edge_padding_low + operand_index * (interior_padding + 1).
  • padding_value אחרת.

קלט

תווית שם סוג מגבלות
(I1) operand tensor, או tensor, quanted tensor (C1), (C2), (C4)
(I2) padding_value img_tensor, tensor, penanty (C1)
(I3) edge_padding_low קבוע tensor חד-ממדי מסוג si64 (C1), (C4)
(I4) edge_padding_high קבוע tensor חד-ממדי מסוג si64 (C1), (C4)
(I5) interior_padding קבוע tensor חד-ממדי מסוג si64 (C2-C4)

פלט

שם סוג מגבלות
result tensor, או tensor, quanted tensor (C3-C6)

מגבלות

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result).
  • (ג2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand).
  • (C3) 0 <= interior_padding.
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.

דוגמאות

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
  edge_padding_low = dense<[0, 1]> : tensor<2xi64>,
  edge_padding_high = dense<[2, 1]> : tensor<2xi64>,
  interior_padding = dense<[1, 2]> : tensor<2xi64>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

דוגמאות נוספות

partition_id

סמנטיקה

הפונקציה יוצרת partition_id מהתהליך הנוכחי.

פלט

שם סוג
result Tensor 0 ממדי מסוג ui32

דוגמאות

%result = "stablehlo.partition_id"() : () -> tensor<ui32>

דוגמאות נוספות

Popcnt

סמנטיקה

הפונקציה מבצעת ספירה של מספר הסיביות ב-Tenor של operand ברמת הרכיב, ויוצרת טנסור result.

קלט

תווית שם סוג מגבלות
(I1) operand tensor מסוג מספר שלם (C1)

פלט

שם סוג מגבלות
result tensor מסוג מספר שלם (C1)

מגבלות

  • (C1) type(operand) = type(result).

דוגמאות

// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]

דוגמאות נוספות

כוח

סמנטיקה

הפונקציה מבצעת הגדלה ברמת הרכיב של טנסור lhs על ידי טנסור rhs ויוצרת טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • במספרים שלמים: הגדלה של מספרים שלמים.
  • לצפים: pow מ-IEEE-754.
  • למספרים מרוכבים: העלאה בחזקה מרוכבת.
  • לסוגים שמבוססים על כמות: dequantize_op_quantize(power, lhs, rhs, type(result)).

קלט

תווית שם סוג מגבלות
(I1) lhs t e n s o r f l o w, t e n s o י , או ב ר ח מ מ מור, ב ר 1 - ט נזור ט 1 1 1 1 1 1 1 m או ב ת יש ט en t בלי ט נזור (C1)
(I2) rhs t e n s o r f l o w, t e n s o י , או ב ר ח מ מ מור, ב ר 1 - ט נזור ט 1 1 1 1 1 1 1 m או ב ת יש ט en t בלי ט נזור (C1)

פלט

שם סוג מגבלות
result t e n s o r f l o w, t e n s o י , או ב ר ח מ מ מור, ב ר 1 - ט נזור ט 1 1 1 1 1 1 1 m או ב ת יש ט en t בלי ט נזור (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]

דוגמאות נוספות

ריאל

סמנטיקה

שולפת את החלק האמיתי, מבחינת הרכיבים, מה-operand ויוצרת tensor result. באופן רשמי יותר, לכל רכיב x: real(x) = is_complex(x) ? real_part(x) : x.

קלט

תווית שם סוג מגבלות
(I1) operand Tensor של נקודה צפה או סוג מורכב (C1), (C2)

פלט

שם סוג מגבלות
result Tensor מסוג נקודה צפה (floating-point) (C1), (C2)

מגבלות

  • (C1) shape(result) = shape(operand).
  • (C2) element_type(result) מוגדר כך:
    • complex_element_type(element_type(operand)) אם is_complex(operand).
    • element_type(operand) אחרת.

דוגמאות

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]

דוגמאות נוספות

תקליטים

סמנטיקה

הפונקציה מקבלת נתונים מערוץ עם channel_id ומפיקה results.

אם הערך של is_host_transfer הוא true, הפעולה תעביר נתונים מהמארח. אחרת, הוא מעביר נתונים ממכשיר אחר. המשמעות היא שההטמעה מוגדרת. הסימון הזה משכפל את המידע שסופק ב-channel_type, כך שבעתיד אנחנו מתכננים לשמור רק אחד מהם (#666).

results מורכב מערכי מטען ייעודי (payload) שמגיעים ראשונים, ואסימון הקודם. בעתיד, אנחנו מתכננים לפצל את המטען הייעודי (payload) והאסימון לשני פלטים נפרדים כדי לשפר את הבהירות (#670).

קלט

תווית שם סוג מגבלות
(I1) token token (C4)
(I2) channel_id קבוע מסוג si64
(I3) channel_type טיפוסים בני מנייה (enum) של DEVICE_TO_DEVICE ושל HOST_TO_DEVICE (C1)
(I4) is_host_transfer קבוע מסוג i1 (C1)

פלט

שם סוג מגבלות
results מספר שונה של טנזורים, טנזורים או אסימונים קוונטיים (C2-C4)

מגבלות

  • (C1) channel_type מוגדר כך:
    • HOST_TO_DEVICE אם is_host_transfer = true,
    • DEVICE_TO_DEVICE אחרת.
  • (ג2) 0 < size(results).
  • (C3) is_empty(result[:-1]) או is_tensor(type(results[:-1])).
  • (C4) is_token(type(results[-1])).

דוגמאות

%results0, %results1 = "stablehlo.recv"(%token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
  is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)

דוגמאות נוספות

הקטנה

סמנטיקה

מחילה את פונקציית ההפחתה body על inputs ו-init_values לאורך dimensions ויוצרת results טנזורים.

סדר ההפחתה מוגדר לפי הטמעה, לכן body ו-init_values חייבים ליצור מונואיד כדי להבטיח שהפעולה תניב את אותן תוצאות עבור כל ערכי הקלט בכל ההטמעות. עם זאת, המצב הזה לא תקף במקרים רבים של הפחתה פופולרית. לדוגמה, הוספה של נקודה צפה (floating-point) עבור body ו-0 עבור init_values לא יוצרת בפועל מונואיד, כי חיבור של נקודה צפה (floating-point) לא אסוציאטיבי.

בצורה יותר רשמית, results...[j0, ..., jR-1] = reduce(input_slices_converted) שבו:

  • input_slices = inputs...[j0, ..., :, ..., jR-1], כאשר : מוכנסות ב-dimensions.
  • input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...).
  • init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...).
  • reduce(input_slices_converted) = exec(schedule) לעץ בינארי מסוים schedule שבו:
    • exec(node) = body(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule הוא עץ בינארי מלא שמוגדר על ידי היישום, שהמעבר שלו לפי סדר כולל:
    • input_slices_converted...[index], לכל ה-index ב-index_space(input_slices_converted) בסדר המילולי עולה של index.
    • אינטראקציה עם סכום מוגדר של init_values_converted במיקומים שהוגדרו על ידי ההטמעה.

קלט

תווית שם סוג מגבלות
(I1) inputs מספר וריאנטים (tensor) של tensor או tensor, p-tensor, (C1-C4), (C6), (C7)
(I2) init_values מספר הווריאנטים של טנזורים חד-ממדיים או מותנים מכניים לכל טנזור (C2), (C3)
(I3) dimensions קבוע tensor חד-ממדי מסוג si64 (C4), (C5), (C7)
(I4) body פונקציה (C6)

פלט

שם סוג מגבלות
results מספר וריאנטים (tensor) של tensor או tensor, p-tensor, (C3), (C7), (C8)

מגבלות

  • (C1) same(shape(inputs...)).
  • (ג2) element_type(inputs...) = element_type(init_values...).
  • (C3) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (C4) 0 <= dimensions < rank(inputs[0]).
  • (C5) is_unique(dimensions).
  • (C6) ב-body יש סוג (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>) כאשר is_promotable(element_type(inputs[i]), Ei).
  • (C7) shape(results...) = shape(inputs...) אלא שגודלי המאפיינים inputs... שתואמים ל-dimensions לא נכללים.
  • (C8) element_type(results[i]) = Ei לכל i שב-[0,N).

דוגמאות

// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  dimensions = dense<1> : tensor<1xi64>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]

דוגמאות נוספות

reduce_precision

סמנטיקה

הפונקציה מבצעת המרה של operand עבור רכיבים שונים לסוג אחר של נקודה צפה שמשתמשת ב-exponent_bits וב-mantissa_bits ובחזרה לסוג המקורי של הנקודה הצפה, ויוצרת את ה-tenor output.

בצורה יותר רשמית:

  • ביטים של המנטיזה של הערך המקורי מתעדכנים כדי לעגל את הערך המקורי לערך הקרוב ביותר שאפשר לייצג באמצעות mantissa_bits באמצעות סמנטיקה של roundToIntegralTiesToEven.
  • לאחר מכן, אם mantissa_bits קטן ממספר הביטים של המטריסטה של הערך המקורי, ביטים של המטריסה נחתכים ל-mantissa_bits.
  • לאחר מכן, אם הסיביות המעריכיות של תוצאת הביניים לא מתאימות לטווח שסופק על ידי exponent_bits, תוצאת הביניים גולשת לאינסוף באמצעות הסימן המקורי, או מתרחקת מתחת לאפס באמצעות הסימן המקורי.
  • לסוגים שמחושבים לפי כמות, הביצועים של dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand t e n s o r f l o w, או של t e n s o l l o w, (C1)
(I2) exponent_bits קבוע מסוג si32 (C2)
(I3) mantissa_bits קבוע מסוג si32 (C3)

פלט

שם סוג מגבלות
output t e n s o r f l o w, או של t e n s o l l o w, (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(output).
  • (ג2) 1 <= exponent_bits.
  • (C3) 0 <= mantissa_bits.

דוגמאות

// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
  exponent_bits = 5 : i32,
  mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]

דוגמאות נוספות

reduce_scatter

סמנטיקה

בתוך כל קבוצת תהליכים ברשת התהליך של StableHLO, הפונקציה מבצעת הפחתה באמצעות computations, על הערכים של ה-tensor operand בכל תהליך, מפצלת את תוצאת ההפחתה לאורך scatter_dimension לחלקים ומפזרת את החלקים המפוצלים בין התהליכים כדי לייצר את ה-result.

הפעולה תפצל את רשת התהליך ב-StableHLO ל-process_groups, שמוגדר כך:

  • cross_replica(replica_groups) אם channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) אם channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) אם channel_id > 0 and use_global_device_ids = true.

לאחר מכן, בכל process_group:

  • reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation).
  • parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension).
  • result@receiver = parts@sender[receiver_index] לכל ה-sender ב-process_group, כאשר receiver_index = process_group.index(receiver).

קלט

תווית שם סוג מגבלות
(I1) operand tensor, או tensor, quanted tensor (C1), (C2), (C7), (C8)
(I2) scatter_dimension קבוע מסוג si64 (C1), (C2), (C8)
(I3) replica_groups קבוע tensor דו-ממדי מסוג si64 (C3-C5)
(I4) channel_id קבוע מסוג si64 (C6)
(I5) use_global_device_ids קבוע מסוג i1 (C6)
(I6) computation פונקציה (C7)

פלט

שם סוג מגבלות
result tensor, או tensor, quanted tensor (C8-C9)

מגבלות

  • (C1) dim(operand, scatter_dimension) % dim(process_groups, 1) = 0.
  • (ג2) 0 <= scatter_dimension < rank(operand).
  • (C3) is_unique(replica_groups).
  • (C4) הערך size(replica_groups) מוגדר כך:
    • num_replicas אם נעשה שימוש בcross_replica.
    • num_replicas אם נעשה שימוש בcross_replica_and_partition.
    • num_processes אם נעשה שימוש בflattened_ids.
  • (C5) 0 <= replica_groups < size(replica_groups).
  • (C6) אם use_global_device_ids = true אז channel_id > 0.
  • (C7) ב-computation יש סוג (tensor<E>, tensor<E>) -> (tensor<E>) כאשר is_promotable(element_type(operand), E).
  • (C8) shape(result) = shape(operand) למעט:
    • dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1).
  • (C9) element_type(result) = E.

דוגמאות

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
  "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
//                  [18, 20]]
// %result@(1, 0): [[14, 16],
//                  [22, 24]]

דוגמאות נוספות

reduce_window

סמנטיקה

מחילה את פונקציית ההפחתה body על חלונות של inputs ו-init_values ויוצרת results.

התרשים הבא מראה כיצד מחושבים הרכיבים ב-results... מ-inputs... באמצעות דוגמה קונקרטית.

באופן רשמי יותר, results...[result_index] = reduce(windows, init_values, axes(inputs...), body) (אפשר לעיין בקטע צמצום) כאשר:

  • padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1).
  • window_start = result_index * window_strides.
  • window_end = window_start + (window_dimensions - 1) * window_dilations + 1.
  • windows = slice(padded_inputs..., window_start, window_end, window_dilations).

קלט

תווית שם סוג מגבלות
(I1) inputs מספר וריאנטים (tensor) של tensor או tensor, p-tensor, (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15)
(I2) init_values מספר הווריאנטים של טנזורים חד-ממדיים או מותנים מכניים לכל טנזור (C1), (C13)
(I3) window_dimensions קבוע tensor חד-ממדי מסוג si64 (C4), (C5), (C15)
(I4) window_strides קבוע tensor חד-ממדי מסוג si64 (C6), (C7), (C15)
(I5) base_dilations קבוע tensor חד-ממדי מסוג si64 (C8) (C9), (C15)
(I6) window_dilations קבוע tensor חד-ממדי מסוג si64 (C10), (C11), (C15)
(I7) padding קבוע tensor דו-ממדי מסוג si64 (C12), (C15)
(I8) body פונקציה (C13)

פלט

שם סוג מגבלות
results מספר וריאנטים (tensor) של tensor או tensor, p-tensor, (C1), (C14-C16)

מגבלות

  • (C1) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (ג2) same(shape(inputs...)).
  • (C3) element_type(inputs...) = element_type(init_values...).
  • (C4) size(window_dimensions) = rank(inputs[0]).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(inputs[0]).
  • (C7) 0 < window_strides.
  • (C8) size(base_dilations) = rank(inputs[0]).
  • (C9) 0 < base_dilations.
  • (C10) size(window_dilations) = rank(inputs[0]).
  • (C11) 0 < window_dilations.
  • (C12) shape(padding) = [rank(inputs[0]), 2].
  • (C13) ב-body יש סוג (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>) כאשר is_promotable(element_type(inputs[i]), Ei).
  • (C14) same(shape(results...)).
  • (C15) shape(results[0]) = num_windows כאשר:
    • dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1.
    • padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1].
    • dilated_window_shape = (window_dimensions - 1) * window_dilations + 1.
    • is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1.
  • (C16) element_type(results[i]) = Ei לכל ה-i ב-[0,N).

דוגמאות

// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = dense<[2, 1]> : tensor<2xi64>,
  window_strides = dense<[4, 1]> : tensor<2xi64>,
  base_dilations = dense<[2, 1]> : tensor<2xi64>,
  window_dilations = dense<[3, 1]> : tensor<2xi64>,
  padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]

דוגמאות נוספות

שארית

סמנטיקה

מבצע את שארית המחולק lhs ואת המחלק rhs של t tensor, ויוצר result t tensor.

באופן רשמי יותר, הסימן של התוצאה נלקח מהדיבידנד, והערך המוחלט של התוצאה תמיד קטן מהערך המוחלט של המחלק. השארית מחושבת כ-lhs - d * rhs, כאשר d ניתנת על ידי:

  • למספרים שלמים: stablehlo.divide(lhs, rhs).
  • במקרה של צף: division(lhs, rhs) מ-IEEE-754 עם מאפיין עיגול המספרים roundTowardZero.
  • למספרים מרוכבים: TBD(#997).
  • בסוגים שמבוססים על כמות:
    • dequantize_op_quantize(remainder, lhs, rhs, type(result)).

באלמנטים של נקודה צפה (floating-point), הפעולה הזו מנוגדת לפעולה remainder ממפרט IEEE-754, שבה d הוא ערך אינטגרלי שקרוב ביותר לערך המדויק של lhs/rhs עם קשר לקשר זוגי.

קלט

תווית שם סוג מגבלות
(I1) lhs tensor של מספר שלם, נקודה צפה (floating-point) או סוג מורכב, או img_tensor p-tensor (C1)
(I2) rhs tensor של מספר שלם, נקודה צפה (floating-point) או סוג מורכב, או img_tensor p-tensor (C1)

פלט

שם סוג מגבלות
result tensor של מספר שלם, נקודה צפה (floating-point) או סוג מורכב, או img_tensor p-tensor (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]

דוגמאות נוספות

replica_id

סמנטיקה

הפונקציה יוצרת replica_id מהתהליך הנוכחי.

פלט

שם סוג
result Tensor 0 ממדי מסוג ui32

דוגמאות

%result = "stablehlo.replica_id"() : () -> tensor<ui32>

דוגמאות נוספות

לשנות את הצורה

סמנטיקה

מבצע שינוי של טנזור operand לטנזור result. בעיקרון, כדאי לשמור על אותו ייצוג קנוני אבל אפשר לשנות את הצורה, למשל מ-tensor<2x3xf32> ל-tensor<3x2xf32> או ל-tensor<6xf32>.

באופן רשמי יותר, result[result_index] = operand[operand_index] כאשר ל-result_index ול-operand_index יש את אותו המיקום בסדר המילוני של index_space(result) ו-index_space(operand).

קלט

תווית שם סוג מגבלות
(I1) operand את img_tensor, או tensor, pantor (C1-C3)

פלט

שם סוג מגבלות
result את img_tensor, או tensor, pantor (C1-C3)

מגבלות

  • (C1) הערך element_type(result) ניתן על ידי:
    • element_type(operand), אם !is_per_axis_quantized(operand).
    • element_type(operand) אלא אם כן quantization_dimension(operand) ו-quantization_dimension(result) עשויים להיות שונים.
  • (ג2) size(operand) = size(result).
  • (C3) אם is_per_axis_quantized(operand):
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).

דוגמאות

// %operand: [[1, 2, 3], [4, 5, 6]]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]

דוגמאות נוספות

הפוך

סמנטיקה

הפונקציה הופכת את סדר הרכיבים ב-operand לאורך ה-dimensions שצוין ויוצרת טנזור result. בצורה רשמית יותר, result[result_index] = operand[operand_index] כאשר:

  • operand_index[d] = dim(result, d) - result_index[d] - 1 אם d ב-dimensions.
  • operand_index[d] = result_index[d] אחרת.

קלט

תווית שם סוג מגבלות
(I1) operand tensor, או tensor, quanted tensor (C1), (C3)
(I2) dimensions קבוע tensor חד-ממדי מסוג si64 (C2), (C3)

פלט

שם סוג מגבלות
result tensor, או tensor, quanted tensor (C1), (C3)

מגבלות

  • (C1) type(operand) = type(result).
  • (ג2) is_unique(dimensions).
  • (C3) 0 <= dimensions < rank(result).

דוגמאות

// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
  dimensions = dense<1> : tensor<1xi64>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]

דוגמאות נוספות

ארנג

סמנטיקה

הפונקציה יוצרת מספרים אקראיים באמצעות האלגוריתם rng_distribution ויוצרת טנזור result של צורה נתונה shape.

אם הערך הוא rng_distribution = UNIFORM, המספרים האקראיים נוצרים לפי ההתפלגות האחידה בקטע [a, b). אם הערך הוא a >= b, ההתנהגות לא מוגדרת.

אם הערך הוא rng_distribution = NORMAL, המספרים האקראיים נוצרים לפי ההתפלגות הנורמלית, כשהממוצע = a וסטיית התקן היא b. אם הערך הוא b < 0, ההתנהגות לא מוגדרת.

אופן היצירה המדויק של מספרים אקראיים מוגדר על ידי היישום. לדוגמה, יכול להיות שהם יהיו דטרמיניסטיים ויכול להיות שלא, והם ישתמשו במצב מוסתר או לא.

בדיונים עם בעלי עניין רבים, האפשרות הזו הוצאה משימוש באותה מידה, כך שבעתיד אנחנו מתכננים להסיר אותה (#597).

קלט

תווית שם סוג מגבלות
(I1) a Tensor 0-ממדי של מספר שלם, סוג בוליאני או נקודה צפה (floating-point) (C1), (C2)
(I2) b Tensor 0-ממדי של מספר שלם, סוג בוליאני או נקודה צפה (floating-point) (C1), (C2)
(I3) shape קבוע tensor חד-ממדי מסוג si64 (C3)
(I4) rng_distribution טיפוסים בני מנייה (enum) של UNIFORM ושל NORMAL (C2)

פלט

שם סוג מגבלות
result Tensor של מספר שלם, סוג בוליאני או נקודה צפה (floating-point) (C1-C3)

מגבלות

  • (C1) element_type(a) = element_type(b) = element_type(result).
  • (C2) אם הערך הוא rng_distribution = NORMAL, אז is_float(a).
  • (C3) shape(result) = shape.

דוגמאות

// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
  rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
//           [1, 0, 1],
//           [1, 1, 1],
//           [0, 0, 0]
//          ]

rng_bit_generator

סמנטיקה

הפונקציה מחזירה output שמלאה בביטים אקראיים אחידים ומצב פלט מעודכן output_state באמצעות האלגוריתם של מחולל המספרים הפסאודו אקראי rng_algorithm במצב ראשוני initial_state. הפלט מובטח להיות פונקציה דטרמיניסטית של initial_state, אבל לא מובטח שהוא יהיה חד-משמעי בין הטמעות.

הערך של rng_algorithm הוא אחד מהבאים:

  • DEFAULT: אלגוריתם מוגדר על ידי הטמעה.
  • THREE_FRY: וריאנט שהוגדר על ידי היישום של אלגוריתם Threefry.*
  • PHILOX: וריאנט שהוגדר על ידי היישום של אלגוריתם Philox.*

* למידע נוסף: Salmon et al. SC 2011. מספרים אקראיים מקבילים: קלים כמו 1, 2, 3.

קלט

תווית שם סוג מגבלות
(I1) rng_algorithm טיפוסים בני מנייה (enum) של DEFAULT, של THREE_FRY ושל PHILOX (C2)
(I2) initial_state img_tensor חד-ממדי מסוג ui64 (C1), (C2)

פלט

שם סוג מגבלות
output_state img_tensor חד-ממדי מסוג ui64 (C1)
output Tensor של מספר שלם או מסוג נקודה צפה (floating-point)

מגבלות

  • (C1) type(initial_state) = type(output_state).
  • (C2) size(initial_state) מוגדר כך:
    • מוגדר אם rng_algorithm = DEFAULT.
    • 2 אם rng_algorithm = THREE_FRY.
    • 2 או 3 אם rng_algorithm = PHILOX.

דוגמאות

// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
  rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
//           [9236835810183407956, 16087790271692313299],
//           [18212823393184779219, 2658481902456610144]
//          ]

round_nearest_afz

סמנטיקה

הפונקציה הזו מבצעת עיגול של הרכיבים כלפי המספר השלם הקרוב ביותר, וקוטעת קשרים מאפס, על הטנזור operand ויוצרת טנזור result. מטמיעה את הפעולה roundToIntegralTiesToAway ממפרט IEEE-754. לסוגים שמבוססים על קוונטים, הפונקציה מבצעת dequantize_op_quantize(round_nearest_afz, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand t e n s o r f l o w, או של t e n s o l l o w, (C1)

פלט

שם סוג מגבלות
result t e n s o r f l o w, או של t e n s o l l o w, (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]

דוגמאות נוספות

round_nearest_even

סמנטיקה

הפונקציה מבצעת עיגול חכם של הרכיבים כלפי המספר השלם הקרוב ביותר, מקשרת את הקשר כלפי המספר השלם הזוגי ב-Tenor operand ויוצרת טנזור result. מטמיע את הפעולה roundToIntegralTiesToEven ממפרט IEEE-754. לסוגים שמחושבים לפי כמות, הביצועים של dequantize_op_quantize(round_nearest_even, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand t e n s o r f l o w, או של t e n s o l l o w, (C1)

פלט

שם סוג מגבלות
result t e n s o r f l o w, או של t e n s o l l o w, (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]

דוגמאות נוספות

RSqrt

סמנטיקה

מבצעת פעולת שורש ריבועית הדדית של רכיבים ב-Tenor של operand, ויוצרת טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • לצפים: rSqrt מ-IEEE-754.
  • למספרים מרוכבים: שורש ריבועי הפוך מורכב.
  • לסוגים שמבוססים על כמות: dequantize_op_quantize(rsqrt, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand t e n s o r f l o w, או t e n s n o l o li bre t e n s o l l o w, (C1)

פלט

שם סוג מגבלות
result t e n s o r f l o w, או t e n s n o l o li bre t e n s o l l o w, (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]

דוגמאות נוספות

scatter

סמנטיקה

הפונקציה מפיקה results טנזורים השווה ל-inputs tenors, למעט שחלק מהמקטעים שצוינו ב-scatter_indices מתעדכנים בערכים updates באמצעות update_computation.

בתרשים הבא מוצגים אופן המיפוי של אלמנטים ב-updates... על אלמנטים ב-results... באמצעות דוגמה קונקרטית. בתרשים נבחרים כמה דוגמאות של אינדקסים מסוג updates..., ומסבירים בפירוט לאילו סוגי אינדקסים של results... הם מתאימים.

באופן רשמי יותר, לכל הupdate_index בindex_space(updates[0]):

  • update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims].
  • update_scatter_index = update_index[update_scatter_dims...].
  • ההגדרה של start_index היא:
    • scatter_indices[si0, ..., :, ..., siN] כאשר si הם רכיבים נפרדים ב-update_scatter_index ו-: מוכנס לאינדקס index_vector_dim, אם index_vector_dim < rank(scatter_indices).
    • [scatter_indices[update_scatter_index]] אחרת.
  • עבור d_input ב-axes(inputs[0]),
    • full_start_index[d_input] = start_index[d_start] אם d_input = scatter_dims_to_operand_dims[d_start].
    • full_start_index[d_input] = 0 אחרת.
  • update_window_index = update_index[update_window_dims...].
  • full_window_index = [wi0, ..., 0, ..., wiN] כאשר wi הם רכיבים בודדים ב-update_window_index, ו-0 מוכנס לאינדקסים מתוך inserted_window_dims.
  • result_index = full_start_index + full_window_index.

בהתאם לכך, results = exec(schedule, inputs), כאשר:

  • schedule הוא פרמוטציה של index_space(updates[0]) שמוגדרת על ידי ההטמעה.
  • exec([update_index, ...], results) = exec([...], updated_results) כאשר:
    • אם result_index בגבולות של shape(results...)
    • updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
    • updated_values = update_computation(results...[result_index], updates_converted)
    • updated_results הוא עותק של results עם הערך results...[result_index] שמוגדר ל-updated_values....
    • אחרת
    • updated_results = results.
  • exec([], results) = results.

אם הערך של indices_are_sorted הוא true, ההטמעה יכולה להניח שהפונקציה scatter_indices ממוינת ביחס ל-scatter_dims_to_operand_dims, אחרת ההתנהגות לא מוגדרת. באופן רשמי יותר, לכל הi1 < i2 מ-indices(result), full_start_index(i1) <= full_start_index(i2).

אם unique_indices הוא true, ההטמעה יכולה להניח שכל האינדקסים של result_index שמפוזרים אליהם הם ייחודיים. אם unique_indices הוא true אבל האינדקסים שמפוזרים אליהם לא ייחודיים, ההתנהגות לא מוגדרת.

קלט

תווית שם סוג מגבלות
(I1) inputs מספר וריאנטים (tensor) של tensor או tensor, p-tensor, (C1), (C2), (C4-C6), (C10), (C13), (C15-C16)
(I2) scatter_indices tensor מסוג מספר שלם (C4), (C11), (C14)
(I3) updates מספר וריאנטים (tensor) של tensor או tensor, p-tensor, (C3-C6), (C8)
(I4) update_window_dims קבוע tensor חד-ממדי מסוג si64 (C2), (C4), (C7), (C8)
(I5) inserted_window_dims קבוע tensor חד-ממדי מסוג si64 (C2), (C4), (C9), (C10)
(I6) scatter_dims_to_operand_dims קבוע tensor חד-ממדי מסוג si64 (C11-C13)
(I7) index_vector_dim קבוע מסוג si64 (C4), (C11), (C14)
(I8) indices_are_sorted קבוע מסוג i1
(I9) unique_indices קבוע מסוג i1
(I10) update_computation פונקציה (C15)

פלט

שם סוג מגבלות
results מספר וריאנטים (tensor) של tensor או tensor, p-tensor, (C15-C17)

מגבלות

  • (C1) same(shape(inputs...)).
  • (ג2) rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims).
  • (C3) same(shape(updates...)).
  • (C4) shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes) כאשר:
    • update_scatter_dim_sizes = shape(scatter_indices) חוץ מהעובדה שגודל המאפיין scatter_indices שתואם לערך index_vector_dim לא נכלל.
    • update_window_dim_sizes <= shape(inputs[0]) חוץ מהמידות של המאפיינים inputs[0] שתואמים ל-inserted_window_dims, לא נכללים.
    • combine מציבה את update_scatter_dim_sizes בצירים שתואמים ל-update_scatter_dims ול-update_window_dim_sizes בצירים שתואמים ל-update_window_dims.
  • (C5) 0 < size(inputs) = size(updates) = N.
  • (C6) element_type(updates...) = element_type(inputs...).
  • (C7) is_unique(update_window_dims) and is_sorted(update_window_dims).
  • (C8) 0 <= update_window_dims < rank(updates[0]).
  • (C9) is_unique(inserted_window_dims) and is_sorted(update_window_dims).
  • (C10) 0 <= inserted_window_dims < rank(inputs[0]).
  • (C11) size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1.
  • (C12) is_unique(scatter_dims_to_operand_dims).
  • (C13) 0 <= scatter_dims_to_operand_dims < rank(inputs[0]).
  • (C14) 0 <= index_vector_dim <= rank(scatter_indices).
  • (C15) ל-update_computation יש סוג (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), וכאשר is_promotable(element_type(inputs[i]), Ei).
  • (C16) shape(inputs...) = shape(results...).
  • (C17) element_type(results[i]) = Ei לכל i שב-[0,N).

דוגמאות

// %input: [
//          [[1, 2], [3, 4], [5, 6], [7, 8]],
//          [[9, 10], [11, 12], [13, 14], [15, 16]],
//          [[17, 18], [19, 20], [21, 22], [23, 24]]
//         ]
// %scatter_indices: [[[0, 2], [1, 0], [2, 1]], [[0, 1], [1, 0], [0, 9]]]
// %update: [
//           [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
//           [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
//          ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension_numbers = #stablehlo.scatter<
    update_window_dims = [2, 3],
    inserted_window_dims = [0],
    scatter_dims_to_operand_dims = [1, 0],
    index_vector_dim = 2>,
  indices_are_sorted = false,
  unique_indices = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<2x3x2x2xi64>) -> tensor<3x4x2xi64>
// %result: [
//           [[1, 2], [5, 6], [7, 8], [7, 8]],
//           [[10, 11], [12, 13], [14, 15], [16, 17]],
//           [[18, 19], [20, 21], [21, 22], [23, 24]]
//          ]

דוגמאות נוספות

בחירה

סמנטיקה

הפונקציה יוצרת טנטור result כאשר כל רכיב נבחר מ-on_true או מהטנזור on_false על סמך הערך של הרכיב התואם של pred. באופן רשמי יותר, result[result_index] = pred_element ? on_true[result_index] : on_false[result_index], איפה pred_element = rank(pred) = 0 ? pred[] : pred[result_index]. לסוגים שמחושבים לפי כמות, הביצועים של dequantize_select_quantize(pred, on_true, on_false, type(result)).

קלט

תווית שם סוג מגבלות
(I1) pred tensor מסוג i1 (C1)
(I2) on_true tensor, או tensor, quanted tensor (C1-C2)
(I3) on_false tensor, או tensor, quanted tensor (C2)

פלט

שם סוג מגבלות
result tensor, או tensor, quanted tensor (C2)

מגבלות

  • (C1) rank(pred) = 0 or shape(pred) = shape(on_true).
  • (ג2) baseline_type(on_true) = baseline_type(on_false) = baseline_type(result).

דוגמאות

// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]

דוגמאות נוספות

select_and_scatter

סמנטיקה

פיזור הערכים מהטנזור source באמצעות scatter על סמך התוצאה של reduce_window של ה-Tenor input באמצעות select, ומפיק טנזור result.

התרשים הבא מראה כיצד מחושבים הרכיבים ב-result מ-operand ומ-source באמצעות דוגמה קונקרטית.

בצורה יותר רשמית:

  • selected_values = reduce_window_without_init(...) עם רכיבי הקלט הבאים:

    • 'inputs = [operand].
    • window_dimensions, window_strides ו-padding, שנעשה בהם שימוש כפי שהם.
    • base_dilations = windows_dilations = 1.
    • ההגדרה של body היא:
    def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>:
      return select(arg0, arg1) ? arg0 : arg1;
    

    כאשר E = element_type(operand) ו-reduce_window_without_init פועלים בדיוק כמו reduce_window, אלא שה-schedule של הערך הבסיסי reduce (מידע נוסף זמין במאמר צמצום) לא כולל ערכים ראשוניים. נכון לעכשיו, לא ברור מה יקרה אם לא יהיו ערכים לחלון המתאים (#731).

  • result[result_index] = reduce([source_values], [init_value], [0], scatter) איפה:

    • source_values = [source[source_index] for source_index in source_indices].
    • selected_index(source_index) = operand_index אם selected_values[source_index] מכיל את הרכיב operand מ-operand_index.
    • source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index].

קלט

תווית שם סוג מגבלות
(I1) operand tensor, או tensor, quanted tensor (C1-C4), (C6), (C8-C11)
(I2) source tensor, או tensor, quanted tensor (C1), (C2)
(I3) init_value img_tensor, tensor, penanty (C3)
(I4) window_dimensions קבוע tensor חד-ממדי מסוג si64 (C2), (C4), (C5)
(I5) window_strides קבוע tensor חד-ממדי מסוג si64 (C2), (C6), (C7)
(I6) padding קבוע tensor דו-ממדי מסוג si64 (C2), (C8)
(I7) select פונקציה (C9)
(I8) scatter פונקציה (C10)

פלט

שם סוג מגבלות
result tensor, או tensor, quanted tensor (C11-C12)

מגבלות

  • (C1) element_type(operand) = element_type(source).
  • (C2) shape(source) = num_windows כאשר:
    • padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1].
    • is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1.
  • (C3) element_type(init_value) = element_type(operand).
  • (C4) size(window_dimensions) = rank(operand).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(operand).
  • (C7) 0 < window_strides.
  • (C8) shape(padding) = [rank(operand), 2].
  • (C9) ב-select יש סוג (tensor<E>, tensor<E>) -> tensor<i1> כאשר E = element_type(operand).
  • (C10) scatter מכיל סוג (tensor<E>, tensor<E>) -> tensor<E> כאשר is_promotable(element_type(operand), E).
  • (C11) shape(operand) = shape(result).
  • (C12) element_type(result) = E.

דוגמאות

// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]

דוגמאות נוספות

שליחה

סמנטיקה

שליחת inputs לערוץ channel_id ויוצרת אסימון result.

אם הערך של is_host_transfer הוא true, הפעולה תעביר נתונים למארח. אחרת, הוא מעביר נתונים למכשיר אחר. המשמעות היא שההטמעה מוגדרת. הסימון הזה משכפל את המידע שסופק ב-channel_type, כך שבעתיד אנחנו מתכננים לשמור רק אחד מהם (#666).

קלט

תווית שם סוג מגבלות
(I1) inputs מספר וריאנטים של טנזורים או מותנים קוונטיים
(I2) token token
(I3) channel_id קבוע מסוג si64
(I4) channel_type טיפוסים בני מנייה (enum) של DEVICE_TO_DEVICE ושל DEVICE_TO_HOST (C1)
(I5) is_host_transfer קבוע מסוג i1 (C1)

פלט

שם סוג
result token

מגבלות

  • (C1) channel_type מוגדר כך:
    • DEVICE_TO_HOST אם is_host_transfer = true,
    • DEVICE_TO_DEVICE אחרת.

דוגמאות

%result = "stablehlo.send"(%operand, %token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
  is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token

דוגמאות נוספות

shift_left

סמנטיקה

מבצע פעולת היסט שמאלית של הרכיב ברמת הרכיבים על הטנזור lhs באמצעות מספר rhs של הסיביות, ומפיק טנזור result.

קלט

תווית שם סוג מגבלות
(I1) lhs tensor מסוג מספר שלם (C1)
(I2) rhs tensor מסוג מספר שלם (C1)

פלט

שם סוג מגבלות
result tensor מסוג מספר שלם (C1)

מגבלות

  • (C1) type(lhs) = type(rhs) = type(result).

דוגמאות

// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]

דוגמאות נוספות

shift_right_arithmetic

סמנטיקה

מבצע פעולת היסט אריתמטי של היסט ימינה בכל הרכיבים של הטנזור lhs ב-rhs מספר הביטים, ומפיק טנזור result.

קלט

תווית שם סוג מגבלות
(I1) lhs tensor מסוג מספר שלם (C1)
(I2) rhs tensor מסוג מספר שלם (C1)

פלט

שם סוג מגבלות
result tensor מסוג מספר שלם (C1)

מגבלות

  • (C1) type(lhs) = type(rhs) = type(result).

דוגמאות

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]

דוגמאות נוספות

shift_right_logical

סמנטיקה

מבצע פעולת היסט לוגית לוגית ברמת הרכיב על הטנזור lhs לפי מספר הביטים rhs ומפיק טנזור result.

קלט

תווית שם סוג מגבלות
(I1) lhs tensor מסוג מספר שלם (C1)
(I2) rhs tensor מסוג מספר שלם (C1)

פלט

שם סוג מגבלות
result tensor מסוג מספר שלם (C1)

מגבלות

  • (C1) type(lhs) = type(rhs) = type(result).

דוגמאות

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]

דוגמאות נוספות

סמל

סמנטיקה

מחזירה את הסימן של operand ברמת הרכיב ומפיקה טנזור result. באופן רשמי יותר, לכל רכיב x ניתן לבטא את הסמנטיקה באמצעות תחביר Python, באופן הבא:

def sign(x):
  if is_integer(x):
    if compare(x, 0, LT, SIGNED): return -1
    if compare(x, 0, EQ, SIGNED): return 0
    return 1
  elif is_float(x):
    if is_nan(x): return NaN
    if compare(x, -0.0, EQ, FLOAT): return -0.0
    if compare(x, +0.0, EQ, FLOAT): return +0.0
    if compare(x, 0.0, LT, FLOAT): return -1.0
    return 1.0
  elif is_complex(x):
    if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
    if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
    return divide(x, convert(abs(x), type(x)))

לסוגים שמחושבים לפי כמות, הביצועים של dequantize_op_quantize(sign, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand t e n s o r f l o w, t e n s o d s l o re ג, או ר בל ט ר נו ה בשביל ת מור ט אן ט 1 ט 1 ש 1111111111111111 כבר{/1} ה ה (C1)

פלט

שם סוג מגבלות
result t e n s o r f l o w, t e n s o d s l o re ג, או ר בל ט ר נו ה בשביל ת מור ט אן ט 1 ט 1 ש 1111111111111111 כבר{/1} ה ה (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]

דוגמאות נוספות

סינוס

סמנטיקה

מבצע פעולת סינוס של רכיב מסוים על טנסור operand ויוצר tensor result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • לצפים: sin מ-IEEE-754.
  • למספרים מרוכבים: סינוס מרוכב.
  • לסוגים שמבוססים על כמות: dequantize_op_quantize(sine, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand t e n s o r f l o w, או t e n s n o l o li bre t e n s o l l o w, (C1)

פלט

שם סוג מגבלות
result t e n s o r f l o w, או t e n s n o l o li bre t e n s o l l o w, (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]

דוגמאות נוספות

פרוסה (slice)

סמנטיקה

מוציאה פרוסה מה-operand באמצעות אינדקסים ראשוניים שמחושבים באופן סטטי ויוצרת טנזור result. start_indices מכיל את האינדקסים המתחילים של הפרוסה בכל מאפיין, limit_indices מכיל את אינדקסי הסיום (לא כולל) של הפלח של כל מאפיין, ו-strides מכיל את השלבים בכל מאפיין.

בצורה יותר רשמית, result[result_index] = operand[operand_index] איפה operand_index = start_indices + result_index * strides.

קלט

תווית שם סוג מגבלות
(I1) operand tensor, או tensor, quanted tensor (C1-C3), (C5)
(I2) start_indices קבוע tensor חד-ממדי מסוג si64 (C2), (C3), (C5)
(I3) limit_indices קבוע tensor חד-ממדי מסוג si64 (C2), (C3), (C5)
(I4) strides קבוע tensor חד-ממדי מסוג si64 (C2), (C4)

פלט

שם סוג מגבלות
result tensor, או tensor, quanted tensor (C1), (C5)

מגבלות

  • (C1) element_type(operand) = element_type(result).
  • (ג2) size(start_indices) = size(limit_indices) = size(strides) = rank(operand).
  • (C3) 0 <= start_indices <= limit_indices <= shape(operand).
  • (C4) 0 < strides.
  • (C5) shape(result) = ceil((limit_indices - start_indices) / strides).

דוגמאות

// %operand: [
//            [0, 0, 0, 0],
//            [0, 0, 1, 1],
//            [0, 0, 1, 1]
//           ]
%result = "stablehlo.slice"(%operand) {
  start_indices = dense<[1, 2]> : tensor<2xi64>,
  limit_indices = dense<[3, 4]> : tensor<2xi64>,
  strides = dense<1> : tensor<2xi64>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
//            [1, 1],
//            [1, 1]
//           ]

דוגמאות נוספות

מיון

סמנטיקה

ממיינת יחד מקטעים חד-ממדיים של inputs לאורך המאפיין dimension, לפי comparator ומפיקה results.

בניגוד לקלט דומה בפעולות אחרות, הפונקציה dimension מאפשרת ערכים שליליים, בהתאם לסמנטיקה שמתוארת בהמשך. בעתיד, ייתכן שהדבר ייאסר מסיבות של עקביות (#1377).

אם הערך של is_stable הוא TRUE, אז המיון יציב, כלומר הסדר היחסי של רכיבים שנחשבים כשווים על ידי המשווה נשמר. במקרה שבו יש קלט אחד, שני הרכיבים e1 ו-e2 נחשבים כשווים על ידי המשווה אם ורק אם comparator(e1, e2) = comparator(e2, e1) = false. בהמשך מוסבר איך התכונה הזו חלה על ערכי קלט מרובים.

באופן רשמי יותר, לכל הresult_index בindex_space(results[0]):

  • adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension.
  • result_slice = [ri0, ..., :, ..., riR-1] כאשר riN הם רכיבים בודדים ב-result_index, ו-: מוכנסת במקום adjusted_dimension.
  • inputs_together = (inputs[0]..., ..., inputs[N-1]...).
  • results_together[result_slice] = sort(inputs_together[result_slice], comparator_together).
  • כאשר sort ממיינת פרוסה חד-ממדית בסדר יורד, מתוך ציפייה ש-comparator_together תחזיר true אם הארגומנט בצד שמאל קטן מהארגומנט בצד ימין.
  • def comparator_together(lhs_together, rhs_together):
      args = []
      for (lhs_el, rhs_el) in zip(lhs_together, rhs_together):
        args.append(lhs_el)
        args.append(rhs_el)
      return comparator(*args)
    
  • (results[0]..., ..., results[N-1]...) = results_together.

קלט

תווית שם סוג מגבלות
(I1) inputs מספר וריאנטים (tensor) של tensor או tensor, p-tensor, (C1-C5)
(I2) dimension קבוע מסוג si64 (C4)
(I3) is_stable קבוע מסוג i1
(I4) comparator פונקציה (C5)

פלט

שם סוג מגבלות
results מספר וריאנטים (tensor) של tensor או tensor, p-tensor, (C2), (C3)

מגבלות

  • (C1) 0 < size(inputs).
  • (ג2) type(inputs...) = type(results...).
  • (C3) same(shape(inputs...) + shape(results...)).
  • (C4) -R <= dimension < R, כאשר R = rank(inputs[0]).
  • (C5) ב-comparator יש סוג (tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>, וכאשר Ei = element_type(inputs[i]).

דוגמאות

// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
    %predicate = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
  dimension = 0 : i64,
  is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]

דוגמאות נוספות

sqrt

סמנטיקה

מבצעת פעולת שורש ריבועית ברמת הרכיב על טנסור operand ויוצרת טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • לצפים: squareRoot מ-IEEE-754.
  • למספרים מרוכבים: שורש ריבועי מורכב.
  • לסוגים שמבוססים על כמות: dequantize_op_quantize(sqrt, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand t e n s o r f l o w, או t e n s n o l o li bre t e n s o l l o w, (C1)

פלט

שם סוג מגבלות
result t e n s o r f l o w, או t e n s n o l o li bre t e n s o l l o w, (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]

דוגמאות נוספות

חיסור

סמנטיקה

מבצעת פעולת חיסור של שני טנזרים ברמת הרכיב lhs ו-rhs ויוצרת טנזור result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • למספרים שלמים: חיסור של מספרים שלמים.
  • לצפים: subtraction מ-IEEE-754.
  • למספרים מרוכבים: פעולת חיסור מרוכבת.
  • בסוגים שמבוססים על כמות:
    • dequantize_op_quantize(subtract, lhs, rhs, type(result)).

קלט

תווית שם סוג מגבלות
(I1) lhs t e n s o r f l o w, t e n s o י , או ב ר ח מ מ מור, ב ר 1 - ט נזור ט 1 1 1 1 1 1 1 m או ב ת יש ט en t בלי ט נזור (C1)
(I2) rhs t e n s o r f l o w, t e n s o י , או ב ר ח מ מ מור, ב ר 1 - ט נזור ט 1 1 1 1 1 1 1 m או ב ת יש ט en t בלי ט נזור (C1)

פלט

שם סוג מגבלות
result t e n s o r f l o w, t e n s o י , או ב ר ח מ מ מור, ב ר 1 - ט נזור ט 1 1 1 1 1 1 1 m או ב ת יש ט en t בלי ט נזור (C1)

מגבלות

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

דוגמאות

// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]

דוגמאות נוספות

טן

סמנטיקה

מבצעת פעולה של טנגנס היפרבולי ברמת הרכיב על טנגנס operand ויוצרת טנגנס result. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • לצפים: tanh מ-IEEE-754.
  • למספרים מרוכבים: טנגנס היפרבולי מורכב.
  • בסוגים שמבוססים על כמות:
    • dequantize_op_quantize(tanh, operand, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand t e n s o r f l o w, או t e n s n o l o li bre t e n s o l l o w, (C1)

פלט

שם סוג מגבלות
result t e n s o r f l o w, או t e n s n o l o li bre t e n s o l l o w, (C1)

מגבלות

  • (C1) baseline_type(operand) = baseline_type(result).

דוגמאות

// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]

דוגמאות נוספות

להחליף

סמנטיקה

משנה את המידות של טנסור operand באמצעות permutation ויוצרת טנזור result. בצורה יותר רשמית, result[result_index] = operand[operand_index] תוך result_index[d] = operand_index[permutation[d]].

קלט

תווית שם סוג מגבלות
(I1) operand את img_tensor, או tensor, pantor (C1-C4)
(I2) permutation קבוע tensor חד-ממדי מסוג si64 (C2-C4)

פלט

שם סוג מגבלות
result את img_tensor, או tensor, pantor (C1), (C3-C4)

מגבלות

  • (C1) הערך element_type(result) ניתן על ידי:
    • element_type(operand), אם !is_per_axis_quantized(operand).
    • element_type(operand) אלא אם כן quantization_dimension(operand) ו-quantization_dimension(result) עשויים להיות שונים.
  • (C2) permutation הוא תמורה של range(rank(operand)).
  • (C3) shape(result) = dim(operand, permutation...).
  • (C4) אם הערך הוא is_per_axis_quantized(result), אז quantization_dimension(operand) = permutation(quantization_dimension(result)).

דוגמאות

// %operand: [
//            [[1,2], [3,4], [5,6]],
//            [[7,8], [9,10], [11,12]]
//           ]
%result = "stablehlo.transpose"(%operand) {
  permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
//           [[1,7], [3,9], [5,11]],
//           [[2,8], [4,10], [6,12]]
//          ]

דוגמאות נוספות

triangular_solve

סמנטיקה

פתרון קבוצות של מערכות של משוואות ליניאריות עם מטריצות של מקדם משולש נמוך או עליון.

באופן רשמי יותר, a ו-b, result[i0, ..., iR-3, :, :] הוא הפתרון לבעיה op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :] כשהערך של left_side הוא true או x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :] כאשר left_side הוא false, הפתרון עבור המשתנה x (x) שבו op(a) נקבע לפי transpose_a, שיכול להיות אחד מהערכים הבאים:

  • NO_TRANSPOSE: ביצוע הפעולה באמצעות a כפי שהיא.
  • TRANSPOSE: יש לבצע פעולה על העברת a.
  • ADJOINT: יש לבצע פעולה על ההחלפה המצומדת של a.

נתוני הקלט ייקראו רק מהמשולש התחתון של a, אם הערך של lower הוא true או המשולש העליון של a, אחרת. נתוני הפלט מוחזרים באותו משולש, והערכים במשולש השני מוגדרים לפי היישום.

אם הערך unit_diagonal הוא true, ההטמעה יכולה להניח שהרכיבים האלכסוניים של a שווים ל-1, אחרת ההתנהגות לא מוגדרת.

לסוגים שמחושבים לפי כמות, הביצועים של dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower, unit_diagonal, transpose_a), a, b, type(result)).

קלט

תווית שם סוג מגבלות
(I1) a t e n s o r f l o w, או t e n s n o l o li bre t e n s o l l o w, (C1-C3)
(I2) b t e n s o r f l o w, או t e n s n o l o li bre t e n s o l l o w, (C1-C4)
(I3) left_side קבוע מסוג i1 (C3)
(I4) lower קבוע מסוג i1
(I5) unit_diagonal קבוע מסוג i1
(I6) transpose_a טיפוסים בני מנייה (enum) של NO_TRANSPOSE, של TRANSPOSE ושל ADJOINT

פלט

שם סוג מגבלות
result t e n s o r f l o w, או t e n s n o l o li bre t e n s o l l o w, (C1)

מגבלות

  • (C1) baseline_element_type(a) = baseline_element_type(b).
  • (ג2) 2 <= rank(a) = rank(b) = R.
  • (C3) הקשר בין shape(a) לבין shape(b) מוגדר כך:
    • shape(a)[:-3] = shape(b)[:-3].
    • dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1).
  • (C4) baseline_type(b) = baseline_type(result).

דוגמאות

// %a = [
//       [1.0, 0.0, 0.0],
//       [2.0, 4.0, 0.0],
//       [3.0, 5.0, 6.0]
//      ]
// %b = [
//       [2.0, 0.0, 0.0],
//       [4.0, 8.0, 0.0],
//       [6.0, 10.0, 12.0]
//      ]
%result = "stablehlo.triangular_solve"(%a, %b) {
  left_side = true,
  lower = true,
  unit_diagonal = false,
  transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
//           [2.0, 0.0, 0.0],
//           [0.0, 2.0, 0.0],
//           [0.0, 0.0, 2.0]
//          ]

tuple

סמנטיקה

נוצר ב-tuple של result מהערכים val.

קלט

תווית שם סוג מגבלות
(I1) val מספר הווריאנטים של הערכים (C1)

פלט

שם סוג מגבלות
result tuple (C1)

מגבלות

  • (C1) result מכיל סוג tuple<E0, ..., EN-1> כאשר Ei = type(val[i]).

דוגמאות

// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))

דוגמאות נוספות

uniform_dequantize

סמנטיקה

הפונקציה מבצעת המרה של הרכיבים של Tensor מקודד operand ל-Tenor עם נקודה צפה (floating-point) result בהתאם לפרמטרים של הקוונטיזציה שמוגדרים בסוג operand.

בצורה יותר רשמית, result = dequantize(operand).

קלט

תווית שם סוג מגבלות
(I1) operand קוונטי img_tensor (C1), (C2)

פלט

שם סוג מגבלות
result Tensor מסוג נקודה צפה (floating-point) (C1), (C2)

מגבלות

  • (C1) shape(operand) = shape(result).
  • (ג2) element_type(result) = expressed_type(operand).

דוגמאות

// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]

uniform_quantize

סמנטיקה

הפונקציה מבצעת המרה של רכיבים (floating-point tensor) או tensor quantized tensor operand ל-tensor מדוד result בהתאם לפרמטרים של הקוונטיזציה שמוגדרים על ידי הסוג result.

בצורה יותר רשמית,

  • אם is_float(operand):
    • result = quantize(operand, type(result)).
  • אם is_quantized(operand):
    • float_result = dequantize(operand).
    • result = quantize(float_result, type(result)).

קלט

תווית שם סוג מגבלות
(I1) operand Tensor של נקודה צפה (floating-point) או מסוג quantated (C1), (C2)

פלט

שם סוג מגבלות
result קוונטי img_tensor (C1), (C2)

מגבלות

  • (C1) shape(operand) = shape(result).
  • (ג2) expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand).

דוגמאות

// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]

// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]

בזמן

סמנטיקה

מפיקה את הפלט מהפעלה של פונקציית body 0 פעמים או יותר בזמן שהפונקציה cond יוצרת את הפלט true. באופן יותר רשמי, אפשר לבטא את הסמנטיקה באמצעות תחביר Python:

internal_state = operand
while cond(*internal_state):
  internal_state = body(*internal_state)
results = internal_state

ההתנהגות של לולאה אינסופית טרם נקבעה (#383).

קלט

תווית שם סוג מגבלות
(I1) operand מספר שונה של טנזורים, טנזורים או אסימונים קוונטיים (C1-C3)
(I2) cond פונקציה (C1)
(I3) body פונקציה (C2)

פלט

שם סוג מגבלות
results מספר שונה של טנזורים, טנזורים או אסימונים קוונטיים (C3)

מגבלות

  • (C1) ב-cond יש סוג (T0, ..., TN-1) -> tensor<i1>, בעוד Ti = type(operand[i]).
  • (C2) body מכיל סוג (T0, ..., TN-1) -> (T0, ..., TN-1), כאשר Ti = type(operand[i]).
  • (C3) type(results...) = type(operand...).

דוגמאות

// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %cond = "stablehlo.compare"(%arg0, %ten) {
      comparison_direction = #stablehlo<comparison_direction LT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    stablehlo.return %cond : tensor<i1>
  }, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %new_sum = stablehlo.add %arg1, %one : tensor<i64>
    %new_i = stablehlo.add %arg0, %one : tensor<i64>
    stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10

דוגמאות נוספות

Xor

סמנטיקה

הפונקציה מבצעת את הפונקציה XOR (XOR) של שני הרכיבים lhs ו-rhs, ומפיקה את ה-result tensor. בהתאם לסוג הרכיב, מבצעים את הפעולות הבאות:

  • לבוליאנים: XOR לוגי.
  • למספרים שלמים: XOR ברמת הסיביות.

קלט

תווית שם סוג מגבלות
(I1) lhs tensor מסוג בוליאני או מספר שלם (C1)
(I2) rhs tensor מסוג בוליאני או מספר שלם (C1)

פלט

שם סוג מגבלות
result tensor מסוג בוליאני או מספר שלם (C1)

מגבלות

  • (C1) type(lhs) = type(rhs) = type(result).

דוגמאות

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]

ביצוע

ביצוע ברצף

הפעלת תוכנית StableHLO מתבצעת על ידי הזנת ערכי קלט לפונקציה main וערכי פלט מחשוב. ערכי הפלט של פונקציה מחושבים על ידי הפעלת התרשים של פעולות שעברו רוט (Root) בפעולה המתאימה של return.

סדר הביצוע מוגדר לפי ההטמעה כל עוד הוא תואם לזרימת הנתונים, כלומר אם הפעולות מתבצעות לפני השימושים שלהן. ב-StableHLO, כל הפעולות שמשפיעות בצד הלקוח צורכות אסימון אחד ומפיקות אסימון אחד (אפשר למרוב כמה אסימונים לאסימון אחד באמצעות after_all), כך שסדר הביצוע של תופעות הלוואי גם תואם לזרימת הנתונים. סדרי ההפעלה האפשריים של התוכנית לדוגמה שלמעלה הם %0%1%2%3%4return או %3%0%0%1%2%4return.

באופן רשמי יותר, תהליך יציב הוא שילוב של: 1) תוכנית StableHLO, 2) סטטוסים של פעולות (עדיין לא בוצע, כבר בוצע) ו-3) ערכי ביניים שהתהליך עובד עליהם. התהליך מתחיל בערכי קלט לפונקציה main, מתקדם דרך תרשים הפעולות שמעדכן את הסטטוסים של הפעולות ואת ערכי הביניים, ומסתיים בערכי הפלט. תהליך רשמי נוסף טרם נקבע (#484).

ביצוע מקביל

אפשר להפעיל תוכנות StableHLO במקביל, לסדר אותן ברשת עיבוד דו-ממדית של num_replicas עד num_partitions, ששתיהן מסוג ui32.

ברשת התהליכים של StableHLO, מתבצעים num_replicas * num_partitions מתוך התהליכים של StableHLO בו-זמנית. לכל תהליך יש process_id = (replica_id, partition_id) ייחודי, כאשר replica_id ב-replica_ids = range(num_replicas) וב-partition_id ב-partition_ids = range(num_partitions) הם מסוג ui32.

הגודל של רשת התהליך ידוע באופן סטטי לכל תוכנה (בעתיד אנחנו מתכננים להפוך אותה לחלק מפורש מתוכניות StableHLO #650), והמיקום בתוך רשת התהליכים ידוע באופן סטטי בכל תהליך. לכל תהליך יש גישה למיקום שלו ברשת התהליך דרך הפעולות replica_id ו-partition_id.

בתוך רשת התהליכים, התוכנות יכולות להיות זהות (בסגנון "תוכנית יחידה, נתונים מרובים"), הן יכולות להיות כולן שונות (בסגנון "תוכנית מרובות, נתונים מרובים") או משהו ביניהן. בעתיד, אנחנו מתכננים להוסיף תמיכה במונחים אחרים של הגדרת תוכניות StableHLO מקבילות, כולל GSPMD (#619).

בתוך רשת התהליך, התהליכים בדרך כלל בלתי תלויים זה בזה – יש להם סטטוסים נפרדים של פעולות, ערכי קלט, ביניים ופלט נפרדים, ורוב הפעולות מבוצעות בנפרד בין התהליכים, למעט מספר קטן של פעולות קולקטיביות שמתוארות בהמשך.

מכיוון שהביצוע של רוב הפעולות משתמש רק בערכים מאותו תהליך, בדרך כלל לא ניתן להתייחס לערכים האלה בשמות שלהם. עם זאת, כשמתארים סמנטיקה של פעולות קולקטיביות, המשמעות היא שהיא לא מספקת ולכן מעלה את הסימון name@process_id שמתייחס לערך name בתהליך מסוים. (מנקודת המבט הזו, אפשר לראות את name ללא התאמה כקיצור של name@(replica_id(), partition_id())).

סדר הביצוע בתהליכים מוגדר, מלבד הסנכרון שמתבצע בתקשורת מנקודה לנקודה ובתפעול קולקטיבי, כפי שמתואר בהמשך.

תקשורת מנקודה לנקודה

תהליכי StableHLO יכולים לתקשר אחד עם השני באמצעות ערוצי StableHLO. ערוץ מיוצג על ידי מזהה חיובי מסוג si64. הפעולות השונות מאפשרות לשלוח ערכים לערוצים ולקבל אותם מערוצים.

תהליך רשמי נוסף, למשל מהיכן מגיעים מזהי הערוצים האלה, איך תהליכים מתודעים להם ואיזה סוג של סנכרון הם מתחילים לעשות, טרם נקבע (#484).

תקשורת סטרימינג

לכל תהליך StableHLO יש גישה לשני ממשקי סטרימינג:

  • בפיד שאפשר לקרוא ממנו.
  • Outfeed שאפשר לכתוב בה.

בניגוד לערוצים, המשמשים לתקשורת בין תהליכים ולכן יש בהם תהליכים בשני הקצוות, פידים של פידים ופידים חיצוניים מוגדרים באמצעות הטמעה אחרת.

עוד לא נקבע, למשל, איך תקשורת סטרימינג משפיעה על סדר הביצוע וסוג הסנכרון שמופעל באמצעותה (#484).

תפעול שיתופי

יש 6 פעולות קולקטיביות ב-SableHLO: all_gather, all_reduce, all_to_all, collective_broadcast, collective_permute ו-reduce_scatter. כל הפעולות האלה מפצלות את התהליכים ברשת התהליך של StableHLO לקבוצות של תהליכי StableHLO, ומבצעות חישוב משותף בתוך כל קבוצת תהליכים, בנפרד מקבוצות תהליכים אחרות.

בתוך כל קבוצת תהליכים, פעולות קולקטיביות עשויות ליצור מחסום לסנכרון. תהליך רשמי נוסף, למשל הבהרנו מתי בדיוק מתרחש הסנכרון, איך בדיוק התהליכים מגיעים למחסום הזה ומה יקרה אם לא תעשו זאת, טרם נקבע (#484).

אם קבוצת התהליכים כוללת תקשורת בין מחיצות, כלומר יש תהליכים בקבוצת התהליכים שמזהי המחיצות שלהם שונים, אז הביצוע של הפעולה הקולקטיבית צריך ערוץ, והתפעול הקולקטיבי צריך לספק channel_id חיובי מסוג si64. כדי ליצור קשרים שיש בהם העתקים לא צריך ערוצים.

החישובים שמבוצעים על ידי הפעולות הקולקטיביות הם ספציפיים לפעולות בודדות ומתוארים בקטעי תפעול נפרדים למעלה. עם זאת, האסטרטגיות שלפיהן רשת התהליכים מחולקת לקבוצות תהליכים משותפות בין הפעולות האלה ומתוארות בקטע הזה. באופן רשמי יותר, StableHLO תומך בארבע האסטרטגיות הבאות.

cross_replica

בתוך כל קבוצת תהליכים מתבצעת רק תקשורת עם העתקים שונים. בשיטה הזו נעשה שימוש ב-replica_groups – רשימה של מזהים של רפליקות – ומחשבת מכפלה קרטזית של replica_groups עד partition_ids. השדה replica_groups חייב לכלול רכיבים ייחודיים ולהכיל את כל הפריטים מסוג replica_ids. באופן רשמי יותר, באמצעות תחביר של Python:

def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    for partition_id in partition_ids:
      process_group = []
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
      yield process_group

לדוגמה, עבור replica_groups = [[0, 1], [2, 3]] ו-num_partitions = 2, הערך cross_replica יפיק [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]].

cross_partition

רק תקשורת בין מחיצות מתקיימת בתוך כל קבוצת תהליכים. בשיטה הזו משתמשים ב-partition_groups – רשימה של מזהי מחיצות – ומחשבת מכפלה קרטזית של partition_groups עד replica_ids. השדה partition_groups חייב לכלול רכיבים ייחודיים ולהכסה את כל רכיבי ה-partition_ids. בצורה יותר רשמית, באמצעות תחביר של Python:

def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
  for partition_group in partition_groups:
    for replica_id in replica_ids:
      process_group = []
      for partition_id in partition_group:
        process_group.append((replica_id, partition_id))
      yield process_group

לדוגמה, עבור partition_groups = [[0, 1]] ו-num_replicas = 4, הערך cross_partition יפיק [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]].

cross_replica_and_partition

תקשורת בין עותקים למחיצות יכולה להתרחש בתוך כל קבוצת תהליכים. בשיטה הזו נעשה שימוש ב-replica_groups – רשימת רשימות של מזהים כפולים, ומחשבת את המכפלה הקרטזית של כל replica_group לפי partition_ids. השדה replica_groups חייב לכלול אלמנטים ייחודיים ולהכסה את כל replica_ids. בצורה יותר רשמית, באמצעות תחביר של Python:

def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    process_group = []
    for partition_id in partition_ids:
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
    yield process_group

לדוגמה, עבור replica_groups = [[0, 1], [2, 3]] ו-num_partitions = 2, הערך cross_replica_and_partition יפיק [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]].

flattened_ids

בשיטה הזו נעשה שימוש ב-flattened_id_groups – רשימת רשימות של מזהי תהליכים 'פשוטים' בצורת replica_id * num_partitions + partition_id – והופכת אותם למזהי תהליכים. השדה flattened_id_groups חייב לכלול רכיבים ייחודיים ולהכיל את כל הפריטים מסוג process_ids. בצורה יותר רשמית, באמצעות תחביר של Python:

def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
  for flattened_id_group in flattened_id_groups:
    process_group = []
    for flattened_id in flattened_id_group:
      replica_id = flattened_id // num_partitions
      partition_id = flattened_id % num_partitions
      process_group.append((replica_id, partition_id))
    yield process_group

לדוגמה, עבור flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]], num_replicas = 4 ו-num_partitions = 2, flattened_ids יפיק [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]].

דיוק

בשלב זה, לא ניתן להבטיח דיוק מספרי ב-SableHLO, אבל זה עשוי להשתנות בעתיד (#1156).

שגיאות

תוכניות StableHLO מאומתות באמצעות מערך נרחב של אילוצים בפעולות ספציפיות, כדי לבטל סיווגים רבים של שגיאות לפני תחילת זמן הריצה. עם זאת, עדיין ייתכנו תנאי שגיאה, כמו חריגות במספרים שלמים, כניסות מחוץ לתחום (out-of-bounds) וכו'. אם לא צוין אחרת במפורש, כל השגיאות האלה יובילו להתנהגות המוגדרת על-ידי ההטמעה, אבל הדבר עשוי להשתנות בעתיד (#1157).

למעט הכלל הזה, לחריגות בנקודה צפה (floating-point) בתוכניות של StableHLO יש התנהגות מוגדרת היטב. פעולות שגורמות לחריגות שמוגדרות בתקן IEEE-754 (פעולה לא חוקית, חלוקה לפי אפס, גלישה, חריגה בתהליך או חריגות לא מדויקות) מפיקות תוצאות ברירת מחדל (כפי שמוגדר בתקן) ומפעילות את הביצוע בלי להעלות את סימון הסטטוס התואם, בדומה לטיפול בחריגות מהתקן raiseNoFlag. בהגדרות ההטמעה מוגדרות חריגות לפעולות לא סטנדרטיות (למשל, פעולות חשבון מורכבות ופונקציות טרנסצנדנטליות מסוימות).

סימון

לצורך תיאור התחביר, במסמך זה נעשה שימוש בגרסה המתוקנת של תחביר EBNF (ISO/IEC 14977:1996, ויקיפדיה), עם שני שינויים: 1) כללים מוגדרים באמצעות ::= במקום באמצעות =.

2) השרשור בא לידי ביטוי באמצעות סמיכות ולא באמצעות ,.

לצורך תיאור הסמנטיקה (כלומר, בקטעים 'סוגים', 'קבועים' ו'אופציות'), אנחנו משתמשים בנוסחאות שמבוססות על תחביר Python מורחב עם תמיכה להבעה תמציתית של פעולות מערך, כפי שמתואר בהמשך. השיטה הזו פועלת היטב בקטעי קוד קטנים, אבל במקרים נדירים שבהם יש צורך בקטעי קוד גדולים יותר, אנחנו משתמשים בתחביר וניל ב-Python שתמיד מיושם באופן מפורש.

נוסחאות

עכשיו נראה איך הנוסחאות פועלות על סמך דוגמה מהמפרט dot_general. אחת מהמגבלות על הפעולה הזו נראית כך: dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).

השמות שכלולים בנוסחה הזו מגיעים משני מקורות: 1) פונקציות גלובליות, כלומר dim, 2) הגדרות חברים של רכיב התוכנית המתאים, כלומר הקלט lhs, lhs_batching_dimensions, rhs ו-rhs_batching_dimensions, שמוגדרות בקטע 'שיטות קלט' של dot_general.

כפי שצוין למעלה, התחביר של הנוסחה הזו מבוסס על Python עם כמה תוספים מוכווני-תמצית. כדי להבין את הנוסחה, נמיר אותה לתחביר וניל ב-Python.

ת) בנוסחאות האלה אנחנו משתמשים ב-= כדי לייצג שוויון, ולכן השלב הראשון להשגת התחביר של Python הוא החלפה של = ב-==, באופן הבא: dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...).

ב) כמו כן, הנוסחאות האלה תומכות בשלוש נקודות (...) שהופכות ביטויים סקלריים לביטויי טנזור. בקצרה, המשמעות של f(xs...) היא "לכל סקלרי x בטנזור xs, מחשבים f(x) סקלרי ואז מחזירים את כל התוצאות הסקלריות האלה יחד כתוצאת טנזור". בתחביר של וניל Python, הנוסחה לדוגמה הופכת ל: [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] == [dim(rhs, dim2) for dim2 in rhs_batching_dimensions].

הודות לשלוש נקודות, ניתן לעיתים קרובות להימנע מעבודה ברמה של סקלרים בודדים. עם זאת, בחלק מהמקרים הבעייתיים אפשר להשתמש בתחביר חצי-לא-פורמלי ברמה נמוכה יותר, כמו בנוסחת start_indices[bi0, ..., :, ..., biN] מהמפרט של gather. למען התמציתיות, אנחנו לא מספקים פורמליזציה מדויקת לתרגום תחביר כזה לווניל Python, בתקווה שיהיה מובן אינטואיטיבי על בסיס כל מקרה לגופו. ספרו לנו אם נוסחאות מסוימות נראות אטומות, וננסה לשפר אותן.

בנוסף, שימו לב שנוסחאות משתמשות בשלוש נקודות כדי להרחיב רשימות מכל מיני סוגים, כולל טנזורים, רשימות טנזורים (שלמשל יכולות לנבוע ממספר משתנה של טנסרים) וכו'. זה עוד תחום שבו אנחנו לא מספקים פורמליות מדויקת (למשל, רשימות לא מסתמכות אפילו על שיטה אינטואיטיבית של StableHLO)

ג) אמצעי הסימון האחרון שבו אנחנו משתמשים הוא שידור מרומז. השיטה של StableHLO לא תומכת בשידור מרומז, אבל הנוסחאות כן תומכות גם בשירות של תמציתיות. בקצרה, אם משתמשים בסקלרי בהקשר שבו צפוי טנסור, הסקלרי משודר לצורה הצפויה.

כדי להמשיך את הדוגמה של dot_general, מציבים כאן מגבלה נוספת: 0 <= lhs_batching_dimensions < rank(lhs). כפי שמוגדר במפרט של dot_general, lhs_batching_dimensions הוא טנס, אבל גם 0 וגם rank(lhs) הם סקלריים. אחרי שנחיל שידור משתמע, הנוסחה תהפוך ל-[0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)].

כשמחילים אותה על פעולת dot_general מסוימת, היא תבצע הערכה ל-tensor של בוליאנים. כשנעשה שימוש בנוסחאות כאילוצים, האילוץ מתקיים אם הנוסחה מבוססת על true או על tenor שמכיל רק רכיבי true.

שמות

בנוסחאות, ההיקף המילוני כולל: 1) פונקציות גלובליות, 2) הגדרות של חברים,

3) הגדרות מקומיות. בהמשך מופיעה רשימת הפונקציות הגלובליות. רשימת הגדרות הרכיבים תלויה ברכיב התוכנית שעליו חל הסימון:

  • בהקשר של פעולות, הגדרות החברים כוללות שמות שהוצגו בקטעים 'קלט' ו'פלט'.
  • בכל שאר התנאים, הגדרות החברים כוללות חלקים מבניים של רכיב התוכנית, שנקראים על שם ה-EBNF הלא-טרמינלים המקבילים. ברוב המקרים, השמות של החלקים המבניים האלה מומרים על ידי המרת השמות של החלקים שאינם מסופים לאותיות קטנות (למשל IntegerLiteral => integer_literal), אבל לפעמים נעשה שימוש מקוצר בשמות (למשל QuantizationStorageType => storage_type). במקרה כזה השמות מופיעים באופן ספציפי באופן דומה לפעולות קלט / פלט בקטעים.
  • בנוסף, הגדרות החברים תמיד כוללות self כדי להתייחס לרכיב התוכנית התואם.

ערכים

כשמתבצעת הערכה של נוסחאות, הן עובדות עם סוגי הערכים הבאים: 1) Value (ערכים בפועל, לדוגמה dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>; הם תמיד יודעים את הסוגים שלהם), 2) Placeholder (ערכים עתידיים, למשל lhs, rhs או result; הערכים שלהם בפועל עדיין לא ידועים, רק הסוגים שלהם ידועים), 3) Type (סוגים שמוגדרים בקטע 'סוגים'), 4) Function (פונקציות מוגדרות בקטע 'גלובלית').

בהתאם להקשר, ייתכן שהשמות מתייחסים לערכים שונים. באופן ספציפי יותר, הקטע 'סמנטיקה' לתפעול (ופונקציות מקבילות לרכיבי תוכנה אחרים) מגדיר את הלוגיקה של זמן הריצה, כך שכל מקורות הקלט זמינים בתור Value. לעומת זאת, הקטע 'אילוצים' עבור פעולות פעולה (ופונקציות מקבילות) מגדיר את לוגיקת "זמן הידור", כלומר משהו שמבוצע בדרך כלל לפני זמן הריצה, ולכן רק קלט קבוע זמין כ-Value וסוגים אחרים של קלט זמינים רק בתור Placeholder.

שמות ב "סמנטיקה" בקטע "אילוצים"
פונקציות גלובליות Function Function
קלט קבוע Value Value
ערכי קלט לא קבועים Value Placeholder
פלט Value Placeholder
הגדרות מקומיות תלוי בהגדרה תלוי בהגדרה

נבחן את פעולת transpose לדוגמה:

%result = "stablehlo.transpose"(%operand) {
  permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>

בפעולה הזו, permutation הוא קבוע ולכן הוא זמין בתור Value גם בסמנטיקה וגם באילוצים. לעומת זאת, operand ו-result זמינים בתור Value בסמנטיקה, אבל רק כ-Placeholder באילוצים.

פונקציות

בנייה של סוגים

אין פונקציות שבהן ניתן להשתמש כדי לבנות סוגים. במקום זאת, אנחנו משתמשים ישירות בתחביר של סוג, מפני שהוא בדרך כלל תמציתי יותר. למשל (tensor<E>, tensor<E>) -> (tensor<E>) במקום function_type( [tensor_type([], E), tensor_type([], E)], [tensor_type([], E)]).

פונקציות על סוגים

  • element_type מוגדר לסוגי tenor ולסוגי tensor מותנים ותחזרות, בהתאמה, ה-TensorElementType או QuantizedTensorElementType של ה-TensorType או ה-QuantizedTensorType התואמים.
def element_type(x: Value | Placeholder | Type):
 if type(x) == TensorType:
    return tensor_element_type(x)
  if type(x) == QuantizedTensorType:
    return quantized_tensor_element_type(x)
  if type(x) is not Type:
    return element_type(type(x))
  • is_per_axis_quantized(x: Value | Placeholder | Type) -> Value הוא קיצור דרך ל-is_quantized(x) and quantization_dimension(x) is not None.

  • is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value הוא קיצור דרך בשביל is_quantized(x) and quantization_dimension(x) is None.

  • is_promotable(x: Type, y: Type) -> bool בודק אם אפשר לקדם את הסוג x לסוג y. כשהערך בשדה x וב-y הוא QuantizedTensorElementType, המבצע חל רק על storage_type. הגרסה הספציפית הזו של המבצע נמצאת כרגע בשימוש בהקשר של חישוב הפחתת מידע (לפרטים נוספים, ראו RFC).

def is_promotable(x: Type, y: Type) -> Value:
  is_same_type = (is_bool(x) and is_bool(y)) or
    (is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
    (is_complex(x) and is_complex(y)) or
    (is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))

  if is_same_type == False:
    return False

  if is_integer(x) or is_float(x):
    return bitwidth(x) <= bitwidth(y)

  if is_complex(x):
    return bitwidth(element_type(x)) <= bitwidth(element_type(y))

  if is_quantized(x):
    return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))

  return false
  • is_quantized(x: Value | Placeholder | Type) -> Value הוא קיצור דרך אל is_quantized_tensor_element_type(x).

  • is_type_name(x: Value | Placeholder | Type) -> Value. זמין לכל הסוגים. למשל, הפונקציה is_float(x) מחזירה true אם הערך של x הוא FloatType. אם x הוא ערך או placeholder, הפונקציה הזו היא קיצור דרך ל-is_type_name(type(x)).

  • max_value(x: Type) -> Value מחזירה את הערך המקסימלי של TensorElementType. אם x אינו TensorElementType, הפונקציה מחזירה את הערך None.

  • min_value(x: Type) -> Value מחזירה את הערך המינימלי האפשרי של TensorElementType. אם x אינו TensorElementType, הפונקציה מחזירה את הערך None.

  • member_name(x: Value | Placeholder | Type) -> Any. זמין לכל הגדרות המנויים member_name מכל הסוגים. לדוגמה, tensor_element_type(x) מחזירה את החלק TensorElementType של TensorType תואם. אם x הוא ערך או placeholder, הפונקציה הזו היא קיצור דרך ל-member_name(type(x)). אם x הוא לא סוג שיש לו איבר מתאים, או ערך או placeholder מסוג כזה, הפונקציה מחזירה None.

בניית ערכים

  • operation_name(*xs: Value | Type) -> Value. זמין לכל הפעולות. לדוגמה, add(lhs, rhs) לוקחת את שני ערכי הטנזור lhs ו-rhs ומחזירה את הפלט של הערכה של הפעולה add עם ערכי הקלט האלה. בפעולות מסוימות, למשל broadcast_in_dim, סוגי הפלט שלהן הם "נושאת עומס", כלומר נדרשים כדי להעריך פעולה. במקרה כזה, הפונקציה לוקחת את הסוגים האלה כארגומנטים.

פונקציה על ערכים

  • כל האופרטורים והפונקציות של Python זמינים. לדוגמה, אפשר להוסיף לאינדקס גם את הערכים subscription וגם את slicing מ-Python ל-tensors, ל-tensors ול-tuples.

  • to_destination_type(x: Value, destination_type: Type) -> Value מוגדר על ידי Tensor ומחזיר את הערך המומר של x על סמך type(x) ו-destination_type, באופן הבא:

def to_destination_type(x: Value, destination_type: Type) -> Value:
  if type(x) == destination_type:
    return x

  if is_quantized(destination_type):
    if is_quantized(type(x)):
      return quantize(x, destination_type)
    assert is_float(type(x))
    return quantize(x, destination_type)

  if is_quantized(type(x)):
    assert destination_type = expressed_type(type(x))
    return dequantize(type(x))

  return convert(x, destination_type)

מתקיים דיון מוקדם על מיזוג הפעולות של convert, uniform_quantize ו-uniform_dequantize (#1576). אחרי המיזוג אין צורך בפונקציה שלמעלה, ונוכל להשתמש בשם הפעולה ל-convert במקום זאת.

  • is_nan(x: Value) -> Value מוגדר בערכי tensor ומחזיר true אם כל הרכיבים של x הם NaN או false אחרת. אם x הוא לא טנזור, הפונקציה מחזירה את הערך None.

  • is_sorted(x: Value) -> Value מוגדר על tensor ומחזיר true אם הרכיבים של x ממוינים בסדר עולה ביחס לסדר המילולי עולה של האינדקסים, ואם לא, false. אם x הוא לא Tenor, הפונקציה מחזירה את הערך None.

  • is_unique(x: Value) -> Value מוגדר בטנזורים ומחזיר true אם ל-x אין רכיבים כפולים, או ל-false אחרת. אם x הוא לא טנזור, הפונקציה מחזירה את הערך None.

  • member_name(x: Value) -> Any מוגדר לכל הגדרות החברים member_name של כל הערכים. לדוגמה, real_part(x) מחזירה את החלק RealPart של ComplexConstant תואם. אם x לא ערך שיש לו איבר מתאים, הפונקציה מחזירה את הערך None.

  • same(x: Value) -> Value מוגדר על טנזורים ומחזיר true אם האלמנטים של x שווים זה לזה, או false. אם הטנזור לא כולל רכיבים, הפונקציה נחשבת כ"כולם שווים זה לזה", כלומר הפונקציה מחזירה true. אם x אינו טנסור, הפונקציה מחזירה את הערך None.

  • הפונקציה split(x: Value, num_results: Value, axis: Value) -> Value מוגדרת על ידי Tensor ומחזירה num_results פרוסות של x לאורך הציר axis. אם x אינו tensor או dim(x, axis) % num_results != 0, הפונקציה מחזירה את הערך None.

חישובי צורות

  • axes(x: Value | Placeholder | Type) -> Value הוא קיצור דרך אל range(rank(x)).

  • dim(x: Value | Placeholder | Type, axis: Value) -> Value הוא קיצור דרך אל shape(x)[axis].

  • dims(x: Value | Placeholder | Type, axes: List) -> List הוא קיצור דרך אל list(map(lambda axis: dim(x, axis), axes)).

  • index_space(x: Value | Placeholder | Type) -> Value מוגדר על ידי טנזורים ומחזיר את האינדקסים size(x) עבור TensorType המתאים, כשהם ממוינים לפי סדר מילולי עולה, כלומר [0, ..., 0], [0, ..., 1], ..., shape(x) - 1. אם x הוא לא סוג של tensor, סוג tensor מסוים או ערך, או placeholder של אחד מהסוגים האלה, הפונקציה מחזירה None.

  • rank(x: Value | Placeholder | Type) -> Value הוא קיצור דרך אל size(shape(x)).

  • השדה shape(x: Value | Placeholder | Type) -> Value מוגדר בקטע Functions on types (פונקציות על סוגים) דרך member_name.

  • size(x: Value | Placeholder | Type) -> Value הוא קיצור דרך אל reduce(lambda x, y: x * y, shape(x)).

חישובי כמות

  • def baseline_element_type(x: Value | Placeholder | Type) -> Type הוא קיצור דרך בשביל element_type(baseline_type(x)).

  • השדה baseline_type מוגדר לסוגי tenor ולסוגי tensor מכווננים והופך אותם ל-"baseline", כלומר סוג עם אותה צורה אבל הפרמטרים של הקוונטיזציה של סוג הרכיב מתאפסים לערכי ברירת המחדל. הוא משמש כטריק שימושי להשוואה בין סוגי t tensor ו-tensor. בסוגים ממוספרים, כך אפשר להשוות בין סוגים שמתעלמים מפרמטרים של קוונטיזציה, כלומר: shape, storage_type, expressed_type, storage_min, storage_max ו-quantization_dimension (לסוג ספציפי לציר שמבוסס על צירים), אבל ייתכנו הבדלים בין scales לבין zero points.

def baseline_type(x: Value | Placeholder | Type) -> Type:
  if type(x) == TensorType:
    return x
  if type(x) == QuantizedTensorType:
    element_type = quantized_tensor_element_type(x)
    baseline_element_type = QuantizedTensorElementType(
      storage_type = storage_type(element_type),
      storage_min = storage_min(element_type),
      storage_max = storage_max(element_type),
      expressed_type = expressed_type(element_type),
      quantization_dimension = quantization_dimension(element_type),
      scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
      zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
    return QuantizedTensorType(shape(x), baseline_element_type)
  if type(x) is not Type:
    return baseline_element_type(type(x))
  • dequantize מוגדר בסוגי tensor מכווננים והופך אותם לסוגים של tensor עם נקודה צפה (floating-point). זה קורה באמצעות המרה של רכיבים כמותיים, שמייצגים ערכים של מספרים שלמים מסוג האחסון לערכי נקודות צפות (floating-point) תואמים מהסוג שמבוטא באמצעות נקודת האפס וקנה המידה המשויכים לסוג הרכיב הזה.
def compute_zero_points(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      zero_points[i] = zero_points(quantized_type)[i[d]]
    return zero_points

def compute_scales(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
            type(result_type))
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      scales[i] = scales(quantized_type)[i[d]]
    return scales

def dequantize(x: Value) -> Value:
  assert is_quantized(x)
  x_storage = bitcast_convert(x, storage_type(x))
  x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
  x_expressed_sub = convert(x_storage_sub, expressed_type(x))
  return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
  • quantize מוגדר בסוגי tensor עם נקודה צפה (floating-point) והופך אותם לסוגי tensor מתוחכמים. זה קורה באמצעות המרה של ערכים של נקודה צפה (floating-point) מהסוג המבוטא לערכים של מספרים שלמים תואמים מסוג האחסון באמצעות נקודת האפס וקנה המידה שמשויכים לסוג הרכיב שצוין.
def quantize(x: Value, type: Type) -> Value:
  assert is_float(x) and is_quantized(type)
  x_expressed_rounded = round_nearest_even(x / compute_scales(type, type(x)))
  x_storage_rounded = convert(x_expressed_rounded, storage_type(type))
  x_storage_add = x_storage_rounded + compute_zero_points(type, type(x_storage_rounded))
  x_storage = clamp(storage_min(type), x_storage_add, storage_max(type))
  return bitcast_convert(x_storage, type)
  • dequantize_op_quantize משמש לציון חישובים ברמת הרכיבים בערכי tensor קוונטיים. היא מבצעת פעולת חיסור, כלומר ממירה רכיבים קוונטיים לסוגים המיוצגים באמצעותם, ואז מבצעת פעולה ולאחר מכן מבצעת פעולת כימת, כלומר ממירה את התוצאות חזרה לסוגי האחסון שלהם. בשלב זה, הפונקציה הזו פועלת רק עבור קוונטיזציה לכל טנזור. אנחנו עובדים על קוונטיזציה לכל ציר (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
  inputs = inputs_and_output_type[:-1]
  output_type = inputs_and_output_type[-1]

  float_inputs = map(dequantize, inputs)
  float_result = op(*float_inputs)
  return quantize(float_result, output_type)

def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
  inputs = inputs_and_output_type[:-3]
  float_inputs = map(dequantize, inputs)
  float_results = op(*float_inputs)
  return map(quantize, float_results, inputs_and_output_type[-3:])

def dequantize_compare(lhs, rhs, comparison_direction):
  float_lhs = dequantize(lhs)
  float_rhs = dequantize(rhs)
  return compare(float_lhs, float_rhs, comparison_direction, FLOAT)

def dequantize_select_quantize(pred, on_true, on_false, output_type):
  float_on_true = dequantize(on_true)
  float_on_false = dequantize(on_false)
  float_result = select(pred, float_on_true, float_on_false)
  return quantize(float_result, output_type)

חישובי רשת

  • cross_partition(replica_groups: Value) -> Value. עיינו בקטע 'cross_Replica' למעלה.

  • cross_replica(replica_groups: Value) -> Value. עיינו בקטע 'cross_Replica' למעלה.

  • cross_replica_and_partition(replica_groups: Value) -> Value. ניתן לעיין בקטע "cross_Replica_and_partition" למעלה.

  • flattened_ids(replica_groups: Value) -> Value. עיינו בקטע "flattened_id" למעלה.