StableHLO 規格

StableHLO 是機器學習 (ML) 模型中高階運算 (HLO) 的運算集。StableHLO 可做為不同 ML 架構和 ML 編譯器之間的可攜性層:產生 StableHLO 程式的 ML 架構與使用 StableHLO 程式的 ML 編譯器相容。

我們的目標是透過在各種機器學習架構 (例如 TensorFlow、JAX 和 PyTorch) 與機器學習編譯器 (例如 XLA 和 IREE) 之間建立更多互通性,簡化機器學習開發流程並加快開發速度。為此,本文件提供 StableHLO 程式設計語言的規範。

這個規格包含三個主要部分首先,程式一節說明 StableHLO 程式的結構,其中包含 StableHLO 函式,而這些函式本身則包含 StableHLO 運算。在該結構中,「Ops」部分會指定個別操作的語意。「Execution」部分會為在程式中一起執行的所有運算提供語意。最後,符號一節會討論規格中使用的符號。

如要查看 StableHLO 先前版本的規格,請開啟您感興趣的標記版本的存放區。例如 StableHLO v0.19.0 規格。如要查看 StableHLO 在每個子版本遞增發生的變更,請參閱 VhloDialect.td 中的版本記錄。

程式

Program ::= {Func}

StableHLO 程式包含任意數量的 StableHLO 函式。以下是含有函式 @main 的程式範例,該函式有 3 個輸入 (%image%weights%bias) 和 1 個輸出。函式主體有 6 個運算。

func.func @main(
  %image: tensor<28x28xf32>,
  %weights: tensor<784x10xf32>,
  %bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
  %0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
  %1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
  %2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  %3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
  %4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  "func.return"(%4): (tensor<1x10xf32>) -> ()
}

函式

Func        ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs  ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput   ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput  ::= ValueType
FuncBody    ::= {Op}

StableHLO 函式 (也稱為已命名函式) 包含一個 ID、輸入/輸出項目和主體。我們計劃日後為函式推出其他中繼資料,以提升與 HLO 的相容性 (#425#626#740#744)。

ID

FuncId  ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
          | '%' letter {letter | digit}
letter  ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit   ::= '0' | ... | '9'

StableHLO ID 與許多程式語言中的 ID 類似,但有兩個特點:1) 所有 ID 都有符號,用於區分不同類型的 ID;2) 值 ID 可以完全以數字表示,簡化 StableHLO 程式的產生作業。

類型

Type         ::= ValueType | NonValueType
ValueType    ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType

StableHLO 類型分為值類型 (也稱為第一類型),代表 StableHLO 值,以及非值類型,用於描述其他程式元素。StableHLO 類型與許多程式設計語言中的類型相似,主要特徵是 StableHLO 的特定領域性質,會產生一些不尋常的結果 (例如,純量類型不是值類型)。

TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'

張量類型代表張量,也就是多維陣列。包含形狀元素類型,其中形狀代表依據對應維度 (也稱為「ax」) 的遞增順序,從 0R-1 的遞增順序表示非負數或不明維度大小。維度 R 的數量稱為「排名」。舉例來說,tensor<2x3xf32> 是形狀為 2x3 的張量類型,元素類型為 f32。它有兩個維度 (或稱兩個軸) - 第 0 維度和第 1 維度,大小分別為 2 和 3。其排名為 2。

形狀可能部分或完全不明 (動態),例如 tensor<?x2xf64> 部分不明,而 tensor<?x?xf64> 完全不明。動態維度大小會使用 ? 表示。形狀無法排序。

我們未來會嘗試將張量類型擴展至維度大小和元素類型以外的其他類型,例如納入版面配置 (#629) 和稀疏度 (#1078)。

QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
                  QuantizationStorageType
                  ['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
                  ':' QuantizationExpressedType
                  [':' QuantizationDimension]
                  ',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerLiteral
QuantizationStorageMax ::= IntegerLiteral
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerLiteral
QuantizationParameters ::= QuantizationParameter
                         | '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale [':' QuantizationZeroPoint]
QuantizationScale ::= FloatLiteral
QuantizationZeroPoint ::= IntegerLiteral
名稱 類型 限制
storage_type 整數類型 (C1-C3)、(C8)
storage_min 整數常數 (C1)、(C3)、(C7)
storage_max 整數常數 (C2)、(C3)、(C7)
expressed_type 浮點類型 (C4)
quantization_dimension 選用整數常數 (C10-C12)
scales 浮點常數數量 (C4-C6)、(C9)、(C10)、(C13)
zero_points 整數常數的變數參數數量 (C7-C9)

量化元素類型代表儲存類型的整數值,範圍為 storage_minstorage_max (含),對應於表示型別的浮點值。對於指定的整數值 i,對應的浮點值 f 可計算為 f = (i - zero_point) * scale,其中 scalezero_point 稱為量化參數storage_minstorage_max 是文法中的選用項目,但預設值為 min_value(storage_type)max_value(storage_type)。量化元素類型具有下列限制:

  • (C1) type(storage_min) = storage_type
  • (C2) type(storage_max) = storage_type
  • (C3) min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type)
  • (C4) type(scales...) = expressed_type
  • (C5) 0 < scales
  • (C6) is_finite(scales...)
  • (C7) storage_min <= zero_points <= storage_max
  • (C8) type(zero_points...) = storage_type
  • (C9) size(scales) = size(zero_points)
  • (C10) 如果 is_empty(quantization_dimension),則 size(scales) = 1
  • (C11) 0 <= quantization_dimension

目前 QuantizationScale 是浮點常數,但使用者對以整數為基礎的比例表達方式 (以乘數和位移表示) 非常感興趣。我們預計在近期內探索這項功能。(#1404)。

我們目前正在討論 QuantizationZeroPoint 的語意,包括類型、值,以及在量化張量類型中是否只能有一個或多個零點。根據這次討論的結果,零點的規格日後可能會改變 (#1405)。

另一個正在討論的議題是 QuantizationStorageMinQuantizationStorageMax 的語意,以決定是否應對這些值和量化張量值施加任何限制 (#1406)。

最後,我們計劃探索如何呈現未知的比例和零點,就像我們計劃探索如何呈現未知的維度大小一樣 (#1407)。

量化張量類型代表具有量化元素的張量。這些張量與一般張量完全相同,但元素具有量化元素類型,而非一般元素類型。

在量化張量中,量化可以是「每個張量」,也就是為整個張量提供一個 scalezero_point;也可以是「每個軸」,也就是為特定維度 quantization_dimension 的每個切片提供一組 scaleszero_points。更正式來說,在張量 t 中,每軸量化會有 quantization_dimensiondim(t, quantization_dimension) 配量:t[:, ..., 0, ..., :], t[:, ..., 1, ..., :] 等。i 切片中的所有元素都使用 scales[i]zero_points[i] 做為量化參數。量化張量類型有下列限制:

  • 針對每個張量量化:
    • 沒有其他限制。
  • 針對每個軸的量化:
    • (C12) quantization_dimension < rank(self)
    • (C13) dim(self, quantization_dimension) = size(scales)
TokenType ::= 'token'

符記類型代表符記,也就是某些作業產生及消耗的非公開值。符記可用來為作業施加執行順序,如執行一節所述。

TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]

元組類型代表元組,也就是異質清單。元組是舊版功能,僅存在於 HLO 相容性。在 HLO 中,元組用於表示可變參數輸入和輸出。在 StableHLO 中,變數輸入和輸出是原生支援的,而 StableHLO 中唯一的用途是全面代表 HLO ABI,例如 Ttuple<T>tuple<tuple<T>> 可能會因特定實作而有實質差異。日後,我們預計會對 HLO ABI 進行變更,這可能會讓我們從 StableHLO 中移除元組類型 (#598)。

TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3'
            | 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2'
            | 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'

元素類型代表張量類型的元素。與許多程式設計語言不同,這些類型並不是 StableHLO 中的第一類別。這表示 StableHLO 程式無法直接表示這些類型的值 (因此,使用 tensor<T> 類型的 0 維度張量值,是慣用 T 類型的純量值)。

  • 布林值類型代表布林值 truefalse
  • 整數類型可以是帶正負號 (si) 或不帶正負號 (ui),並且具有支援的位元寬度 (248163264)。帶正負號的 siN 類型代表從 -2^(N-1)2^(N-1)-1 的整數值 (含括在內),而未帶正負號的 uiN 類型則代表從 02^N-1 的整數值 (含括在內)。
  • 浮點類型可以是下列任一值:
  • 複雜類型代表包含相同元素類型「實際部分」和「虛部分」的複雜值。支援的複合類型為 complex<f32> (兩個部分皆為 f32 類型) 和 complex<f64> (兩個部分皆為 f64 類型)。
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]

函式類型代表已命名和匿名函式。這些資料類型具有輸入類型 (-> 左側的類型清單) 和輸出類型 (-> 右側的類型清單)。在許多程式設計語言中,函式類型是第一類,但在 StableHLO 中則不是。

StringType ::= 'string'

字串型別代表位元組序列。與許多程式語言不同,字串類型在 StableHLO 中並非第一類型,只用於為程式元素指定靜態中繼資料。

作業

StableHLO 運算 (也稱為 ops) 代表機器學習模型中一組封閉的高階運算。如上所述,StableHLO 語法受到 MLIR 的啟發,這不一定是最符合人體工學的替代方案,但最適合 StableHLO 的目標:在機器學習架構和機器學習編譯器之間建立更高的互通性。

Op            ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName        ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic    ::= 'abs' | 'add' | ...

StableHLO 作業 (也稱為 ops) 具有名稱、輸入/輸出和簽章。名稱包含 stablehlo. 前置字元和助憶法,可用於唯一識別其中一個支援的作業。請參閱下列完整清單,瞭解所有支援的作業。

OpInputs        ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues   ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue    ::= ValueId
OpInputFuncs    ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs    ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs       ::= [OpOutput {',' OpOutput} '=']
OpOutput        ::= ValueId

作業會消耗輸入並產生輸出。輸入內容分為輸入值 (在執行期間計算)、輸入函式 (以靜態方式提供,因為在 StableHLO 中,函式不是一級值) 和輸入屬性 (同樣以靜態方式提供)。運算消耗及產生的輸入和輸出種類取決於其記憶法。舉例來說,add 運算子會使用 2 個輸入值,並產生 1 個輸出值。相比之下,select_and_scatter 運算會使用 3 個輸入值、2 個輸入函式和 3 個輸入屬性。

OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused      ::= '^' digit {digit}
              | '^' letter {letter | digit}

輸入函式 (也稱為「匿名函式」) 與已命名函式非常相似,但有以下差異:1) 沒有 ID (因此稱為「匿名」),2) 不會宣告輸出類型 (輸出類型會從函式中的 return op 推斷而來)。

輸入函式的語法包含目前未使用的部分 (請參閱上述的 Unused 正式環境),這是能與 MLIR 相容性的部分。在 MLIR 中,有一個更廣義的「區域」概念,可透過跳躍運算子連結多個「運算子」的「區塊」。這些區塊的 ID 會對應至 Unused 產出作業,因此可彼此區分。StableHLO 沒有跳躍運算,因此 MLIR 語法的對應部分未使用 (但仍存在)。

OpInputAttr      ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName  ::= letter {letter | digit}
OpInputAttrValue ::= Constant

輸入屬性具有名稱和值,其中值為支援的常數之一。這是為節目元素指定靜態中繼資料的主要方式。舉例來說,concatenate 運算子會使用屬性 dimension 指定輸入值連接的維度。同樣地,slice 運算子會使用 start_indiceslimit_indices 等多個屬性,指定用於切片輸入值的邊界。

目前,實際運作的 StableHLO 程式有時會包含本文件未提及的屬性。我們計劃日後將這些屬性吸收至 StableHLO 運算元,或禁止這些屬性出現在 StableHLO 程式中。在此期間,請參閱下列屬性清單:

  • layout (#629)。
  • mhlo.frontend_attributes (#628)。
  • mhlo.sharding (#619)。
  • output_operand_aliases (#740)。
  • 地點中繼資料 (#594)。
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'

運算子簽章包含所有輸入值的類型 (-> 左側的類型清單) 和所有輸出值的類型 (-> 右側的類型清單)。嚴格來說,輸入類型是多餘的,輸出類型也幾乎總是多餘的 (因為對於大多數 StableHLO 運算子,輸出類型可從輸入推斷)。不過,為了與 MLIR 相容,我們刻意將運算子簽章納入 StableHLO 語法。

以下是其記憶法為 select_and_scatter 的運算範例。這個函式會使用 3 個輸入值 (%operand%source%init_value)、2 個輸入函式和 3 個輸入屬性 (window_dimensionswindow_stridespadding)。請注意,這個函式的簽章只包含輸入值的類型 (但不包含內嵌提供的輸入函式和屬性類型)。

%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i32>, tensor<i32>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
    "stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>

常數

Constant ::= BooleanConstant
           | IntegerConstant
           | FloatConstant
           | ComplexConstant
           | TensorConstant
           | QuantizedTensorConstant
           | StringConstant
           | EnumConstant

StableHLO 常數具有常值和類型,兩者共同代表 StableHLO 值。一般來說,類型是常數語法的一部分,除非是不混淆的情況 (例如布林值常數具有 i1 類型,而整數常數可以有多個可能的類型)。

BooleanConstant ::= BooleanLiteral
BooleanLiteral  ::= 'true' | 'false'

布林值常數代表布林值 truefalse。布林常數具有 i1 類型。

IntegerConstant   ::= IntegerLiteral ':' IntegerType
IntegerLiteral    ::= ['-' | '+'] DecimalDigits
                    | ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits     ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit      ::= '0' | ... | '9'
hexadecimalDigit  ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'

整數常數會透過使用十六進制標記法的字串表示整數值。系統不支援其他基數,例如二進位或八進位。整數常數有下列限制:

  • (C1) is_wellformed(integer_literal, integer_type)
FloatConstant  ::= FloatLiteral ':' FloatType
FloatLiteral   ::= SignPart IntegerPart FractionalPart ScientificPart
                 | '0x' [HexadecimalDigits]
SignPart       ::= ['-' | '+']
IntegerPart    ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]

浮點常數會透過使用十進位或科學記號的字串表示浮點值。此外,十六進位記號法可用於直接指定對應類型的浮點格式中的基本位元。浮點常數具有下列限制:

  • (C1) 如果使用非十六進位標記法,is_wellformed(float_literal, float_type)
  • (C2) 如果使用十六進位表示法,則為 size(hexadecimal_digits) = num_bits(float_type) / 4
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral  ::= '(' RealPart ',' ImaginaryPart ')'
RealPart        ::= FloatLiteral
ImaginaryPart   ::= FloatLiteral

複數常數會使用實數部分 (先出現) 和虛數部分 (後出現) 的清單,代表複數值。例如,(1.0, 0.0) : complex<f32> 代表 1.0 + 0.0i(0.0, 1.0) : complex<f32> 代表 0.0 + 1.0i。這些部分儲存在記憶體中的順序是由實作定義。複雜常數有下列限制:

  • (C1) is_wellformed(real_part, complex_element_type(complex_type))
  • (C2) is_wellformed(imaginary_part, complex_element_type(complex_type))
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral   ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements  ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral

Tensor 常數使用 NumPy 標記法指定的巢狀清單來代表張量值。舉例來說,dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> 代表張量值,將以下從索引對應至元素:{0, 0} => 1{0, 1} => 2{0, 2} => 3{1, 0} => 4{1, 1} => 5{1, 2} => 6。這些元素在記憶體中儲存的順序由實作方式定義。張量常數有下列限制:

  • (C1) has_syntax(tensor_literal, element_type(tensor_type)),其中:
    • has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type)
    • has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type)
  • (C2) has_shape(tensor_literal, shape(tensor_type)),其中:
    • has_shape(element_literal: Syntax, []) = true
    • has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:])
    • 否則為 false
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'

量化張量常數會使用與張量常數相同的符號表示量化張量值,其中元素會指定為其儲存類型的常數。量化張量常數有下列限制:

  • (C1) has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type))
  • (C2) has_shape(quantized_tensor_literal, shape(quantized_tensor_type))
StringConstant  ::= StringLiteral
StringLiteral   ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence  ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))

字串常值由使用 ASCII 字元和逸出序列指定的位元組所組成。這些值不受編碼方式影響,因此這些位元組的解讀方式由實作方式定義。字串常值的類型為 string

作業數

腹部

語義學

operand 張量執行元素級別的絕對值運算,並產生 result 張量。視元素類型而定,執行下列操作:

  • 如為帶號整數:整數模數。
  • 浮點值:IEEE-754 的 abs
  • 針對複數:複數的模數。
  • 對於量化類型:dequantize_op_quantize(abs, operand, type(result))

輸入

標籤 名稱 類型 限制
(I1) operand 帶正負號整數、浮點數或複雜類型或每個張量的量化張量 (C1-C2)

輸出內容

名稱 類型 限制
result 有符號整數或浮點類型的張量,或每個張量的量化張量 (C1-C2)

限制

  • (C1) shape(result) = shape(operand)
  • (C2) baseline_element_type(result) 的定義為:
    • is_complex(operand) 則設為 complex_element_type(element_type(operand))
    • 否則為 baseline_element_type(operand)

範例

// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]

更多範例

add

語義學

這個外掛程式能執行兩個張量 lhsrhs 的元素相關新增,並產生 result 張量。視元素類型而定,執行下列操作:

  • 布林值:邏輯 OR。
  • 整數:加整數。
  • 浮點值:IEEE-754 的 addition
  • 複數:複數加法。
  • 對於量化類型:dequantize_op_quantize(add, lhs, rhs, type(result))

輸入

標籤 名稱 類型 限制
(I1) lhs 張量或量化張量 (C1-C6)
(I2) rhs 張量或量化張量 (C1-C5)、(C7)

輸出內容

名稱 類型 限制
result 張量或量化張量 (C1-C7)

限制

  • 如果運算使用非量化的張量:
    • (C1) type(lhs) = type(rhs) = type(result)
  • 如果運算使用經過量化的張量:
    • (C2) is_quantized(lhs) and is_quantized(rhs) and is_quantized(result)
    • (C3) storage_type(lhs) = storage_type(rhs) = storage_type(result)
    • (C4) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
    • (C5) (is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result)
    • (C6) 如果 is_per_axis_quantized(lhs),則 quantization_dimension(lhs) = quantization_dimension(result)
    • (C7) 如果 is_per_axis_quantized(rhs),則 quantization_dimension(rhs) = quantization_dimension(result)

範例

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]

更多範例

after_all

語義學

確保產生 inputs 的作業會在任何依賴 result 的作業之前執行。這項作業不執行任何動作,只是為了建立從 resultinputs 的資料依附元件。

輸入

標籤 名稱 類型
(I1)。 inputs token 的變數參數數量

輸出內容

名稱 類型
result token

範例

// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token

更多範例

all_gather

語義學

在 StableHLO 程序格線的每個程序群組中,沿著 all_gather_dim 連結每個程序的 operands 張量值,並產生 results 張量。

這項作業會將 StableHLO 處理格線分割為 process_groups,其定義如下:

  • cross_replica(replica_groups)channel_id <= 0 and use_global_device_ids = false
  • 如果為 channel_id > 0 and use_global_device_ids = false,則為 cross_replica_and_partition(replica_groups)
  • flattened_ids(replica_groups)channel_id > 0 and use_global_device_ids = true

接著,在每個 process_group 中:

  • process_group 中的所有 receiveroperands...@receiver = [operand@sender for sender in process_group]
  • process_group 中的所有 processresults...@process = concatenate(operands...@process, all_gather_dim)

輸入

標籤 名稱 類型 限制
(I1) operands 變異量或每個張量量化張量 (C1)、(C6)
(I2) all_gather_dim 類型為 si64 的常數 (C1)、(C6)
(I3) replica_groups 類型為 si64 的 2 維張量常數 (C2-C4)
(I4) channel_id 類型為 si64 的常數 (C5)
(I5) use_global_device_ids i1 類型的常數 (C5)

輸出內容

名稱 類型 限制
results 可變數張量數量或每個張量量化張量 (C6)。

限制

  • (C1) 0 <= all_gather_dim < rank(operands...)
  • (C2) is_unique(replica_groups)
  • (C3) size(replica_groups) 的定義為:
    • 如果使用 cross_replica,則為 num_replicas
    • 如果使用 cross_replica_and_partition,則為 num_replicas
    • 如果使用 flattened_ids,則為 num_processes
  • (C4) 0 <= replica_groups < size(replica_groups)
  • (C5) 如果 use_global_device_ids = true,則 channel_id > 0
  • (C6) type(results...) = type(operands...) 除下列情況外:
    • dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1)

範例

// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
  all_gather_dim = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  // channel_id = 0
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  // use_global_device_ids = false
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]

更多範例

all_reduce

語義學

在 StableHLO 程序網格中的每個程序群組內,將縮減函式 computation 套用至每個程序的 operands 張量值,並產生 results 張量。

這項作業會將 StableHLO 處理格線分割為 process_groups,其定義如下:

  • cross_replica(replica_groups)channel_id <= 0 and use_global_device_ids = false
  • 如果為 channel_id > 0 and use_global_device_ids = false,則為 cross_replica_and_partition(replica_groups)
  • flattened_ids(replica_groups)channel_id > 0 and use_global_device_ids = true

接著,在每個 process_group 中:

  • 部分二進位檔樹狀結構 scheduleresults...@process[result_index] = exec(schedule),其中:
    • exec(node) = computation(exec(node.left), exec(node.right))
    • exec(leaf) = leaf.value
  • schedule 是實作定義的二元樹,其內部順序遍歷是 to_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0]))

輸入

標籤 名稱 類型 限制
(I1) operands 可變數張量數量或每個張量量化張量 (C5)、(C6)
(I2) replica_groups 類型為 si64 的 1 維張量常數變數數量 (C1-C3)
(I3) channel_id si64 類型的常數 (C4)
(I4) use_global_device_ids 類型為 i1 的常數 (C4)
(I5)。 computation 函式 (C5)

輸出內容

名稱 類型 限制
results 可變數張量數量或每個張量量化張量 (C6-C7)

限制

  • (C1) is_unique(replica_groups)
  • (C2) size(replica_groups) 的定義為:
    • 如果使用 cross_replica,則為 num_replicas
    • 如果使用 cross_replica_and_partition,則為 num_replicas
    • 如果使用 flattened_ids,則為 num_processes
  • (C3) 0 <= replica_groups < size(replica_groups)
  • (C4) 如果 use_global_device_ids = true,則 channel_id > 0
  • (C5) computation 具有 (tensor<E>, tensor<E>) -> (tensor<E>) 類型,其中 is_promotable(element_type(operand), E)
  • (C6) shape(results...) = shape(operands...)
  • (C7) element_type(results...) = E

範例

// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  // channel_id = 0
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  // use_global_device_ids = false
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]

更多範例

all_to_all

語義學

all_to_all

在 StableHLO 程序格線的每個程序群組中,沿著 split_dimensionoperands 張量的值分割成部分,在程序之間散布分割的部分,沿著 concat_dimension 連結散布的部分,並產生 results 張量。這項作業會將 StableHLO 處理格線分割為 process_groups,其定義如下:

  • channel_id <= 0 則設為 cross_replica(replica_groups)
  • channel_id > 0 則設為 cross_partition(replica_groups)

之後,在每個 process_group 內:

  • split_parts...@sender = split(operands...@sender, split_count, split_dimension) process_group 中的所有 sender
  • scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group],其中 receiver_index = process_group.index(receiver)
  • results...@process = concatenate(scattered_parts...@process, concat_dimension)

輸入

標籤 名稱 類型 限制
(I1) operands 可變數張量數量或每個張量量化張量 (C1-C3)、(C9)
(I2)。 split_dimension 類型為 si64 的常數 (C1)、(C2)、(C9)
(I3)。 concat_dimension 類型為 si64 的常數 (C3)、(C9)
(I4) split_count si64 類型的常數 (C2)、(C4)、(C8)、(C9)
(I5) replica_groups 類型為 si64 的 2 維張量常數 (C5-C8)
(I6) channel_id 類型為 si64 的常數

輸出內容

名稱 類型 限制
results 變異量或每個張量量化張量 (C9)

限制

  • (C1) 0 <= split_dimension < rank(operands...)
  • (C2) dim(operands..., split_dimension) % split_count = 0
  • (C3) 0 <= concat_dimension < rank(operands...)
  • (C4) 0 < split_count
  • (C5) is_unique(replica_groups)
  • (C6) size(replica_groups) 的定義為:
    • 如果使用 cross_replica,則為 num_replicas
    • 如果使用 cross_partition,則為 num_partitions
  • (C7) 0 <= replica_groups < size(replica_groups)
  • (C8) dim(replica_groups, 1) = split_count
  • (C9) type(results...) = type(operands...),但 split_dimension != concat_dimension 除外:
    • dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count
    • dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count

範例

// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
//                    [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
//                    [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
//                    [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
//                    [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
  split_dimension = 1 : i64,
  concat_dimension = 0 : i64,
  split_count = 2 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  // channel_id = 0
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]

更多範例

語義學

對兩個張量 lhsrhs 執行元素 AND 運算,並產生 result 張量。根據元素類型執行以下操作:

  • 布林值:邏輯 AND。
  • 整數:位元 AND。

輸入

標籤 名稱 類型 限制
(I1)。 lhs 布林值或整數類型的張量 (C1)
(I2) rhs 布林值或整數類型的張量 (C1)

輸出內容

名稱 類型 限制
result 布林值或整數類型的張量 (C1)。

限制

  • (C1) type(lhs) = type(rhs) = type(result)

範例

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]

更多範例

atan2

語義學

lhsrhs 張量執行元素級別 atan2 運算,並產生 result 張量。視元素類型而定,執行下列操作:

  • 浮點值:IEEE-754 的 atan2
  • 複數:complex atan2。
  • 對於量化類型:dequantize_op_quantize(atan2, lhs, rhs, type(result))

輸入

標籤 名稱 類型 限制
(I1) lhs 浮點或複雜型別或每個張量的量化張量 (C1)
(I2) rhs 浮點或複數類型的張量,或每個張量的量化張量 (C1)

輸出內容

名稱 類型 限制
result 浮點或複雜型別或每個張量的量化張量 (C1)

限制

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

範例

// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]

更多範例

batch_norm_grad

語義學

計算 batch_norm_traininggrad_output 反向傳播的多個輸入的梯度,並產生 grad_operandgrad_scalegrad_offset 張量。更正式地說,此作業可使用 Python 語法表示為現有的 StableHLO 作業進行分解,如下所示:

def compute_sum(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  return sum

def compute_mean(operand, feature_index):
  sum = compute_sum(operand, feature_index)
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
  # Broadcast inputs to type(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance`
  # Intermediate values will be useful for computing gradients
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)

  # Use the implementation from batchnorm_expander.cc in XLA
  # Temporary variables have exactly the same names as in the C++ code
  elements_per_feature = broadcast_in_dim(
      constant(divide(size(operand), dim(operand, feature_index)),
               element_type(grad_output)),
      [], type(operand))
  i1 = multiply(grad_output, elements_per_feature)
  i2 = broadcast_in_dim(
      compute_sum(grad_output, feature_index), [feature_index], type(operand))
  i3 = broadcast_in_dim(
      compute_sum(multiply(grad_output, centered_operand), feature_index),
      [feature_index], type(operand))
  i4 = multiply(i3, centered_operand)
  i5 = divide(i4, add(variance_bcast, epsilon_bcast))
  i6 = subtract(subtract(i1, i2), i5)

  grad_operand =
      multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
  grad_scale =
      compute_sum(multiply(grad_output, normalized_operand), feature_index)
  grad_offset = compute_sum(grad_output, feature_index)

  return grad_operand, grad_scale, grad_offset

如果是量化類型,請執行 dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean, variance, grad_output: batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index), operand, scale, mean, variance, grad_output, type(grad_operand), type(grad_scale), type(feature_index))

輸入

標籤 名稱 類型 限制
(I1) operand 浮點類型或每個張量的量化張量 (C1-C3)、(C5)
(I2) scale 1D 浮點或個別張量量化類型的 1D 張量 (C2)、(C4)、(C5)
(I3) mean 浮點或每個張量量化類型的 1 維張量 (C2)、(C4)
(I4) variance 浮點或每個張量量化類型的 1 維張量 (C2)、(C4)
(I5) grad_output 浮點類型的張量或每個張量的量化張量 (C2)、(C3)
(I6)。 epsilon 類型為 f32 的常數
(I7) feature_index si64 類型的常數 (C1)、(C5)

輸出內容

名稱 類型 限制
grad_operand 浮點類型或每個張量的量化張量 (C2)、(C3)
grad_scale 浮點或每個張量量化類型的 1 維張量 (C2)、(C4)
grad_offset 浮點或每個張量量化類型的 1 維張量 (C2)、(C4)

限制

  • (C1) 0 <= feature_index < rank(operand)
  • (C2) operandscalemeanvariancegrad_outputgrad_operandgrad_scalegrad_offset 具有相同的 baseline_element_type
  • (C3) operandgrad_outputgrad_operand 具有相同的形狀。
  • (C4) scalemeanvariancegrad_scalegrad_offset 具有相同的形狀。
  • (C5) size(scale) = dim(operand, feature_index)

範例

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
//                [[0.1, 0.1], [0.1, 0.1]],
//                [[0.1, 0.1], [0.1, 0.1]]
//               ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
     tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
//                 [[0.0, 0.0], [0.0, 0.0]],
//                 [[0.0, 0.0], [0.0, 0.0]]
//                ]
// %grad_scale:  [0.0, 0.0]
// %grad_offset: [0.4, 0.4]

batch_norm_inference

語義學

會將 operand 張量在所有維度中進行規格化,但不包括 feature_index 維度,並產生 result 張量。更正式地說,這個運算可使用 Python 語法表示為對現有 StableHLO 運算的細分,如下所示:

def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
  # Broadcast inputs to shape(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance` instead of
  # computing them like `batch_norm_training` does.
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)
  return add(multiply(scale_bcast, normalized_operand), offset_bcast)

針對量化類型,執行 dequantize_op_quantize(lambda operand, scale, offset, mean, variance: batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index), operand, scale, offset, mean, variance, type(result))

輸入

標籤 名稱 類型 限制
(I1) operand 浮點類型或每個張量的量化張量 (C1-C7)
(I2) scale 浮點或每個張量量化類型的 1 維張量 (C2)、(C3)
(I3) offset 浮點或每個張量量化類型的 1 維張量 (C2)、(C4)
(I4) mean 1D 浮點或個別張量量化類型的 1D 張量 (C5)
(I5)。 variance 浮點或每個張量量化類型的 1 維張量 (C2)、(C6)
(I6) epsilon 類型為 f32 的常數
(I7) feature_index 類型為 si64 的常數 (C1)、(C3-C6)

輸出內容

名稱 類型 限制
result 浮點類型的張量或每個張量的量化張量 (C2)、(C7)

限制

  • (C1) 0 <= feature_index < rank(operand)
  • (C2) operandscaleoffsetmeanvarianceresult 具有相同的 baseline_element_type
  • (C3) size(scale) = dim(operand, feature_index)
  • (C4) size(offset) = dim(operand, feature_index)
  • (C5) size(mean) = dim(operand, feature_index)
  • (C6) size(variance) = dim(operand, feature_index)
  • (C7) baseline_type(operand) = baseline_type(result)

範例

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]

batch_norm_training

語義學

計算所有維度的平均值和變異數 (feature_index 維度除外),並將 operand 張量標準化,產生 outputbatch_meanbatch_var 張量。更正式地說,這個運算可使用 Python 語法表達為現有 StableHLO 運算的解構,如下所示:

def compute_mean(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def compute_variance(operand, feature_index):
  mean = compute_mean(operand, feature_index)
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  centered_operand = subtract(operand, mean_bcast)
  return compute_mean(mul(centered_operand, centered_operand), feature_index)

def batch_norm_training(operand, scale, offset, epsilon, feature_index):
  mean = compute_mean(operand, feature_index)
  variance = compute_variance(operand, feature_index)
  return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
                              feature_index),
         mean, variance

針對量化類型,執行 dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset: batch_norm_training(operand, scale, offset, epsilon, feature_index), operand, scale, offset, type(output), type(batch_mean), type(batch_var))

輸入

標籤 名稱 類型 限制
(I1) operand 浮點類型的張量或每個張量的量化張量 (C1)
(I2) scale 浮點值或每個張量量化的 1 維張量 (C2)、(C3)
(I3) offset 浮點值或每個張量量化的 1 維張量 (C2)、(C4)
(I4) epsilon 類型為 f32 的常數 (C1)、(C3-C6)
(I5)。 feature_index 類型為 si64 的常數 (C1)、(C3-C6)

輸出內容

名稱 類型 限制
output 浮點類型或每個張量的量化張量 (C7)
batch_mean 浮點值或每個張量量化的 1 維張量 (C2)、(C5)
batch_var 1D 浮點或個別張量量化 (C2)、(C6)

限制

  • (C1) 0 <= feature_index < rank(operand)
  • (C2) operandscaleoffsetbatch_meanbatch_varoutput 具有相同的 baseline_element_type
  • (C3) size(scale) = dim(operand, feature_index)
  • (C4) size(offset) = dim(operand, feature_index)
  • (C5) size(batch_mean) = dim(operand, feature_index)
  • (C6) size(batch_var) = dim(operand, feature_index)
  • (C7) baseline_type(output) = baseline_type(operand)

範例

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
    (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]

bitcast_convert

語義學

operand 張量執行位元轉換運算,並產生 result 張量,其中整個 operand 張量的位元會使用 result 張量的類型重新解讀。

更正式地說,假設有 E = element_type(operand)E' = element_type(result)R = rank(operand)

  • 如果是 num_bits(E') < num_bits(E),則為 bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])
  • 如果是 num_bits(E') > num_bits(E),則為 bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])
  • 如果是 num_bits(E') = num_bits(E),則為 bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])

bits 會傳回指定值的記憶體內表示法,而其行為是由實作定義,因為張量的確切表示法是由實作定義,元素類型的確切表示法也是由實作定義。

輸入

標籤 名稱 類型 限制
(I1) operand 張量或量化張量 (C1-C2)

輸出內容

名稱 類型 限制
result 張量或量化張量 (C1-C2)

限制

  • (C1) 指定 E = is_quantized(operand) ? storage_type(operand) : element_type(operand)E' = is_quantized(result) ? storage_type(result) : element_type(result)R = rank(operand)
    • 如果是 num_bits(E') = num_bits(E)shape(result) = shape(operand)
    • 如果為 num_bits(E') < num_bits(E)
    • rank(result) = R + 1
    • 所有0 <= i < Rdim(result, i) = dim(operand, i)
    • dim(result, R) * num_bits(E') = num_bits(E)
    • 如果為 num_bits(E') > num_bits(E)
    • rank(result) = R - 1
    • dim(result, i) = dim(operand, i) 適用於所有 0 <= i < R
    • dim(operand, R - 1) * num_bits(E) = num_bits(E')
  • (C2) 如果 is_complex(operand) or is_complex(result),則為 is_complex(operand) and is_complex(result)

範例

// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation

更多範例

broadcast_in_dim

語義學

透過複製 operand 張量中的資料,並產生 result 張量,擴充輸入張量的維度和/或秩。更正式的 result[result_index] = operand[operand_index],其中 axes(operand) 中的所有 d

  • dim(operand, d) = 1 則設為 operand_index[d] = 0
  • 否則傳回 operand_index[d] = result_index[broadcast_dimensions[d]]

輸入

標籤 名稱 類型 限制
(I1) operand 張量或量化張量 (C1-C2)、(C5-C6)
(I2) broadcast_dimensions 1 維 si64 類型張量常數 (C2-C6)

輸出內容

名稱 類型 限制
result 張量或量化張量 (C1)、(C3)、(C5-C6)

限制

  • (C1) element_type(result) 的提供者:
    • element_type(operand),如果 !is_per_axis_quantized(operand)
    • element_type(operand)quantization_dimension(operand)scales(operand)zero_points(operand) 可能與 quantization_dimension(result)scales(result)zero_points(result) 之間不同。
  • (C2) size(broadcast_dimensions) = rank(operand)
  • (C3) 0 <= broadcast_dimensions < rank(result)
  • (C4) is_unique(broadcast_dimensions)
  • (C5) 針對 axes(operand) 中的所有 d
    • dim(operand, d) = 1
    • dim(operand, d) = dim(result, broadcast_dimensions[d])
  • (C6) 如果 is_per_axis_quantized(result)
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
    • 如果是 dim(operand, quantization_dimension(operand)) = 1,則 scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))

範例

// %operand: [
//            [1, 2, 3]
//           ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
  broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

更多範例

保護殼

語義學

根據 index 的值,從 branches 執行一項函式產生輸出內容。更正式的 result = selected_branch(),其中:

  • 0 <= index < size(branches) 則設為 selected_branch = branches[index]
  • 否則傳回 selected_branch = branches[-1]

輸入

標籤 名稱 類型 限制
(I1) index si32 類型的 0 維度張量
(I2) branches 可變函式數 (C1-C4)

輸出內容

名稱 類型 限制
results 變異量、量化張量或代詞 (C4)

限制

  • (C1) 0 < size(branches)
  • (C2) input_types(branches...) = []
  • (C3) same(output_types(branches...))
  • (C4) type(results...) = output_types(branches[0])

範例

// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
  "stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
  "stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]

更多範例

cbrt

語義學

operand 張量執行元素級別的立方根運算,並產生 result 張量。視元素類型而定,執行下列操作:

  • 浮點值:IEEE-754 的 rootn(x, 3)
  • 複數:複數立方根。
  • 針對量化類型:dequantize_op_quantize(cbrt, operand, type(result))

輸入

標籤 名稱 類型 限制
(I1) operand 浮點或複雜型別或每個張量的量化張量 (C1)

輸出內容

名稱 類型 限制
result 浮點或複雜型別或每個張量的量化張量 (C1)

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]

更多範例

ceil

語義學

operand 張量執行元素逐項的圓整,並產生 result 張量。實作 IEEE-754 規格中的 roundToIntegralTowardPositive 運算。如果是量化類型,請執行 dequantize_op_quantize(ceil, operand, type(result))

輸入

標籤 名稱 類型 限制
(I1) operand 浮點類型的張量或每個張量的量化張量 (C1)

輸出內容

名稱 類型 限制
result 浮點類型的張量或每個張量的量化張量 (C1)

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]

更多範例

cholesky

語義學

計算一批矩陣的 Cholesky 分解。

更正式地說,對於 index_space(result) 中的所有 iresult[i0, ..., iR-3, :, :]a[i0, ..., iR-3, :, :] 的 Cholesky 分解,以下三角 (如果 lowertrue) 或上三角 (如果 lowerfalse) 矩陣的形式呈現。相反三角形中的輸出值 (即嚴格上三角形或嚴格下三角形) 則由實作定義。

如果輸入矩陣不是 Hermitian 正定矩陣的 i,則行為為未定義。

針對量化類型,執行 dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result))

輸入

標籤 名稱 類型 限制
(I1) a 浮點或複雜型別或每個張量的量化張量 (C1-C3)
(I2)。 lower i1 類型的 0 維張量常數

輸出內容

名稱 類型 限制
result 浮點或複雜型別或每個張量的量化張量 (C1)

限制

  • (C1) baseline_type(a) = baseline_type(result)
  • (C2) 2 <= rank(a)
  • (C3) dim(a, -2) = dim(a, -1)

範例

// %a: [
//      [1.0, 2.0, 3.0],
//      [2.0, 20.0, 26.0],
//      [3.0, 26.0, 70.0]
//     ]
%result = "stablehlo.cholesky"(%a) {
  lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
//           [1.0, 0.0, 0.0],
//           [2.0, 4.0, 0.0],
//           [3.0, 5.0, 6.0]
//          ]

限制取值範圍

語義學

operand 張量的每個元素限制在最小值和最大值之間,並產生 result 張量。更正式的 result[result_index] = minimum(maximum(operand[result_index], min_element), max_element),其中 min_element = rank(min) = 0 ? min[] : min[result_index]max_element = rank(max) = 0 ? max[] : max[result_index]。針對量化類型,執行 dequantize_op_quantize(clamp, min, operand, max, type(result))

對複數強制排序會涉及意料之外的語意,因此我們未來預計會移除此運算對複數的支援 (#560)。

輸入

標籤 名稱 類型 限制
(I1) min 張量或每個張量量化張量 (C1)、(C3)
(I2)。 operand 張量或每個張量量化張量 (C1-C4)
(I3) max 張量或每個張量量化張量 (C2)、(C3)

輸出內容

名稱 類型 限制
result 張量或每個張量量化張量 (C4)

限制

  • (C1) rank(min) = 0 or shape(min) = shape(operand)
  • (C2) rank(max) = 0 or shape(max) = shape(operand)
  • (C3) baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max)
  • (C4) baseline_type(operand) = baseline_type(result)

範例

// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]

更多範例

collective_broadcast

語義學

在 StableHLO 程序格線的每個程序群組中,將來源程序的 operand 張量值傳送至目標程序,並產生 result 張量。

該作業會將 StableHLO 程序格線分割為 process_groups,定義如下:

  • channel_id <= 0 則設為 cross_replica(replica_groups)
  • 如果 channel_id > 0,則為 cross_partition(replica_groups)

之後,result@process 的值會由下列公式計算:

  • 如果有 i 導致程序位於 process_groups[i],則為 operand@process_groups[i, 0]
  • 否則為 broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))

輸入

標籤 名稱 類型 限制
(I1) operand 張量或每個張量量化張量 (C3)
(I2)。 replica_groups 類型為 si64 的 1 維張量常數變數數量 (C1)、(C2)
(I3) channel_id 類型為 si64 的常數

輸出內容

名稱 類型 限制
result 張量或每個張量量化張量 (C3)

限制

  • (C1) is_unique(replica_groups)
  • (C2) 0 <= replica_groups < N,其中 N 定義為:
    • 如果使用 cross_replica,則為 num_replicas
    • 如果使用 cross_partition,則為 num_partitions
  • (C3) type(result) = type(operand)

範例

// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
  replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]

collective_permute

語義學

在 StableHLO 程序格線的每個程序群組中,將 operand 張量的值從來源程序傳送至目標程序,並產生 result 張量。

這項作業會將 StableHLO 處理格線分割為 process_groups,其定義如下:

  • 如果 channel_id <= 0,則為 cross_replica(source_target_pairs)
  • 如果 channel_id > 0,則為 cross_partition(source_target_pairs)

之後,result@process 的值會由下列公式計算:

  • operand@process_groups[i, 0],如果存在 i,則 process_groups[i, 1] = process
  • 否則為 broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))

輸入

標籤 名稱 類型 限制
(I1) operand 張量或每個張量量化張量 (C5)
(I2) source_target_pairs 類型為 si64 的 2 維張量常數 (C1-C4)
(I3) channel_id 類型為 si64 的常數

輸出內容

名稱 類型 限制
result 張量或每個張量量化張量 (C1)。

限制

  • (C1) dim(source_target_pairs, 1) = 2
  • (C2) is_unique(source_target_pairs[:, 0])
  • (C3) is_unique(source_target_pairs[:, 1])
  • (C4) 0 <= source_target_pairs < N,其中 N 的定義如下:
    • 如果使用 cross_replica,則為 num_replicas
    • 如果使用 cross_partition,則為 num_partitions
  • (C5) type(result) = type(operand)

範例

// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
  source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]

更多範例

比較

語義學

根據 comparison_directioncompare_type 執行 lhsrhs 張元素依據元素的比較,並產生 result 張量。

comparison_directioncompare_type 的值具有以下語義:

布林值和整數元素類型:

  • EQlhs = rhs
  • NElhs != rhs
  • GElhs >= rhs
  • GTlhs > rhs
  • LElhs <= rhs
  • LTlhs < rhs

針對具有 compare_type = FLOAT 的浮點元素類型,運算會實作下列 IEEE-754 作業:

  • EQcompareQuietEqual
  • NEcompareQuietNotEqual
  • GEcompareQuietGreaterEqual
  • GTcompareQuietGreater
  • LEcompareQuietLessEqual
  • LTcompareQuietLess

對於含有 compare_type = TOTALORDER 的浮點元素類型,運算子會使用 IEEE-754 中的 totalOrdercompareQuietEqual 運算組合。

針對複雜的元素類型,系統會使用提供的 comparison_directioncompare_type,執行 (real, imag) 組合的字典排序比對。對複數強制排序會涉及意義不明的語意,因此我們未來計劃在 comparison_directionGEGTLELT 時,移除對複數的支援 (#560)。

針對量化類型,請執行 dequantize_compare(lhs, rhs, comparison_direction)

輸入

標籤 名稱 類型 限制
(I1) lhs 張量或每個張量量化張量 (C1-C3)
(I2) rhs 張量或每個張量量化張量 (C1-C2)
(I3)。 comparison_direction EQNEGEGTLELT 的列舉
(I4) compare_type FLOATTOTALORDERSIGNEDUNSIGNED 的列舉 (C3)

輸出內容

名稱 類型 限制
result 布林值類型的張量 (C2)

限制

  • (C1) baseline_element_type(lhs) = baseline_element_type(rhs)
  • (C2) shape(lhs) = shape(rhs) = shape(result)
  • (C3) compare_type 定義為:
    • is_signed_integer(element_type(lhs)) 則設為 SIGNED
    • 如果 is_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs)),則為 UNSIGNED
    • 如果是 is_float(element_type(lhs)),則為 FLOATTOTALORDER
    • is_complex(element_type(lhs)) 則設為 FLOAT

範例

// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
  comparison_direction = #stablehlo<comparison_direction LT>,
  compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]

更多範例

複雜

語義學

這個外掛程式能從一組實際值和虛構值 (lhsrhs) 執行元素轉換到複雜值,並產生 result 張量。

輸入

標籤 名稱 類型 限制
(I1) lhs 類型為 f32f64 的張量 (C1-C3)
(I2) rhs 類型為 f32f64 的張量 (C1)

輸出內容

名稱 類型 限制
result 複雜類型的張量 (C2)、(C3)

限制

  • (C1) type(lhs) = type(rhs)
  • (C2) shape(result) = shape(lhs)
  • (C3) element_type(result) 具有 complex<E> 類型,其中 E = element_type(lhs)

範例

// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]

更多範例

複合

語義學

封裝其他 StableHLO 作業 (由其他 StableHLO 作業組成) 的作業,執行 inputscomposite_attributes 並產生 resultsdecomposition 屬性會實作作業的語意。composite 運算子可以替換為其分解項目,而不會變更程式語意。如果內嵌分解作業無法提供相同的運算子語義,建議使用 custom_call

version 欄位 (預設為 0) 用於表示組合的語意變更時間。

輸入

標籤 名稱 類型
(I1)。 inputs 值的變化數
(I2) name string 類型的常數
(I3) composite_attributes 屬性字典
(I4)。 decomposition 類型為 string 的常數
(I5) version 類型為 si32 的常數

輸出內容

名稱 類型
results 值的變化數

限制

  • (C1) is_namespaced_op_name(name)
  • (C2) is_defined_in_parent_scope(decomposition)
  • (C3) types(inputs...) == input_types(decomposition)
  • (C4) types(results...) == output_types(decomposition)

範例

%results = "stablehlo.composite"(%input0, %input1) {
  name = "my_namespace.my_op",
  composite_attributes = {
    my_attribute = "my_value"
  },
  decomposition = @my_op,
  version = 1 : i32
} : (tensor<f32>, tensor<f32>) -> tensor<f32>

更多範例

concatenate

語義學

按照與指定引數相同的順序,沿著 dimension 維度連接 inputs,並產生 result 張量。更正式的說法是 result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1],其中:

  1. id = d0 + ... + dk-1 + kd
  2. d 等於 dimension,而 d0、... 是 inputsd 維度大小。

輸入

標籤 名稱 類型 限制
(I1) inputs 可變數張量數量或每個張量量化張量 (C1-C6)
(I2) dimension 類型為 si64 的常數 (C2)、(C4)、(C6)

輸出內容

名稱 類型 限制
result 張量或每個張量量化張量 (C5-C6)

限制

  • (C1) same(element_type(inputs...))
  • (C2) same(shape(inputs...)) (dim(inputs..., dimension) 除外)。
  • (C3) 0 < size(inputs)
  • (C4) 0 <= dimension < rank(inputs[0])
  • (C5) element_type(result) = element_type(inputs[0])
  • (C6) shape(result) = shape(inputs[0]) 除下列情況外:
    • dim(result, dimension) = dim(inputs[0], dimension) + ...

範例

// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
  dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]

更多範例

常數

語義學

從常數 value 產生 output 張量。

輸入

標籤 名稱 類型 限制
(I1)。 value 常數 (C1)。

輸出內容

名稱 類型 限制
output 張量或量化張量 (C1)

限制

  • (C1) type(value) = type(output)

範例

%output = "stablehlo.constant"() {
  value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]

更多範例

完成轉換

語義學

這個外掛程式能在 operand 張量執行元素類型之間的元素轉換,並產生 result 張量。

對於布林值至任何支援類型轉換,值 false 會轉換為零,而值 true 會轉換為一。對於any-supported-type-to-boolean的轉換,零值會轉換為 false,非零值會轉換為 true。請參閱下文,瞭解此做法如何適用於複雜類型。

如果轉換作業涉及整數轉整數整數轉浮點浮點轉浮點,如果來源值可在目的地類型中精確表示,結果值就是該精確表示法。否則,行為未定 (#180)。

針對涉及floating-point-to-integer的轉換,系統會截斷分數部分。如果截斷的值無法以目的地類型表示,行為為待定 (#180)。

涉及複雜至複雜的轉換會採用與浮點至浮點轉換相同的行為模式,轉換真實與虛構部分。

對於複雜至任何其他類型任何其他類型至複雜轉換,系統會分別忽略來源虛構值或將目的地虛構值設為零。實部份的轉換會遵循浮點轉換。

原則上,這個運算可表示去量化 (從量化張量轉換為一般張量)、量化 (從一般張量轉換為量化張量) 和重新量化 (在量化張量之間轉換),但目前我們有專屬的運算可用於此 - uniform_dequantize 適用於第一個用途,uniform_quantize 適用於第二個和第三個用途。日後,這兩個運算可能會合併為 convert (#1576)。

輸入

標籤 名稱 類型 限制
(I1)。 operand 張量 (C1)

輸出內容

名稱 類型 限制
result 張量 (C1)

限制

  • (C1) shape(operand) = shape(result)

範例

// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]

更多範例

卷積

語義學

計算 lhs 視窗與 rhs 切片之間的內積,並產生 result。下圖以具體範例說明如何從 lhsrhs 計算 result 中的元素。

卷積

更正式地說,請考慮以下以 lhs 為準的輸入內容重構,以便表達 lhs 的視窗:

  • lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension))
  • lhs_window_strides = lhs_shape(1, window_strides, 1)
  • lhs_padding = lhs_shape([0, 0], padding, [0, 0])
  • lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)
  • lhs_window_dilations = lhs_shape(1, rhs_dilation, 1)

這項重構會使用下列輔助函式:

  • lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])
  • result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])
  • permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1] 其中 j[d] = i[permutation[d]]

如果 feature_group_count = 1batch_group_count = 1,則針對 index_space(dim(result, output_spatial_dimensions...)) 中的所有 output_spatial_indexresult[result_shape(:, output_spatial_index, :)] = dot_product 為:

  • padding_value = constant(0, element_type(lhs))
  • padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)
  • lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides
  • lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations)
  • reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true])。這項功能似乎未使用,因此我們預計日後會移除這項功能 (#1181)。
  • dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension])

如果 feature_group_count > 1

  • lhses = split(lhs, feature_group_count, input_feature_dimension)
  • rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)
  • results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)
  • result = concatenate(results, output_feature_dimension)

如果 batch_group_count > 1

  • lhses = split(lhs, batch_group_count, input_batch_dimension)
  • rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)
  • results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...)
  • result = concatenate(results, output_feature_dimension)

針對量化類型,執行 dequantize_op_quantize( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs, type(result))

針對混合量化類型,執行 hybrid_dequantize_then_op( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs)

輸入

標籤 名稱 類型 限制
(I1) lhs 張量或每個張量量化張量 (C1)、(C10-C11)、(C14)、(C25)、(C27-C28)、(C31-C32)、(C34)
(I2)。 rhs 張量或量化張量 (C1)、(C14-C16)、(C25)、(C27-C29)、(C31-C34)
(I3) window_strides 1 維 si64 類型張量常數 (C2-C3)、(C25)
(I4)。 padding 類型為 si64 的 2 維張量常數 (C4)、(C25)
(I5) lhs_dilation 1 維 si64 類型張量常數 (C5-C6)、(C25)
(I6)。 rhs_dilation 1 維 si64 類型張量常數 (C7-C8)、(C25)
(I7) window_reversal 1 維 i1 類型張量常數 (C9)。
(I8) input_batch_dimension 類型為 si64 的常數 (C10)、(C13)、(C25)
(I9)。 input_feature_dimension 類型為 si64 的常數 (C11)、(C13-C14)
(I10) input_spatial_dimensions 1 維 si64 類型張量常數 (C12)、(C13)、(C25)
(I11) kernel_input_feature_dimension 類型為 si64 的常數 (C14)、(C18)
(I12) kernel_output_feature_dimension 類型為 si64 的常數 (C15-C16)、(C18)、(C25)、(C29)
(I13) kernel_spatial_dimensions 1 維 si64 類型張量常數 (C17-C18)、(C25)
(I14) output_batch_dimension 類型為 si64 的常數 (C20)、(C25)
(I15)。 output_feature_dimension 類型為 si64 的常數 (C20)、(C25)、(C30)
(I16) output_spatial_dimensions 1 維 si64 類型張量常數 (C19-C20)、(C25)
(I17) feature_group_count si64 類型的常數 (C11)、(C14)、(C16)、(C21)、(C23)
(I18)。 batch_group_count 類型為 si64 的常數 (C10)、(C15)、(C22)、(C23)、(C25)
(I19)。 precision_config DEFAULTHIGHHIGHEST 列舉的變數數量 (C24)。

輸出內容

名稱 類型 限制
result 張量或量化張量 (C25-C28)、(C30)、(C32-34)

限制

  • (C1) N = rank(lhs) = rank(rhs)
  • (C2) size(window_strides) = N - 2
  • (C3) 0 < window_strides
  • (C4) shape(padding) = [N - 2, 2]
  • (C5) size(lhs_dilation) = N - 2
  • (C6) 0 < lhs_dilation
  • (C7) size(rhs_dilation) = N - 2
  • (C8) 0 < rhs_dilation
  • (C9) size(window_reversal) = N - 2
  • (C10) dim(lhs, input_batch_dimension) % batch_group_count = 0
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0
  • (C12) size(input_spatial_dimensions) = N - 2
  • (C13) 假設 input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
    • is_unique(input_dimensions)
    • 0 <= input_dimensions < N
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
  • (C17) size(kernel_spatial_dimensions) = N - 2
  • (C18) Given kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:
    • is_unique(kernel_dimensions)
    • 0 <= kernel_dimensions < N
  • (C19) size(output_spatial_dimensions) = N - 2
  • (C20) 假設 output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
    • is_unique(output_dimensions)
    • 0 <= output_dimensions < N
  • (C21) 0 < feature_group_count
  • (C22) 0 < batch_group_count
  • (C23) feature_group_count = 1 or batch_group_count = 1
  • (C24) size(precision_config) = 2
  • (C25) dim(result, result_dim) 的定義為:
    • result_dim = output_batch_dimension 則設為 dim(lhs, input_batch_dimension) / batch_group_count
    • 如果 result_dim = output_feature_dimension,則為 dim(rhs, kernel_output_feature_dimension)
    • num_windows 否則:
    • output_spatial_dimensions[spatial_dim] = result_dim
    • lhs_dim = input_spatial_dimensions[spatial_dim]
    • rhs_dim = kernel_spatial_dimensions[spatial_dim]
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
  • (C26) rank(result) = N
  • 如果運算使用非量化的張量:
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result)
  • 如果作業使用量化張量:
    • (C28) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
    • (C29) 如果 is_per_axis_quantized(rhs),則 quantization_dimension(rhs) = kernel_output_feature_dimension
    • (C30) 如果是 is_per_axis_quantized(result),則 quantization_dimension(result) = output_feature_dimension
    • 如果 is_quantized(lhs)
    • (C31) storage_type(lhs) = storage_type(rhs)
    • (C32) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
    • (C33) 如果 is_per_tensor_quantized(rhs),則為 is_per_tensor_quantized(result)
    • 如果 !is_quantized(lhs)
    • (C34) element_type(lhs) = expressed_type(rhs) = element_type(result).

範例

// %lhs: [[
//        [
//          [1], [2], [5], [6]
//        ],
//        [
//          [3], [4], [7], [8]
//        ],
//        [
//          [10], [11], [14], [15]
//        ],
//        [
//          [12], [13], [16], [17]
//        ]
//      ]]
//
// %rhs: [
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]]
//       ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
  window_strides = array<i64: 4, 4>,
  padding = dense<0> : tensor<2x2xi64>,
  lhs_dilation = array<i64: 2, 2>,
  rhs_dilation = array<i64: 1, 1>,
  window_reversal = array<i1: false, false>,
  // In the StableHLO dialect, dimension numbers are encoded via:
  // `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
  // "b" is batch dimension, "f" is feature dimension,
  // "i" is input feature dimension, "o" is output feature dimension,
  // "0/1/etc" are spatial dimensions.
  dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
  batch_group_count = 1 : i64,
  feature_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
//            [[10], [26]],
//            [[46], [62]]
//          ]]

更多範例

餘弦

語義學

operand 張量執行元素級別餘弦運算,並產生 result 張量。視元素類型而定,執行下列操作:

  • 浮點值:IEEE-754 的 cos
  • 適用於複數:複數。
  • 量化類型:dequantize_op_quantize(cosine, operand, type(result))

輸入

標籤 名稱 類型 限制
(I1) operand 浮點或複雜型別或每個張量的量化張量 (C1)

輸出內容

名稱 類型 限制
result 浮點或複雜型別或每個張量的量化張量 (C1)

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]

更多範例

count_leading_zeros

語義學

按元素計算 operand 張前方的零位元數,並產生 result 張量。

輸入

標籤 名稱 類型 限制
(I1) operand 整數類型的張量 (C1)

輸出內容

名稱 類型 限制
result 整數類型的張量 (C1)

限制

  • (C1) type(operand) = type(result)

範例

// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]

更多範例

custom_call

語義學

封裝實作定義的作業 call_target_name,這會使用 inputscalled_computations 並產生 resultshas_side_effectbackend_configapi_version 可用於提供其他實作定義的中繼資料。

目前,此運算包含相當雜亂的中繼資料集合,反映了 XLA 編譯器中對應運算的自然演進。我們計劃日後要統合這個中繼資料 (#741)。

輸入

標籤 名稱 類型
(I1)。 inputs 值的變化數
(I2) call_target_name string 類型的常數
(I3) has_side_effect 類型為 i1 的常數
(I4) backend_config string 類型的常數或屬性字典
(I5) api_version 類型為 si32 的常數
(I6)。 called_computations 類型為 string 的常數變數參數數量

輸出內容

名稱 類型
results 值的變化數

範例

%results = "stablehlo.custom_call"(%input0) {
  call_target_name = "foo",
  has_side_effect = false,
  backend_config = {bar = 42 : i32},
  api_version = 4 : i32,
  called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>

除號

語義學

對除數 lhs 和除數 rhs 張量執行元素逐元素除法,並產生 result 張量。視元素類型而定,執行下列操作:

  • 整數:整數除法,產生代數商,並捨棄任何小數部分。
  • 浮點值:IEEE-754 的 division
  • 複數:複數除法。
  • 針對量化型別:
    • dequantize_op_quantize(divide, lhs, rhs, type(result))

輸入

標籤 名稱 類型 限制
(I1)。 lhs 整數、浮點數或複雜型別或每個張量的量化張量 (C1)
(I2) rhs 整數、浮點或複數類型的張量,或每個張量的量化張量 (C1)

輸出內容

名稱 類型 限制
result 整數、浮點或複數類型的張量,或每個張量的量化張量 (C1)

限制

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

範例

// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]

更多範例

dot_general

語義學

計算 lhs 切片和 rhs 切片之間的內積,並產生 result 張量。

更正式的說法是 result[result_index] = dot_product,其中:

  • lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions]
  • rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions]
  • result_batching_index + result_lhs_index + result_rhs_index = result_index,其中 size(result_batching_index) = size(lhs_batching_dimensions)size(result_lhs_index) = size(lhs_result_dimensions)size(result_rhs_index) = size(rhs_result_dimensions)
  • transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions)
  • transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :])
  • reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions))
  • transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions)
  • transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :])
  • reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions))
  • dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y))

針對量化類型,執行 dequantize_op_quantize( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs, type(result))

針對混合量化類型,執行 hybrid_dequantize_then_op( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs)

precision_config 會控制加速器後端運算的速度與準確度之間的權衡。可以是下列其中一個值 (目前這些列舉值的語意未明確指定,但我們預計會在 #755 中解決這個問題):

  • DEFAULT:計算速度最快,但近似於原始數字的近似值不精確。
  • HIGH:計算速度較慢,但可更準確地估算原始數字。
  • HIGHEST:計算速度最慢,但最接近原始數字。

DotAlgorithm 會定義用於實作點運算法的必要屬性,並定義精確度。如果設定了演算法屬性欄位,則 precision_config 必須為 DEFAULTDotAlgorithms 沒有預設值,因為預設參數是由實作定義。因此,所有點狀演算法欄位都可能設為 None,以指定空白點狀演算法,而該演算法會改用 precision_config 值。

DotAlgorithm 欄位包括:

  • lhs_precision_typerhs_precision_type,作業的 LHS 和 RHS 會四捨五入至此精確度。精確度類型與輸入和輸出內容的儲存類型無關。
  • accumulation_type 用於累加的精確度。
  • lhs_component_countrhs_component_countnum_primitive_operations 的運算方法適用於將 LHS 和/或 RHS 分解為多個元件,並對這些值執行多項「原始」點運算時,通常用於模擬較高的精確度 (例如運用 bfloat16 bfloat16 picial Intelligence Datatype for Higher-Precision tf2put6_6:如果是沒有分解的演算法,則應將這些值設為 1
  • allow_imprecise_accumulation 可用於指定是否允許在某些步驟中使用較低精確度的累加運算 (例如 CUBLASLT_MATMUL_DESC_FAST_ACCUM)。

DotAlgorithm 屬性範例:

// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
 rhs_precision_type = tf32,
 accumulation_type = f32,
 lhs_component_count = 1,
 rhs_component_count = 1,
 num_primitive_operations = 1,
 allow_imprecise_accumulation = false}


// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
 rhs_precision_type = bf16,
 accumulation_type = f32,
 lhs_component_count = 3,
 rhs_component_count = 3,
 num_primitive_operations = 6,
 allow_imprecise_accumulation = false}


// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
 rhs_precision_type = f8e5m2,
 accumulation_type = f32,
 lhs_component_count = 1,
 rhs_component_count = 1,
 num_primitive_operations = 1,
 allow_imprecise_accumulation = true}

實作方式會決定支援哪些組合。一般而言,無法保證 StableHLO 消費者支援各種加速器類型。如果不支援特定演算法,則應引發錯誤,而非改回使用替代方法。StableHLO 驗證會盡力驗證,避免在任何硬體上支援未知的演算法。

如要瞭解部分支援的演算法值,請參閱 xla_data.proto > Algorithm。支援單 #2483 記錄了建立一份關於後端支援演算法的集中式文件的計畫。

輸入

標籤 名稱 類型 限制
(I1) lhs 張量或每個張量量化張量 (C5-C6)、(C9-C10)、(C12-C14)、(C17-C18)、(C20)
(I2)。 rhs 張量或量化張量 (C7-C10)、(C12-C20)
(I3) lhs_batching_dimensions 1 維 si64 類型張量常數 (C1)、(C3)、(C5)、(C9)、(C12)
(I4) rhs_batching_dimensions 1 維 si64 類型張量常數 (C1)、(C4)、(C7)、(C9)
(I5) lhs_contracting_dimensions si64 類型的 1D 張量常數 (C2)、(C3)、(C6)、(C10)
(I6) rhs_contracting_dimensions 1 維 si64 類型張量常數 (C2)、(C4)、(C8)、(C10)、(C16)
(I7)。 precision_config DEFAULTHIGHHIGHEST 的列舉變數數量 (C11)、(C21)
(I8) lhs_precision_type FloatType 或 TensorFloat32 (C21)。
(I9) rhs_precision_type FloatType 或 TensorFloat32 (C21)
(I10) accumulation_type FloatType 或 TensorFloat32 (C21)
(I11)。 lhs_component_count 類型為 si32 的常數 (C21)、(C22)
(I12) rhs_component_count 類型為 si32 的常數 (C21)、(C23)
(I13) num_primitive_operations 類型為 si32 的常數 (C21)、(C24)
(I14) allow_imprecise_accumulation 類型為 bool 的常數 (C21)

輸出內容

名稱 類型 限制
result 張量或量化張量 (C12)、(C14)、(C18-C20)

限制

  • (C1) size(lhs_batching_dimensions) = size(rhs_batching_dimensions)
  • (C2) size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions)
  • (C3) is_unique(lhs_batching_dimensions + lhs_contracting_dimensions)
  • (C4) is_unique(rhs_batching_dimensions + rhs_contracting_dimensions)
  • (C5) 0 <= lhs_batching_dimensions < rank(lhs)
  • (C6) 0 <= lhs_contracting_dimensions < rank(lhs)
  • (C7) 0 <= rhs_batching_dimensions < rank(rhs)
  • (C8) 0 <= rhs_contracting_dimensions < rank(rhs)
  • (C9) dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
  • (C10) dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...)
  • (C11) size(precision_config) = 2
  • (C12) shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions)
  • 如果作業使用非量化張量:
    • (C13) element_type(lhs) = element_type(rhs)
  • 如果運算使用經過量化的張量:
    • (C14) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
    • (C15) zero_points(rhs) = 0
    • (C16) 如果 is_per_axis_quantized(rhs),則 quantization_dimension(rhs) 不在 rhs_contracting_dimensions 中。
    • 如果 is_quantized(lhs)
    • (C17) storage_type(lhs) = storage_type(rhs)
    • (C18) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
    • (C19) 如果 is_per_tensor_quantized(rhs),則為 is_per_tensor_quantized(result)
    • 如果 !is_quantized(lhs)
    • (C20) element_type(lhs) = expressed_type(rhs) = element_type(result)
  • 如果 !is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation)
    • (C21) precision_config... = DEFAULT
    • (C22) 0 < lhs_component_count
    • (C23) 0 < rhs_component_count
    • (C24) 0 < num_primitive_operations

範例

// %lhs: [
//        [[1, 2],
//         [3, 4]],
//        [[5, 6],
//         [7, 8]]
//       ]
// %rhs: [
//        [[1, 0],
//         [0, 1]],
//        [[1, 0],
//         [0, 1]]
//       ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
  dot_dimension_numbers = #stablehlo.dot<
    lhs_batching_dimensions = [0],
    rhs_batching_dimensions = [0],
    lhs_contracting_dimensions = [2],
    rhs_contracting_dimensions = [1]
  >,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
  algorithm = #stablehlo.dot_algorithm<
    lhs_precision_type = tf32,
    rhs_precision_type = tf32,
    accumulation_type = f32,
    lhs_component_count = 1,
    rhs_component_count = 1,
    num_primitive_operations = 1,
    allow_imprecise_accumulation = false
  >
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
//           [[1, 2],
//            [3, 4]],
//           [[5, 6],
//            [7, 8]]
//          ]

更多範例

dynamic_broadcast_in_dim

語義學

這項作業的功能與 broadcast_in_dim 作業相同,但結果形狀會透過 output_dimensions 動態指定。

這項作業也接受選用屬性 known_expanding_dimensionsknown_nonexpanding_dimensions,用來表示有關維度展開行為的靜態知識。如果未指定,系統會假設所有維度都可能會展開。

輸入

標籤 名稱 類型 限制
(I1) operand 張量或量化張量 (C1-C2)、(C5-C6)、(C9)
(I2)。 output_dimensions 整數類型的 1 維張量 (C7)。
(I3) broadcast_dimensions 整數類型的 1D 常數張量 (C2-C6)
(I4) known_expanding_dimensions 整數類型的 1 維常數張量 (C8-C9)
(I5) known_nonexpanding_dimensions 整數類型的 1 維常數張量 (C8-C9)

輸出內容

名稱 類型 限制
result 張量或量化張量 (C1)、(C3)、(C5-C7)

限制

  • (C1) element_type(result) 的提供者:
    • element_type(operand),如果 !is_per_axis_quantized(operand)
    • element_type(operand)quantization_dimension(operand)scales(operand)zero_points(operand) 可能與 quantization_dimension(result)scales(result)zero_points(result) 之間不同。
  • (C2) size(broadcast_dimensions) = rank(operand)
  • (C3) 0 <= broadcast_dimensions < rank(result)
  • (C4) is_unique(broadcast_dimensions)
  • (C5) 針對 axes(operand) 中的所有 d
    • dim(operand, d) = 1
    • dim(operand, d) = dim(result, broadcast_dimensions[d])
  • (C6) 如果 is_per_axis_quantized(result)
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
    • 如果是 dim(operand, quantization_dimension(operand)) = 1,則 scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
  • (C7) size(output_dimensions) = rank(result)
  • (C8) is_unique(known_expanding_dimensions + known_nonexpanding_dimensions)
  • (C9) 0 <= known_expanding_dimensions < rank(operand)
  • (C10) 0 <= known_nonexpanding_dimensions < rank(operand)

範例

// %operand: [
//            [1, 2, 3]
//           ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
  broadcast_dimensions = array<i64: 2, 1>,
  known_expanding_dimensions = array<i64: 0>,
  known_nonexpanding_dimensions = array<i64: 1>
} : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

更多範例

dynamic_conv

語義學

這項作業的功能與卷積作業相同,但邊框會透過 padding 動態指定。

輸入

標籤 名稱 類型 限制
(I1) lhs 張量或每個張量量化張量 (C1)、(C10-C11)、(C14)、(C25)、(C26-C27)、(C30-C31)、(C33)
(I2)。 rhs 張量或量化張量 (C1)、(C14-C16)、(C26-C28)、(C30-C33)
(I3)。 padding 整數類型的 2 維張量 (C4)
(I4) window_strides 1 維 si64 類型張量常數 (C2-C3)
(I5) lhs_dilation 1 維 si64 類型張量常數 (C5-C6)
(I6) rhs_dilation 1 維 si64 類型張量常數 (C7-C8)
(I7) window_reversal 1 維 i1 類型張量常數 (C9)。
(I8) input_batch_dimension 類型為 si64 的常數 (C10)、(C13)
(I9) input_feature_dimension 類型為 si64 的常數 (C11)、(C13-C14)
(I10) input_spatial_dimensions 1 維 si64 類型張量常數 (C12)、(C13)
(I11) kernel_input_feature_dimension 類型為 si64 的常數 (C14)、(C18)
(I12) kernel_output_feature_dimension 類型為 si64 的常數 (C15-C16)、(C18)、(C28)
(I13) kernel_spatial_dimensions 1 維 si64 類型張量常數 (C17-C18)
(I14) output_batch_dimension 類型為 si64 的常數 (C20)
(I15) output_feature_dimension si64 類型的常數 (C20)、(C29)
(I16)。 output_spatial_dimensions 1 維 si64 類型張量常數 (C19-C20)
(I17) feature_group_count si64 類型的常數 (C11)、(C14)、(C16)、(C21)、(C23)
(I18)。 batch_group_count si64 類型的常數 (C10)、(C15)、(C22)、(C23)
(I19) precision_config DEFAULTHIGHHIGHEST 列舉的變數數量 (C24)。

輸出內容

名稱 類型 限制
result 張量或量化張量 (C25-C27)、(C29)、(C31-C33)

限制

  • (C1) N = rank(lhs) = rank(rhs)
  • (C2) size(window_strides) = N - 2
  • (C3) 0 < window_strides
  • (C4) shape(padding) = [N - 2, 2]
  • (C5) size(lhs_dilation) = N - 2
  • (C6) 0 < lhs_dilation
  • (C7) size(rhs_dilation) = N - 2
  • (C8) 0 < rhs_dilation
  • (C9) size(window_reversal) = N - 2
  • (C10) dim(lhs, input_batch_dimension) % batch_group_count = 0
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0
  • (C12) size(input_spatial_dimensions) = N - 2
  • (C13) 假設 input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
    • is_unique(input_dimensions)
    • 0 <= input_dimensions < N
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
  • (C17) size(kernel_spatial_dimensions) = N - 2
  • (C18) Given kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:
    • is_unique(kernel_dimensions)
    • 0 <= kernel_dimensions < N
  • (C19) size(output_spatial_dimensions) = N - 2
  • (C20) 假設 output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
    • is_unique(output_dimensions)
    • 0 <= output_dimensions < N
  • (C21) 0 < feature_group_count
  • (C22) 0 < batch_group_count
  • (C23) feature_group_count = 1 or batch_group_count = 1
  • (C24) size(precision_config) = 2
  • (C25) dim(result, result_dim) 的定義為:
    • result_dim = output_batch_dimension 則設為 dim(lhs, input_batch_dimension) / batch_group_count
    • 如果 result_dim = output_feature_dimension,則為 dim(rhs, kernel_output_feature_dimension)
    • num_windows 否則:
    • output_spatial_dimensions[spatial_dim] = result_dim
    • lhs_dim = input_spatial_dimensions[spatial_dim]
    • rhs_dim = kernel_spatial_dimensions[spatial_dim]
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
  • (C26) rank(result) = N
  • 如果運算使用非量化的張量:
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result)
  • 如果作業使用量化張量:
    • (C28) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
    • (C29) 如果 is_per_axis_quantized(rhs),則 quantization_dimension(rhs) = kernel_output_feature_dimension
    • (C30) 如果是 is_per_axis_quantized(result),則 quantization_dimension(result) = output_feature_dimension
    • 如果 is_quantized(lhs)
    • (C31) storage_type(lhs) = storage_type(rhs)
    • (C32) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
    • (C33) 如果 is_per_tensor_quantized(rhs),則為 is_per_tensor_quantized(result)
    • 如果 !is_quantized(lhs)
    • (C34) element_type(lhs) = expressed_type(rhs) = element_type(result).

範例

// %lhs: [[
//        [[1], [2], [5], [6]],
//        [[3], [4], [7], [8]],
//        [[10], [11], [14], [15]],
//        [[12], [13], [16], [17]]
//      ]]
//
// %rhs: [
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]]
//        ]
// %padding: [[1, 1],
//            [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
  window_strides = array<i64: 4, 4>,
  lhs_dilation = array<i64: 2, 2>,
  rhs_dilation = array<i64: 1, 1>,
  window_reversal = array<i1: false, false>,
  dimension_numbers = #stablehlo.conv<raw
    input_batch_dimension = 0,
    input_feature_dimension = 3,
    input_spatial_dimensions = [0, 1],
    kernel_input_feature_dimension = 2,
    kernel_output_feature_dimension = 3,
    kernel_spatial_dimensions = [0, 1],
    output_batch_dimension = 0,
    output_feature_dimension = 3,
    output_spatial_dimensions = [1, 2]
  >,
  feature_group_count = 1 : i64,
  batch_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
//            [[1], [5]],
//            [[10], [14]]
//          ]]

更多範例

dynamic_gather

語義學

這個運算的功能與 gather 運算相同,但 slice_sizes 會以值的形式動態指定。

輸入

標籤 名稱 類型 限制
(I1) operand 張量或每個張量量化張量 (C1)、(C7)、(C10-C12)、(C14)
(I2)。 start_indices 整數類型的張量 (C2)、(C3)、(C13)
(I3) slice_sizes 整數類型的 1 維張量 (C8)、(C11-C13)
(I4) offset_dims 1 維 si64 類型張量常數 (C1)、(C4-C5)、(C13)
(I5) collapsed_slice_dims 1 維 si64 類型張量常數 (C1)、(C6-C8)、(C13)
(I6) start_index_map 1 維 si64 類型張量常數 (C3)、(C9)、(C10)
(I7)。 index_vector_dim 類型為 si64 的常數 (C2)、(C3)、(C13)
(I8)。 indices_are_sorted i1 類型的常數

輸出內容

名稱 類型 限制
result 張量或每個張量量化張量 (C5)、(C13-C14)

限制

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims)
  • (C2) 0 <= index_vector_dim <= rank(start_indices)
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims)
  • (C5) 0 <= offset_dims < rank(result)
  • (C6) is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims)
  • (C7) 0 <= collapsed_slice_dims < rank(operand)
  • (C8) slice_sizes[collapsed_slice_dims...] <= 1
  • (C9) is_unique(start_index_map)
  • (C10) 0 <= start_index_map < rank(operand)
  • (C11) size(slice_sizes) = rank(operand)
  • (C12) 0 <= slice_sizes <= shape(operand)
  • (C13) shape(result) = combine(batch_dim_sizes, offset_dim_sizes) 其中:
    • batch_dim_sizes = shape(start_indices),但不含與 index_vector_dim 對應的 start_indices 維度大小。
    • offset_dim_sizes = shape(slice_sizes),但不包含 slice_sizes 中對應 collapsed_slice_dims 的維度大小。
    • combine 會將 batch_dim_sizes 放在對應 batch_dims 的軸上,並將 offset_dim_sizes 放在對應 offset_dims 的軸上。
  • (C14) element_type(operand) = element_type(result)

範例

// %operand: [
//            [[1, 2], [3, 4], [5, 6], [7, 8]],
//            [[9, 10],[11, 12], [13, 14], [15, 16]],
//            [[17, 18], [19, 20], [21, 22], [23, 24]]
//           ]
// %start_indices: [
//                  [[0, 0], [1, 0], [2, 1]],
//                  [[0, 1], [1, 1], [0, 2]]
//                 ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [2, 3],
    collapsed_slice_dims = [0],
    start_index_map = [1, 0],
    index_vector_dim = 2>,
  indices_are_sorted = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi64>
// %result: [
//            [
//              [[1, 2], [3, 4]],
//              [[3, 4], [5, 6]],
//              [[13, 14], [15, 16]]
//            ],
//            [
//              [[9, 10], [11, 12]],
//              [[11, 12], [13, 14]],
//              [[17, 18], [19, 20]]
//            ]
//          ]

更多範例

dynamic_iota

語義學

這個運算的功能與 iota 運算相同,但結果形狀會透過 output_shape 動態指定。

輸入

標籤 名稱 類型 限制
(I1) output_shape 整數類型的 1 維張量 (C1)、(C2)
(I2) iota_dimension si64 (C1)

輸出內容

名稱 類型 限制
result 整數、浮點或複數類型的張量,或每個張量的量化張量 (C2)

限制

  • (C1) 0 <= iota_dimension < size(output_shape)
  • (C2) rank(result) = size(output_shape)

範例

%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
  iota_dimension = 0 : i64
} : (tensor<2xi64>) -> tensor<4x5xi64>
// %result: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

更多範例

dynamic_pad

語義學

這個作業的功能與 pad 作業相同,但 edge_padding_lowedge_padding_highinterior_padding 會以值的形式動態指定。

輸入

標籤 名稱 類型 限制
(I1) operand 張量或每個張量量化張量 (C1)、(C2)、(C4)
(I2) padding_value 0 維張量或每個張量量化張量 (C1)
(I3) edge_padding_low 整數類型的 1 維張量 (C1)、(C4)
(I4)。 edge_padding_high 整數類型的 1 維張量 (C1)、(C4)
(I5) interior_padding 整數類型的 1 維張量 (C2-C4)

輸出內容

名稱 類型 限制
result 張量或每個張量量化張量 (C3-C6)

限制

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result)
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
  • (C3) 0 <= interior_padding
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high

範例

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
  %edge_padding_low, %edge_padding_high, %interior_padding
) : (tensor<2x3xi64>, tensor<i64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x9xi64>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

更多範例

dynamic_reshape

語義學

這個運算的功能與 reshape 運算相同,但結果形狀會透過 output_shape 動態指定。

輸入

標籤 名稱 類型 限制
(I1) operand 張量或量化張量 (C1-C3)
(I2) output_shape 整數類型的 1 維張量 (C4)

輸出內容

名稱 類型 限制
result 張量或量化張量 (C1-C4)

限制

  • (C1) element_type(result) 的提供者:
    • element_type(operand),如果 !is_per_axis_quantized(operand)
    • element_type(operand),但 quantization_dimension(operand)quantization_dimension(result) 可能不同。
  • (C2) size(operand) = size(result)
  • (C3) 如果 is_per_axis_quantized(operand)
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
  • (C4) size(output_shape) = rank(result)

範例

// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
// %result: [[1, 2], [3, 4], [5, 6]]

更多範例

dynamic_slice

語義學

使用動態運算的起始索引從 operand 擷取配量,並產生 result 張量。start_indices 包含每個維度切片的起始索引 (可能會調整),而 slice_sizes 則包含每個維度的切片大小。更正式的說法是 result[result_index] = operand[operand_index],其中:

  • adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes)
  • operand_index = adjusted_start_indices + result_index

輸入

標籤 名稱 類型 限制
(I1) operand 張量或每個張量量化張量 (C1)、(C2)、(C4)
(I2) start_indices 整數類型 0 維度張量的變數參數數量 (C2)、(C3)
(I3) slice_sizes 1 維 si64 類型張量常數 (C2)、(C4)、(C5)

輸出內容

名稱 類型 限制
result 張量或每個張量量化張量 (C1)、(C5)

限制

  • (C1) element_type(operand) = element_type(result)
  • (C2) size(start_indices) = size(slice_sizes) = rank(operand)
  • (C3) same(type(start_indices...))
  • (C4) 0 <= slice_sizes <= shape(operand)
  • (C5) shape(result) = slice_sizes

範例

// %operand: [
//            [0, 0, 1, 1],
//            [0, 0, 1, 1],
//            [0, 0, 0, 0],
//            [0, 0, 0, 0]
//           ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
  slice_sizes = array<i64: 2, 2>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
//           [1, 1],
//           [1, 1]
//          ]

更多範例

dynamic_update_slice

語義學

產生與 operand 張相等的 result 張量,但從 start_indices 開始的切片會更新為 update 中的值。更正式的說法,result[result_index] 的定義為:

  • update[update_index] 如果 0 <= update_index < shape(update) 符合下列條件:
    • adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update))
    • update_index = result_index - adjusted_start_indices
  • 否則傳回 operand[result_index]

輸入

標籤 名稱 類型 限制
(I1) operand 張量或每個張量量化張量 (C1-C4)、(C6)
(I2)。 update 張量或每個張量量化張量 (C2)、(C3)、(C6)
(I3) start_indices 整數類型 0 維度張量的變數參數數量 (C4)、(C5)

輸出內容

名稱 類型 限制
result 張量或每個張量量化張量 (C1)。

限制

  • (C1) type(operand) = type(result)
  • (C2) element_type(update) = element_type(operand)
  • (C3) rank(update) = rank(operand)
  • (C4) size(start_indices) = rank(operand)
  • (C5) same(type(start_indices...))
  • (C6) 0 <= shape(update) <= shape(operand)

範例

// %operand: [
//            [1, 1, 0, 0],
//            [1, 1, 0, 0],
//            [1, 1, 1, 1],
//            [1, 1, 1, 1]
//           ]
// %update: [
//           [1, 1],
//           [1, 1]
//          ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
  : (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1]
//          ]

更多範例

指數

語義學

operand 張量執行元素級別指數運算,並產生 result 張量。視元素類型而定,執行下列操作:

  • 浮點值:IEEE-754 的 exp
  • 複數:複數指數。
  • 量化類型:dequantize_op_quantize(exponential, operand, type(result))

輸入

標籤 名稱 類型 限制
(I1) operand 浮點或複雜型別或每個張量的量化張量 (C1)

輸出內容

名稱 類型 限制
result 浮點或複雜型別或每個張量的量化張量 (C1)

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]

更多範例

exponential_minus_one

語義學

operand 張量執行元素級別指數減一運算,並產生 result 張量。視元素類型而定,執行下列操作:

  • 浮點值:IEEE-754 的 expm1
  • 複數:複數指數減一。
  • 量化型別:dequantize_op_quantize(exponential_minus_one, operand, type(result))

輸入

標籤 名稱 類型 限制
(I1) operand 浮點或複雜型別或每個張量的量化張量 (C1)

輸出內容

名稱 類型 限制
result 浮點或複雜型別或每個張量的量化張量 (C1)

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]

更多範例

fft

語義學

針對實數和複數輸入/輸出執行正向和反向傅立葉變換。

fft_type 是下列其中一項:

  • FFT:轉送複雜至複雜的 FFT。
  • IFFT:複數對複數的逆 FFT。
  • RFFT:正向實數至複數 FFT。
  • IRFFT:反向實數至複數 FFT (即取複數,傳回實數)。

更正式地說,假設函式 fft 會將複雜型別的 1 維張量陣列做為輸入,產生相同類型的 1 維張量陣列做為輸出,並計算離散傅立葉轉換:

對於 fft_type = FFTresult 的定義為一系列 L 運算的最終結果,其中 L = size(fft_length)。例如 L = 3

  • result1[i0, ..., :] = fft(operand[i0, ..., :])
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])

此外,請考慮具有相同類型簽章的函式 ifft,該函式會計算 fft 的倒數:

如果是 fft_type = IFFTresult 的定義為 fft_type = FFT 的反向運算。以 L = 3 為例:

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
  • result[i0, ..., :] = ifft(result2[i0, ..., :])

此外,如果函式 rfft 會接受浮點型別的 1 維張量,則會產生相同浮點語義的複雜型別 1 維張量,並運作如下:

  • rfft(real_operand) = truncated_result,其中
  • complex_operand... = (real_operand..., 0.0)
  • complex_result = fft(complex_operand)
  • truncated_result = complex_result[:(rank(complex_result) / 2 + 1)]

(當您為實際運算元計算不同的 Fourier 轉換時,結果的第一個 N/2 + 1 元素會明確定義結果的其餘部分,因此會截斷 rfft 的結果,以免運算多餘的元素。

對於 fft_type = RFFTresult 的定義為一系列 L 運算的最終結果,其中 L = size(fft_length)。以 L = 3 為例:

  • result1[i0, ..., :] = rfft(operand[i0, ..., :])
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])

最後,請考慮具有相同類型簽章的函式 irfft,該函式會計算 rfft 的倒數:

對於 fft_type = IRFFTresult 定義為 fft_type = RFFT 運算的倒數。例如 L = 3

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
  • result[i0, ..., :] = irfft(result2[i0, ..., :])

輸入

標籤 名稱 類型 限制
(I1)。 operand 浮點或複雜類型的張量 (C1)、(C2)、(C4)、(C5)
(I2)。 fft_type FFTIFFTRFFTIRFFT 的列舉 (C2)、(C5)
(I3)。 fft_length 1 維 si64 類型張量常數 (C1)、(C3)、(C4)

輸出內容

名稱 類型 限制
result 浮點或複雜類型的張量 (C2)、(C4)、(C5)

限制

  • (C1) size(fft_length) <= rank(operand)
  • (C2) operandresult 元素類型之間的關係有所不同:
    • 如果 fft_type = FFTelement_type(operand)element_type(result) 具有相同的複雜類型。
    • 如果 fft_type = IFFTelement_type(operand)element_type(result) 具有相同的複雜類型,
    • 如果 fft_type = RFFTelement_type(operand) 是浮點類型,且 element_type(result) 是相同浮點語意的複雜類型。
    • 如果 fft_type = IRFFTelement_type(operand) 是複雜類型,且 element_type(result) 是相同浮點語意的浮點類型。
  • (C3) 1 <= size(fft_length) <= 3
  • (C4) 如果 operandresult 中有浮點類型張量 real,則為 shape(real)[-size(fft_length):] = fft_length
  • (C5) shape(result) = shape(operand),但下列情況除外:
    • 如果是 fft_type = RFFT,則為 dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1
    • 如果是 fft_type = IRFFTdim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1

範例

// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
  fft_type = #stablehlo<fft_type FFT>,
  fft_length = array<i64: 4>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]

floor

語義學

執行 operand 張元素階層的底層,並產生 result 張量。依據 IEEE-754 規範實作 roundToIntegralTowardNegative 作業。針對量化類型,執行 dequantize_op_quantize(floor, operand, type(result))

輸入

標籤 名稱 類型 限制
(I1) operand 浮點類型的張量或每個張量的量化張量 (C1)

輸出內容

名稱 類型 限制
result 浮點類型的張量或每個張量的量化張量 (C1)

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]

更多範例

gather

語義學

start_indices 中指定的偏移值從 operand 張量收集切片,並產生 result 張量。

以下圖表以具體範例說明 result 中的元素如何對應至 operand 中的元素。此圖表會挑選幾個 result 索引範例,並詳細說明這些索引對應的 operand 索引。

gather

更正式的result[result_index] = operand[operand_index],其中:

  • batch_dims = [d for d in axes(result) and d not in offset_dims]
  • batch_index = result_index[batch_dims...]
  • start_index 的定義如下:
    • start_indices[bi0, ..., :, ..., biN],其中 bibatch_index 中的個別元素,如果 index_vector_dim < rank(start_indices): 會插入 index_vector_dim 索引。
    • 否則傳回 [start_indices[batch_index]]
  • axes(operand) 中的 d_operand
    • 如果為 d_operand = start_index_map[d_start],則為 full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])
    • 否則傳回 full_start_index[d_operand] = 0
  • 針對 axes(operand) 中的 d_operand
    • full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)],如果 d_operand = operand_batching_dims[i_batching]d_start = start_indices_batching_dims[i_batching]
    • 否則傳回 full_batching_index[d_operand] = 0
  • offset_index = result_index[offset_dims...]
  • full_offset_index = [oi0, ..., 0, ..., oiN],其中 oioffset_index 中的個別元素,而 0 則會插入 collapsed_slice_dimsoperand_batching_dims 的索引。
  • operand_index = full_start_index + full_batching_index + full_offset_index

如果 indices_are_sortedtrue,實作可以假設 start_indices 是根據 start_index_map 排序,否則行為未定義。更正式地說,對於 indices(result) 的所有 i1 < i2full_start_index(i1) <= full_start_index(i2)

輸入

標籤 名稱 類型 限制
(I1) operand 張量或每個張量量化張量 (C1)、(C8)、(C11)、(C17)、(C19-C21)、(C23)
(I2) start_indices 整數類型的張量 (C2-C3)、(C14)、(C17)、(C22)
(I3)。 offset_dims 1 維 si64 類型張量常數 (C1)、(C4-C5)、(C22)
(I4)。 collapsed_slice_dims 1 維 si64 類型張量常數 (C1)、(C6-C9)、(C22)
(I5)。 operand_batching_dims si64 類型的 1D 張量常數 (C1)、(C6)、(C10-C12)、(C16-C18)、(C22)
(I6) start_indices_batching_dims 1 維 si64 類型張量常數 (C13-C17)
(I7) start_index_map 1 維 si64 類型張量常數 (C3)、(C18-C19)
(I8)。 index_vector_dim 類型為 si64 的常數 (C2-C3)、(C15)、(C22)
(I9) slice_sizes 1 維 si64 類型張量常數 (C9)、(C12)、(C20-C22)
(I10) indices_are_sorted 類型為 i1 的常數

輸出內容

名稱 類型 限制
result 張量或每個張量量化張量 (C5)、(C22-C23)

限制

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims)
  • (C2) 0 <= index_vector_dim <= rank(start_indices)
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims)
  • (C5) 0 <= offset_dims < rank(result)
  • (C6) is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))
  • (C7) is_sorted(collapsed_slice_dims)
  • (C8) 0 <= collapsed_slice_dims < rank(operand)
  • (C9) slice_sizes[collapsed_slice_dims...] <= 1
  • (C10) is_sorted(operand_batching_dims)
  • (C11) 0 <= operand_batching_dims < rank(operand)
  • (C12) slice_sizes[operand_batching_dims...] <= 1
  • (C13) is_unique(start_indices_batching_dims)
  • (C14) 0 <= start_indices_batching_dims < rank(start_indices)
  • (C15) index_vector_dim not in start_indices_batching_dims
  • (C16) size(operand_batching_dims) == size(start_indices_batching_dims)
  • (C17) dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...)
  • (C18) is_unique(concatenate(start_index_map, operand_batching_dims))
  • (C19) 0 <= start_index_map < rank(operand).
  • (C20) size(slice_sizes) = rank(operand)
  • (C21) 0 <= slice_sizes <= shape(operand)
  • (C22) shape(result) = combine(batch_dim_sizes, offset_dim_sizes) 其中:
    • batch_dim_sizes = shape(start_indices),但不含與 index_vector_dim 對應的 start_indices 維度大小。
    • offset_dim_sizes = slice_sizes,但不包含 slice_sizes 中對應 collapsed_slice_dimsoperand_batching_dims 的維度大小。
    • combine 會將 batch_dim_sizes 放在對應 offset_dims 的軸上,位於與 batch_dimsoffset_dim_sizes 相對應的軸位置。
  • (C23) element_type(operand) = element_type(result)

範例

// %operand: [
//            [
//             [[1, 2], [3, 4], [5, 6], [7, 8]],
//             [[9, 10],[11, 12], [13, 14], [15, 16]],
//             [[17, 18], [19, 20], [21, 22], [23, 24]]
//            ],
//            [
//             [[25, 26], [27, 28], [29, 30], [31, 32]],
//             [[33, 34], [35, 36], [37, 38], [39, 40]],
//             [[41, 42], [43, 44], [45, 46], [47, 48]]
//            ]
//           ]
// %start_indices: [
//                  [
//                   [[0, 0], [1, 0], [2, 1]],
//                   [[0, 1], [1, 1], [0, 9]]
//                  ],
//                  [
//                   [[0, 0], [2, 1], [2, 2]],
//                   [[1, 2], [0, 1], [1, 0]]
//                  ]
//                 ]
%result = "stablehlo.gather"(%operand, %start_indices) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [3, 4],
    collapsed_slice_dims = [1],
    operand_batching_dims = [0],
    start_indices_batching_dims = [1],
    start_index_map = [2, 1],
    index_vector_dim = 3>,
  slice_sizes = array<i64: 1, 1, 2, 2>,
  indices_are_sorted = false
} : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32>
// %result: [
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[3, 4], [5, 6]],
//             [[13, 14], [15, 16]]
//            ],
//            [
//             [[33, 34], [35, 36]],
//             [[35, 36], [37, 38]],
//             [[41, 42], [43, 44]]
//            ]
//           ],
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[13, 14], [15, 16]],
//             [[21, 22], [23, 24]]
//            ],
//            [
//             [[43, 44], [45, 46]],
//             [[33, 34], [35, 36]],
//             [[27, 28], [29, 30]]
//            ]
//           ]
//          ]

更多範例

get_dimension_size

語義學

產生 operand 的指定 dimension 大小。更正式的說法是 result = dim(operand, dimension)。語意學只會處理型別的形狀元件。元素類型可以是任何內容。

輸入

標籤 名稱 類型 限制
(I1) operand 張量或量化張量 (C1)
(I2)。 dimension 類型為 si64 的常數 (C1)。

輸出內容

名稱 類型
result si32 類型的 0 維度張量

限制

  • (C1) 0 <= dimension < rank(operand)

範例

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
  dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3

更多範例

get_tuple_element

語義學

擷取位於 operand 元組的 index 位置的元素,並產生 result。更正式的說法是 result = operand[index]

輸入

標籤 名稱 類型 限制
(I1) operand 元組 (C1)、(C2)
(I2) index si32 類型的常數 (C1)、(C2)

輸出內容

名稱 類型 限制
result 任何支援的類型 (C2)

限制

  • (C1) 0 <= index < size(operand)
  • (C2) type(result) = tuple_element_types(operand)[index]

範例

// %operand: ([1.0, 2.0], (3))
  index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]

更多範例

如果

語義學

根據 pred 的值,產生執行 true_branchfalse_branch 中正確一個函式的輸出內容。更正式的說法是 result = pred ? true_branch() : false_branch()

輸入

標籤 名稱 類型 限制
(I1) pred i1 類型的 0 維度張量
(I2) true_branch 函式 (C1-C3)
(I3) false_branch 函式 (C1)、(C2)

輸出內容

名稱 類型 限制
results 變異量、量化張量或代詞 (C3)

限制

  • (C1) input_types(true_branch) = input_types(false_branch) = []
  • (C2) output_types(true_branch) = output_types(false_branch)
  • (C3) type(results...) = output_types(true_branch)

範例

// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
  "stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
  "stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10

更多範例

imag

語義學

operand 逐元素擷取虛數部分,並產生 result 張量。更正式地說,對於每個元素 ximag(x) = is_complex(x) ? imaginary_part(x) : constant(0, element_type(result))

輸入

標籤 名稱 類型 限制
(I1)。 operand 浮點或複雜類型的張量 (C1)、(C2)

輸出內容

名稱 類型 限制
result 浮點類型的張量 (C1)、(C2)

限制

  • (C1) shape(result) = shape(operand)
  • (C2) element_type(result) 定義為:
    • is_complex(operand) 則設為 complex_element_type(element_type(operand))
    • 否則傳回 element_type(operand)

範例

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]

更多範例

動態內廣告

語義學

讀取內嵌動態饋給中的資料,並產生 results

infeed_config 的語意是由實作方式定義。

results 包含酬載值 (放在最前面) 和符記 (放在最後)。我們預計日後將酬載和符記分成兩個獨立的輸出內容,以便清楚顯示 (#670)。

輸入

標籤 名稱 類型
(I1) token token
(I2) infeed_config string 類型的常數

輸出內容

名稱 類型 限制
results 張量、量化張量或符記的變數參數數量 (C1-C3)

限制

  • (C1) 0 < size(results)
  • (C2) is_empty(result[:-1])is_tensor(type(results[:-1]))
  • (C3) is_token(type(results[-1]))

範例

// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]

更多範例

iota

語義學

iota_dimension 維度從零開始,依照遞增順序填入 output 張量。更正式的說法是:

output[output_index] = constant(is_quantized(output) ? quantize(output_index[iota_dimension], element_type(output)) : output_index[iota_dimension], element_type(output))

輸入

標籤 名稱 類型 限制
(I1) iota_dimension si64 (C1)。

輸出內容

名稱 類型 限制
output 整數、浮點或複數類型的張量,或每個張量的量化張量 (C1)

限制

  • (C1) 0 <= iota_dimension < rank(output)

範例

%output = "stablehlo.iota"() {
  iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

%output = "stablehlo.iota"() {
  iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4]
//          ]

更多範例

is_finite

語義學

依據元素執行檢查 x 中的值是否為有限 (即 +Inf、-Inf 或 NaN) 並產生 y 張量。實作 IEEE-754 規格中的 isFinite 運算。對於量化類型,結果一律為 true

輸入

標籤 名稱 類型 限制
(I1) x 浮點類型的張量或每個張量的量化張量 (C1)

輸出內容

名稱 類型 限制
y 布林值類型的張量 (C1)

限制

  • (C1) shape(x) = shape(y)

範例

// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]

更多範例

log

語義學

operand 張執行元素的對數運算,並產生 result 張量。視元素類型而定,執行下列操作:

  • 浮點值:IEEE-754 的 log
  • 複數:複對數對數。
  • 量化類型:dequantize_op_quantize(log, operand, type(result))

輸入

標籤 名稱 類型 限制
(I1) operand 浮點或複雜型別或每個張量的量化張量 (C1)

輸出內容

名稱 類型 限制
result 浮點或複雜型別或每個張量的量化張量 (C1)

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]

更多範例

log_plus_one

語義學

operand 張量執行元素級別對數加一項運算,並產生 result 張量。視元素類型而定,執行下列操作:

  • 浮點值:IEEE-754 的 logp1
  • 複數:複數對數加一。
  • 量化類型:dequantize_op_quantize(log_plus_one, operand, type(result))

輸入

標籤 名稱 類型 限制
(I1) operand 浮點或複雜型別或每個張量的量化張量 (C1)

輸出內容

名稱 類型 限制
result 浮點或複雜型別或每個張量的量化張量 (C1)

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]

更多範例

物流

語義學

operand 張量執行元素級別邏輯運算,並產生 result 張量。視元素類型而定,執行下列操作:

  • 浮點值:IEEE-754 的 division(1, addition(1, exp(-x)))
  • 複數:複雜的邏輯。
  • 量化型別:dequantize_op_quantize(logistic, operand, type(result))

輸入

標籤 名稱 類型 限制
(I1) operand 浮點或複雜型別或每個張量的量化張量 (C1)

輸出內容

名稱 類型 限制
result 浮點或複雜型別或每個張量的量化張量 (C1)

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]

更多範例

地圖

語義學

將對應函式 computation 同時套用至 dimensionsinputs,並產生 result 張量。

更正式,result[result_index] = computation(inputs...[result_index])

輸入

標籤 名稱 類型 限制
(I1) inputs 可變數張量數量或每個張量量化張量 (C1-C4)
(I2) dimensions 1 維 si64 類型張量常數 (C3)
(I3) computation 函式 (C4)

輸出內容

名稱 類型 限制
result 張量或每個張量量化張量 (C1)、(C4)

限制

  • (C1) shape(inputs...) = shape(result)
  • (C2) 0 < size(inputs) = N
  • (C3) dimensions = range(rank(inputs[0]))
  • (C4) computation 具有 (tensor<E0>, ..., tensor<EN-1>) -> tensor<E'> 類型,其中 Ei = element_type(inputs[i])E' = element_type(result)

範例

// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
    stablehlo.return %0 : tensor<i64>
}) {
  dimensions = array<i64: 0, 1>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]

更多範例

最高

語義學

對張量 lhsrhs 執行元素級別最大值運算,並產生 result 張量。視元素類型而定,執行下列操作:

  • 布林值:邏輯 OR。
  • 整數:最大整數。
  • 浮點值:IEEE-754 的 maximum
  • 複數:(real, imaginary) 組合的字典順序最大值。對複雜數字表示排序涉及令人意想不到的語意,因此我們預計日後會移除對此運算的複雜數字的支援 (#560)。
  • 針對量化型別:
    • dequantize_op_quantize(maximum, lhs, rhs, type(result))

輸入

標籤 名稱 類型 限制
(I1) lhs 張量或每個張量量化張量 (C1)
(I2) rhs 張量或每個張量量化張量 (C1)

輸出內容

名稱 類型 限制
result 張量或每個張量量化張量 (C1)。

限制

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

範例

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]

更多範例

最低

語義學

對張量 lhsrhs 執行元素下限運算,並產生 result 張量。視元素類型而定,執行下列操作:

  • 布林值:邏輯 AND。
  • 整數:整數下限。
  • 浮點值:IEEE-754 的 minimum
  • 針對複數:(real, imaginary) 組合的字母數字最小值。對複數強制排序會涉及意料之外的語意,因此我們未來預計會移除此運算對複數的支援 (#560)。
  • 針對量化型別:
    • dequantize_op_quantize(minimum, lhs, rhs, type(result))

輸入

標籤 名稱 類型 限制
(I1) lhs 張量或每個張量量化張量 (C1)
(I2) rhs 張量或每個張量量化張量 (C1)

輸出內容

名稱 類型 限制
result 張量或每個張量量化張量 (C1)。

限制

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

範例

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]

更多範例

相乘

語義學

執行兩個張量 lhsrhs 的元素相乘運算,並產生 result 張量。視元素類型而定,執行下列操作:

  • 布林值:邏輯 AND。
  • 整數:整數相乘。
  • 浮點值:IEEE-754 的 multiplication
  • 複數:複數相乘。
  • 量化類型:
    • dequantize_op_quantize(multiply, lhs, rhs, type(result))

輸入

標籤 名稱 類型 限制
(I1) lhs 張量或每個張量量化張量 (C1)
(I2) rhs 張量或每個張量量化張量 (C1)

輸出內容

名稱 類型 限制
result 張量或每個張量量化張量 (C1)。

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]

更多範例

negate

語義學

operand 張量執行元素逐元素否定,並產生 result 張量。視元素類型而定,執行下列操作:

  • 如為正負號整數:整數否定。
  • 對於無符號整數:位元轉換為帶正負號整數、整數反轉、位元轉換回無符號整數。
  • 浮點值:IEEE-754 的 negate
  • 複數:複數否定。
  • 量化類型:dequantize_op_quantize(negate, operand, type(result))

輸入

標籤 名稱 類型 限制
(I1) operand 整數、浮點或複數類型的張量,或每個張量的量化張量 (C1)。

輸出內容

名稱 類型 限制
result 整數、浮點或複數類型的張量,或每個張量的量化張量 (C1)

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]

// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]

更多範例

not

語義學

針對張量 operand 執行元素逐元素 NOT 運算,並產生 result 張量。視元素類型而定,執行下列操作:

  • 布林值:邏輯 NOT。
  • 整數:位元 NOT。

引數

名稱 類型 限制
operand 布林值或整數類型的張量 (C1)

輸出內容

名稱 類型 限制
result 布林值或整數類型的張量 (C1)。

限制

  • (C1) type(operand) = type(result)

範例

// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]

// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]

更多範例

optimization_barrier

語義學

確保產生 operand 的作業會在任何依賴 result 的作業之前執行,並防止編譯器轉換作業跨越障礙。此外,作業為身分,即 result = operand

引數

名稱 類型 限制
operand 變異量、每個張量量化張量或代詞 (C1)

輸出內容

名稱 類型 限制
result 變異量、每個張量量化張量或代詞 (C1)

限制

  • (C1) type(operand...) = type(result...)

範例

// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0

更多範例

語義學

對兩個張量 lhsrhs 執行元素 OR 運算,並產生 result 張量。視元素類型而定,執行下列操作:

  • 布林值:邏輯 OR。
  • 整數:位元 OR。

輸入

標籤 名稱 類型 限制
(I1) lhs 整數或布林值類型的張量 (C1)
(I2)。 rhs 整數或布林值類型的張量 (C1)

輸出內容

名稱 類型 限制
result 整數或布林值類型的張量 (C1)

限制

  • (C1) type(lhs) = type(rhs) = type(result)

範例

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]

更多範例

outfeed

語義學

inputs 寫入傳出動態饋給,並產生 result 權杖。

outfeed_config 的語意是由實作定義。

輸入

標籤 名稱 類型
(I1) inputs 張量或量化張量的變數參數數量
(I2) token token
(I3) outfeed_config 類型為 string 的常數

輸出內容

名稱 類型
result token

範例

%result = "stablehlo.outfeed"(%input0, %token) {
  outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token

更多範例

墊片

語義學

使用指定的 padding_value 在張量周圍和張量元素之間填入邊框,藉此展開 operand

edge_padding_lowedge_padding_high 會分別指定在各維度的低端 (位於索引 0 旁邊) 和高階 (位於最高索引旁邊) 加上的邊框間距量。邊框的大小可以為負值,負邊框的絕對值表示從指定維度移除的元素數量。

interior_padding 會指定在每個維度的任兩個元素之間加入的邊框間距數量,且不得為負值。內部邊框間距會先於邊框間距,因此負邊框間距會從內部邊框間距運算元中移除元素。

result[result_index] 的正式定義如下:

  • result_index = edge_padding_low + operand_index * (interior_padding + 1) 則設為 operand[operand_index]
  • 否則傳回 padding_value

輸入

標籤 名稱 類型 限制
(I1) operand 張量或每個張量量化張量 (C1)、(C2)、(C4)
(I2) padding_value 0 維張量或每個張量量化張量 (C1)
(I3)。 edge_padding_low 1 維 si64 類型張量常數 (C1)、(C4)
(I4)。 edge_padding_high 1 維 si64 類型張量常數 (C1)、(C4)
(I5) interior_padding si64 類型的 1D 張量常數 (C2-C4)

輸出內容

名稱 類型 限制
result 張量或每個張量量化張量 (C3-C6)

限制

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result)
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
  • (C3) 0 <= interior_padding
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high

範例

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
  edge_padding_low = array<i64: 0, 1>,
  edge_padding_high = array<i64: 2, 1>,
  interior_padding = array<i64: 1, 2>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

更多範例

partition_id

語義學

產生目前程序的 partition_id

輸出內容

名稱 類型
result ui32 類型的 0 維度張量

範例

%result = "stablehlo.partition_id"() : () -> tensor<ui32>

更多範例

popcnt

語義學

逐元素計算 operand 張量中設定的位元數,並產生 result 張量。

輸入

標籤 名稱 類型 限制
(I1) operand 整數類型的張量 (C1)

輸出內容

名稱 類型 限制
result 整數類型的張量 (C1)

限制

  • (C1) type(operand) = type(result)

範例

// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]

更多範例

指數

語義學

lhs 張量使用 rhs 張量,以元素為單位執行指數運算,並產生 result 張量。視元素類型而定,執行下列操作:

  • 整數:整數指數運算。
  • 浮點值:IEEE-754 的 pow
  • 複數:複數指數
  • 對於量化類型:dequantize_op_quantize(power, lhs, rhs, type(result))

輸入

標籤 名稱 類型 限制
(I1) lhs 整數、浮點或複數類型的張量,或每個張量的量化張量 (C1)。
(I2) rhs 整數張量、浮點數或複雜類型,或每個張量的量化張量 (C1)。

輸出內容

名稱 類型 限制
result 整數、浮點或複數類型的張量,或每個張量的量化張量 (C1)

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]

更多範例

real

語義學

operand 依據元素擷取實際部分,並產生 result 張量。更正式地說,對於每個元素 xreal(x) = is_complex(x) ? real_part(x) : x

輸入

標籤 名稱 類型 限制
(I1)。 operand 浮點或複雜類型的張量 (C1)、(C2)

輸出內容

名稱 類型 限制
result 浮點類型的張量 (C1)、(C2)

限制

  • (C1) shape(result) = shape(operand)
  • (C2) element_type(result) 定義為:
    • is_complex(operand) 則設為 complex_element_type(element_type(operand))
    • 否則傳回 element_type(operand)

範例

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]

更多範例

recv

語義學

這個外掛程式能接收採用 channel_id 的頻道資料,並產生 results

如果 is_host_transfertrue,則作業會從主機傳輸資料。否則,系統會從其他裝置轉移資料。這表示實作方式由實作端定義。這個標記與 channel_type 中提供的資訊重複,因此我們日後打算只保留其中一個 (#666)。

results 包含酬載值 (放在最前面) 和符記 (放在最後)。我們預計日後將酬載和符記分成兩個獨立的輸出內容,以便清楚顯示 (#670)。

輸入

標籤 名稱 類型 限制
(I1) token token (C4)
(I2) channel_id 類型為 si64 的常數
(I3) channel_type DEVICE_TO_DEVICEHOST_TO_DEVICE 的列舉 (C1)
(I4)。 is_host_transfer 類型為 i1 的常數 (C1)

輸出內容

名稱 類型 限制
results 變異量、量化張量或代詞 (C2-C4)

限制

  • (C1) channel_type 的定義為:
    • HOST_TO_DEVICEis_host_transfer = true
    • 否則傳回 DEVICE_TO_DEVICE
  • (C2) 0 < size(results)
  • (C3) is_empty(result[:-1])is_tensor(type(results[:-1]))
  • (C4) is_token(type(results[-1]))

範例

%results0, %results1 = "stablehlo.recv"(%token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
  is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)

更多範例

reduce

語義學

沿著 dimensions 將縮減函式 body 套用至 inputsinit_values,並產生 results 張量。

縮減順序是由實作定義,也就是說 bodyinit_values 必須形成單元,以確保運算可在所有實作中為所有輸入產生相同的結果。不過,許多常見的縮減作業並未符合這個條件。舉例來說,body 的浮點加法和 init_values 的零實際上並不會形成單元,因為浮點加法不遵守結合律。

更正式的說法是 results...[j0, ..., jR-1] = reduce(input_slices_converted),其中:

  • input_slices = inputs...[j0, ..., :, ..., jR-1],其中 : 會插入 dimensions
  • input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...)
  • init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)
  • reduce(input_slices_converted) = exec(schedule) 適用於某些二元樹schedule,其中:
    • exec(node) = body(exec(node.left), exec(node.right))
    • exec(leaf) = leaf.value
  • schedule 是實作定義的完整二進位樹狀結構,其順序週遊包含下列內容:
    • input_slices_converted...[index] 值,適用於 index_space(input_slices_converted) 中所有 index 的遞增字典編列順序。
    • 在實作定義的位置中交錯實作定義的 init_values_converted 數量。

輸入

標籤 名稱 類型 限制
(I1) inputs 可變數張量數量或每個張量量化張量 (C1-C4)、(C6)、(C7)
(I2) init_values 單一維度張量或每個張量的量化張量 (C2)、(C3)
(I3) dimensions 1 維 si64 類型張量常數 (C4)、(C5)、(C7)
(I4) body 函式 (C6)

輸出內容

名稱 類型 限制
results 變異量或每個張量量化張量 (C3)、(C7)、(C8)

限制

  • (C1) same(shape(inputs...))
  • (C2) element_type(inputs...) = element_type(init_values...)
  • (C3) 0 < size(inputs) = size(init_values) = size(results) = N
  • (C4) 0 <= dimensions < rank(inputs[0])
  • (C5) is_unique(dimensions)
  • (C6) body 具有 (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>) 類型,其中 is_promotable(element_type(inputs[i]), Ei)
  • (C7) shape(results...) = shape(inputs...),但不包含與 dimensions 對應的 inputs... 維度大小。
  • (C8) element_type(results[i]) = Ei[0,N) 中的所有 i

範例

// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]

更多範例

reduce_precision

語義學

這個外掛程式能對使用 exponent_bitsmantissa_bits 的另一個浮點類型執行 operand 的元素轉換,並改回原始浮點類型,並產生 output 張量。

更正式:

  • 原始值的小數位元會更新,將原始值四捨五入至 roundToIntegralTiesToEven 語意中可用 mantissa_bits 表示的最接近值。
  • 接著,如果 mantissa_bits 小於原始值的小數值位元數,則小數值位元會截斷為 mantissa_bits
  • 接著,如果中間結果的指數位元不符合 exponent_bits 提供的範圍,中間結果會以原始符號溢位至無窮大,或以原始符號溢位至零。
  • 針對量化類型,執行 dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result))

輸入

標籤 名稱 類型 限制
(I1) operand 浮點類型的張量或每個張量的量化張量 (C1)
(I2) exponent_bits 類型為 si32 的常數 (C2)
(I3)。 mantissa_bits 類型為 si32 的常數 (C3)

輸出內容

名稱 類型 限制
output 浮點類型的張量或每個張量的量化張量 (C1)

限制

  • (C1) baseline_type(operand) = baseline_type(output)
  • (C2) 1 <= exponent_bits
  • (C3) 0 <= mantissa_bits

範例

// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
  exponent_bits = 5 : i32,
  mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]

更多範例

reduce_scatter

語義學

reduce_scatter

在 StableHLO 程序格線的每個程序群組中,使用 computations 對每個程序的 operand 張量值執行縮減作業,將縮減結果沿著 scatter_dimension 分割成部分,並在程序之間散布分割部分,產生 result

該作業會將 StableHLO 程序格線分割為 process_groups,定義如下:

  • cross_replica(replica_groups)channel_id <= 0 and use_global_device_ids = false
  • 如果為 channel_id > 0 and use_global_device_ids = false,則為 cross_replica_and_partition(replica_groups)
  • flattened_ids(replica_groups)channel_id > 0 and use_global_device_ids = true

接著,在每個 process_group 中:

  • reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation)
  • parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension)
  • result@receiver = parts@sender[receiver_index]process_group 中所有 sender 的值,其中 receiver_index = process_group.index(receiver)

輸入

標籤 名稱 類型 限制
(I1) operand 張量或每個張量量化張量 (C1)、(C2)、(C7)、(C8)
(I2) scatter_dimension 類型為 si64 的常數 (C1)、(C2) 和 (C8)
(I3) replica_groups si64 類型的 2D 張張量常數 (C3-C5)
(I4)。 channel_id 類型為 si64 的常數 (C6)。
(I5)。 use_global_device_ids i1 類型的常數 (C6)
(I6) computation 函式 (C7)

輸出內容

名稱 類型 限制
result 張量或每個張量量化張量 (C8-C9)

限制

  • (C1) dim(operand, scatter_dimension) % dim(process_groups, 1) = 0
  • (C2) 0 <= scatter_dimension < rank(operand)
  • (C3) is_unique(replica_groups)
  • (C4) size(replica_groups) 定義為:
    • 如果使用 cross_replica,則為 num_replicas
    • 如果使用 cross_replica_and_partition,則為 num_replicas
    • 如果使用 flattened_ids,則為 num_processes
  • (C5) 0 <= replica_groups < size(replica_groups)
  • (C6) 如果 use_global_device_ids = true,則 channel_id > 0
  • (C7) computation 具有 (tensor<E>, tensor<E>) -> (tensor<E>) 類型,其中 is_promotable(element_type(operand), E)
  • (C8) shape(result) = shape(operand) except:
    • dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1)
  • (C9) element_type(result) = E

範例

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
  "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
//                  [18, 20]]
// %result@(1, 0): [[14, 16],
//                  [22, 24]]

更多範例

reduce_window

語義學

將縮減函式 body 套用至 inputsinit_values 的視窗,並產生 results

下圖以具體範例說明如何從 inputs... 計算 results... 中的元素。

reduce_window

更正式的 results...[result_index] = reduce(windows, init_values, axes(inputs...), body) (請參閱 縮減),其中:

  • padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1)
  • window_start = result_index * window_strides
  • window_end = window_start + (window_dimensions - 1) * window_dilations + 1
  • windows = slice(padded_inputs..., window_start, window_end, window_dilations)

輸入

標籤 名稱 類型 限制
(I1) inputs 變異量或每個張量量化張量 (C1-C4)、(C6)、(C8)、(C10)、(C12)、(C13)、(C15)
(I2) init_values 0 維度張量或每個張量量化張量的變數參數數量 (C1)、(C13)
(I3)。 window_dimensions si64 類型的 1D 張量常數 (C4)、(C5)、(C15)
(I4) window_strides 1 維 si64 類型張量常數 (C6)、(C7)、(C15)
(I5) base_dilations si64 類型的 1D 張量常數 (C8)、(C9)、(C15)
(I6)。 window_dilations si64 類型的 1D 張量常數 (C10)、(C11)、(C15)
(I7) padding 類型為 si64 的 2 維張量常數 (C12)、(C15)
(I8) body 函式 (C13)。

輸出內容

名稱 類型 限制
results 可變數張量數量或每個張量量化張量 (C1)、(C14-C16)

限制

  • (C1) 0 < size(inputs) = size(init_values) = size(results) = N
  • (C2) same(shape(inputs...))
  • (C3) element_type(inputs...) = element_type(init_values...)
  • (C4) size(window_dimensions) = rank(inputs[0])
  • (C5) 0 < window_dimensions
  • (C6) size(window_strides) = rank(inputs[0])
  • (C7) 0 < window_strides
  • (C8) size(base_dilations) = rank(inputs[0])
  • (C9) 0 < base_dilations
  • (C10) size(window_dilations) = rank(inputs[0])
  • (C11) 0 < window_dilations
  • (C12) shape(padding) = [rank(inputs[0]), 2]
  • (C13) body 具有 (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>) 類型,其中 is_promotable(element_type(inputs[i]), Ei)
  • (C14) same(shape(results...))
  • (C15) shape(results[0]) = num_windows where:
    • dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1
    • padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]
    • dilated_window_shape = (window_dimensions - 1) * window_dilations + 1
    • is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape
    • num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1
  • (C16) element_type(results[i]) = Ei 適用於 [0,N) 中的所有 i

範例

// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = array<i64: 2, 1>,
  window_strides = array<i64: 4, 1>,
  base_dilations = array<i64: 2, 1>,
  window_dilations = array<i64: 3, 1>,
  padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]

更多範例

其餘

語義學

執行被除數 lhs 和除數 rhs 張元素的餘數,並產生 result 張量。

更正式地說,結果的符號取自除數,且結果的絕對值一律小於除數的絕對值。其餘數的計算方式為 lhs - d * rhs,其中 d 的計算方式為:

  • 整數:stablehlo.divide(lhs, rhs)
  • 浮點值:使用舍入屬性的 IEEE-754 標準 division(lhs, rhs) roundTowardZero
  • 複數:待定 (#997)。
  • 針對量化型別:
    • dequantize_op_quantize(remainder, lhs, rhs, type(result))

對浮點元素類型而言,這項作業與 IEEE-754 規格的 remainder 運算不同,其中 d 是接近 lhs/rhs 與偶數的確切值。

輸入

標籤 名稱 類型 限制
(I1)。 lhs 整數、浮點數或複雜型別或每個張量的量化張量 (C1)
(I2) rhs 整數、浮點或複數類型的張量,或每個張量的量化張量 (C1)

輸出內容

名稱 類型 限制
result 整數、浮點數或複雜型別或每個張量的量化張量 (C1)

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]

更多範例

replica_id

語義學

產生目前程序的 replica_id

輸出內容

名稱 類型
result ui32 類型的 0 維度張量

範例

%result = "stablehlo.replica_id"() : () -> tensor<ui32>

更多範例

reshape

語義學

operand 張量重塑為 result 張量。從概念上來說,這等於維持相同的標準表示法,但可能會變更形狀,例如從 tensor<2x3xf32> 變更為 tensor<3x2xf32>tensor<6xf32>

更正式地說,result[result_index] = operand[operand_index]index_space(result)index_space(operand) 的字典順序中,result_indexoperand_index 有相同的位置。

輸入

標籤 名稱 類型 限制
(I1) operand 張量或量化張量 (C1-C3)

輸出內容

名稱 類型 限制
result 張量或量化張量 (C1-C3)

限制

  • (C1) element_type(result) 的提供者:
    • element_type(operand),如果 !is_per_axis_quantized(operand)
    • element_type(operand),但 quantization_dimension(operand)quantization_dimension(result) 可能不同。
  • (C2) size(operand) = size(result)
  • (C3) 如果 is_per_axis_quantized(operand)
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)

範例

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]

更多範例

反向排序

語義學

沿著指定的 dimensions 反轉 operand 中元素的順序,並產生 result 張量。更正式的說法是 result[result_index] = operand[operand_index],其中:

  • operand_index[d] = dim(result, d) - result_index[d] - 1 如果 ddimensions 中。
  • 否則傳回 operand_index[d] = result_index[d]

輸入

標籤 名稱 類型 限制
(I1) operand 張量或每個張量量化張量 (C1)、(C3)
(I2) dimensions 1 維 si64 類型張量常數 (C2)、(C3)

輸出內容

名稱 類型 限制
result 張量或每個張量量化張量 (C1)、(C3)

限制

  • (C1) type(operand) = type(result)
  • (C2) is_unique(dimensions)
  • (C3) 0 <= dimensions < rank(result)

範例

// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
  dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]

更多範例

恩格

語義學

使用 rng_distribution 演算法產生隨機號碼,並產生具有指定形狀 shaperesult 張量。

如果 rng_distribution = UNIFORM,則系統會根據 [a, b) 間隔的均勻分布方式產生隨機數字。如果為 a >= b,則行為未定義。

如果設為 rng_distribution = NORMAL,則系統會在常態分佈後產生隨機數字,平均值 = a,標準差 = b。如果是 b < 0,表示未定義行為。

隨機數字的產生方式由實作方式定義。舉例來說,這些事件可能會或不會是決定性的,也可能會或不會使用隱藏狀態。

在與許多利益相關者進行對談後,我們發現這個作業已實際淘汰,因此未來我們將著手移除這個作業 (#597)。

輸入

標籤 名稱 類型 限制
(I1) a 整數、布林值或浮點值類型的 0 維張量 (C1)、(C2)
(I2) b 整數、布林值或浮點值類型的 0 維張量 (C1)、(C2)
(I3) shape 1 維 si64 類型張量常數 (C3)
(I4) rng_distribution UNIFORMNORMAL 的列舉 (C2)

輸出內容

名稱 類型 限制
result 整數、布林值或浮點類型的張量 (C1-C3)

限制

  • (C1) element_type(a) = element_type(b) = element_type(result)
  • (C2) 如果 rng_distribution = NORMAL,則 is_float(a)
  • (C3) shape(result) = shape

範例

// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
  rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
//           [1, 0, 1],
//           [1, 1, 1],
//           [0, 0, 0]
//          ]

rng_bit_generator

語義學

在給定初始狀態 initial_state 的情況下,使用虛擬亂數產生器演算法 rng_algorithm 傳回填入均勻隨機位元組的 output,以及更新的輸出狀態 output_state。輸出內容保證為 initial_state 的確定性函式,但不保證在實作之間是確定性的。

rng_algorithm 是下列其中一項:

  • DEFAULT:實作定義的演算法。
  • THREE_FRY:Threefry 演算法的實作定義變化版本*。
  • PHILOX:Philox 演算法的實作定義變化版本*。

* 參見:Salmon et al. SC 2011. 平行隨機號碼:就像 1、2、3 一樣簡單。

輸入

標籤 名稱 類型 限制
(I1) rng_algorithm DEFAULTTHREE_FRYPHILOX 的列舉 (C2)
(I2)。 initial_state 1 維的 ui64 類型張量 (C1)、(C2)

輸出內容

名稱 類型 限制
output_state 1 維的 ui64 類型張量 (C1)
output 整數或浮點類型的張量

限制

  • (C1) type(initial_state) = type(output_state)
  • (C2) size(initial_state) 的定義為:
    • 如果 rng_algorithm = DEFAULT,則會定義實作項目。
    • rng_algorithm = THREE_FRY 則設為 2
    • 如果是 rng_algorithm = PHILOX,則為 23

範例

// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
  rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
//           [9236835810183407956, 16087790271692313299],
//           [18212823393184779219, 2658481902456610144]
//          ]

round_nearest_afz

語義學

operand 張量上執行元素逐項四捨五入至最接近的整數,並產生 result 張量。實作 IEEE-754 規格中的 roundToIntegralTiesToAway 運算。如果是量化類型,請執行 dequantize_op_quantize(round_nearest_afz, operand, type(result))

輸入

標籤 名稱 類型 限制
(I1) operand 浮點類型的張量或每個張量的量化張量 (C1)

輸出內容

名稱 類型 限制
result 浮點類型的張量或每個張量的量化張量 (C1)

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]

更多範例

round_nearest_even

語義學

operand 張量上執行元素逐項四捨五入至最接近的整數,並將繫結值切斷至偶數整數,並產生 result 張量。實作 IEEE-754 規格中的 roundToIntegralTiesToEven 運算。針對量化類型,執行 dequantize_op_quantize(round_nearest_even, operand, type(result))

輸入

標籤 名稱 類型 限制
(I1) operand 浮點類型的張量或每個張量的量化張量 (C1)

輸出內容

名稱 類型 限制
result 浮點類型的張量或每個張量的量化張量 (C1)

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]

更多範例

rsqrt

語義學

operand 張量執行元素級別的倒數平方根運算,並產生 result 張量。視元素類型而定,執行下列操作:

  • 浮點值:IEEE-754 的 rSqrt
  • 複數:複數的倒數平方根。
  • 量化類型:dequantize_op_quantize(rsqrt, operand, type(result))

輸入

標籤 名稱 類型 限制
(I1) operand 浮點或複雜型別或每個張量的量化張量 (C1)

輸出內容

名稱 類型 限制
result 浮點或複雜型別或每個張量的量化張量 (C1)

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]

更多範例

散布

語義學

產生的 results 張量等同於 inputs 張量,但 scatter_indices 指定的多個切片會使用 update_computation 的值 updates 進行更新。

以下圖表以具體範例說明 updates... 中的元素如何對應至 results... 中的元素。此圖表會挑選幾個 updates... 索引範例,並詳細說明這些索引對應的 results... 索引。

散布

更正式的說法是,對於 index_space(updates[0]) 中的所有 update_index

  • update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]
  • update_scatter_index = update_index[update_scatter_dims...]
  • start_index 的定義如下:
    • scatter_indices[si0, ..., :, ..., siN],其中 siupdate_scatter_index 中的個別元素,如果 index_vector_dim < rank(scatter_indices): 會插入 index_vector_dim 索引。
    • 否則為 [scatter_indices[update_scatter_index]]
  • axes(inputs[0]) 中的 d_input
    • d_input = scatter_dims_to_operand_dims[d_start] 則設為 full_start_index[d_input] = start_index[d_start]
    • 否則傳回 full_start_index[d_input] = 0
  • axes(inputs[0]) 中的 d_input
    • full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)],如果 d_input = input_batching_dims[i_batching]d_start = scatter_indices_batching_dims[i_batching]
    • 否則為 full_batching_index[d_input] = 0
  • update_window_index = update_index[update_window_dims...]
  • full_window_index = [wi0, ..., 0, ..., wiN],其中 wiupdate_window_index 中的個別元素,而 0 會插入來自 inserted_window_dimsinput_batching_dims 的索引。
  • result_index = full_start_index + full_batching_index + full_window_index

因此,results = exec(schedule, inputs),其中:

  • scheduleindex_space(updates[0]) 的實作定義排列組合。
  • exec([update_index, ...], results) = exec([...], updated_results) where:
    • 如果 result_index 位於 shape(results...) 的邊界內
    • updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
    • updated_values = update_computation(results...[result_index], updates_converted)
    • updated_resultsresults 的副本,results...[result_index] 已設為 updated_values...
    • 你也可以
    • updated_results = results
  • exec([], results) = results

如果 indices_are_sortedtrue,則實作程序可以假設 scatter_indices 會依 scatter_dims_to_operand_dims 排序,否則行為未定義。更正式地說,對於 indices(result) 中的所有 i1 < i2full_start_index(i1) <= full_start_index(i2)

如果 unique_indicestrue,實作可以假設所有 result_index 索引都是不重複的。如果 unique_indicestrue,但被分散的索引重複,就不會定義該行為。

輸入

標籤 名稱 類型 限制
(I1) inputs 變異量或每個張量量化張量 (C1)、(C2)、(C4-C6)、(C11)、(C13)、(C18)、(C21)、(C23-C24)
(I2)。 scatter_indices 整數類型的張量 (C4)、(C15)、(C19)、(C22)
(I3) updates 變異量或每個張量量化張量 (C3-C6)、(C8)
(I4) update_window_dims 1 維 si64 類型張量常數 (C2)、(C4)、(C7-C8)
(I5) inserted_window_dims 1 維 si64 類型張量常數 (C2)、(C4)、(C9-C11)
(I6) input_batching_dims 1 維 si64 類型張量常數 (C2)、(C4)、(C9)、(C12-13)、(C17-18)、(C20)
(I7) scatter_indices_batching_dims si64 類型的 1D 張量常數 (C14-C18)
(I8)。 scatter_dims_to_operand_dims 1 維 si64 類型張量常數 (C19-C21)
(I9) index_vector_dim 類型為 si64 的常數 (C4)、(C16)、(C19)、(C22)
(I10) indices_are_sorted 類型為 i1 的常數
(I11) unique_indices 類型為 i1 的常數
(I12) update_computation 函式 (C23)

輸出內容

名稱 類型 限制
results 可變數張量數量或每個張量量化張量 (C24-C25)

限制

  • (C1) same(shape(inputs...))
  • (C2) `rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims)
    • size(input_batching_dims)`。
  • (C3) same(shape(updates...))
  • (C4) shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes) where:
    • update_scatter_dim_sizes = shape(scatter_indices),但不包括與 index_vector_dim 對應的 scatter_indices 尺寸大小。
    • update_window_dim_sizes <= shape(inputs[0]),但不包含 inputs[0] 中對應 inserted_window_dimsinput_batching_dims 的維度大小。
    • combine 會將 update_scatter_dim_sizes 放在對應 update_scatter_dims 的軸上,並將 update_window_dim_sizes 放在對應 update_window_dims 的軸上。
  • (C5) 0 < size(inputs) = size(updates) = N
  • (C6) element_type(updates...) = element_type(inputs...)
  • (C7) is_unique(update_window_dims) and is_sorted(update_window_dims)
  • (C8) 0 <= update_window_dims < rank(updates[0])
  • (C9) is_unique(concatenate(inserted_window_dims, input_batching_dims))
  • (C10) is_sorted(inserted_window_dims)
  • (C11) 0 <= inserted_window_dims < rank(inputs[0])
  • (C12) is_sorted(input_batching_dims)
  • (C13) 0 <= input_batching_dims < rank(inputs[0]))
  • (C14) is_unique(scatter_indices_batching_dims)
  • (C15) 0 <= scatter_indices_batching_dims < rank(scatter_indices)
  • (C16) index_vector_dim not in scatter_indices_batching_dims
  • (C17) size(input_batching_dims) == size(scatter_indices_batching_dims)
  • (C18) dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...)
  • (C19) size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1.
  • (C20) is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims))
  • (C21) 0 <= scatter_dims_to_operand_dims < rank(inputs[0])
  • (C22) 0 <= index_vector_dim <= rank(scatter_indices)
  • (C23) update_computation 具有 (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>) 類型,其中 is_promotable(element_type(inputs[i]), Ei)
  • (C24) shape(inputs...) = shape(results...)
  • (C25) [0,N) 中所有 ielement_type(results[i]) = Ei

範例

// %input: [
//          [
//           [[1, 2], [3, 4], [5, 6], [7, 8]],
//           [[9, 10],[11, 12], [13, 14], [15, 16]],
//           [[17, 18], [19, 20], [21, 22], [23, 24]]
//          ],
//          [
//           [[25, 26], [27, 28], [29, 30], [31, 32]],
//           [[33, 34], [35, 36], [37, 38], [39, 40]],
//           [[41, 42], [43, 44], [45, 46], [47, 48]]
//          ]
//         ]
// %scatter_indices: [
//                    [
//                     [[0, 0], [1, 0], [2, 1]],
//                     [[0, 1], [1, 1], [0, 9]]
//                    ],
//                    [
//                     [[0, 0], [2, 1], [2, 2]],
//                     [[1, 2], [0, 1], [1, 0]]
//                    ]
//                   ]
// %update: [
//           [
//            [[1, 1], [1, 1], [1, 1]],
//            [[1, 1], [1, 1], [1, 1]]
//           ],
//           [
//            [[1, 1], [1, 1], [1, 1]],
//            [[1, 1], [1, 1], [1, 1]]
//           ]
//          ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension_numbers = #stablehlo.scatter<
    update_window_dims = [3, 4],
    inserted_window_dims = [1],
    input_batching_dims = [0],
    scatter_indices_batching_dims = [1],
    scatter_dims_to_operand_dims = [2, 1],
    index_vector_dim = 3>,
  indices_are_sorted = false,
  unique_indices = false
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
// %result: [
//           [
//            [[3, 4], [6, 7], [6, 7], [7, 8]],
//            [[9, 10],[11, 12], [15, 16], [17, 18]],
//            [[17, 18], [19, 20], [22, 23], [24, 25]]
//           ],
//           [
//            [[25, 26], [28, 29], [30, 31], [31, 32]],
//            [[35, 36], [38, 39], [38, 39], [39, 40]],
//            [[41, 42], [44, 45], [46, 47], [47, 48]]
//           ]
//          ]

更多範例

選取

語義學

產生 result 張量,其中每個元素皆根據 pred 對應元素的值,從 on_trueon_false 張量中選取。更正式的說法是 result[result_index] = pred_element ? on_true[result_index] : on_false[result_index],其中 pred_element = rank(pred) = 0 ? pred[] : pred[result_index]。針對量化類型,執行 dequantize_select_quantize(pred, on_true, on_false, type(result))

輸入

標籤 名稱 類型 限制
(I1)。 pred i1 類型的張量 (C1)。
(I2) on_true 張量或每個張量量化張量 (C1-C2)
(I3) on_false 張量或每個張量量化張量 (C2)

輸出內容

名稱 類型 限制
result 張量或每個張量量化張量 (C2)

限制

  • (C1) rank(pred) = 0 or shape(pred) = shape(on_true)
  • (C2) baseline_type(on_true) = baseline_type(on_false) = baseline_type(result)

範例

// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]

更多範例

select_and_scatter

語義學

根據 select 使用 input 張量的 reduce_window 結果,使用 scattersource 張量散布值,並產生 result 張量。

下圖以具體範例說明如何從 operandsource 計算 result 中的元素。

select_and_scatter

更正式:

  • selected_values = reduce_window_without_init(...) 替換為下列輸入內容:

    • inputs = [operand].
    • window_dimensionswindow_stridespadding,可直接使用。
    • base_dilations = windows_dilations = 1
    • body 定義為:
    def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>:
      return select(arg0, arg1) ? arg0 : arg1;
    

    其中 E = element_type(operand)reduce_window_without_init 的運作方式與 reduce_window 完全相同,但基礎 reduceschedule (請參閱「縮減」) 不包含 init 值。目前如果對應的視窗沒有值,系統會未指定會發生什麼事 (#731)。

  • result[result_index] = reduce([source_values], [init_value], [0], scatter) where:

    • source_values = [source[source_index] for source_index in source_indices]
    • selected_index(source_index) = operand_index:如果 selected_values[source_index] 包含 operand_index 中的 operand 元素。
    • source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index]

輸入

標籤 名稱 類型 限制
(I1) operand 張量或每個張量量化張量 (C1-C4)、(C6)、(C8-C11)
(I2) source 張量或每個張量量化張量 (C1)、(C2)
(I3) init_value 0 維張量或每個張量量化張量 (C3)
(I4)。 window_dimensions 1 維 si64 類型張量常數 (C2)、(C4)、(C5)
(I5)。 window_strides 1 維 si64 類型張量常數 (C2)、(C6)、(C7)
(I6)。 padding 類型為 si64 的 2 維張量常數 (C2)、(C8)
(I7)。 select 函式 (C9)
(I8) scatter 函式 (C10)

輸出內容

名稱 類型 限制
result 張量或每個張量量化張量 (C11-C12)

限制

  • (C1) element_type(operand) = element_type(source)
  • (C2) shape(source) = num_windows 其中:
    • padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1]
    • is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape
    • num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1
  • (C3) element_type(init_value) = element_type(operand)
  • (C4) size(window_dimensions) = rank(operand)
  • (C5) 0 < window_dimensions
  • (C6) size(window_strides) = rank(operand)
  • (C7) 0 < window_strides
  • (C8) shape(padding) = [rank(operand), 2]
  • (C9) select 具有 (tensor<E>, tensor<E>) -> tensor<i1> 類型,其中 E = element_type(operand)
  • (C10) scatter 具有 (tensor<E>, tensor<E>) -> tensor<E> 類型,其中 is_promotable(element_type(operand), E)
  • (C11) shape(operand) = shape(result)
  • (C12) element_type(result) = E.

範例

// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = array<i64: 3, 1>,
  window_strides = array<i64: 2, 1>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]

更多範例

傳送

語義學

inputs 傳送至頻道 channel_id,並產生 result 權杖。

如果 is_host_transfertrue,則作業會將資料傳輸至主機。否則,系統會將資料傳輸到其他裝置。何謂實作定義。這個旗標會重複 channel_type 中提供的資訊,因此我們日後預計只保留其中一個 (#666)。

輸入

標籤 名稱 類型 限制
(I1) inputs 張量或量化張量的變數參數數量
(I2) token token
(I3) channel_id 類型為 si64 的常數
(I4) channel_type DEVICE_TO_DEVICEDEVICE_TO_HOST 的列舉 (C1)
(I5) is_host_transfer i1 類型的常數 (C1)

輸出內容

名稱 類型
result token

限制

  • (C1) channel_type 的定義如下:
    • DEVICE_TO_HOSTis_host_transfer = true
    • 否則為 DEVICE_TO_DEVICE

範例

%result = "stablehlo.send"(%operand, %token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
  is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token

更多範例

shift_left

語義學

lhs 張按 rhs 個位元執行元素的左移作業,並產生 result 張量。

輸入

標籤 名稱 類型 限制
(I1) lhs 整數類型的張量 (C1)
(I2) rhs 整數類型的張量 (C1)

輸出內容

名稱 類型 限制
result 整數類型的張量 (C1)

限制

  • (C1) type(lhs) = type(rhs) = type(result)

範例

// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]

更多範例

shift_right_arithmetic

語義學

根據 rhs 位元數量,對 lhs 張量執行元素逐項算術右移運算,並產生 result 張量。

輸入

標籤 名稱 類型 限制
(I1) lhs 整數類型的張量 (C1)
(I2) rhs 整數類型的張量 (C1)

輸出內容

名稱 類型 限制
result 整數類型的張量 (C1)

限制

  • (C1) type(lhs) = type(rhs) = type(result)

範例

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]

更多範例

shift_right_logical

語義學

根據 rhs 位元數量,對 lhs 張量執行元素邏輯右移運算,並產生 result 張量。

輸入

標籤 名稱 類型 限制
(I1) lhs 整數類型的張量 (C1)
(I2) rhs 整數類型的張量 (C1)

輸出內容

名稱 類型 限制
result 整數類型的張量 (C1)

限制

  • (C1) type(lhs) = type(rhs) = type(result)

範例

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]

更多範例

簽署

語義學

依元素傳回 operand 的符號,並產生 result 張量。更正式地說,對於每個元素 x,語意可使用 Python 語法表達,如下所示:

def sign(x):
  if is_integer(x):
    if compare(x, 0, LT, SIGNED): return -1
    if compare(x, 0, EQ, SIGNED): return 0
    return 1
  elif is_float(x):
    if is_nan(x): return NaN
    if compare(x, -0.0, EQ, FLOAT): return -0.0
    if compare(x, +0.0, EQ, FLOAT): return +0.0
    if compare(x, 0.0, LT, FLOAT): return -1.0
    return 1.0
  elif is_complex(x):
    if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
    if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
    return divide(x, convert(abs(x), type(x)))

針對量化類型,執行 dequantize_op_quantize(sign, operand, type(result))

輸入

標籤 名稱 類型 限制
(I1) operand 有符號整數、浮點或複數型別的張量,或每個張量的量化張量 (C1)

輸出內容

名稱 類型 限制
result 帶正負號整數、浮點數或複雜類型或每個張量的量化張量 (C1)

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]

更多範例

正弦

語義學

operand 張量執行元素級別的正弦運算,並產生 result 張量。視元素類型而定,執行下列操作:

  • 浮點值:IEEE-754 的 sin
  • 複數:複數正弦。
  • 對於量化類型:dequantize_op_quantize(sine, operand, type(result))

輸入

標籤 名稱 類型 限制
(I1) operand 浮點或複雜型別或每個張量的量化張量 (C1)

輸出內容

名稱 類型 限制
result 浮點或複雜型別或每個張量的量化張量 (C1)

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]

更多範例

配量

語義學

使用靜態計算的起始索引,從 operand 擷取切片,並產生 result 張量。start_indices 包含每個維度切片的起始索引,limit_indices 包含每個維度切片的結束索引 (不含),而 strides 則包含每個維度的步幅。

更正式的說法是 result[result_index] = operand[operand_index],其中 operand_index = start_indices + result_index * strides

輸入

標籤 名稱 類型 限制
(I1) operand 張量或每個張量量化張量 (C1-C3)、(C5)
(I2) start_indices 1 維 si64 類型張量常數 (C2)、(C3)、(C5)
(I3)。 limit_indices 1 維 si64 類型張量常數 (C2)、(C3)、(C5)
(I4) strides 1 維 si64 類型張量常數 (C2)、(C4)

輸出內容

名稱 類型 限制
result 張量或每個張量量化張量 (C1)、(C5)

限制

  • (C1) element_type(operand) = element_type(result)
  • (C2) size(start_indices) = size(limit_indices) = size(strides) = rank(operand)
  • (C3) 0 <= start_indices <= limit_indices <= shape(operand)
  • (C4) 0 < strides
  • (C5) shape(result) = ceil((limit_indices - start_indices) / strides)

範例

// %operand: [
//            [0, 0, 0, 0],
//            [0, 0, 1, 1],
//            [0, 0, 1, 1]
//           ]
%result = "stablehlo.slice"(%operand) {
  start_indices = array<i64: 1, 2>,
  limit_indices = array<i64: 3, 4>,
  strides = array<i64: 1, 1>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
//            [1, 1],
//            [1, 1]
//           ]

更多範例

排序

語義學

依據 comparator 排序 inputs 的 1 維切片,並沿著維度 dimension 一起排序,產生 results

與其他作業中的類似輸入不同,dimension 允許使用負數,其語意如下所述。不過,基於一致性考量,我們日後可能會禁止這類行為 (#1377)。

如果 is_stable 為 true,排序作業就會穩定,也就是說,比較器會保留視為相等的元素相對順序。對於只有單一輸入值的情況,比較器會將兩個元素 e1e2 視為相等,前提是 comparator(e1, e2) = comparator(e2, e1) = false 必須為 請參閱下方的正式化,瞭解如何以多種方式將內容一般化為多個輸入內容。

更正式的說法是,對於 index_space(results[0]) 中的所有 result_index

  • adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension
  • result_slice = [ri0, ..., :, ..., riR-1],其中 riNresult_index 中的個別元素,: 則是在 adjusted_dimension 插入。
  • inputs_together = (inputs[0]..., ..., inputs[N-1]...)
  • results_together[result_slice] = sort(inputs_together[result_slice], comparator_together)
  • 其中 sort 會以非遞減順序排序 1 維切片,並預期如果左側引數小於右側第二個引數,comparator_together 會傳回 true
  • def comparator_together(lhs_together, rhs_together):
      args = []
      for (lhs_el, rhs_el) in zip(lhs_together, rhs_together):
        args.append(lhs_el)
        args.append(rhs_el)
      return comparator(*args)
    
  • (results[0]..., ..., results[N-1]...) = results_together

輸入

標籤 名稱 類型 限制
(I1) inputs 可變數張量數量或每個張量量化張量 (C1-C5)
(I2) dimension 類型為 si64 的常數 (C4)
(I3) is_stable 類型為 i1 的常數
(I4) comparator 函式 (C5)

輸出內容

名稱 類型 限制
results 變異量或每個張量量化張量 (C2)、(C3)

限制

  • (C1) 0 < size(inputs)
  • (C2) type(inputs...) = type(results...)
  • (C3) same(shape(inputs...) + shape(results...))
  • (C4) -R <= dimension < R,其中 R = rank(inputs[0])
  • (C5) comparator 採用 (tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1> 類型,其中 Ei = element_type(inputs[i])

範例

// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
    %predicate = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
  dimension = 0 : i64,
  is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]

更多範例

sqrt

語義學

operand 張量執行元素級別平方根運算,並產生 result 張量。視元素類型而定,執行下列操作:

  • 浮點值:IEEE-754 的 squareRoot
  • 複數:複數的平方根。
  • 對於量化類型:dequantize_op_quantize(sqrt, operand, type(result))

輸入

標籤 名稱 類型 限制
(I1) operand 浮點或複雜型別或每個張量的量化張量 (C1)

輸出內容

名稱 類型 限制
result 浮點或複雜型別或每個張量的量化張量 (C1)

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]

更多範例

subtract

語義學

執行兩個張量相減 lhsrhs 的元素,產生 result 張量。根據元素類型執行以下操作:

  • 整數:整數減法。
  • 浮點值:IEEE-754 的 subtraction
  • 複數:複數減法。
  • 量化類型:
    • dequantize_op_quantize(subtract, lhs, rhs, type(result))

輸入

標籤 名稱 類型 限制
(I1) lhs 整數、浮點或複數類型的張量,或每個張量的量化張量 (C1)。
(I2) rhs 整數張量、浮點數或複雜類型,或每個張量的量化張量 (C1)。

輸出內容

名稱 類型 限制
result 整數、浮點或複數類型的張量,或每個張量的量化張量 (C1)

限制

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

範例

// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]

更多範例

tan

語義學

operand 張量執行元素級別切線運算,並產生 result 張量。視元素類型而定,執行下列操作:

  • 浮點值:IEEE-754 的 tan
  • 複數:複數正切。
  • 對於量化類型:dequantize_op_quantize(tan, operand, type(result))

輸入

標籤 名稱 類型 限制
(I1) operand 浮點或複雜型別或每個張量的量化張量 (C1)

輸出內容

名稱 類型 限制
result 浮點或複雜型別或每個張量的量化張量 (C1)

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.tan"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [
//           [0.0, 1.63312e+16],
//           [0.0, 5.44375e+15]
//          ]

更多範例

tanh

語義學

operand 張量執行元素級別的雙曲正切運算,並產生 result 張量。視元素類型而定,執行下列操作:

  • 浮點值:IEEE-754 的 tanh
  • 複數:複數雙曲正切。
  • 針對量化型別:
    • dequantize_op_quantize(tanh, operand, type(result))

輸入

標籤 名稱 類型 限制
(I1) operand 浮點或複雜型別或每個張量的量化張量 (C1)

輸出內容

名稱 類型 限制
result 浮點或複雜型別或每個張量的量化張量 (C1)

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]

更多範例

轉置

語義學

使用 permutationoperand 張量縮小,並產生 result 張量。更正式的說法是 result[result_index] = operand[operand_index],其中 result_index[d] = operand_index[permutation[d]]

輸入

標籤 名稱 類型 限制
(I1) operand 張量或量化張量 (C1-C4)
(I2) permutation si64 類型的 1D 張量常數 (C2-C4)

輸出內容

名稱 類型 限制
result 張量或量化張量 (C1)、(C3-C4)

限制

  • (C1) element_type(result) 的提供者:
    • element_type(operand),如果 !is_per_axis_quantized(operand)
    • element_type(operand),但 quantization_dimension(operand)quantization_dimension(result) 可能不同。
  • (C2) permutationrange(rank(operand)) 的排列。
  • (C3) shape(result) = dim(operand, permutation...)
  • (C4) 如果 is_per_axis_quantized(result),則 quantization_dimension(operand) = permutation(quantization_dimension(result))

範例

// %operand: [
//            [[1,2], [3,4], [5,6]],
//            [[7,8], [9,10], [11,12]]
//           ]
%result = "stablehlo.transpose"(%operand) {
  permutation = array<i64: 2, 1, 0>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
//           [[1,7], [3,9], [5,11]],
//           [[2,8], [4,10], [6,12]]
//          ]

更多範例

triangular_solve

語義學

使用上三角或下三角係數矩陣,解開批次的線性聯立方程式。

更正式的說法,就 ab 而言,當 left_sidetruex * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :] 時為 x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :] 時是 op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :] 的解決方案,其中 op(a)transpose_a 決定為 x 的解決方式,其中 op(a) 為下列其中 transpose_aresult[i0, ..., iR-3, :, :]left_sidefalse

  • NO_TRANSPOSE:依原樣使用 a 執行操作。
  • TRANSPOSE:對 a 轉置作業執行。
  • ADJOINT:對 a 的共軛轉置作業執行作業。

如果 lowertrue,則輸入資料只會從 a 的下三角讀取;如果是 a 的上半三角,則輸入資料會從 a 的下三角讀取。輸出資料會在同一個三角形中傳回,其他三角形的值則是實作定義。

如果 unit_diagonal 為 true,實作作業可以假設 a 的對角元素等於 1,否則行為未定義。

針對量化類型,執行 dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower, unit_diagonal, transpose_a), a, b, type(result))

輸入

標籤 名稱 類型 限制
(I1) a 浮點或複雜型別或每個張量的量化張量 (C1-C3)
(I2) b 浮點或複雜型別或每個張量的量化張量 (C1-C4)
(I3) left_side i1 類型的常數 (C3)。
(I4) lower i1 類型的常數
(I5) unit_diagonal 類型為 i1 的常數
(I6) transpose_a NO_TRANSPOSETRANSPOSEADJOINT 的列舉

輸出內容

名稱 類型 限制
result 浮點或複雜型別或每個張量的量化張量 (C1)

限制

  • (C1) baseline_element_type(a) = baseline_element_type(b)
  • (C2) 2 <= rank(a) = rank(b) = R
  • (C3) shape(a)shape(b) 之間的關係定義如下:
    • shape(a)[:-3] = shape(b)[:-3]
    • dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1)
  • (C4) baseline_type(b) = baseline_type(result)

範例

// %a = [
//       [1.0, 0.0, 0.0],
//       [2.0, 4.0, 0.0],
//       [3.0, 5.0, 6.0]
//      ]
// %b = [
//       [2.0, 0.0, 0.0],
//       [4.0, 8.0, 0.0],
//       [6.0, 10.0, 12.0]
//      ]
%result = "stablehlo.triangular_solve"(%a, %b) {
  left_side = true,
  lower = true,
  unit_diagonal = false,
  transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
//           [2.0, 0.0, 0.0],
//           [0.0, 2.0, 0.0],
//           [0.0, 0.0, 2.0]
//          ]

元組

語義學

從值 val 產生 result 元組。

輸入

標籤 名稱 類型 限制
(I1) val 值的變化數 (C1)

輸出內容

名稱 類型 限制
result 元組 (C1)

限制

  • (C1) result 屬於 tuple<E0, ..., EN-1> 類型,其中 Ei = type(val[i])

範例

// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))

更多範例

uniform_dequantize

語義學

根據 operand 類型定義的量化參數,將量化張量 operand 逐元素轉換為浮點張量 result

更正式的說法是 result = dequantize(operand)

輸入

標籤 名稱 類型 限制
(I1)。 operand 量化張量 (C1)、(C2)

輸出內容

名稱 類型 限制
result 浮點類型的張量 (C1)、(C2)

限制

  • (C1) shape(operand) = shape(result)
  • (C2) element_type(result) = expressed_type(operand)

範例

// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]

uniform_quantize

語義學

根據 result 類型定義的量化參數,對浮點張量或量化張量 operand 執行元素轉換,並轉換為量化張量 result

更正式的說法是:

  • 如果 is_float(operand)
    • result = quantize(operand, type(result))
  • 如果 is_quantized(operand)
    • float_result = dequantize(operand)
    • result = quantize(float_result, type(result))

輸入

標籤 名稱 類型 限制
(I1) operand 浮點或量化型別的張量 (C1)、(C2)

輸出內容

名稱 類型 限制
result 量化張量 (C1)、(C2)

限制

  • (C1) shape(operand) = shape(result)
  • (C2) expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand)

範例

// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]

// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]

while

語義學

cond 函式輸出 true 的同時,產生執行 body 函式 0 次以上所產生的輸出內容。更正式地說,語意可使用 Python 語法表達,如下所示:

internal_state = operand
while cond(*internal_state):
  internal_state = body(*internal_state)
results = internal_state

無限迴圈的行為為 TBD (#383)。

輸入

標籤 名稱 類型 限制
(I1) operand 張量、量化張量或符記的變數參數數量 (C1-C3)
(I2)。 cond 函式 (C1)
(I3) body 函式 (C2)

輸出內容

名稱 類型 限制
results 變異量、量化張量或代詞 (C3)

限制

  • (C1) cond 屬於 (T0, ..., TN-1) -> tensor<i1> 類型,其中 Ti = type(operand[i])
  • (C2) body 屬於 (T0, ..., TN-1) -> (T0, ..., TN-1) 類型,其中 Ti = type(operand[i])
  • (C3) type(results...) = type(operand...)

範例

// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %cond = "stablehlo.compare"(%arg0, %ten) {
      comparison_direction = #stablehlo<comparison_direction LT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    stablehlo.return %cond : tensor<i1>
  }, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %new_sum = stablehlo.add %arg1, %one : tensor<i64>
    %new_i = stablehlo.add %arg0, %one : tensor<i64>
    stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10

更多範例

xor

語義學

對兩個張量 lhsrhs 執行元素 XOR,並產生 result 張量。視元素類型而定,執行下列操作:

  • 布林值:邏輯 XOR。
  • 整數:位元 XOR。

輸入

標籤 名稱 類型 限制
(I1)。 lhs 布林值或整數類型的張量 (C1)
(I2) rhs 布林值或整數類型的張量 (C1)

輸出內容

名稱 類型 限制
result 布林值或整數類型的張量 (C1)。

限制

  • (C1) type(lhs) = type(rhs) = type(result)

範例

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]

更多範例

方言互通性

目前,實際環境中的 StableHLO 程式有時會包含非 StableHLO 定義的作業。

模組、函式、呼叫和傳回

StableHLO 針對 ModuleOp、FuncOp、CallOp 和 ReturnOp 使用上游 MLIR 作業。這麼做是為了與現有的 MLIR 機制進行更佳的互通作業,因為許多實用的傳遞都是以 FuncOp 和 ModuleOp 為目標而編寫,而許多編譯管道都會預期這些運算子會出現。這些作業會套用完整相容性保證。如果這些作業的變更方式不相容 (例如移除),系統會新增 StableHLO 等價項目,以維持相容性。

CHLO

CHLO 運算組合包含分解至 StableHLO 的層級較高的作業。我們目前並未針對 CHLO 提供相容性保證。針對相容性保證,序列化之前必須使用 chlo-legalize-to-stablehlo Pass

形狀運算

在社群中,動態 StableHLO 程式中使用核心 MLIR 方言的特定運算來執行形狀運算,是常見的用途。最常見的情況包括 shape 方言運算 (例如 shape_ofnum_elements)、tensor 方言運算 (例如 dimfrom_elements),以及內建的 index 類型。

動態 RFC > O2 將這些項目視為超出範圍,但為了互通性,我們仍納入了對 index 類型的部分支援。我們無法保證這些運算或類型具有相容性。shape-legalize-to-stablehlo 管道可用於將這些運算轉換為完全支援的 StableHLO 運算。

已淘汰的作業

有幾個 StableHLO 作業是繼承自 MHLO,這些作業已淘汰,並將從 StableHLO 中移除。如需這些移除作業的完整詳細資料,請參閱 StableHLO 第 1.0 版清理作業 #2283。這些淘汰作業的追蹤問題為 #2340

這些作業可分為以下幾類:

  • StableHLO 運算的「不在 HLO」類別 - 它們最初是 StableHLO 運算組合的一部分,但後來被認為不適合: broadcastcreate_tokencross-replica-sumdoteinsumtorch_index_selectunary_einsum (#3)。
  • 未使用的運算 - 這些作業有時可能很實用,但作業未經開發,或是將這些作業的管道重構為不再需要這些作業。包括 maptuple (#598)、get_tuple_elementrngcomplex 比較 #560 和卷積 window_reversal (#1181)。

由於這些運算可使用現有運算 (broadcastcreate_tokencross-replica-sumdotunary_einsum) 表達,因此部分運算可輕易移除,並會在現有相容性時間窗口 (6 個月) 過後移除。其他運算子 (einsumget_tuple_elementmaprngtorch_index_selecttuplecomplex 比較、window_reversal) 仍在評估是否移除。我們會等待社群的意見回饋,再決定是否移除這些運算子,或是將這些運算子加入規格,以便提供完整支援。直到已知運算 Future 為止,僅保證 6 個月的相容性。

執行

依序執行

藉由提供 main 函式的輸入值並運算輸出值,即可執行 StableHLO 程式。系統會執行在對應的 return 運算根基運算圖,藉此計算函式的輸出值。

只要執行順序與資料流程一致 (也就是在使用前執行運算),執行順序就會由實作定義。在 StableHLO 中,所有副作用運算都會使用一個符記並產生一個符記 (多個符記可透過 after_all 多路復用為一個符記),因此副作用的執行順序也會與資料流保持一致。例如,在下列程式中,執行順序有兩種:%0%1%2return%1%0%2return

func.func @main() -> tensor<f64> {
  %0 = stablehlo.constant dense<1.0> : tensor<f64>
  %1 = stablehlo.constant dense<2.0> : tensor<f64>
  %2 = stablehlo.add %0, %1 : tensor<f64>
  return %2 : tensor<f64>
}

更正式的說法是,StableHLO 程序是由:1) StableHLO 程式、2) 作業狀態 (尚未執行、已經執行),以及 3) 程序正在執行的中繼值。這個程序會從 main 函式的輸入值開始,逐步透過更新作業狀態和中繼值的作業圖表,最後輸出值。進一步的規範化待定 (#484)。

平行執行

StableHLO 程式可並行執行,並以 num_replicasnum_partitions 的 2D 處理格狀排列,兩者皆為 ui32 類型。

StableHLO 程序格線中,StableHLO 程序會同時執行 num_replicas * num_partitions。每個程序都有專屬的 process_id = (replica_id, partition_id),其中 replica_ids = range(num_replicas)partition_ids = range(num_partitions) 中的 replica_id 都有 ui32 類型。partition_id

每個程式的程序格線大小會以靜態方式提供 (我們預計在未來將其設為 StableHLO 程式的明確部分 #650),而每個程序在程序格線中的位置也會以靜態方式提供。每個程序都可以透過 replica_idpartition_id 作業,存取程序格線中的位置。

在程序格線中,程序可以全部相同 (在「單一程序、多個資料」樣式中)、全部不同 (在「多個程序、多個資料」樣式中),或介於兩者之間。日後,我們預計會推出支援定義平行 StableHLO 程式的其他慣用語,包括 GSPMD (#619)。

在程序網格中,程序多半彼此獨立,它們具有不同的作業狀態、獨立的輸入/中繼/輸出值,且大多數作業會在程序之間分別執行,但以下所述的少數集體運算除外。

由於大多數運算的執行作業只使用相同程序的值,因此通常可以透過名稱參照這些值,不會造成混淆。不過,在描述集體運算的語意時,這並不足夠,且這有助於標記 name@process_id,以參照特定程序中的 name 值。(從這個角度來看,未經限定的 name 可視為 name@(replica_id(), partition_id()) 的簡寫)。

跨程序的執行順序是由實作程序定義,但點對點通訊和集體作業所引進的同步處理例外,如以下所述。

點對點通訊

StableHLO 程序可透過 StableHLO 管道互相通訊。頻道由 si64 類型的正 ID 表示。透過各種作業,您可以將值傳送至管道,並從管道接收值。

進一步的規範化 (例如這些管道 ID 的來源、程序如何得知這些 ID,以及這些 ID 引進的同步處理) 尚未定案 (#484)。

串流通訊

每個 StableHLO 程序都能存取兩個串流介面:

  • 可讀取的Infeed
  • 可寫入的外部動態饋給

與管道不同,這類管道用於處理程序之間的通訊,因此兩端都有處理程序,而內部動態饋給和外部動態饋給則是透過其他端點實作。

進一步的規範化,例如串流通訊如何影響執行順序,以及這會引進哪種同步處理,則尚未確定 (#484)。

集體作業

StableHLO 中有六個集合運算:all_gatherall_reduceall_to_allcollective_broadcastcollective_permutereduce_scatter。所有這些運算都會將 StableHLO 程序格狀圖中的程序分割為 StableHLO 程序群組,並在各程序群組中執行共同運算,不受其他程序群組影響。

在每個程序群組中,集體作業可能會引入同步化障礙。進一步的規範化,例如詳述這項同步作業的確切發生時間、程序如何抵達這個障礙,以及如果未抵達會發生什麼事,則是未定 (#484)。

如果程序群組涉及跨區別通訊,也就是程序群組中包含的程序區別 ID 不同,則集體作業的執行作業需要管道,且集體作業必須提供 si64 類型的正值 channel_id。跨複本通訊不需要管道。

集體運算執行的運算專屬於個別運算,並且會在上方個別運算章節中說明。不過,這些作業之間會共用將程序格狀圖分割為程序群組的策略,而這會在本節中說明。更正式地說,StableHLO 支援下列四種策略。

cross_replica

每個程序群組之間只會發生跨副本通訊。這個策略會採用 replica_groups (複本 ID 的清單清單),並計算 replica_groupspartition_ids 的笛卡兒積。replica_groups 必須包含專屬元素,並涵蓋所有 replica_ids。更正式地說,使用 Python 語法:

def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    for partition_id in partition_ids:
      process_group = []
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
      yield process_group

例如,針對 replica_groups = [[0, 1], [2, 3]]num_partitions = 2cross_replica 會產生 [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]]

cross_partition

每個處理序群組內只會發生跨區塊通訊。此策略會採用 partition_groups (分割區 ID 的清單清單),並透過 replica_ids 計算 partition_groups 的笛卡兒積。partition_groups 必須包含專屬元素,且涵蓋所有 partition_ids。更正式地說,使用 Python 語法:

def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
  for partition_group in partition_groups:
    for replica_id in replica_ids:
      process_group = []
      for partition_id in partition_group:
        process_group.append((replica_id, partition_id))
      yield process_group

例如,針對 partition_groups = [[0, 1]]num_replicas = 4cross_partition 會產生 [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]]

cross_replica_and_partition

每個程序群組都可能發生跨副本和跨區塊的通訊。此策略會採用 replica_groups (複本 ID 清單清單),並透過 partition_ids 計算每個 replica_group 的笛卡兒積。replica_groups 必須包含不重複的元素,並涵蓋所有 replica_ids。更正式地說,使用 Python 語法:

def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    process_group = []
    for partition_id in partition_ids:
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
    yield process_group

例如,針對 replica_groups = [[0, 1], [2, 3]]num_partitions = 2cross_replica_and_partition 會產生 [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]]

flattened_ids

這項策略會採用 flattened_id_groups,也就是以 replica_id * num_partitions + partition_id 格式呈現的「扁平化」程序 ID 清單清單,並將這些項目轉換為程序 ID。flattened_id_groups 必須包含專屬元素,並涵蓋所有 process_ids。更正式地說,使用 Python 語法:

def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
  for flattened_id_group in flattened_id_groups:
    process_group = []
    for flattened_id in flattened_id_group:
      replica_id = flattened_id // num_partitions
      partition_id = flattened_id % num_partitions
      process_group.append((replica_id, partition_id))
    yield process_group

例如,對於 flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]num_replicas = 4num_partitions = 2flattened_ids 會產生 [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]]

準確率

目前,StableHLO 不保證數值準確度,但這項規定日後可能會有所變動 (#1156)。

量化運算的執行語意

量化 StableHLO 作業的解讀方式可能會因硬體需求和功能而異。舉例來說,某些硬體可能會選擇使用「解量化、執行浮點運算,最後再量化」策略來解讀量化運算。其他人則可能會使用整數算術執行整個運算。因此,量化的 StableHLO 作業的解讀方式,完全取決於特定實作方式。混合量化 (#1575) 的解讀方式應以規格中規定的語意為依據 (透過 1792)。

錯誤

StableHLO 程式會透過一組針對個別運算的廣泛限制進行驗證,藉此在執行時間之前排除許多類型的錯誤。不過,仍可能發生錯誤狀況,例如整數溢位、超出邊界存取等。除非明確指出,否則所有這些錯誤都會導致實作定義的行為,但這可能會在日後變更 (#1157)。

浮點例外狀況

這項規則的例外狀況是,StableHLO 程式中的浮點例外狀況具有明確的行為。導致 IEEE-754 標準定義例外狀況的作業 (無效作業、除以零、溢位、反向溢位或不精確的例外狀況) 會產生預設結果 (如標準定義) 並在不提高對應的狀態旗標的情況下繼續執行;與標準的 raiseNoFlag 例外狀況處理類似。非標準作業 (例如複雜算術和特定半透明函式) 的例外狀況是由實作定義。

形狀不符

StableHLO 支援動態形狀張量。不過,形狀必須在執行階段相符,否則行為會未定義。StableHLO 不會明確提供可斷言張量在執行階段具有特定形狀的運算。產生正確的程式碼是製作者的責任。

以下列程式為例,說明有效的程式。不過,在執行階段,%arg0%arg1 的確切形狀必須相同,否則程式的行為會不明確:

func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
    %0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
    return %0 : tensor<?xi32>
}

Notation

為了說明語法,這份文件使用經過修改的 ISO 風格 EBNF 語法 (ISO/IEC 14977:1996維基百科),並進行兩項修改:1) 使用 ::= 而非 = 定義規則。

2) 串接是使用並列,而非 ,

為了說明語意 (即「類型」、「常數」和「運算」部分),我們使用了以 Python 語法為基礎的公式,並擴充支援功能,以便簡潔地表達陣列運算,如下所述。這麼做很適合用於小型程式碼片段,但在少數需要較大程式碼片段的情況下,我們會使用純 Python 語法,這類語法一律會明確引入。

公式

接下來,讓我們根據 dot_general 規格中的範例,探索公式的運作方式。這項作業的其中一個限制如下所示:dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)

這個公式中使用的名稱來自兩個來源:1) 全域函式,例如 dim;2) 對應程式元素的會員定義,例如 dot_general 的「Inputs」部分中定義的 lhslhs_batching_dimensionsrhsrhs_batching_dimensions 輸入內容。

如上所述,這個公式的語法是以 Python 為基礎,並加入一些簡潔的擴充功能。為了瞭解公式,我們將轉換成一般 Python 語法。

A) 在這些公式中,我們使用 = 代表相等,因此取得 Python 語法的第一步驟,就是將 = 替換為 ==,如下所示:dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...)

B) 這些公式也支援橢圓 (...),可將標量運算式轉換為張量運算式。簡而言之,f(xs...) 大致是指「針對張量 xs 中的每個標量 x,計算標量 f(x),然後將所有標量結果一起傳回做為張量結果」。在一般 Python 語法中,我們的範例公式會變成:[dim(lhs, dim1) for dim1 in lhs_batching_dimensions] == [dim(rhs, dim2) for dim2 in rhs_batching_dimensions]

多虧了省略號,我們通常可以避免在個別標量層級運作。不過,在某些棘手的情況下,可能會使用較低層級的半正式語法,例如 gather 規格中的 start_indices[bi0, ..., :, ..., biN] 公式。為了簡潔起見,我們不會提供將這類語法轉譯為一般 Python 的確切形式,希望您能根據個別情況,直覺地瞭解這類語法。如果某些特定公式看起來不透明,請告訴我們,我們會盡力改善。

此外,您會發現公式會使用橢圓形展開所有類型的清單,包括張量、張量清單 (例如可由變數張量數量產生) 等。這是我們無法提供確切形式主義的另一個領域 (例如清單甚至不是 StableHLO 類型系統的一部分),而是依賴直覺易懂的形式。

C) 最後一個值得注意的符號工具是隱含廣播。雖然 StableHLO 運算集不支援隱式廣播,但公式也能使用精簡服務。簡單來說,如果在預期張量的情況下使用純量,純量將會廣播至預期的形狀。

接著繼續說明 dot_general 的例子,以下是另一個限制:0 <= lhs_batching_dimensions < rank(lhs)。如 dot_general 規格所定義,lhs_batching_dimensions 是張量,但 0rank(lhs) 都是純量。套用隱含廣播後,公式會變成 [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)]

套用於特定 dot_general 運算時,這個公式會評估為布林值的張量。當公式用於限制條件時,如果公式評估結果為 true,或只有 true 元素的張量,則限制條件會生效。

名稱

在公式中,字彙範圍包括:1) 全域函式、2) 成員定義、

3) 本機定義。以下提供全域函式清單。元素定義清單取決於該符號要套用的程式元素:

  • 對於運算,成員定義包含「輸入」和「輸出」部分中引入的名稱。
  • 對於其他所有內容,成員定義包含程式元素的結構部分,以相應的 EBNF 非終端符命名。在大多數情況下,這些結構性部分的名稱會透過將非終端符號的名稱轉換為蛇形命名法 (例如 IntegerLiteral => integer_literal) 來取得,但有時名稱會在過程中縮寫 (例如 QuantizationStorageType => storage_type),在這種情況下,名稱會以類似於操作規格中「輸入」/「輸出」部分的方式明確引入。
  • 此外,成員定義一律會包含 self,用於參照對應的程式元素。

評估公式時,公式會使用下列類型的值:1) Value (實際值,例如 dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>;這些值一律會知道其類型)、2) Placeholder (未來值,例如 lhsrhsresult;這些值的實際值尚未知曉,只知道其類型)、3) Type (在「類型」一節中定義的類型)、4) Function (在「函式」一節中定義的全域函式)。

依情境而定,名稱可能會參照不同的值。更具體來說,運算的「語意」區段 (以及其他程式元素的對等項目) 定義了執行階段邏輯,因此所有輸入內容均以 Value 形式提供。相反地,運算 (和等價項目) 的「限制」部分會定義「編譯時間」邏輯,也就是通常在執行階段前執行的項目,因此只有常數輸入可做為 Value,其他輸入則只能做為 Placeholder

名稱 在「語意」中 在「限制條件」中
全域函式 Function Function
常數輸入 Value Value
非常數輸入 Value Placeholder
輸出內容 Value Placeholder
本機定義 取決於定義 取決於定義

假設 transpose 作業範例:

%result = "stablehlo.transpose"(%operand) {
  permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>

對於這項作業,permutation 是常數,因此在語意和限制中都會以 Value 的形式提供。相反地,operandresult 可做為語意中的 Value,但僅限於限制中的 Placeholder

函式

類型的建構

沒有任何函式可用來建構類型。我們會直接使用類型語法,因為這通常更精簡。例如:(tensor<E>, tensor<E>) -> (tensor<E>) 而非 function_type( [tensor_type([], E), tensor_type([], E)], [tensor_type([], E)])

類型的函式

  • element_type 是在張量類型和量化張量類型上定義,並分別傳回對應 TensorTypeQuantizedTensorTypeTensorElementTypeQuantizedTensorElementType 部分。
def element_type(x: Value | Placeholder | Type):
 if type(x) == TensorType:
    return tensor_element_type(x)
  if type(x) == QuantizedTensorType:
    return quantized_tensor_element_type(x)
  if type(x) is not Type:
    return element_type(type(x))
  • is_per_axis_quantized(x: Value | Placeholder | Type) -> Valueis_quantized(x) and quantization_dimension(x) is not None 的捷徑。

  • is_per_tensor_quantized(x: Value | Placeholder | Type) -> Valueis_quantized(x) and quantization_dimension(x) is None 的捷徑。

  • is_promotable(x: Type, y: Type) -> bool 會檢查 x 類型是否可升級為 y 類型。如果 xyQuantizedTensorElementType,則促銷活動只會套用至 storage_type。這個特定版本的宣傳活動目前用於減少運算的情況 (詳情請參閱 RFC)。

def is_promotable(x: Type, y: Type) -> Value:
  is_same_type = (is_bool(x) and is_bool(y)) or
    (is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
    (is_complex(x) and is_complex(y)) or
    (is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))

  if is_same_type == False:
    return False

  if is_integer(x) or is_float(x):
    return bitwidth(x) <= bitwidth(y)

  if is_complex(x):
    return bitwidth(element_type(x)) <= bitwidth(element_type(y))

  if is_quantized(x):
    return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))

  return false
  • is_quantized(x: Value | Placeholder | Type) -> Valueis_quantized_tensor_element_type(x) 的捷徑。

  • is_type_name(x: Value | Placeholder | Type) -> Value。適用於所有類型。舉例來說,如果 xFloatTypeis_float(x) 就會傳回 true。如果 x 是值或預留位置,這個函式就是 is_type_name(type(x)) 的捷徑。

  • max_value(x: Type) -> Value 會傳回 TensorElementType 的最大值。如果 x 不是 TensorElementType,則會傳回 None

  • min_value(x: Type) -> Value 會傳回 TensorElementType 的最小可能值。如果 x 不是 TensorElementType,就會傳回 None

  • member_name(x: Value | Placeholder | Type) -> Any。適用於所有類型的所有成員定義 member_name。舉例來說,tensor_element_type(x) 會傳回對應 TensorTypeTensorElementType 部分。如果 x 是值或預留位置,這個函式就是 member_name(type(x)) 的捷徑。如果 x 不是具有適當成員的類型,或是該類型的值或預留位置,就會傳回 None

  • is_empty_algorithm(*args: Type) 會檢查所有點狀演算法欄位是否設為 None。這點需要這麼做,因為點點演算法已實作定義預設行為,因此指定預設值可能不正確。

值的建構

  • operation_name(*xs: Value | Type) -> Value:適用於所有作業。舉例來說,add(lhs, rhs) 會使用兩個張量值 lhsrhs,並使用這些輸入內容傳回評估 add 運算的輸出內容。對於某些運算 (例如 broadcast_in_dim),其輸出類型是「負載」,也就是需要評估運算。在這種情況下,函式會將這些類型做為引數。

值函式

  • 所有 Python 的運算子和函式都可供使用。例如,Python 中的訂閱切片符號都可用於索引張量、量化張量和元組。

  • to_destination_type(x: Value, destination_type: Type) -> Value 是在張量上定義,並會根據 type(x)destination_type 傳回 x 的轉換值,如下所示:

def to_destination_type(x: Value, destination_type: Type) -> Value:
  if type(x) == destination_type:
    return x

  if is_quantized(destination_type):
    if is_quantized(type(x)):
      return quantize(x, destination_type)
    assert is_float(type(x))
    return quantize(x, destination_type)

  if is_quantized(type(x)):
    assert destination_type = expressed_type(type(x))
    return dequantize(type(x))

  return convert(x, destination_type)

我們曾討論過合併 convertuniform_quantizeuniform_dequantize 作業 (#1576)。合併後,我們不需要上述函式,可以改用 convert 的運算名稱。

  • is_nan(x: Value) -> Value 是在張量上定義,如果 x 的所有元素都是 NaN,則會傳回 true,否則會傳回 false。如果 x 不是張量,則會傳回 None

  • is_sorted(x: Value) -> Value 定義在張量上,如果 x 的元素按照遞增順序排序,則根據索引的遞增順序排序,否則會傳回 truefalse如果 x 不是張量,就會傳回 None

  • is_unique(x: Value) -> Value 是在張量上定義,如果 x 沒有重複的元素,則會傳回 true;如果有重複的元素,則會傳回 false。如果 x 不是張量,則會傳回 None

  • member_name(x: Value) -> Any 是所有值的所有成員定義 member_name 的定義。例如,real_part(x) 會傳回對應 ComplexConstantRealPart 部分。如果 x 不是具有適當成員的值,則會傳回 None

  • same(x: Value) -> Value 是在張量上定義,如果 x 的元素都相等,則會傳回 true;否則會傳回 false。如果張量沒有元素,則視為「所有元素都相等」,也就是函式會傳回 true。如果 x 不是張量,就會傳回 None

  • split(x: Value, num_results: Value, axis: Value) -> Value 是在張量上定義,並且會沿著 axis 軸傳回 xnum_results 配量。如果 x 不是張量或 dim(x, axis) % num_results != 0,則會傳回 None

  • is_defined_in_parent_scope(x: Value) -> Value 是在字串上定義,如果 x 是與相關運算子的父函式在相同範圍內定義的函式名稱,則會傳回 true

  • is_namespaced_op_name(x: Value) -> Value 是在字串上定義,如果 x 是有效的 op 名稱,則會傳回 true,也就是會遵循下列規則運算式:[a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+

形狀運算

  • axes(x: Value | Placeholder | Type) -> Valuerange(rank(x)) 的捷徑。

  • dim(x: Value | Placeholder | Type, axis: Value) -> Valueshape(x)[axis] 的捷徑。

  • dims(x: Value | Placeholder | Type, axes: List) -> Listlist(map(lambda axis: dim(x, axis), axes)) 的捷徑。

  • index_space(x: Value | Placeholder | Type) -> Value 已在張量上定義,然後針對對應的 TensorType 傳回 size(x) 索引,這些 TensorType 會按照字母順序排列,也就是 [0, ..., 0][0, ..., 1]、...、shape(x) - 1。如果 x 不是張量類型、量化張量類型,或是上述其中一種類型的值或預留位置,系統會傳回 None

  • rank(x: Value | Placeholder | Type) -> Valuesize(shape(x)) 的捷徑。

  • shape(x: Value | Placeholder | Type) -> Value 是在「類型函式」一節中透過 member_name 定義。

  • size(x: Value | Placeholder | Type) -> Valuereduce(lambda x, y: x * y, shape(x)) 的捷徑。

量化運算

  • def baseline_element_type(x: Value | Placeholder | Type) -> Typeelement_type(baseline_type(x)) 的捷徑。

  • baseline_type 會在張量類型和量化張量類型上定義,並將這些類型轉換為「基準」,也就是形狀相同但元素類型量化參數已重設為預設值的類型。這項功能可用來比較張量和量化張量類型,這項功能相當實用,以量化類型來說,這麼做可讓比較類型忽略量化參數,也就是 shapestorage_typeexpressed_typestorage_minstorage_maxquantization_dimension (針對每個軸的量化類型) 必須全部相符,但 scaleszero points 可能會不同。

def baseline_type(x: Value | Placeholder | Type) -> Type:
  if type(x) == TensorType:
    return x
  if type(x) == QuantizedTensorType:
    element_type = quantized_tensor_element_type(x)
    baseline_element_type = QuantizedTensorElementType(
      storage_type = storage_type(element_type),
      storage_min = storage_min(element_type),
      storage_max = storage_max(element_type),
      expressed_type = expressed_type(element_type),
      quantization_dimension = quantization_dimension(element_type),
      scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
      zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
    return QuantizedTensorType(shape(x), baseline_element_type)
  if type(x) is not Type:
    return baseline_element_type(type(x))
  • dequantize 會在量化張量類型上定義,並將其轉換為浮點張量類型。這會透過使用與量化元素類型相關聯的零點和比例,將代表儲存類型整數值的量化元素轉換為表示類型的對應浮點值。
def compute_zero_points(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      zero_points[i] = zero_points(quantized_type)[i[d]]
    return zero_points

def compute_scales(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
            type(result_type))
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      scales[i] = scales(quantized_type)[i[d]]
    return scales

def dequantize(x: Value) -> Value:
  assert is_quantized(x)
  x_storage = bitcast_convert(x, storage_type(x))
  x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
  x_expressed_sub = convert(x_storage_sub, expressed_type(x))
  return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
  • quantize 是在浮點張量類型上定義,並轉換為量化張量類型。這會透過使用與量化元素類型相關聯的零點和比例,將表示型別的浮點值轉換為儲存型別的對應整數值。
def quantize(x: Value, result_type: Type) -> Value:
  assert is_float(x) and is_quantized(result_type)
  zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
  converted_zero_points = convert(zero_points, expressed_type(result_type))
  converted_min = convert(storage_min(result_type), expressed_type(result_type))
  converted_max = convert(storage_max(result_type), expressed_type(result_type))

  x_scaled = x / compute_scales(result_type, type(x))
  x_scaled_add_zp = x_scaled + converted_zero_points
  x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
  x_rounded = round_nearest_even(x_clamped)
  return convert(x_rounded, result_type)
  • dequantize_op_quantize 可用於在量化張量上指定元素運算。系統會將量化元素 (例如將量化元素轉換為其代表的類型),然後執行作業,然後量化,例如將結果轉換回其儲存空間類型。目前,這項功能僅適用於每個張量量化。我們正在進行每個軸向量化作業 (#1574)。
def dequantize_op_quantize(op, *inputs_and_output_type):
  inputs = inputs_and_output_type[:-1]
  output_type = inputs_and_output_type[-1]

  float_inputs = map(dequantize, inputs)
  float_result = op(*float_inputs)
  return quantize(float_result, output_type)

def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
  inputs = inputs_and_output_type[:-3]
  float_inputs = map(dequantize, inputs)
  float_results = op(*float_inputs)
  return map(quantize, float_results, inputs_and_output_type[-3:])

def dequantize_compare(lhs, rhs, comparison_direction):
  float_lhs = dequantize(lhs)
  float_rhs = dequantize(rhs)
  return compare(float_lhs, float_rhs, comparison_direction, FLOAT)

def dequantize_select_quantize(pred, on_true, on_false, output_type):
  float_on_true = dequantize(on_true)
  float_on_false = dequantize(on_false)
  float_result = select(pred, float_on_true, float_on_false)
  return quantize(float_result, output_type)
  • hybrid_dequantize_then_op 用於指定混合運算的權重式量化,該運算可接受浮點型左值和量化型右值。它會將量化輸入值解量化為其表示型別,並以浮點值執行運算。浮點型左側張量元素類型和量化型右側張量表示類型應相同。
def hybrid_dequantize_then_op(op, lhs, rhs):
  assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
  return op(lhs, dequantize(rhs))

格線運算

  • cross_partition(replica_groups: Value) -> Value。請參閱上方的「cross_replica」一節。

  • cross_replica(replica_groups: Value) -> Value。請參閱上方的「cross_replica」一節。

  • cross_replica_and_partition(replica_groups: Value) -> Value。請參閱上方的「cross_replica_and_partition」一節。

  • flattened_ids(replica_groups: Value) -> Value。請參閱上方的「flattened_ids」一節。

動態

StableHLO 值可包含動態維度大小,例如 tensor<?xi64>。不過,StableHLO 值不得包含動態數量的維度 (例如 tensor<*xi64>)。即使大小設有限制,運算元和結果仍可使用動態維度大小。系統會盡可能以靜態方式驗證限制條件,否則會將限制條件延後至執行階段,不相符的情況會導致未定義的行為。請查看以下範例。

單項元素運算的形狀不符

請考慮以下玩具程式:

func.func @foo(%arg0: tensor<?xf64>) {
  %0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
  return
}

這種程式很不尋常,因為通常不會知道結果的形狀,但輸入的形狀卻是已知的。儘管如此,這是有效的 StableHLO 程式。由於運算元的確切形狀不明,因此無法對這個程式中的 abs 運算進行靜態驗證。不過,這些形狀確實相容,而且可以進行靜態檢查:? 在執行階段可能會變成 2,但不會有任何問題。不過,? 也可能會變成其他整數,在這種情況下,行為就無法定義。

請注意,如果結果中的維度大小為動態大小,則不會出現未定義的行為。事實上,並沒有「預期」的大小,因此不會出現不相符的情況。

二元元素運算的形狀不符

請考慮使用下列玩具計畫:

func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
  %0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
  return
}

就二元元素運算而言,輸入內容和結果的形狀必須在執行階段一致。在編譯期間,靜態維度必須相等,否則只需要相容即可。如果輸入內容中的任一維度為動態,則在執行階段可能會出現未定義的行為,因為動態大小可能與其他運算元 (不論是靜態或動態) 中的對應大小不符。如果所有輸入內容都是靜態的,則結果是否為動態並不重要:靜態已知維度會以靜態方式檢查,而動態維度不會強制任何限制。

對於將輸出形狀做為運算元的運算子,形狀不相符

請考慮以下玩具程式:

func.func @foo(%arg0: tensor<2xi32>) {
  %0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
  return
}

在執行階段,形狀運算子中的值必須與結果的形狀相符,否則行為將未定義。也就是說,在執行階段 %arg0 必須具有 dense<[3, 4]> : tensor<2xi32> 的值。如果形狀運算元是常數,則可進行靜態驗證。如果結果形狀是完全動態的,就不會出現不相符的情況。