StableHLO 規格

StableHLO 是機器學習 (ML) 模型中高階作業 (HLO) 的作業集。StableHLO 可做為不同機器學習架構和機器學習編譯器之間的可攜層:產生 StableHLO 程式的機器學習架構,與使用 StableHLO 程式的機器學習編譯器相容。

我們的目標是在不同的機器學習架構 (例如 TensorFlow、JAX 和 PyTorch) 與機器學習編譯器 (例如 XLA 和 IREE) 之間建立更互通性,藉此簡化及加快機器學習開發作業。因此,本文件提供 StableHLO 程式設計語言的規格。

這個規格包含三個主要部分首先,「程式」一節說明瞭 StableHLO 程式的結構,StableHLO 函式本身是由 StableHLO 運算組成。在這個結構中,「Ops」區段會指定個別運算的語意。「Execution」一節會針對在程式中一起執行的所有作業提供語意。最後,「標記法」一節將討論整個規格中使用的標記法。

程式

Program ::= {Func}

StableHLO 程式包含任意數量的 StableHLO 函式。以下是含有 @main 函式的範例程式,函式包含 3 個輸入 (%image%weights%bias) 和 1 個輸出內容。函式主體有 6 個運算。

func.func @main(
  %image: tensor<28x28xf32>,
  %weights: tensor<784x10xf32>,
  %bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
  %0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
  %1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
  %2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  %3 = "stablehlo.constant"() { value = dense<0.0> : tensor<1x10xf32> } : () -> tensor<1x10xf32>
  %4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  "func.return"(%4): (tensor<1x10xf32>) -> ()
}

函式

Func        ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs  ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput   ::= '%' ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput  ::= ValueType
FuncBody    ::= {Op}

StableHLO 函式 (也稱為「已命名函式」) 具有 ID、輸入/輸出和主體。日後,我們計劃為函式推出其他中繼資料,以便提高與 HLO 的相容性 (#425#626#740#744)。

ID

FuncId  ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
          | '%' letter {letter | digit}
letter  ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit   ::= '0' | ... | '9'

StableHLO ID 與許多程式設計語言中的 ID 相似,但基本上有兩大意義:1) 所有 ID 都有能夠區分不同種類的 ID,2) 值 ID 可以是完全數字,簡化 StableHLO 程式的產生作業。

類型

Type         ::= ValueType | NonValueType
ValueType    ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType

StableHLO 類型分為「值類型」 (也稱為「一類類型」),代表 StableHLO 值和描述其他程式元素的非值類型。StableHLO 類型與許多程式設計語言的類型類似,主要謹慎度是 StableHLO 的特定領域,導致某些不尋常的結果 (例如純量類型不是值類型)。

TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit}

Tensor 類型代表張量,即多維陣列。它們具有形狀元素類型,其中形狀代表非負的維度大小,按照 0R-1 的對應維度 (也稱為「軸」) 遞增順序排列。維度數量 R 稱為「排名」。舉例來說,tensor<2x3xf32> 是形狀類型為 2x3 且元素類型為 f32 的張量類型。這項特徵有兩個維度 (或兩個軸,也就是兩個軸) - 第 0 個維度和 1 個維度,大小為 2 和 3。其排名為 2。

這會定義對靜態形狀的支援,其中尺寸是靜態的。未來,我們也計劃引進對動態形狀的支援,其中尺寸大小為部分或完全不明 (#8)。此外,我們也打算探索除了尺寸大小和元素類型以外的張量類型,例如納入版面配置 (#629) 和稀疏度 (#1078)。

QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
                  QuantizationStorageType
                  ['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
                  ':' QuantizationExpressedType
                  [':' QuantizationDimension]
                  ',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerConstant
QuantizationStorageMax ::= IntegerConstant
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerConstant
QuantizationParameters ::= QuantizationParameter
                         | '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale ':' QuantizationZeroPoint
QuantizationScale ::= FloatConstant
QuantizationZeroPoint ::= IntegerConstant
名稱 類型 限制
storage_type 整數類型 (C1-C4)、(C9)
storage_min 整數常數 (C2)、(C4)、(C8)
storage_max 整數常數 (C3)、(C4)、(C8)
expressed_type 浮點類型 (C1)、(C5)
quantization_dimension 選擇性整數常數 (C11-C13)
scales 變異數浮點常數 (C5-C7)、(C10)、(C11)、(C13)
zero_points 變異數整數常數 (C8-C10)

「量化元素類型」代表儲存空間類型storage_minstorage_max (含) 之間的整數值,對應於表達類型的浮點值。以指定整數值 i 來說,對應的浮點值 f 可視為 f = (i - zero_point) * scale 計算,其中 scalezero_point 稱為量化參數。文法中的 storage_minstorage_max 為選用項目,但預設值分別為 min_value(storage_type)max_value(storage_type)。量化元素類型具有下列限制:

  • (C1) num_bits(storage_type) < num_bits(expressed_type)
  • (C2) type(storage_min) = storage_type
  • (C3) type(storage_max) = storage_type
  • (C4) min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type)
  • (C5) type(scales...) = expressed_type
  • (C6) 0 < scales
  • (C7) is_finite(scales...)
  • (C8) storage_min <= zero_points <= storage_max
  • (C9) type(zero_points...) = storage_type
  • (C10) size(scales) = size(zero_points)
  • (C11) 如為 is_empty(quantization_dimension),則預設為 size(scales) = 1
  • (C12) 0 <= quantization_dimension

目前,QuantizationScale 是浮點常數,但對於以整數為基礎的量表來說,以乘法和位移表示。我們計劃在未來進行探索。(#1404)。

有關 QuantizationZeroPoint 的語意持續討論,包括類型、值,以及量化張量類型中是否只能有一個或可能多個零點。根據本討論的結果,零點左右的規格日後可能會有所異動 (#1405)。

另一項持續進行中的討論涉及 QuantizationStorageMinQuantizationStorageMax 的語意,以判斷是否應對這些值和量化張量設定任何限制 (#1406)。

最後,我們打算研究代表未知量表和零點,做法類似於計劃呈現不明維度大小 (#1407)。

量化張量類型代表有量化元素的張量。這些張量與一般張量相同,但其元素具有量化元素類型,而非一般元素類型。

在量化張量中,量化可以是「各張量」,也就是為整個張量包含一個 scalezero_point,或者可以是每軸,即具有多個 scaleszero_points,每個維度 quantization_dimension 的每個配量一個對組。正式上例,在採用每軸量化的張量 t 中,quantization_dimensiondim(t, quantization_dimension) 配量:t[:, ..., 0, ..., :], t[:, ..., 1, ..., :] 等。i 片段中的所有元素都會使用 scales[i]zero_points[i] 做為量化參數。量化張量類型具有下列限制:

  • 針對每張量量化:
    • 沒有其他限制條件。
  • 如為每軸量化:
    • (C12) quantization_dimension < rank(self)
    • (C13) dim(self, quantization_dimension) = size(scales)
TokenType ::= 'token'

「權杖類型」代表符記,即部分作業產生和使用的不透明值。如執行一節所述,憑證的用途是強制執行作業的執行順序。

TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]

元組類型代表元組,即異質清單。Tuples 是舊版功能,只為與 HLO 相容而存在。在 HLO 中,元組用於表示各變的輸入和輸出內容。在 StableHLO 中,系統原生支援變異的輸入和輸出,而且 StableHLO 中元組的唯一使用方法是全面代表 HLO ABI,例如 Ttuple<T>tuple<tuple<T>> 可能會因特定實作方式而有明顯差異。我們日後計劃變更 HLO ABI,這樣或許能從 StableHLO 移除元組類型 (#598)。

TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ'
            | 'f8E4M3B11FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'

元素類型代表張量類型的元素。與許多程式設計語言不同,這些類型並不是 StableHLO 中的第一類別。這表示 StableHLO 程式無法直接表示這些類型的值 (因此,以 tensor<T> 類型的 0 個維度張量值表示 T 類型的純量值是慣用做法。

  • 「布林值類型」代表布林值 truefalse
  • 整數類型可以是正負號 (si) 或無正負號 (ui),而且包含支援的其中一種位元寬度 (48163264)。已簽署的 siN 類型代表從 -2^(N-1)2^(N-1)-1 的整數值,無正負號 uiN 類型代表從 02^N-1 之間的整數值。
  • 浮點類型可以是下列任一值:
  • 「複雜類型」代表具有「實際部分」以及屬於相同元素類型虛部分的複雜值。支援的複雜類型為 complex<f32> (這兩個部分都是 f32 類型) 和 complex<f64> (這兩個部分都是 f64 類型)。
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]

函式類型代表已命名和匿名函式。它們具有輸入類型 (-> 左側的類型清單) 和輸出類型 (-> 右側的類型清單)。在許多程式設計語言中,函式類型都是第一類,但在 StableHLO 中則不是。

StringType ::= 'string'

字串類型代表位元組序列。與許多程式設計語言不同,字串類型並非 StableHLO 中的第一個類別,只會用來為程式元素指定靜態中繼資料。

作業套件

StableHLO 作業 (也稱為「作業」) 代表機器學習模型中的一組高階作業。如上所述,StableHLO 語法非常受到 MLIR 啟發,這不一定是最能體工學的替代方法,但或許最適合 StableHLO 的目標是在機器學習架構與機器學習編譯器之間建立更互通性的目標。

Op            ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName        ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic    ::= 'abs' | 'add' | ...

StableHLO 作業 (也稱為「操作」) 具有名稱、輸入/輸出和簽名。名稱包含 stablehlo. 前置字元和一個可明確識別其中一項支援的運算的記憶。請參閱下方的完整支援作業清單。

目前,野外的 StableHLO 程式有時會包含本文未提及的作業。未來我們打算將這些作業吸收到 StableHLO 運算組中,或禁止這些作業出現在 StableHLO 程式中。與此同時,以下是這些作業的清單:

  • builtin.modulefunc.funcfunc.callfunc.return (#425)。
  • chlo 運算 (#602)。
  • 「不在 HLO」 StableHLO 運算類別 - 最初是 StableHLO 運算集的一部分,但後來被視為不適合其:broadcastcreate_tokencross-replica-sumdoteinsumtorch_index_selectunary_einsum (#3)。
  • StableHLO 運算的「Dynamism」類別是由 MHLO 自行啟動,但我們尚未找到:compute_reshape_shapecstr_reshapabledynamic_broadcast_in_dimdynamic_convdynamic_gatherdynamic_iotadynamic_paddynamic_reshapereal_dynamic_sliceset_dimension_size (#8)。
  • 形狀計算,包括 arithshapetensor 運算 (#8)。
OpInputs        ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues   ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue    ::= ValueId
OpInputFuncs    ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs    ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs       ::= [OpOutput {',' OpOutput} '=']
OpOutput        ::= ValueId

作業會取用輸入並產生「輸出」。輸入內容分為輸入值 (執行期間計算)、輸入函式 (因為在 StableHLO 函式中並非第一類別值) 和輸入屬性 (同樣以靜態方式提供)。運算使用和產生的輸入與輸出種類取決於其記憶。舉例來說,add 運算會使用 2 個輸入值並產生 1 個輸出值。相較之下,select_and_scatter 運算會使用 3 個輸入值、2 個輸入函式和 3 個輸入屬性。

OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused      ::= '^' digit {digit}
              | '^' letter {letter | digit}

輸入函式 (又稱為「匿名函式」) 與已命名函式非常類似,但前者沒有 ID (因此稱為「匿名」),2) 它們不會宣告輸出類型 (從函式中的 return 運算推斷輸出類型)。

輸入函式的語法包含目前未使用的部分 (請參閱上方的 Unused 實際工作環境),該部分可與 MLIR 相容。在 MLIR 中,有一個較通用的「地區」概念,可以有多個「區塊」透過跳躍運算連結。這些區塊的 ID 與 Unused 實際工作環境相對應,因此可以彼此區分。StableHLO 沒有跳躍作業,因此不會使用 MLIR 語法的對應部分 (但仍然存在)。

OpInputAttr      ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName  ::= letter {letter | digit}
OpInputAttrValue ::= Constant

輸入屬性具有名稱和值,這是支援的常數之一。是為節目元素指定靜態中繼資料的主要方式。舉例來說,concatenate 運算會使用 dimension 屬性指定串連輸入值的維度。同樣地,slice 運算會使用 start_indiceslimit_indices 等多個屬性來指定用於切分輸入值的邊界。

目前,野外的 StableHLO 程式有時會包含本文未說明的屬性。未來我們打算將這些屬性吸收到 StableHLO 運算組中,或禁止這些屬性出現在 StableHLO 程式中。同時,以下是這些屬性的清單:

  • layout (#629)。
  • mhlo.frontend_attributes (#628)。
  • mhlo.sharding (#619)。
  • output_operand_aliases (#740)。
  • 位置中繼資料 (#594)。
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'

「運算簽名」由所有輸入值的類型 (-> 左側的類型清單) 和所有輸出值的類型 (-> 右側的類型清單) 組成。嚴格說來,輸入類型都是多餘的,輸出類型也幾乎都是多餘的 (因為多數 StableHLO 運算的輸出類型可根據輸入推論)。不過,運算簽名是刻意安排為 StableHLO 語法的一部分,以與 MLIR 相容。

以下是記憶為 select_and_scatter 的運算範例。這個公式會使用 3 個輸入值 (%operand%source%init_value)、2 個輸入函式和 3 個輸入屬性 (window_dimensionswindow_stridespadding)。請注意,運算的簽章僅包含其輸入值的類型 (但不含內嵌提供的輸入函式和屬性類型)。

%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i32>, tensor<i32>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
    "stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>

常數

Constant ::= BooleanConstant
           | IntegerConstant
           | FloatConstant
           | ComplexConstant
           | TensorConstant
           | QuantizedTensorConstant
           | StringConstant
           | EnumConstant

StableHLO 常數具有常值和類型,一起代表 StableHLO 值。一般來說,這個類型是常數語法的一部分,除非是不明確的情況 (例如布林值常數含有 i1 類型,而整數常數可以有多種可能類型)。

BooleanConstant ::= BooleanLiteral
BooleanLiteral  ::= 'true' | 'false'

布林值常數代表布林值 truefalse。布林常數具有 i1 類型。

IntegerConstant   ::= IntegerLiteral ':' IntegerType
IntegerLiteral    ::= ['-' | '+'] DecimalDigits
                    | ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits     ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit      ::= '0' | ... | '9'
hexadecimalDigit  ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'

整數常數會透過使用十進位或十六進位標記法的字串來表示整數值。系統不支援其他底數 (例如二進位或八進位)。 整數常數具有下列限制:

  • (C1) is_wellformed(integer_literal, integer_type)
FloatConstant  ::= FloatLiteral ':' FloatType
FloatLiteral   ::= SignPart IntegerPart FractionalPart ScientificPart
                 | '0x' [HexadecimalDigits]
SignPart       ::= ['-' | '+']
IntegerPart    ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]

浮點常數透過使用小數或科學標記法的字串表示浮點值。此外,十六進位標記法也可用來直接指定對應類型浮點格式的基礎位元。浮點常數具有下列限制:

  • (C1) 如果使用非十六進位標記法,請設為 is_wellformed(float_literal, float_type)
  • (C2) 如果使用十六進位標記法,則 size(hexadecimal_digits) = num_bits(float_type) / 4
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral  ::= '(' RealPart ',' ImaginaryPart ')'
RealPart        ::= FloatLiteral
ImaginaryPart   ::= FloatLiteral

複雜常數會使用實際部分 (先推出) 和虛部分 (第二) 的清單表示複雜的值。例如,(1.0, 0.0) : complex<f32> 代表 1.0 + 0.0i(0.0, 1.0) : complex<f32> 則代表 0.0 + 1.0i。這些部分在記憶體中的儲存順序是由實作定義。複雜常數具有下列限制:

  • (C1) is_wellformed(real_part, complex_element_type(complex_type))
  • (C2) is_wellformed(imaginary_part, complex_element_type(complex_type))
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral   ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements  ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral

Tensor 常數表示使用透過 NumPy 標記指定的巢狀清單表示張量值。舉例來說,dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> 代表張量值,下列從索引對應至元素:{0, 0} => 1{0, 1} => 2{0, 2} => 3{1, 0} => 4{1, 1} => 5{1, 2} => 6。這些元素會在記憶體中的儲存順序實作。張量常數具有下列限制:

  • (C1) has_syntax(tensor_literal, element_type(tensor_type)),其中:
    • has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type).
    • has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type).
  • (C2) has_shape(tensor_literal, shape(tensor_type)),其中:
    • has_shape(element_literal: Syntax, []) = true.
    • has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:]).
    • 否則為 false
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'

「量化張量常數」代表量化張量值,並使用與張量常數相同的標記法表示,而且元素會指定為其儲存空間類型的常數。量化張量常數具有下列限制:

  • (C1) has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type))
  • (C2) has_shape(quantized_tensor_literal, shape(quantized_tensor_type))
StringConstant  ::= StringLiteral
StringLiteral   ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence  ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))

字串常值由使用 ASCII 字元和逸出序列指定的位元組組成。這些位元組各不具有編碼功能,因此對這些位元組的解釋是由實作定義。字串常值具有 string 類型。

作業數

abs

語義

operand 張量執行元素層級抽象運算,並產生 result 張量。根據元素類型執行以下操作:

  • 帶正負號整數:整數模數。
  • 浮點值:由 IEEE-754 距離 abs
  • 複數:複數的模數。
  • 量化類型:dequantize_op_quantize(abs, operand, type(result))

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 帶正負號整數、浮點、複雜類型,或每個張量的量化張量的張量 (C1-C2)

輸出

名稱 類型 限制
result 帶正負號整數、浮點類型,或每個張量量化張量的張量 (C1-C2)

限制

  • (C1) shape(result) = shape(operand)
  • (C2) baseline_element_type(result) 定義為:
    • 如果 is_complex(operand) 則為 complex_element_type(element_type(operand))
    • 否則傳回 baseline_element_type(operand)

範例

// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]

更多範例

add

語義

對兩個張量 lhsrhs 執行元素相關新增作業,並產生 result 張量。根據元素類型執行以下操作:

  • 布林值:邏輯 OR。
  • 整數:加整數。
  • 浮點值:由 IEEE-754 距離 addition
  • 複數:複數的加法。
  • 量化類型:dequantize_op_quantize(add, lhs, rhs, type(result))

輸入內容

標籤 名稱 類型 限制
(I1)。 lhs 張量或每個張量化張量 (C1)。
(I2)。 rhs 張量或每個張量化張量 (C1)。

輸出

名稱 類型 限制
result 張量或每個張量化張量 (C1)。

限制

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

範例

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]

更多範例

after_all

語義

確保產生 inputs 的作業在依附 result 的任何作業之前執行。這項作業不會執行任何動作,系統只會建立從 resultinputs 的資料依附元件。

輸入內容

標籤 名稱 類型
(I1)。 inputs 變異數token

輸出

名稱 類型
result token

範例

// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token

更多範例

all_gather

語義

在 StableHLO 程序格線中的每個程序群組中,按照 all_gather_dim 將每個程序的 operand 張量值串連起來,然後產生 result 張量。

這項作業會將 StableHLO 程序格線分割為 process_groups,定義如下:

  • cross_replica(replica_groups) 如果為 channel_id <= 0 and use_global_device_ids = false
  • cross_replica_and_partition(replica_groups) 如果為 channel_id > 0 and use_global_device_ids = false
  • flattened_ids(replica_groups) 如果為 channel_id > 0 and use_global_device_ids = true

之後,在每個 process_group 中:

  • operands@receiver = [operand@sender for sender in process_group] 適用於 process_group 中所有 receiver
  • result@process = concatenate(operands@process, all_gather_dim) 適用於 process_group 中所有 process

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 張量或每個張量化張量 (C1)、(C6)
(I2)。 all_gather_dim si64 類型的常數 (C1)、(C6)
(I3)。 replica_groups si64 類型的 2 維張量常數 (C2-C4)
(I4)。 channel_id si64 類型的常數 (C5)
(I5) use_global_device_ids i1 類型的常數 (C5)

輸出

名稱 類型 限制
result 張量或每個張量化張量 (C6)

限制

  • (C1) 0 <= all_gather_dim < rank(operand)
  • (C2) is_unique(replica_groups)
  • (C3) size(replica_groups) 定義為:
    • 如果使用 cross_replica,則為 num_replicas
    • 如果使用 cross_replica_and_partition,則為 num_replicas
    • 如果使用 flattened_ids,則為 num_processes
  • (C4) 0 <= replica_groups < size(replica_groups)
  • (C5) 如為 use_global_device_ids = true,則預設為 channel_id > 0
  • (C6) type(result) = type(operand),但下列項目除外:
    • dim(result, all_gather_dim) = dim(operand, all_gather_dim) * dim(process_groups, 1).

範例

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
%result = "stablehlo.all_gather"(%operand) {
  all_gather_dim = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  // channel_id = 0
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  // use_global_device_ids = false
} : (tensor<2x2xi64>) -> tensor<2x4xi64>
// %result@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]

更多範例

all_reduce

語義

在 StableHLO 程序格線中的每個程序群組中,將縮減函式 computation 套用至每個程序中的 operand 張量值,然後產生 result 張量。

這項作業會將 StableHLO 程序格線分割為 process_groups,定義如下:

  • cross_replica(replica_groups) 如果為 channel_id <= 0 and use_global_device_ids = false
  • cross_replica_and_partition(replica_groups) 如果為 channel_id > 0 and use_global_device_ids = false
  • flattened_ids(replica_groups) 如果為 channel_id > 0 and use_global_device_ids = true

之後,在每個 process_group 中:

  • result@process[result_index] = exec(schedule) 代表二進位樹狀結構 schedule,其中:
    • exec(node) = computation(exec(node.left), exec(node.right))
    • exec(leaf) = leaf.value
  • schedule 是實作定義的二進位樹狀結構,其順序週遊為 to_destination_type(operands@process_group...[result_index], type(func_inputs(computation)[0]))

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 張量或每個張量化張量 (C5)、(C6)
(I2)。 replica_groups si64 類型的 1 維張量常數的變異數 (C1-C3)
(I3)。 channel_id si64 類型的常數 (C4)
(I4)。 use_global_device_ids i1 類型的常數 (C4)
(I5) computation 函式 (C5)

輸出

名稱 類型 限制
result 張量或每個張量化張量 (C6-C7)

限制

  • (C1) is_unique(replica_groups)
  • (C2) size(replica_groups) 定義為:
    • 如果使用 cross_replica,則為 num_replicas
    • 如果使用 cross_replica_and_partition,則為 num_replicas
    • 如果使用 flattened_ids,則為 num_processes
  • (C3) 0 <= replica_groups < size(replica_groups)
  • (C4) 如為 use_global_device_ids = true,則預設為 channel_id > 0
  • (C5) computation 具有類型 (tensor<E>, tensor<E>) -> (tensor<E>),其中 is_promotable(element_type(operand), E)
  • (C6) shape(result) = shape(operand)
  • (C7) element_type(result) = E

範例

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [1, 2, 3, 4]
// %operand@(1, 0): [5, 6, 7, 8]
%result = "stablehlo.all_reduce"(%operand) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<i64>) -> tensor<i64>
// %result@(0, 0): [6, 8, 10, 12]
// %result@(1, 0): [6, 8, 10, 12]

更多範例

all_to_all

語義

在 StableHLO 程序格線的每個程序群組中,將 operand 張量的值沿著 split_dimension 分割為部分,將分割部分分散到程序之間,沿著 concat_dimension 串連分散的部分,然後產生 result 張量。

這項作業會將 StableHLO 程序格線分割為 process_groups,定義如下:

  • 如果 channel_id <= 0 則為 cross_replica(replica_groups)
  • 如果 channel_id > 0 則為 cross_partition(replica_groups)

之後,在每個 process_group 中:

  • split_parts@sender = split(operand@sender, split_count, split_dimension) 適用於 process_group 中所有 sender
  • scattered_parts@receiver = [split_parts@sender[receiver_index] for sender in process_group],即 receiver_index = process_group.index(receiver)
  • result@process = concatenate(scattered_parts@process, concat_dimension).

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 張量或每個張量化張量 (C1-C3)、(C9)
(I2)。 split_dimension si64 類型的常數 (C1)、(C2)、(C9)
(I3)。 concat_dimension si64 類型的常數 (C3)、(C9)
(I4)。 split_count si64 類型的常數 (C2)、(C4)、(C8)、(C9)
(I5) replica_groups si64 類型的 2 維張量常數 (C5-C8)
(I6)。 channel_id si64 類型的常數

輸出

名稱 類型 限制
result 張量或每個張量化張量 (C9)

限制

  • (C1) 0 <= split_dimension < rank(operand)
  • (C2) dim(operand, split_dimension) % split_count = 0
  • (C3) 0 <= concat_dimension < rank(operand)
  • (C4) 0 < split_count
  • (C5) is_unique(replica_groups)
  • (C6) size(replica_groups) 定義為:
    • 如果使用 cross_replica,則為 num_replicas
    • 如果使用 cross_partition,則為 num_partitions
  • (C7) 0 <= replica_groups < size(replica_groups)
  • (C8) dim(replica_groups, 1) = split_count
  • (C9) type(result) = type(operand),但下列項目除外:
    • dim(result, split_dimension) = dim(operand, split_dimension) / split_count.
    • dim(result, concat_dimension) = dim(operand, concat_dimension) * split_count.

範例

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.all_to_all"(%operand) {
  split_dimension = 1 : i64,
  concat_dimension = 0 : i64,
  split_count = 2 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
} : (tensor<2x4xi64>) -> tensor<4x2xi64>
// %result@(0, 0): [[1, 2],
//                  [5, 6],
//                  [9, 10],
//                  [13, 14]]
// %result@(1, 0): [[3, 4],
//                  [7, 8],
//                  [11, 12],
//                  [15, 16]]

更多範例

以及

語義

lhsrhs 兩個張量為目標執行元素 AND,然後產生 result 張量。根據元素類型執行以下操作:

  • 布林值:邏輯 AND。
  • 整數:位元 AND。

輸入內容

標籤 名稱 類型 限制
(I1)。 lhs 布林值或整數類型的張量 (C1)。
(I2)。 rhs 布林值或整數類型的張量 (C1)。

輸出

名稱 類型 限制
result 布林值或整數類型的張量 (C1)。

限制

  • (C1) type(lhs) = type(rhs) = type(result)

範例

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]

atan2

語義

lhsrhs 張量執行元素相關 atan2 作業,並產生 result 張量。根據元素類型執行以下操作:

  • 浮點值:由 IEEE-754 距離 atan2
  • 複數:複數:複數。
  • 量化類型:dequantize_op_quantize(atan2, lhs, rhs, type(result))

輸入內容

標籤 名稱 類型 限制
(I1)。 lhs 浮點、複雜類型或每個張量化張量的張量 (C1)。
(I2)。 rhs 浮點、複雜類型或每個張量化張量的張量 (C1)。

輸出

名稱 類型 限制
result 浮點、複雜類型或每個張量化張量的張量 (C1)。

限制

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

範例

// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]

更多範例

batch_norm_grad

語義

計算從 grad_output 反向傳播 batch_norm_training 的數個輸入梯度,並產生 grad_operandgrad_scalegrad_offset 張量。更正式地說,這個運算可透過 Python 語法以分解的形式表示到現有 StableHLO 作業,如下所示:

def compute_sum(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  return sum

def compute_mean(operand, feature_index):
  sum = compute_sum(operand, feature_index)
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
  # Broadcast inputs to type(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance`
  # Intermediate values will be useful for computing gradients
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)

  # Use the implementation from batchnorm_expander.cc in XLA
  # Temporary variables have exactly the same names as in the C++ code
  elements_per_feature = broadcast_in_dim(
      constant(divide(size(operand), dim(operand, feature_index)),
               element_type(grad_output)),
      [], type(operand))
  i1 = multiply(grad_output, elements_per_feature)
  i2 = broadcast_in_dim(
      compute_sum(grad_output, feature_index), [feature_index], type(operand))
  i3 = broadcast_in_dim(
      compute_sum(multiply(grad_output, centered_operand), feature_index),
      [feature_index], type(operand))
  i4 = multiply(i3, centered_operand)
  i5 = divide(i4, add(variance_bcast, epsilon_bcast))
  i6 = subtract(subtract(i1, i2), i5)

  grad_operand =
      multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
  grad_scale =
      compute_sum(multiply(grad_output, normalized_operand), feature_index)
  grad_offset = compute_sum(grad_output, feature_index)

  return grad_operand, grad_scale, grad_offset

如果是量化類型,請執行 dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean, variance, grad_output: batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index), operand, scale, mean, variance, grad_output, type(grad_operand), type(grad_scale), type(feature_index))

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 浮點類型或每個張量量化張量的張量 (C1-C3)、(C5)
(I2)。 scale 浮點或每個張量量化類型的 1 維張量 (C2)、(C4)、(C5)
(I3)。 mean 浮點或每個張量量化類型的 1 維張量 (C2)、(C4)
(I4)。 variance 浮點或每個張量量化類型的 1 維張量 (C2)、(C4)
(I5) grad_output 浮點類型或每個張量量化張量的張量 (C2)、(C3)
(I6)。 epsilon f32 類型的常數
(I7) feature_index si64 類型的常數 (C1)、(C5)

輸出

名稱 類型 限制
grad_operand 浮點類型或每個張量量化張量的張量 (C2)、(C3)
grad_scale 浮點或每個張量量化類型的 1 維張量 (C2)、(C4)
grad_offset 浮點或每個張量量化類型的 1 維張量 (C2)、(C4)

限制

  • (C1) 0 <= feature_index < rank(operand)
  • (C2) operandscalemeanvariancegrad_outputgrad_operandgrad_scalegrad_offset 有相同的 baseline_element_type
  • (C3) operandgrad_outputgrad_operand 形狀相同。
  • (C4) scalemeanvariancegrad_scalegrad_offset 的形狀相同。
  • (C5) size(scale) = dim(operand, feature_index)

範例

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
//                [[0.1, 0.1], [0.1, 0.1]],
//                [[0.1, 0.1], [0.1, 0.1]]
//               ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
     tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
//                 [[0.0, 0.0], [0.0, 0.0]],
//                 [[0.0, 0.0], [0.0, 0.0]]
//                ]
// %grad_scale:  [0.0, 0.0]
// %grad_offset: [0.4, 0.4]

batch_norm_inference

語義

將所有維度的 operand 張量正規化 (feature_index 維度除外),並產生 result 張量。更正式的,這個運算可透過 Python 語法以分解的形式表示到現有 StableHLO 作業,如下所示:

def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
  # Broadcast inputs to shape(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance` instead of
  # computing them like `batch_norm_training` does.
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)
  return add(multiply(scale_bcast, normalized_operand), offset_bcast)

如果是量化類型,請執行 dequantize_op_quantize(lambda operand, scale, offset, mean, variance: batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index), operand, scale, offset, mean, variance, type(result))

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 浮點類型或每個張量量化張量的張量 (C1-C7)
(I2)。 scale 浮點或每個張量量化類型的 1 維張量 (C2)、(C3)
(I3)。 offset 浮點或每個張量量化類型的 1 維張量 (C2)、(C4)
(I4)。 mean 浮點或每個張量量化類型的 1 維張量 (C5)
(I5) variance 浮點或每個張量量化類型的 1 維張量 (C2)、(C6)
(I6)。 epsilon f32 類型的常數
(I7) feature_index si64 類型的常數 (C1)、(C3-C6)

輸出

名稱 類型 限制
result 浮點類型或每個張量量化張量的張量 (C2)、(C7)

限制

  • (C1) 0 <= feature_index < rank(operand)
  • (C2) operandscaleoffsetmeanvarianceresult 有相同的 baseline_element_type
  • (C3) size(scale) = dim(operand, feature_index)
  • (C4) size(offset) = dim(operand, feature_index)
  • (C5) size(mean) = dim(operand, feature_index)
  • (C6) size(variance) = dim(operand, feature_index)
  • (C7) baseline_type(operand) = baseline_type(result)

範例

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]

batch_norm_training

語義

計算所有維度的平均值和變異數,但 feature_index 維度除外,並將產生 outputbatch_meanbatch_var 張量的 operand 張量正規化。更正式的做法是,使用 Python 語法以分解為現有 StableHLO 作業的形式表示,如下所示:

def compute_mean(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def compute_variance(operand, feature_index):
  mean = compute_mean(operand, feature_index)
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  centered_operand = subtract(operand, mean_bcast)
  return compute_mean(mul(centered_operand, centered_operand), feature_index)

def batch_norm_training(operand, scale, offset, epsilon, feature_index):
  mean = compute_mean(operand, feature_index)
  variance = compute_variance(operand, feature_index)
  return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
                              feature_index),
         mean, variance

如果是量化類型,請執行 dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset: batch_norm_training(operand, scale, offset, epsilon, feature_index), operand, scale, offset, type(output), type(batch_mean), type(batch_var))

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 浮點類型或每個張量量化張量的張量 (C1)。
(I2)。 scale 浮點或每個張量的 1 維張量 (C2)、(C3)
(I3)。 offset 浮點或每個張量的 1 維張量 (C2)、(C4)
(I4)。 epsilon f32 類型的常數 (C1)、(C3-C6)
(I5) feature_index si64 類型的常數 (C1)、(C3-C6)

輸出

名稱 類型 限制
output 浮點類型或每個張量量化張量的張量 (C7)
batch_mean 浮點或每個張量的 1 維張量 (C2)、(C5)
batch_var 浮點或每個張量的 1 維張量 (C2)、(C6)

限制

  • (C1) 0 <= feature_index < rank(operand)
  • (C2) operandscaleoffsetbatch_meanbatch_varoutput 有相同的 baseline_element_type
  • (C3) size(scale) = dim(operand, feature_index)
  • (C4) size(offset) = dim(operand, feature_index)
  • (C5) size(batch_mean) = dim(operand, feature_index)
  • (C6) size(batch_var) = dim(operand, feature_index)
  • (C7) baseline_type(output) = baseline_type(operand)

範例

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
    (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]

bitcast_convert

語義

operand 張量執行 Bitcast 作業,並產生 result 張量,其中整個 operand 張量的位元會使用 result 張量類型重新解譯。

更正式,有 E = element_type(operand)E' = element_type(result)R = rank(operand)

  • 如果值為 num_bits(E') < num_bits(E),則為 bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])
  • 如果值為 num_bits(E') > num_bits(E),則為 bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])
  • 如果值為 num_bits(E') = num_bits(E),則為 bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])

bits 會傳回指定值的記憶體內表示法,且其行為是由實作定義,因為實作方式已定義實際代表張量,且元素類型的確切表示法也經過實作定義。

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 張量或量化張量 (C1-C2)

輸出

名稱 類型 限制
result 張量或量化張量 (C1-C2)

限制

  • (C1) 如 E = is_quantized(operand) ? storage_type(operand) : element_type(operand)E' = is_quantized(result) ? storage_type(result) : element_type(result)R = rank(operand)
    • 如果值為 num_bits(E') = num_bits(E),則值為 shape(result) = shape(operand)
    • 如為 num_bits(E') < num_bits(E)
    • rank(result) = R + 1.
    • 所有 0 <= i < Rdim(result, i) = dim(operand, i)
    • dim(result, R) * num_bits(E') = num_bits(E).
    • 如為 num_bits(E') > num_bits(E)
    • rank(result) = R - 1.
    • 所有 0 <= i < Rdim(result, i) = dim(operand, i)
    • dim(operand, R - 1) * num_bits(E) = num_bits(E').
  • (C2) 如果為 is_complex(operand) or is_complex(result),則 is_complex(operand) and is_complex(result)

範例

// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation

更多範例

broadcast_in_dim

語義

複製 operand 張量中的資料並產生 result 張量,以展開輸入張量的維度和/或排名。更正式的 result[result_index] = operand[operand_index],其中 axes(operand) 中所有 d

  • 如果 dim(operand, d) = 1 則為 operand_index[d] = 0
  • 否則傳回 operand_index[d] = result_index[broadcast_dimensions[d]]

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 張量或量化張量 (C1-C2)、(C5-C6)
(I2)。 broadcast_dimensions si64 類型的 1 維張量常數 (C2-C6)

輸出

名稱 類型 限制
result 張量或量化張量 (C1)、(C3)、(C5-C6)

限制

  • (C1) element_type(result) 的提供者:
    • element_type(operand) (如果為 !is_per_axis_quantized(operand))。
    • element_type(operand) 除外,但 quantization_dimension(operand)scales(operand)zero_points(operand) 可能與 quantization_dimension(result)scales(result)zero_points(result) 不同。
  • (C2) size(broadcast_dimensions) = rank(operand)
  • (C3) 0 <= broadcast_dimensions < rank(result)
  • (C4) is_unique(broadcast_dimensions)
  • (C5) 針對 axes(operand) 中的所有 d
    • dim(operand, d) = 1
    • dim(operand, d) = dim(result, broadcast_dimensions[d]).
  • (C6) 如果為 is_per_axis_quantized(result)
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].
    • 如果為 dim(operand, quantization_dimension(operand)) = 1,則值為 scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))

範例

// %operand: [
//            [1, 2, 3]
//           ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
  broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

更多範例

客服案件

語義

根據 index 的值,從 branches 僅執行一個函式產生輸出內容。正式來說,result = selected_branch() 其中:

  • 如果 0 <= index < size(branches) 則為 selected_branch = branches[index]
  • 否則傳回 selected_branch = branches[-1]

輸入內容

標籤 名稱 類型 限制
(I1)。 index si32 類型的 0 維張量
(I2)。 branches 變異函數數量 (C1-C4)

輸出

名稱 類型 限制
results 各種張量、量化張量或符記 (C4)

限制

  • (C1) 0 < size(branches)
  • (C2) input_types(branches...) = []
  • (C3) same(output_types(branches...))
  • (C4) type(results...) = output_types(branches[0])

範例

// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
  "stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
  "stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]

更多範例

Cbrt

語義

operand 張量執行元素相關立方根作業,並產生 result 張量。根據元素類型執行以下操作:

  • 浮點值:由 IEEE-754 距離 rootn(x, 3)
  • 複數:複雜的立方根。
  • 量化類型:dequantize_op_quantize(cbrt, operand, type(result))

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 浮點、複雜類型或每個張量化張量的張量 (C1)。

輸出

名稱 類型 限制
result 浮點、複雜類型或每個張量化張量的張量 (C1)。

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]

更多範例

水泥

語義

執行 operand 張量的元素圓形,並產生 result 張量。從 IEEE-754 規格實作 roundToIntegralTowardPositive 作業。如果是量化類型,請執行 dequantize_op_quantize(ceil, operand, type(result))

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 浮點類型或每個張量量化張量的張量 (C1)。

輸出

名稱 類型 限制
result 浮點類型或每個張量量化張量的張量 (C1)。

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]

更多範例

白洞

語義

計算一批矩陣的膽天分解法。

更確切地說,對於 index_space(result) 中的所有 iresult[i0, ..., iR-3, :, :]a[i0, ..., iR-3, :, :] 的孔洞分解,採用下三角形 (如果 lowertrue) 或上三角形 (如果 lowerfalse) 矩陣。相對三角形的輸出值 (對應所對應的嚴格上限三角形或嚴格三角形) 的輸出值是實作定義。

如果有 i 的輸入矩陣不是隱士頓陽性矩陣,則行為將未定義。

如果是量化類型,請執行 dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result))

輸入內容

標籤 名稱 類型 限制
(I1)。 a 浮點、複雜類型或每個張量化張量的張量 (C1-C3)
(I2)。 lower i1 類型的 0 維張量常數

輸出

名稱 類型 限制
result 浮點、複雜類型或每個張量化張量的張量 (C1)。

限制

  • (C1) baseline_type(a) = baseline_type(result)
  • (C2) 2 <= rank(a)
  • (C3) dim(a, -2) = dim(a, -1)

範例

// %a: [
//      [1.0, 2.0, 3.0],
//      [2.0, 20.0, 26.0],
//      [3.0, 26.0, 70.0]
//     ]
%result = "stablehlo.cholesky"(%a) {
  lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
//           [1.0, 0.0, 0.0],
//           [2.0, 4.0, 0.0],
//           [3.0, 5.0, 6.0]
//          ]

限制取值範圍

語義

operand 張量的每個元素限制在最小值和最大值之間,產生 result 張量。更正式的是 result[result_index] = minimum(maximum(operand[result_index], min_element), max_element),其中 min_element = rank(min) = 0 ? min[] : min[result_index]max_element = rank(max) = 0 ? max[] : max[result_index]。如果是量化類型,則執行 dequantize_op_quantize(clamp, min, operand, max, type(result))

在複雜數字上排序複雜的語意會帶來令人意想不到的語意,因此我們未來計劃移除對此運算複雜數字的支援 (#560)。

輸入內容

標籤 名稱 類型 限制
(I1)。 min 張量或每個張量化張量 (C1)、(C3)
(I2)。 operand 張量或每個張量化張量 (C1-C4)
(I3)。 max 張量或每個張量化張量 (C2)、(C3)

輸出

名稱 類型 限制
result 張量或每個張量化張量 (C4)

限制

  • (C1) rank(min) = 0 or shape(min) = shape(operand)
  • (C2) rank(max) = 0 or shape(max) = shape(operand)
  • (C3) baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max)
  • (C4) baseline_type(operand) = baseline_type(result)

範例

// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]

更多範例

collective_broadcast

語義

在 StableHLO 程序格線中的每個程序群組中,將 operand 張量的值從來源程序傳送至目標程序,並產生 result 張量。

這項作業會將 StableHLO 程序格線分割為 process_groups,定義如下:

  • 如果 channel_id <= 0 則為 cross_replica(replica_groups)
  • 如果 channel_id > 0 則為 cross_partition(replica_groups)

之後,result@process 將由以下提供:

  • 如果有 i,因此程序位於 process_groups[i],則為 operand@process_groups[i, 0]
  • broadcast_in_dim(constant(0, element_type(result)), [], type(result)) 否則。

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 張量 (C3)
(I2)。 replica_groups si64 類型的 1 維張量常數的變異數 (C1)、(C2)
(I3)。 channel_id si64 類型的常數

輸出

名稱 類型 限制
result 張量 (C3)

限制

  • (C1) is_unique(replica_groups)
  • (C2) 0 <= replica_groups < N,其中 N 定義為:
    • 如果使用 cross_replica,則為 num_replicas
    • 如果使用 cross_partition,則為 num_partitions
  • (C3) type(result) = type(operand)

範例

// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
  replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]

collective_permute

語義

在 StableHLO 程序格線中的每個程序群組中,將 operand 張量的值從來源程序傳送至目標程序,並產生 result 張量。

這項作業會將 StableHLO 程序格線分割為 process_groups,定義如下:

  • 如果 channel_id <= 0 則為 cross_replica(source_target_pairs)
  • 如果 channel_id > 0 則為 cross_partition(source_target_pairs)

之後,result@process 將由以下提供:

  • operand@process_groups[i, 0] (如果 i 存在這類 process_groups[i, 1] = process)。
  • broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result)) 否則。

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 張量或每個張量化張量 (C5)
(I2)。 source_target_pairs si64 類型的 2 維張量常數 (C1-C4)
(I3)。 channel_id si64 類型的常數

輸出

名稱 類型 限制
result 張量或每個張量化張量 (C1)。

限制

  • (C1) dim(source_target_pairs, 1) = 2
  • (C2) is_unique(source_target_pairs[:, 0])
  • (C3) is_unique(source_target_pairs[:, 1])
  • (C4) 0 <= source_target_pairs < N,其中 N 定義為:
    • 如果使用 cross_replica,則為 num_replicas
    • 如果使用 cross_partition,則為 num_partitions
  • (C5) type(result) = type(operand)

範例

// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
  source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]

更多範例

compare

語義

根據 comparison_directioncompare_typelhsrhs 張量執行元素層級比較,並產生 result 張量。

comparison_directioncompare_type 的值包含下列語意:

布林值和整數元素類型:

  • EQlhs = rhs
  • NElhs != rhs
  • GElhs >= rhs
  • GTlhs > rhs
  • LElhs <= rhs
  • LTlhs < rhs

如果是含有 compare_type = FLOAT 的浮點元素類型,運算會實作下列 IEEE-754 作業:

  • EQcompareQuietEqual
  • NEcompareQuietNotEqual
  • GEcompareQuietGreaterEqual
  • GTcompareQuietGreater
  • LEcompareQuietLessEqual
  • LTcompareQuietLess

如果是具有 compare_type = TOTALORDER 的浮點元素類型,這項運算會使用 IEEE-754 的 totalOrdercompareQuietEqual 運算。這項功能似乎未使用,因此我們計劃在日後將其移除 (#584)。

如果是複雜的元素類型,系統會使用提供的 comparison_directioncompare_type(real, imag) 組合的字母組合比較。在複雜數字上排序複雜數字會帶來令人意想不到的語意,因此我們未來計劃在 comparison_directionGEGTLELT 時移除對複數的支援 (#560)。

如果是量化類型,則執行 dequantize_compare(lhs, rhs, comparison_direction)

輸入內容

標籤 名稱 類型 限制
(I1)。 lhs 張量或每個張量化張量 (C1-C3)
(I2)。 rhs 張量或每個張量化張量 (C1-C2)
(I3)。 comparison_direction EQNEGEGTLELT 的列舉項目
(I4)。 compare_type FLOATTOTALORDERSIGNEDUNSIGNED 的列舉 (C3)

輸出

名稱 類型 限制
result 布林類型的張量 (C2)

限制

  • (C1) baseline_element_type(lhs) = baseline_element_type(rhs)
  • (C2) shape(lhs) = shape(rhs) = shape(result)
  • (C3) compare_type 定義為:
    • 如果 is_signed_integer(element_type(lhs)) 則為 SIGNED
    • 如果 is_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs)) 則為 UNSIGNED
    • FLOATTOTALORDER 表示 is_float(element_type(lhs))
    • 如果 is_complex(element_type(lhs)) 則為 FLOAT

範例

// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
  comparison_direction = #stablehlo<comparison_direction LT>,
  compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]

更多範例

複雜

語義

從實際和虛值 (lhsrhs) 的一對複雜值執行元素相關轉換,然後產生 result 張量。

輸入內容

標籤 名稱 類型 限制
(I1)。 lhs f32f64 類型的張量 (C1-C3)
(I2)。 rhs f32f64 類型的張量 (C1)。

輸出

名稱 類型 限制
result 複雜類型的張量 (C2)、(C3)

限制

  • (C1) type(lhs) = type(rhs)
  • (C2) shape(result) = shape(lhs)
  • (C3) element_type(result) 具有類型 complex<E>,其中 E = element_type(lhs)

範例

// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]

更多範例

concatenate

語義

按照指定引數的順序,將 inputs 沿著 dimension 維度串連,並產生 result 張量。更正式的 result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1],其中:

  1. id = d0 + ... + dk-1 + kd.
  2. d 等於 dimension,而 d0d inputs 的維度大小。

輸入內容

標籤 名稱 類型 限制
(I1)。 inputs 各種張量或每個張量量化張量 (C1-C6)
(I2)。 dimension si64 類型的常數 (C2)、(C4)、(C6)

輸出

名稱 類型 限制
result 張量或每個張量化張量 (C5-C6)

限制

  • (C1) same(element_type(inputs...))
  • (C2) same(shape(inputs...))dim(inputs..., dimension) 除外。
  • (C3) 0 < size(inputs)
  • (C4) 0 <= dimension < rank(inputs[0])
  • (C5) element_type(result) = element_type(inputs[0])
  • (C6) shape(result) = shape(inputs[0]),但以下項目除外:
    • dim(result, dimension) = dim(inputs[0], dimension) + ....

範例

// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
  dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]

更多範例

常數

語義

從常數 value 產生 output 張量。

輸入內容

標籤 名稱 類型 限制
(I1)。 value 常數 (C1)。

輸出

名稱 類型 限制
output 張量或量化張量 (C1)。

限制

  • (C1) type(value) = type(output)

範例

%output = "stablehlo.constant"() {
  value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]

更多範例

完成轉換

語義

這個外掛程式能在 operand 張量上,從某個元素類型到另一種元素類型轉換,然後產生 result 張量。

如果是 boolean-to-any-supported-type 轉換,系統會將 false 值轉換為零,而 true 值會轉換成一。針對any-supported-type-to-boolean轉換,0 值會轉換為 false,非零值則會轉換為 true。請參閱下文瞭解這種做法在複雜類型中的運作方式。

如果轉換包含整數到整數整數到浮點點floating-point-to-floating-point,如果來源值可在目的地類型中精準呈現,結果值就是確切的表示法。否則行為為待定。(#180)。

如果轉換涉及floating-point-to-integer,小數部分會遭到截斷。如果目的地類型無法呈現截斷的值,則行為為未定 (#180)。

如果轉換涉及複雜至複雜,其行為與浮點對浮點點轉換行為相同,可用於轉換實際部分和虛部分。

針對complex-to-any-other-typecomplex-to-any-other-type 轉換,系統會忽略來源虛值,或目的地虛值的值為零。實際轉換次數會跟著浮點轉換

原則上,這個運算可以表示去量化 (從量化張量轉換為一般張量)、量化 (從一般張量轉換為量化張量) 和重新量化 (量化張量之間的轉換),但目前我們針對第一個用途有 uniform_dequantize 的專屬運算,第二和第三種用途則為 uniform_quantize。這兩個運算日後可能會合併為 convert (#1576)。

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 張量 (C1)。

輸出

名稱 類型 限制
result 張量 (C1)。

限制

  • (C1) shape(operand) = shape(result)

範例

// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]

更多範例

卷積

語義

計算 lhs 視窗和 rhs 配量之間的點積,並產生 result。下圖顯示如何使用具體範例,從 lhsrhs 計算 result 中的元素。

更正式的做法是以下列為 lhs 修正輸入內容,以便表達 lhs 的區間:

  • lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension)).
  • lhs_window_strides = lhs_shape(1, window_strides, 1).
  • lhs_padding = lhs_shape([0, 0], padding, [0, 0]).
  • lhs_base_dilations = lhs_shape(1, lhs_dilation, 1).
  • lhs_window_dilations = lhs_shape(1, rhs_dilation, 1).

這項參照使用下列輔助函式:

  • lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]).
  • result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]).
  • permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1],即 j[d] = i[permutation[d]]

如果為 feature_group_count = 1batch_group_count = 1,則針對 index_space(dim(result, output_spatial_dimensions...)) 中的所有 output_spatial_indexresult[result_shape(:, output_spatial_index, :)] = dot_product 其中:

  • padding_value = constant(0, element_type(lhs)).
  • padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1).
  • lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides.
  • lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations).
  • reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true])。這項功能似乎未使用,因此我們預計在日後將其移除 (#1181)。
  • dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension]).

如為 feature_group_count > 1

  • lhses = split(lhs, feature_group_count, input_feature_dimension).
  • rhses = split(rhs, feature_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension).

如為 batch_group_count > 1

  • lhses = split(lhs, batch_group_count, input_batch_dimension).
  • rhses = split(rhs, batch_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension)

如果是量化類型,請執行 dequantize_op_quantize( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs, type(result))

輸入內容

標籤 名稱 類型 限制
(I1)。 lhs 張量或每個張量化張量 (C1)、(C10-C11)、(C14) (C25)、(C27-C30)
(I2)。 rhs 張量或量化張量 (C1)、(C14-C16)、(C25)、(C27-C32)
(I3)。 window_strides si64 類型的 1 維張量常數 (C2-C3)、(C25)
(I4)。 padding si64 類型的 2 維張量常數 (C4)、(C25)
(I5) lhs_dilation si64 類型的 1 維張量常數 (C5-C6)、(C25)
(I6)。 rhs_dilation si64 類型的 1 維張量常數 (C7-C8)、(C25)
(I7) window_reversal i1 類型的 1 維張量常數 (C9)
(I8) input_batch_dimension si64 類型的常數 (C10)、(C13)、(C25)
(I9) input_feature_dimension si64 類型的常數 (C11)、(C13-C14)
(I10)。 input_spatial_dimensions si64 類型的 1 維張量常數 (C12)、(C13)、(C25)
(I11)。 kernel_input_feature_dimension si64 類型的常數 (C14)、(C18)
(I12)。 kernel_output_feature_dimension si64 類型的常數 (C15-C16)、(C18)、(C25)、(C32)
(I13)。 kernel_spatial_dimensions si64 類型的 1 維張量常數 (C17-C18)、(C25)
(I14)。 output_batch_dimension si64 類型的常數 (C20)、(C25)
(I15) output_feature_dimension si64 類型的常數 (C20)、(C25)、(C33)
(I16)。 output_spatial_dimensions si64 類型的 1 維張量常數 (C19-C20)、(C25)
(I17)。 feature_group_count si64 類型的常數 (C11)、(C14)、(C16)、(C21)、(C23)
(I18) batch_group_count si64 類型的常數 (C10)、(C15)、(C22)、(C23)、(C25)
(I19) precision_config DEFAULTHIGHHIGHEST 的列舉數量 (C24)

輸出

名稱 類型 限制
result 張量或量化張量 (C25-C28)、(C30-C31)、(C33)

限制

  • (C1) N = rank(lhs) = rank(rhs)
  • (C2) size(window_strides) = N - 2
  • (C3) 0 < window_strides
  • (C4) shape(padding) = [N - 2, 2]
  • (C5) size(lhs_dilation) = N - 2
  • (C6) 0 < lhs_dilation
  • (C7) size(rhs_dilation) = N - 2
  • (C8) 0 < rhs_dilation
  • (C9) size(window_reversal) = N - 2
  • (C10) dim(lhs, input_batch_dimension) % batch_group_count = 0
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0
  • (C12) size(input_spatial_dimensions) = N - 2
  • (C13) 為 input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
    • is_unique(input_dimensions).
    • 0 <= input_dimensions < N.
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
  • (C17) size(kernel_spatial_dimensions) = N - 2
  • (C18) 為 kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
    • is_unique(kernel_dimensions).
    • 0 <= kernel_dimensions < N.
  • (C19) size(output_spatial_dimensions) = N - 2
  • (C20) 為 output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
    • is_unique(output_dimensions).
    • 0 <= output_dimensions < N.
  • (C21) 0 < feature_group_count
  • (C22) 0 < batch_group_count
  • (C23) feature_group_count = 1 or batch_group_count = 1
  • (C24) size(precision_config) = 2
  • (C25) dim(result, result_dim) 定義為:
    • 如果 result_dim = output_batch_dimension 則為 dim(lhs, input_batch_dimension) / batch_group_count
    • 如果 result_dim = output_feature_dimension 則為 dim(rhs, kernel_output_feature_dimension)
    • num_windows 否則,其中:
    • output_spatial_dimensions[spatial_dim] = result_dim.
    • lhs_dim = input_spatial_dimensions[spatial_dim].
    • rhs_dim = kernel_spatial_dimensions[spatial_dim].
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
  • (C26) rank(result) = N
  • 如果作業使用非量化張量:
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result)
  • 如果作業使用量化張量:
    • (C28) is_quantized_tensor(lhs) and is_quantized_tensor(rhs) and is_quantized_tensor(result)
    • (C29) storage_type(lhs) = storage_type(rhs)
    • (C30) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
    • (C31) 如為 is_per_tensor_quantized(rhs),則值為 is_per_tensor_quantized(result)
    • (C32) 如為 is_per_axis_quantized(rhs),則值為 quantization_dimension(rhs) = kernel_output_feature_dimension
    • (C33) 如為 is_per_axis_quantized(result),則值為 quantization_dimension(result) = output_feature_dimension

範例

// %lhs: [[
//        [
//          [1], [2], [5], [6]
//        ],
//        [
//          [3], [4], [7], [8]
//        ],
//        [
//          [10], [11], [14], [15]
//        ],
//        [
//          [12], [13], [16], [17]
//        ]
//      ]]
//
// %rhs : [
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]]
//        ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
  window_strides = dense<4> : tensor<2xi64>,
  padding = dense<0> : tensor<2x2xi64>,
  lhs_dilation = dense<2> : tensor<2xi64>,
  rhs_dilation = dense<1> : tensor<2xi64>,
  window_reversal = dense<false> : tensor<2xi1>,
  // In the StableHLO dialect, dimension numbers are encoded via:
  // `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
  // "b" is batch dimension, "f" is feature dimension,
  // "i" is input feature dimension, "o" is output feature dimension,
  // "0/1/etc" are spatial dimensions.
  dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
  feature_group_count = 1 : i64,
  batch_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi32>, tensor<3x3x1x1xi32>) -> tensor<1x2x2x1xi32>
// %result: [[
//            [[10], [26]],
//            [[46], [62]]
//          ]]

餘弦

語義

operand 張量執行元素相關餘弦運算,並產生 result 張量。根據元素類型執行以下操作:

  • 浮點值:由 IEEE-754 距離 cos
  • 複數:複數的餘弦。
  • 量化類型:dequantize_op_quantize(cosine, operand, type(result))

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 浮點、複雜類型或每個張量化張量的張量 (C1)。

輸出

名稱 類型 限制
result 浮點、複雜類型或每個張量化張量的張量 (C1)。

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]

更多範例

count_leading_zeros

語義

針對 operand 張量中的前置零位元執行元素相關計數,並產生 result 張量。

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 整數類型的張量 (C1)。

輸出

名稱 類型 限制
result 整數類型的張量 (C1)。

限制

  • (C1) type(operand) = type(result)

範例

// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]

更多範例

custom_call

語義

封裝實作定義作業 call_target_name,該作業採用 inputscalled_computations 並產生 resultshas_side_effectbackend_configapi_version 可用於提供其他的實作定義中繼資料。

目前,這項作業包含一系列結構相當多元的中繼資料集合,反映了其在 XLA 編譯器中對應作業的自然演進。我們計劃日後統合這類中繼資料。(#741)。

輸入內容

標籤 名稱 類型
(I1)。 inputs 變異數值
(I2)。 call_target_name string 類型的常數
(I3)。 has_side_effect i1 類型的常數
(I4)。 backend_config string 類型的常數
(I5) api_version si32 類型的常數
(I6)。 called_computations string 類型的常數數

輸出

名稱 類型
results 變異數值

範例

%results = "stablehlo.custom_call"(%input0) {
  call_target_name = "foo",
  has_side_effect = false,
  backend_config = "bar",
  api_version = 1 : i32,
  called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>

語義

執行被除 lhs 和除數 rhs 張量的元素相關除法,並產生 result 張量。根據元素類型執行以下操作:

  • 整數:整數除法,產生捨棄任何小數部分的代數商。
  • 浮點值:由 IEEE-754 距離 division
  • 複數:複數。
  • 量化類型:
    • dequantize_op_quantize(divide, lhs, rhs, type(result)).

輸入內容

標籤 名稱 類型 限制
(I1)。 lhs 整數、浮點、複雜類型,或每個張量的量化張量 (C1)。
(I2)。 rhs 整數、浮點、複雜類型,或每個張量的量化張量 (C1)。

輸出

名稱 類型 限制
result 整數、浮點、複雜類型或每個張量量化張量的張量 (C1)。

限制

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

範例

// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]

更多範例

dot_general

語義

計算 lhs 切片和 rhs 配量之間的內積,並產生 result 張量。

更正式的 result[result_index] = dot_product,其中:

  • lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions].
  • rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions].
  • result_batching_index + result_lhs_index + result_rhs_index = result_index 其中 size(result_batching_index) = size(lhs_batching_dimensions)size(result_lhs_index) = size(lhs_result_dimensions)size(result_rhs_index) = size(rhs_result_dimensions)
  • transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions).
  • transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :]).
  • reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions)).
  • transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions).
  • transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :]).
  • reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions)).
  • dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y))

如果是量化類型,請執行 dequantize_op_quantize( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs, type(result))

這個方法只會指定每個張量量化的語意。每軸量化目前正在處理中 (#1574)。此外,我們日後可能會考慮新增對混合量化支援 (#1575)。

precision_config 可控制加速器後端運算速度與準確率之間的取捨。可能的值如下 (於當下,這些列舉值的語意尚未指定,但我們計劃在 #755 中解決這個問題):

  • DEFAULT:最快計算,但近似原始數字的精確度較不精確。
  • HIGH:計算速度較慢,但近似原始數字的估計值會更準確。
  • HIGHEST:計算速度最慢,但近似原始數字最準確的預測結果。

輸入內容

標籤 名稱 類型 限制
(I1)。 lhs 張量或每個張量化張量 (C5-C6)、(C9-C10)、(C12-C16)
(I2)。 rhs 張量或每個張量化張量 (C7-C10)、(C12)
(I3)。 lhs_batching_dimensions si64 類型的 1 維張量常數 (C1)、(C3)、(C5)、(C9)、(C12)
(I4)。 rhs_batching_dimensions si64 類型的 1 維張量常數 (C1)、(C4)、(C7)、(C9)
(I5) lhs_contracting_dimensions si64 類型的 1 維張量常數 (C2)、(C3)、(C6)、(C10)
(I6)。 rhs_contracting_dimensions si64 類型的 1 維張量常數 (C2)、(C4)、(C8)、(C10)
(I7) precision_config DEFAULTHIGHHIGHEST 的列舉數量 (C11)

輸出

名稱 類型 限制
result 張量或每個張量化張量 (C12)、(C14)、(C16)

限制

  • (C1) size(lhs_batching_dimensions) = size(rhs_batching_dimensions)
  • (C2) size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions)
  • (C3) is_unique(lhs_batching_dimensions + lhs_contracting_dimensions)
  • (C4) is_unique(rhs_batching_dimensions + rhs_contracting_dimensions)
  • (C5) 0 <= lhs_batching_dimensions < rank(lhs)
  • (C6) 0 <= lhs_contracting_dimensions < rank(lhs)
  • (C7) 0 <= rhs_batching_dimensions < rank(rhs)
  • (C8) 0 <= rhs_contracting_dimensions < rank(rhs)
  • (C9) dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
  • (C10) dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...)
  • (C11) size(precision_config) = 2
  • (C12) shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions)
  • 如果作業使用非量化張量:
    • (C13) element_type(lhs) = element_type(rhs)
  • 如果作業使用量化張量:
    • (C14) is_quantized(lhs) and is_quantized(rhs) and is_quantized(result)
    • (C15) storage_type(lhs) = storage_type(rhs)
    • (C16) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
    • (C17) zero_points(rhs) = 0

範例

// %lhs: [
//        [[1, 2],
//         [3, 4]],
//        [[5, 6],
//         [7, 8]]
//       ]
// %rhs: [
//        [[1, 0],
//         [0, 1]],
//        [[1, 0],
//         [0, 1]]
//       ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
  dot_dimension_numbers = #stablehlo.dot<
    lhs_batching_dimensions = [0],
    rhs_batching_dimensions = [0],
    lhs_contracting_dimensions = [2],
    rhs_contracting_dimensions = [1]
  >,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
//           [[1, 2],
//            [3, 4]],
//           [[5, 6],
//            [7, 8]]
//          ]

更多範例

dynamic_slice

語義

使用動態計算的啟動索引從 operand 擷取切片,然後產生 result 張量。start_indices 包含每個維度的切片起始索引,slice_sizes 包含每個維度的切片大小。更正式的 result[result_index] = operand[operand_index],其中:

  • adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes).
  • operand_index = adjusted_start_indices + result_index.

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 張量或每個張量化張量 (C1)、(C2)、(C4)
(I2)。 start_indices 整數類型的 0 維張量變異數 (C2)、(C3)
(I3)。 slice_sizes si64 類型的 1 維張量常數 (C2)、(C4)、(C5)

輸出

名稱 類型 限制
result 張量或每個張量化張量 (C1)、(C5)

限制

  • (C1) element_type(operand) = element_type(result)
  • (C2) size(start_indices) = size(slice_sizes) = rank(operand)
  • (C3) same(type(start_indices...))
  • (C4) 0 <= slice_sizes <= shape(operand)
  • (C5) shape(result) = slice_sizes

範例

// %operand: [
//            [0, 0, 1, 1],
//            [0, 0, 1, 1],
//            [0, 0, 0, 0],
//            [0, 0, 0, 0]
//           ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
  slice_sizes = dense<[2, 2]> : tensor<2xi64>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
//           [1, 1],
//           [1, 1]
//          ]

更多範例

dynamic_update_slice

語義

產生等於 operand 張量的 result 張量,但從 start_indices 開始的切片會以 update 中的值更新。更正式的,result[result_index] 定義為:

  • 如果 0 <= update_index < shape(update) 且符合以下條件,則為 update[update_index]
    • adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update)).
    • update_index = result_index - adjusted_start_indices.
  • 否則傳回 operand[result_index]

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 張量或每個張量化張量 (C1-C4)、(C6)
(I2)。 update 張量或每個張量化張量 (C2)、(C3)、(C6)
(I3)。 start_indices 整數類型的 0 維張量變異數 (C4)、(C5)

輸出

名稱 類型 限制
result 張量或每個張量化張量 (C1)。

限制

  • (C1) type(operand) = type(result)
  • (C2) element_type(update) = element_type(operand)
  • (C3) rank(update) = rank(operand)
  • (C4) size(start_indices) = rank(operand)
  • (C5) same(type(start_indices...))
  • (C6) 0 <= shape(update) <= shape(operand)

範例

// %operand: [
//            [1, 1, 0, 0],
//            [1, 1, 0, 0],
//            [1, 1, 1, 1],
//            [1, 1, 1, 1]
//           ]
// %update: [
//           [1, 1],
//           [1, 1]
//          ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
  : (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1]
//          ]

更多範例

指數

語義

operand 張量執行元素相關指數運算,並產生 result 張量。根據元素類型執行以下操作:

  • 浮點值:由 IEEE-754 距離 exp
  • 複數:複數的指數。
  • 量化類型:dequantize_op_quantize(exponential, operand, type(result))

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 浮點、複雜類型或每個張量化張量的張量 (C1)。

輸出

名稱 類型 限制
result 浮點、複雜類型或每個張量化張量的張量 (C1)。

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]

更多範例

exponential_minus_one

語義

operand 張量執行元素指數減去一次作業,然後產生 result 張量。根據元素類型執行以下操作:

  • 浮點值:由 IEEE-754 距離 expm1
  • 複數:複數的指數減 1。
  • 量化類型:dequantize_op_quantize(exponential_minus_one, operand, type(result))

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 浮點、複雜類型或每個張量化張量的張量 (C1)。

輸出

名稱 類型 限制
result 浮點、複雜類型或每個張量化張量的張量 (C1)。

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]

更多範例

fft

語義

針對真實複雜的輸入/輸出執行向前和反向 Fourier 轉換。

fft_type 是下列其中一項:

  • FFT:將複雜到複雜的 FFT。
  • IFFT:複數到複雜 FFT。
  • RFFT:正向複雜 FFT。
  • IRFFT:非實際為複雜 FFT (即需要複雜、傳回真實)。

更正式的說,由於 fft 函式會取用複雜類型的 1 維張量做為輸入,會產生與輸出相同類型的 1 維張量,並計算離散的 Fourier 轉換:

如果是 fft_type = FFTresult 是一系列 L 運算的最終結果,其中 L = size(fft_length)。例如,L = 3

  • result1[i0, ..., :] = fft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

此外,由於 ifft 函式的類型簽章相同,且會計算 fft 的反向運算:

對於 fft_type = IFFTresult 定義為 fft_type = FFT 運算的反函式。例如,L = 3

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = ifft(result2[i0, ..., :]).

此外,由於 rfft 函式會取用浮點類型的 1 維張量,因此會產生相同浮點語意的 1 維複雜張量,且運作方式如下:

  • rfft(real_operand) = truncated_result,其中
  • complex_operand... = (real_operand..., 0.0).
  • complex_result = fft(complex_operand).
  • truncated_result = complex_result[:(rank(complex_result) / 2 + 1)].

(為實際運算元計算獨立的 Fourier 轉換時,結果的第一個 N/2 + 1 元素會明確定義其餘結果,因此系統會截斷 rfft 的結果,避免計算多餘的元素。

如果是 fft_type = RFFTresult 是一系列 L 運算的最終結果,其中 L = size(fft_length)。例如,L = 3

  • result1[i0, ..., :] = rfft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

最後,假設 irfft 函式的類型簽章相同,且會計算 rfft 的反向運算:

對於 fft_type = IRFFTresult 定義為 fft_type = RFFT 運算的反函式。例如,L = 3

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = irfft(result2[i0, ..., :]).

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 浮點或複雜類型的張量 (C1)、(C2)、(C4)、(C5)
(I2)。 fft_type FFTIFFTRFFTIRFFT 的列舉 (C2)、(C5)
(I3)。 fft_length si64 類型的 1 維張量常數 (C1)、(C3)、(C4)

輸出

名稱 類型 限制
result 浮點或複雜類型的張量 (C2)、(C4)、(C5)

限制

  • (C1) size(fft_length) <= rank(operand)
  • (C2) operandresult 元素類型之間的關係如下:
    • 如果 fft_type = FFTelement_type(operand)element_type(result) 的類型相同,
    • 如果 fft_type = IFFTelement_type(operand)element_type(result) 的類型相同,
    • 如果 fft_type = RFFTelement_type(operand) 是浮點類型,element_type(result) 是相同浮點語意的複雜類型。
    • 如果 fft_type = IRFFTelement_type(operand) 是複雜類型,element_type(result) 是相同浮點語意的浮點類型。
  • (C3) 1 <= size(fft_length) <= 3
  • (C4) 如果介於 operandresult 之間,則會有浮點類型的張量 real,然後是 shape(real)[-size(fft_length):] = fft_length
  • (C5) shape(result) = shape(operand),但以下項目除外:
    • 如果值為 fft_type = RFFT,則為 dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1
    • 如果值為 fft_type = IRFFT,則為 dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1

範例

// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
  fft_type = #stablehlo<fft_type FFT>,
  fft_length = dense<4> : tensor<1xi64>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]

floor

語義

執行 operand 張量的元素樓層,並產生 result 張量。從 IEEE-754 規格實作 roundToIntegralTowardNegative 作業。如果是量化類型,請執行 dequantize_op_quantize(floor, operand, type(result))

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 浮點類型或每個張量量化張量的張量 (C1)。

輸出

名稱 類型 限制
result 浮點類型或每個張量量化張量的張量 (C1)。

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]

更多範例

收集

語義

start_indices 中指定的偏移值收集 operand 張量,並產生 result 張量。

下圖以具體範例說明 result 中的元素如何對應到 operand 中的元素。圖中挑選了一些 result 索引範例,並詳細說明它們對應的 operand 索引。

更正式的 result[result_index] = operand[operand_index],其中:

  • batch_dims = [d for d in axes(result) and d not in offset_dims].
  • batch_index = result_index[batch_dims...].
  • start_index 定義為:
    • start_indices[bi0, ..., :, ..., biN],其中 bibatch_index 中的個別元素,如果 index_vector_dim < rank(start_indices),則 : 會插入 index_vector_dim 索引。
    • 否則傳回 [start_indices[batch_index]]
  • 針對 axes(operand) 中的 d_operand
    • full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand]) 如果為 d_operand = start_index_map[d_start]
    • 否則傳回 full_start_index[d_operand] = 0
  • offset_index = result_index[offset_dims...].
  • full_offset_index = [oi0, ..., 0, ..., oiN],其中 oioffset_index 中的個別元素,而 0 會插入 collapsed_slice_dims 的索引。
  • operand_index = full_start_index + full_offset_index

如果 indices_are_sortedtrue,實作會假設 start_indices 會依 start_index_map 排序,否則行為未定義。更正式地是,針對 indices(result)full_start_index(i1) <= full_start_index(i2) 的所有 i1 < i2

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 張量或每個張量化張量 (C1)、(C7)、(C10-C12)、(C14)
(I2)。 start_indices 整數類型的張量 (C2)、(C3)、(C13)
(I3)。 offset_dims si64 類型的 1 維張量常數 (C1)、(C4-C5)、(C13)
(I4)。 collapsed_slice_dims si64 類型的 1 維張量常數 (C1)、(C6-C8)、(C13)
(I5) start_index_map si64 類型的 1 維張量常數 (C3)、(C9)、(C10)
(I6)。 index_vector_dim si64 類型的常數 (C2)、(C3)、(C13)
(I7) slice_sizes si64 類型的 1 維張量常數 (C8)、(C11-C13)
(I8) indices_are_sorted i1 類型的常數

輸出

名稱 類型 限制
result 張量或每個張量化張量 (C5)、(C13-C14)

限制

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims)
  • (C2) 0 <= index_vector_dim <= rank(start_indices)
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims)
  • (C5) 0 <= offset_dims < rank(result)
  • (C6) is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims)
  • (C7) 0 <= collapsed_slice_dims < rank(operand)
  • (C8) slice_sizes[collapsed_slice_dims...] <= 1
  • (C9) is_unique(start_index_map)
  • (C10) 0 <= start_index_map < rank(operand)
  • (C11) size(slice_sizes) = rank(operand)
  • (C12) 0 <= slice_sizes <= shape(operand)
  • (C13) shape(result) = combine(batch_dim_sizes, offset_dim_sizes),其中:
    • batch_dim_sizes = shape(start_indices),但資料不會包含與 index_vector_dim 對應的 start_indices 尺寸大小。
    • offset_dim_sizes = shape(slice_sizes),但不會包含 slice_sizes 中與 collapsed_slice_dims 對應的尺寸大小。
    • combine 會將 batch_dim_sizes 置於與 batch_dimsoffset_dim_sizes 對應的軸上,並放在與 offset_dims 相對應的軸。
  • (C14) element_type(operand) = element_type(result)

範例

// %operand: [
//            [[1, 2], [3, 4], [5, 6], [7, 8]],
//            [[9, 10],[11, 12], [13, 14], [15, 16]],
//            [[17, 18], [19, 20], [21, 22], [23, 24]]
//           ]
// %start_indices: [
//                  [[0, 0], [1, 0], [2, 1]],
//                  [[0, 1], [1, 1], [0, 2]]
//                 ]
%result = "stablehlo.gather"(%operand, %start_indices) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [2, 3],
    collapsed_slice_dims = [0],
    start_index_map = [1, 0],
    index_vector_dim = 2>,
  slice_sizes = dense<[1, 2, 2]> : tensor<3xi64>,
  indices_are_sorted = false
} : (tensor<3x4x2xi32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xi32>
// %result: [
//            [
//              [[1, 2], [3, 4]],
//              [[3, 4], [5, 6]],
//              [[13, 14], [15, 16]]
//            ],
//            [
//              [[9, 10], [11, 12]],
//              [[11, 12], [13, 14]],
//              [[17, 18], [19, 20]]
//            ]
//          ]

更多範例

get_dimension_size

語義

產生 operand 的指定 dimension 大小。更正式的為 result = dim(operand, dimension)

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 張量 (C1)。
(I2)。 dimension si64 類型的常數 (C1)。

輸出

名稱 類型
result si32 類型的 0 維張量

限制

  • (C1) 0 <= dimension < rank(operand)

範例

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
  dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3

更多範例

get_tuple_element

語義

擷取 operand 元組 index 位置的元素,並產生 result。正式的是 result = operand[index]

輸入內容

標籤 名稱 類型 限制
(I1)。 operand tuple (C1)、(C2)
(I2)。 index si32 類型的常數 (C1)、(C2)

輸出

名稱 類型 限制
result 任何支援的類型 (C2)

限制

  • (C1) 0 <= index < size(operand)
  • (C2) type(result) = tuple_element_types(operand)[index]

範例

// %operand: ([1.0, 2.0], (3))
%result = "stablehlo.get_tuple_element"(%operand) {
  index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]

更多範例

if

語義

根據 pred 的值,從 true_branchfalse_branch 僅執行一個函式產生輸出內容。正式的是 result = pred ? true_branch() : false_branch()

輸入內容

標籤 名稱 類型 限制
(I1)。 pred i1 類型的 0 維張量
(I2)。 true_branch 函式 (C1-C3)
(I3)。 false_branch 函式 (C1)、(C2)

輸出

名稱 類型 限制
results 各種張量、量化張量或符記 (C3)

限制

  • (C1) input_types(true_branch) = input_types(false_branch) = []
  • (C2) output_types(true_branch) = output_types(false_branch)
  • (C3) type(results...) = output_types(true_branch)

範例

// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
  "stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
  "stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10

更多範例

想像力

語義

operand 擷取元素部分,然後產生 result 張量。更正式的做法是每個元素 ximag(x) = is_complex(x) ? imaginary_part(x) : constant(0, element_type(result))

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 浮點或複雜類型的張量 (C1)、(C2)

輸出

名稱 類型 限制
result 浮點類型的張量 (C1)、(C2)

限制

  • (C1) shape(result) = shape(operand)
  • (C2) element_type(result) 定義為:
    • 如果 is_complex(operand) 則為 complex_element_type(element_type(operand))
    • 否則傳回 element_type(operand)

範例

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]

更多範例

動態內廣告

語義

這個外掛程式能讀取動態饋給中的資料,並產生 results

infeed_config 的語意是實作定義。

results 包含最先出現的酬載值和最後的符記。我們日後打算將酬載和權杖分成兩個不同的輸出內容,以便清楚呈現。(#670)。

輸入內容

標籤 名稱 類型
(I1)。 token token
(I2)。 infeed_config string 類型的常數

輸出

名稱 類型 限制
results 各種張量、量化張量或符記 (C1-C3)

限制

  • (C1) 0 < size(results)
  • (C2) is_empty(result[:-1])is_tensor(type(results[:-1]))
  • (C3) is_token(type(results[-1]))

範例

// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]

更多範例

Iota

語義

output 張量填入 iota_dimension 維度,從零開始遞增的值。更正式的

output[result_index] = constant(is_quantized(output) ? quantize(result_index[iota_dimension], element_type(output)) : result_index[iota_dimension], element_type(output)).

輸入內容

標籤 名稱 類型 限制
(I1)。 iota_dimension si64 (C1)。

輸出

名稱 類型 限制
output 整數、浮點、複雜類型,或每個張量的量化張量 (C1)。

限制

  • (C1) 0 <= iota_dimension < rank(output)

範例

%output = "stablehlo.iota"() {
  iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

%output = "stablehlo.iota"() {
  iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4]
//          ]

更多範例

is_finite

語義

根據元素執行元素檢查 x 中的值是否為有限 (即不是 +Inf、-Inf 或 NaN),並產生 y 張量。根據 IEEE-754 規格實作 isFinite 作業。如果是量化類型,結果一律為 true

輸入內容

標籤 名稱 類型 限制
(I1)。 x 浮點類型或每個張量量化張量的張量 (C1)。

輸出

名稱 類型 限制
y 布林類型的張量 (C1)。

限制

  • (C1) shape(x) = shape(y)

範例

// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]

更多範例

紀錄/記錄檔

語義

operand 張量執行元素相關對數運算,並產生 result 張量。根據元素類型執行以下操作:

  • 浮點值:由 IEEE-754 距離 log
  • 複數:複數。
  • 量化類型:dequantize_op_quantize(log, operand, type(result))

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 浮點、複雜類型或每個張量化張量的張量 (C1)。

輸出

名稱 類型 限制
result 浮點、複雜類型或每個張量化張量的張量 (C1)。

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]

更多範例

log_plus_one

語義

operand 張量執行元素相關對數加上一項運算,然後產生 result 張量。根據元素類型執行以下操作:

  • 浮點值:由 IEEE-754 距離 logp1
  • 複數:複數的對數加上一。
  • 量化類型:dequantize_op_quantize(log_plus_one, operand, type(result))

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 浮點、複雜類型或每個張量化張量的張量 (C1)。

輸出

名稱 類型 限制
result 浮點、複雜類型或每個張量化張量的張量 (C1)。

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]

更多範例

物流

語義

operand 張量執行元素相關邏輯運算,並產生 result 張量。根據元素類型執行以下操作:

  • 浮點值:由 IEEE-754 距離 division(1, addition(1, exp(-x)))
  • 複數:複雜的邏輯。
  • 量化類型:dequantize_op_quantize(logistic, operand, type(result))

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 浮點、複雜類型或每個張量化張量的張量 (C1)。

輸出

名稱 類型 限制
result 浮點、複雜類型或每個張量化張量的張量 (C1)。

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]

更多範例

地圖

語義

沿著 dimensions 將對應函式 computation 套用至 inputs,並建立 result 張量。

正式是 result[result_index] = computation(inputs...[result_index])。請注意,dimensions 目前並未使用,未來可能會移除 (#487)。

輸入內容

標籤 名稱 類型 限制
(I1)。 inputs 各種張量或每個張量量化張量 (C1-C4)
(I2)。 dimensions si64 類型的 1 維張量常數 (C3)
(I3)。 computation 函式 (C4)

輸出

名稱 類型 限制
result 張量或每個張量化張量 (C1)、(C4)

限制

  • (C1) shape(inputs...) = shape(result)
  • (C2) 0 < size(inputs) = N
  • (C3) dimensions = range(rank(inputs[0]))
  • (C4) computation 具有 (tensor<E0>, ..., tensor<EN-1>) -> tensor<E'> 類型,其中 Ei = element_type(inputs[i])E' = element_type(result)

範例

// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
    stablehlo.return %0 : tensor<i64>
}) {
  dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]

更多範例

最高

語義

對張量 lhsrhs 執行元素相關最大值作業,並產生 result 張量。根據元素類型執行以下操作:

  • 布林值:邏輯 OR。
  • 整數:整數最大值。
  • 浮點值:由 IEEE-754 距離 maximum
  • 複數:(real, imaginary) 組合的字母順序上限。在複雜數字上排序複雜的語意會帶來令人意想不到的語意,因此我們未來計劃移除對此運算複雜數字的支援 (#560)。
  • 量化類型:
    • dequantize_op_quantize(maximum, lhs, rhs, type(result)).

輸入內容

標籤 名稱 類型 限制
(I1)。 lhs 張量或每個張量化張量 (C1)。
(I2)。 rhs 張量或每個張量化張量 (C1)。

輸出

名稱 類型 限制
result 張量或每個張量化張量 (C1)。

限制

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

範例

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]

更多範例

最低

語義

對張量 lhsrhs 執行元素相關最低作業,並產生 result 張量。根據元素類型執行以下操作:

  • 布林值:邏輯 AND。
  • 整數:整數的最小值。
  • 浮點值:由 IEEE-754 距離 minimum
  • 複數:(real, imaginary) 組合的字母順序下限。在複雜數字上排序複雜的語意會帶來令人意想不到的語意,因此我們未來計劃移除對此運算複雜數字的支援 (#560)。
  • 量化類型:
    • dequantize_op_quantize(minimum, lhs, rhs, type(result)).

輸入內容

標籤 名稱 類型 限制
(I1)。 lhs 張量或每個張量化張量 (C1)。
(I2)。 rhs 張量或每個張量化張量 (C1)。

輸出

名稱 類型 限制
result 張量或每個張量化張量 (C1)。

限制

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

範例

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]

更多範例

語義

執行兩個張量 lhsrhs 的元素相關乘積,並產生 result 張量。根據元素類型執行以下操作:

  • 布林值:邏輯 AND。
  • 整數:整數乘法。
  • 浮點值:由 IEEE-754 距離 multiplication
  • 複數:複雜的乘法。
  • 量化類型:
    • dequantize_op_quantize(multiply, lhs, rhs, type(result)).

輸入內容

標籤 名稱 類型 限制
(I1)。 lhs 張量或每個張量化張量 (C1)。
(I2)。 rhs 張量或每個張量化張量 (C1)。

輸出

名稱 類型 限制
result 張量或每個張量化張量 (C1)。

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]

更多範例

negate

語義

執行 operand 張量的元素否定,並產生 result 張量。根據元素類型執行以下操作:

  • 帶正負號整數:整數否定。
  • 無正負號整數:bitcast 到帶正負號整數、整數否定,位元轉換回無正負號整數。
  • 浮點值:由 IEEE-754 距離 negate
  • 複數:複雜的否定。
  • 量化類型:dequantize_op_quantize(negate, operand, type(result))

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 整數、浮點、複雜類型或每個張量量化張量的張量 (C1)。

輸出

名稱 類型 限制
result 整數、浮點、複雜類型或每個張量量化張量的張量 (C1)。

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]

// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]

更多範例

不是

語義

於「不」執行張量 operand 的元素時,產生 result 張量。根據元素類型執行以下操作:

  • 布林值:邏輯 NOT。
  • 整數:位元 NOT。

引數

名稱 類型 限制
operand 布林值或整數類型的張量 (C1)。

輸出

名稱 類型 限制
result 布林值或整數類型的張量 (C1)。

限制

  • (C1) type(operand) = type(result)

範例

// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]

// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]

optimization_barrier

語義

確保產生 operand 的作業會在依附 result 的任何作業之前執行,並防止編譯器轉換作業在障礙之間移動作業。除此之外,作業則是身分 (例如 result = operand)。

引數

名稱 類型 限制
operand 各種張量、各張量化張量或符記 (C1)。

輸出

名稱 類型 限制
result 各種張量、各張量化張量或符記 (C1)。

限制

  • (C1) type(operand...) = type(result...)

範例

// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0

更多範例

語義

以元素或 rhs 的方式執行兩個張量 lhsrhs,並產生 result 張量。根據元素類型執行以下操作:

  • 布林值:邏輯 OR。
  • 整數:位元 OR。

輸入內容

標籤 名稱 類型 限制
(I1)。 lhs 整數或布林值類型的張量 (C1)。
(I2)。 rhs 整數或布林值類型的張量 (C1)。

輸出

名稱 類型 限制
result 整數或布林值類型的張量 (C1)。

限制

  • (C1) type(lhs) = type(rhs) = type(result)

範例

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]

外動態

語義

inputs 寫入外動態饋給,並產生 result 權杖。

outfeed_config 的語意是實作定義。

輸入內容

標籤 名稱 類型
(I1)。 inputs 各種張量或量化張量
(I2)。 token token
(I3)。 outfeed_config string 類型的常數

輸出

名稱 類型
result token

範例

%result = "stablehlo.outfeed"(%inputs0, %token) {
  outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token

更多範例

墊片

語義

透過在張量周圍以及具有指定的 padding_value 的張量元素之間加上邊框間距來展開 operand

edge_padding_lowedge_padding_high 分別指定在低端 (索引 0 旁邊) 和頂端 (最高索引旁邊) 新增的邊框間距量。邊框間距量可以是負數,其中負邊框間距的絕對值代表要從指定維度中移除的元素數量。

interior_padding 會指定在每個維度的任意兩個元素之間加入的邊框間距量,但不得為負數。內部邊框間距發生在邊緣邊框間距之前,因此負的邊緣邊框間距會從內部填充運算元中移除元素。

更正式的,result[result_index] 定義為:

  • 如果為 result_index = edge_padding_low + operand_index * (interior_padding + 1),則為 operand[operand_index]
  • 否則傳回 padding_value

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 張量或每個張量化張量 (C1)、(C2)、(C4)
(I2)。 padding_value 0 維張量或按張量化張量 (C1)。
(I3)。 edge_padding_low si64 類型的 1 維張量常數 (C1)、(C4)
(I4)。 edge_padding_high si64 類型的 1 維張量常數 (C1)、(C4)
(I5) interior_padding si64 類型的 1 維張量常數 (C2-C4)

輸出

名稱 類型 限制
result 張量或每個張量化張量 (C3-C6)

限制

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result)
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
  • (C3) 0 <= interior_padding
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high

範例

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
  edge_padding_low = dense<[0, 1]> : tensor<2xi64>,
  edge_padding_high = dense<[2, 1]> : tensor<2xi64>,
  interior_padding = dense<[1, 2]> : tensor<2xi64>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

更多範例

partition_id

語義

產生目前程序的 partition_id

輸出

名稱 類型
result ui32 類型的 0 維張量

範例

%result = "stablehlo.partition_id"() : () -> tensor<ui32>

更多範例

罌粟紅

語義

針對 operand 張量中設定的位元數執行元素相關位元數,並產生 result 張量。

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 整數類型的張量 (C1)。

輸出

名稱 類型 限制
result 整數類型的張量 (C1)。

限制

  • (C1) type(operand) = type(result)

範例

// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]

更多範例

功率

語義

透過 rhs 張量執行 lhs 張量的元素指數,並產生 result 張量。根據元素類型執行以下操作:

  • 整數:整數指數。
  • 浮點值:由 IEEE-754 距離 pow
  • 複數:複數的指數。
  • 量化類型:dequantize_op_quantize(power, lhs, rhs, type(result))

輸入內容

標籤 名稱 類型 限制
(I1)。 lhs 整數、浮點、複雜類型或每個張量量化張量的張量 (C1)。
(I2)。 rhs 整數、浮點、複雜類型或每個張量量化張量的張量 (C1)。

輸出

名稱 類型 限制
result 整數、浮點、複雜類型或每個張量量化張量的張量 (C1)。

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]

更多範例

real

語義

operand 逐元素擷取實際部分,並產生 result 張量。更正式的做法是每個元素 xreal(x) = is_complex(x) ? real_part(x) : x

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 浮點或複雜類型的張量 (C1)、(C2)

輸出

名稱 類型 限制
result 浮點類型的張量 (C1)、(C2)

限制

  • (C1) shape(result) = shape(operand)
  • (C2) element_type(result) 定義為:
    • 如果 is_complex(operand) 則為 complex_element_type(element_type(operand))
    • 否則傳回 element_type(operand)

範例

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]

更多範例

Recv

語義

接收含有 channel_id 的頻道資料,並產生 results

如果 is_host_transfertrue,則作業會從主機移轉資料。否則,系統會轉移其他裝置中的資料。這代表的是實作定義。這個旗標會複製 channel_type 中提供的資訊,因此未來我們打算只會保留其中一個 (#666)。

results 包含最先出現的酬載值和最後的符記。我們日後打算將酬載和權杖分成兩個不同的輸出內容,以便清楚呈現。(#670)。

輸入內容

標籤 名稱 類型 限制
(I1)。 token token (C4)
(I2)。 channel_id si64 類型的常數
(I3)。 channel_type DEVICE_TO_DEVICEHOST_TO_DEVICE 的列舉項目 (C1)。
(I4)。 is_host_transfer i1 類型的常數 (C1)。

輸出

名稱 類型 限制
results 各種張量、量化張量或符記 (C2-C4)

限制

  • (C1) channel_type 定義為:
    • HOST_TO_DEVICE 表示 is_host_transfer = true
    • 否則傳回 DEVICE_TO_DEVICE
  • (C2) 0 < size(results)
  • (C3) is_empty(result[:-1])is_tensor(type(results[:-1]))
  • (C4) is_token(type(results[-1]))

範例

%results0, %results1 = "stablehlo.recv"(%token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
  is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)

更多範例

減少

語義

將縮減函式 body 沿著 dimensions 套用至 inputsinit_values,並產生 results 張張。

縮減順序是由實作定義,這表示 bodyinit_values 必須形成單聲道屬性,以確保該作業在所有實作項目上,都會產生相同的結果。然而,這項條件並不適用於許多常見的縮減項目。例如,body 的浮點加至 init_values 與 0 實際上並未形成單聲道,因為浮點加上沒有關聯。

更正式的 results...[j0, ..., jR-1] = reduce(input_slices_converted),其中:

  • input_slices = inputs...[j0, ..., :, ..., jR-1],其中 : 會在 dimensions 插入。
  • input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...).
  • init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...).
  • reduce(input_slices_converted) = exec(schedule) 代表二進位樹狀結構 schedule,其中:
    • exec(node) = body(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule 是實作定義的完整二進位樹狀結構,其按順序週遊包含:
    • input_slices_converted...[index] 值,適用於 index_space(input_slices_converted) 中所有 index,並依 index 的字母順序排列。
    • 在實作定義的位置,會與實作定義的 init_values_converted 數量交錯。

輸入內容

標籤 名稱 類型 限制
(I1)。 inputs 各種張量或每個張量量化張量 (C1-C4)、(C6)、(C7)
(I2)。 init_values 各種 0 維張量或每個張量量化張量的數量 (C2)、(C3)
(I3)。 dimensions si64 類型的 1 維張量常數 (C4)、(C5)、(C7)
(I4)。 body 函式 (C6)

輸出

名稱 類型 限制
results 各種張量或每個張量量化張量 (C3)、(C7)、(C8)

限制

  • (C1) same(shape(inputs...))
  • (C2) element_type(inputs...) = element_type(init_values...)
  • (C3) 0 < size(inputs) = size(init_values) = size(results) = N
  • (C4) 0 <= dimensions < rank(inputs[0])
  • (C5) is_unique(dimensions)
  • (C6) body 具有 (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>) 類型,其中 is_promotable(element_type(inputs[i]), Ei)
  • (C7) shape(results...) = shape(inputs...),但不含與 dimensions 對應的 inputs... 尺寸大小。
  • (C8) [0,N) 中所有 ielement_type(results[i]) = Ei

範例

// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  dimensions = dense<1> : tensor<1xi64>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]

更多範例

reduce_precision

語義

operand 的元素導向元素轉換,使用 exponent_bitsmantissa_bits 並回到原始浮點類型,然後產生 output 張量。

更正式:

  • 原始值的 mantisa 位元會更新,使用 mantissa_bits 語意將原始值四捨五入至最接近的值。 :roundToIntegralTiesToEven
  • 然後,如果 mantissa_bits 小於原始值的 mantisa 位元數,系統會將 mantisa 位元截斷為 mantissa_bits
  • 接著,如果中繼結果的指數位元不符合 exponent_bits 提供的範圍,則中繼結果會使用原始記號溢位至無限,或使用原始符號後溢位至零。
  • 如果是量化類型,請執行 dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result))

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 浮點類型或每個張量量化張量的張量 (C1)。
(I2)。 exponent_bits si32 類型的常數 (C2)
(I3)。 mantissa_bits si32 類型的常數 (C3)

輸出

名稱 類型 限制
output 浮點類型或每個張量量化張量的張量 (C1)。

限制

  • (C1) baseline_type(operand) = baseline_type(output)
  • (C2) 1 <= exponent_bits
  • (C3) 0 <= mantissa_bits

範例

// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
  exponent_bits = 5 : i32,
  mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]

更多範例

reduce_scatter

語義

在 StableHLO 程序格線中的每個程序群組中,使用 computations 對每個程序的 operand 張量值執行縮減,將縮減結果沿著 scatter_dimension 分割成部分,並將分割部分分散到不同程序之間以產生 result

這項作業會將 StableHLO 程序格線分割為 process_groups,定義如下:

  • cross_replica(replica_groups) 如果為 channel_id <= 0 and use_global_device_ids = false
  • cross_replica_and_partition(replica_groups) 如果為 channel_id > 0 and use_global_device_ids = false
  • flattened_ids(replica_groups) 如果為 channel_id > 0 and use_global_device_ids = true

之後,在每個 process_group 中:

  • reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation).
  • parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension).
  • process_group 中所有 senderresult@receiver = parts@sender[receiver_index],其中 receiver_index = process_group.index(receiver)

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 張量或每個張量化張量 (C1)、(C2)、(C7)、(C8)
(I2)。 scatter_dimension si64 類型的常數 (C1)、(C2)、(C8)
(I3)。 replica_groups si64 類型的 2 維張量常數 (C3-C5)
(I4)。 channel_id si64 類型的常數 (C6)
(I5) use_global_device_ids i1 類型的常數 (C6)
(I6)。 computation 函式 (C7)

輸出

名稱 類型 限制
result 張量或每個張量化張量 (C8-C9)

限制

  • (C1) dim(operand, scatter_dimension) % dim(process_groups, 1) = 0
  • (C2) 0 <= scatter_dimension < rank(operand)
  • (C3) is_unique(replica_groups)
  • (C4) size(replica_groups) 定義為:
    • 如果使用 cross_replica,則為 num_replicas
    • 如果使用 cross_replica_and_partition,則為 num_replicas
    • 如果使用 flattened_ids,則為 num_processes
  • (C5) 0 <= replica_groups < size(replica_groups)
  • (C6) 如為 use_global_device_ids = true,則預設為 channel_id > 0
  • (C7) computation 具有類型 (tensor<E>, tensor<E>) -> (tensor<E>),其中 is_promotable(element_type(operand), E)
  • (C8) shape(result) = shape(operand),但下列項目除外:
    • dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1).
  • (C9) element_type(result) = E

範例

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
  "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
//                  [18, 20]]
// %result@(1, 0): [[14, 16],
//                  [22, 24]]

更多範例

reduce_window

語義

將縮減函式 body 套用至 inputsinit_values 的視窗,並產生 results

下圖以具體範例說明系統如何從 inputs... 計算 results... 中的元素。

正式來說,results...[result_index] = reduce(windows, init_values, axes(inputs...), body) (請參閱 reduce) 其中:

  • padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1).
  • window_start = result_index * window_strides.
  • window_end = window_start + (window_dimensions - 1) * window_dilations + 1.
  • windows = slice(padded_inputs..., window_start, window_end, window_dilations).

輸入內容

標籤 名稱 類型 限制
(I1)。 inputs 各種張量或每個張量量化張量 (C1-C4)、(C6)、(C8)、(C10)、(C12)、(C13)、(C15)
(I2)。 init_values 各種 0 維張量或每個張量量化張量的數量 (C1)、(C13)
(I3)。 window_dimensions si64 類型的 1 維張量常數 (C4)、(C5)、(C15)
(I4)。 window_strides si64 類型的 1 維張量常數 (C6)、(C7)、(C15)
(I5) base_dilations si64 類型的 1 維張量常數 (C8)、(C9)、(C15)
(I6)。 window_dilations si64 類型的 1 維張量常數 (C10)、(C11)、(C15)
(I7) padding si64 類型的 2 維張量常數 (C12)、(C15)
(I8) body 函式 (C13)

輸出

名稱 類型 限制
results 各種張量或每個張量量化張量 (C1)、(C14-C16)

限制

  • (C1) 0 < size(inputs) = size(init_values) = size(results) = N
  • (C2) same(shape(inputs...))
  • (C3) element_type(inputs...) = element_type(init_values...)
  • (C4) size(window_dimensions) = rank(inputs[0])
  • (C5) 0 < window_dimensions
  • (C6) size(window_strides) = rank(inputs[0])
  • (C7) 0 < window_strides
  • (C8) size(base_dilations) = rank(inputs[0])
  • (C9) 0 < base_dilations
  • (C10) size(window_dilations) = rank(inputs[0])
  • (C11) 0 < window_dilations
  • (C12) shape(padding) = [rank(inputs[0]), 2]
  • (C13) body 具有 (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>) 類型,其中 is_promotable(element_type(inputs[i]), Ei)
  • (C14) same(shape(results...))
  • (C15) shape(results[0]) = num_windows,其中:
    • dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1.
    • padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1].
    • dilated_window_shape = (window_dimensions - 1) * window_dilations + 1.
    • is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1.
  • (C16) element_type(results[i]) = Ei 適用於 [0,N) 中的所有 i

範例

// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = dense<[2, 1]> : tensor<2xi64>,
  window_strides = dense<[4, 1]> : tensor<2xi64>,
  base_dilations = dense<[2, 1]> : tensor<2xi64>,
  window_dilations = dense<[3, 1]> : tensor<2xi64>,
  padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]

更多範例

剩餘

語義

執行被除 lhs 和除數 rhs 張量的元素相關餘數,然後產生 result 張量。

更正式的,表示結果的正負號從被除數取下,而結果的絕對值一律小於除數的絕對值。餘數的計算方式為 lhs - d * rhs,其中 d 為:

  • 整數:stablehlo.divide(lhs, rhs)
  • 浮點值:來自 IEEE-754 的 division(lhs, rhs),帶有捨入屬性 roundTowardZero
  • 如為複數:未定 (#997)。
  • 量化類型:
    • dequantize_op_quantize(remainder, lhs, rhs, type(result)).

就浮點元素類型而言,這個運算與 IEEE-754 規格的 remainder 運算有差異,其中 d 是最接近 lhs/rhs 值的整數值,而且相對值相等。

輸入內容

標籤 名稱 類型 限制
(I1)。 lhs 整數、浮點、複雜類型,或每個張量的量化張量 (C1)。
(I2)。 rhs 整數、浮點、複雜類型,或每個張量的量化張量 (C1)。

輸出

名稱 類型 限制
result 整數、浮點、複雜類型,或每個張量的量化張量 (C1)。

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]

更多範例

replica_id

語義

產生目前程序的 replica_id

輸出

名稱 類型
result ui32 類型的 0 維張量

範例

%result = "stablehlo.replica_id"() : () -> tensor<ui32>

更多範例

重塑

語義

執行將 operand 張量重新塑形為 result 張量。從概念上來說,最好維持相同的標準表示法,但可能需要改變形狀,例如從 tensor<2x3xf32> 變更為 tensor<3x2xf32>tensor<6xf32>

一般而言,result[result_index] = operand[operand_index],其中 result_indexoperand_index 按照 index_space(result)index_space(operand) 的字母順序排列相同的位置。

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 張量或量化張量 (C1-C3)

輸出

名稱 類型 限制
result 張量或量化張量 (C1-C3)

限制

  • (C1) element_type(result) 的提供者:
    • element_type(operand) (如果為 !is_per_axis_quantized(operand))。
    • element_type(operand) 除外,但 quantization_dimension(operand)quantization_dimension(result) 可能不同。
  • (C2) size(operand) = size(result)
  • (C3) 如為 is_per_axis_quantized(operand)
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).

範例

// %operand: [[1, 2, 3], [4, 5, 6]]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]

更多範例

反向排序

語義

依照指定的 dimensions 反轉 operand 中的元素順序,並產生 result 張量。更正式的 result[result_index] = operand[operand_index],其中:

  • 如果 dimensions 中的 d,則為 operand_index[d] = dim(result, d) - result_index[d] - 1
  • 否則傳回 operand_index[d] = result_index[d]

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 張量或每個張量化張量 (C1)、(C3)
(I2)。 dimensions si64 類型的 1 維張量常數 (C2)、(C3)

輸出

名稱 類型 限制
result 張量或每個張量化張量 (C1)、(C3)

限制

  • (C1) type(operand) = type(result)
  • (C2) is_unique(dimensions)
  • (C3) 0 <= dimensions < rank(result)

範例

// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
  dimensions = dense<1> : tensor<1xi64>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]

更多範例

奈米

語義

使用 rng_distribution 演算法產生隨機號碼,並產生指定形狀 shaperesult 張量。

如果為 rng_distribution = UNIFORM,則隨機數會根據間隔 [a, b) 的統一分佈產生。如果為 a >= b,表示行為未定義。

如果為 rng_distribution = NORMAL,則隨機數字會在常態分佈後產生,平均值為 a,標準差 = b。如果為 b < 0,表示行為未定義。

隨機號碼的產生方式確切的產生方式是由實作定義。例如,不一定有確定性,而且不一定使用隱藏狀態。

在與許多利害關係人的對話中,這項運算實際上已確實淘汰,因此日後我們計劃設法移除這個功能 (#597)。

輸入內容

標籤 名稱 類型 限制
(I1)。 a 整數、布林值或浮點類型的 0 維張量 (C1)、(C2)
(I2)。 b 整數、布林值或浮點類型的 0 維張量 (C1)、(C2)
(I3)。 shape si64 類型的 1 維張量常數 (C3)
(I4)。 rng_distribution UNIFORMNORMAL 的列舉項目 (C2)

輸出

名稱 類型 限制
result 整數、布林值或浮點類型的張量 (C1-C3)

限制

  • (C1) element_type(a) = element_type(b) = element_type(result)
  • (C2) 如為 rng_distribution = NORMAL,則預設為 is_float(a)
  • (C3) shape(result) = shape

範例

// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
  rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
//           [1, 0, 1],
//           [1, 1, 1],
//           [0, 0, 0]
//          ]

rng_bit_generator

語義

使用虛擬隨機號碼產生器演算法 rng_algorithm,為初始狀態 initial_state 傳回已填入統一隨機位元的 output,並更新輸出狀態 output_state。保證提供 initial_state 的確定性函式,但不保證在實作之間具確定性。

rng_algorithm 是下列其中一項:

  • DEFAULT:導入定義的演算法。
  • THREE_FRY:導入 Threefry 演算法定義的變化版本*。
  • PHILOX:實作定義的 Philox 演算法變化版本*。

* 請參閱:Salmon 等人,SC 2011。平行隨機數:和 1、2、3 一樣簡單。

輸入內容

標籤 名稱 類型 限制
(I1)。 rng_algorithm DEFAULTTHREE_FRYPHILOX 的列舉項目 (C2)
(I2)。 initial_state ui64 類型的 1 維張量 (C1)、(C2)

輸出

名稱 類型 限制
output_state ui64 類型的 1 維張量 (C1)。
output 整數或浮點類型的張量

限制

  • (C1) type(initial_state) = type(output_state)
  • (C2) size(initial_state) 定義為:
    • rng_algorithm = DEFAULT 時定義。
    • 如果 rng_algorithm = THREE_FRY 則為 2
    • 23 表示 rng_algorithm = PHILOX

範例

// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
  rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
//           [9236835810183407956, 16087790271692313299],
//           [18212823393184779219, 2658481902456610144]
//          ]

round_nearest_afz

語義

針對 operand 張量執行元素相關進位法,將元素四捨五入至最接近的整數,從零打斷,並產生 result 張量。從 IEEE-754 規格實作 roundToIntegralTiesToAway 作業。如果是量化類型,請執行 dequantize_op_quantize(round_nearest_afz, operand, type(result))

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 浮點類型或每個張量量化張量的張量 (C1)。

輸出

名稱 類型 限制
result 浮點類型或每個張量量化張量的張量 (C1)。

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]

更多範例

round_nearest_even

語義

針對 operand 張量執行元素相關進位法,將元素四捨五入至最接近的整數,並將鏈結至偶數整數,並產生 result 張。從 IEEE-754 規格實作 roundToIntegralTiesToEven 作業。如果是量化類型,請執行 dequantize_op_quantize(round_nearest_even, operand, type(result))

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 浮點類型或每個張量量化張量的張量 (C1)。

輸出

名稱 類型 限制
result 浮點類型或每個張量量化張量的張量 (C1)。

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]

更多範例

rsqrt

語義

operand 張量執行元素相關平方根運算,並產生 result 張量。根據元素類型執行以下操作:

  • 浮點值:由 IEEE-754 距離 rSqrt
  • 如為複數:複雜的倒數平方根。
  • 量化類型:dequantize_op_quantize(rsqrt, operand, type(result))

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 浮點、複雜類型或每個張量化張量的張量 (C1)。

輸出

名稱 類型 限制
result 浮點、複雜類型或每個張量化張量的張量 (C1)。

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]

更多範例

scatter

語義

產生等於 inputs 張量的 results 張量,但 scatter_indices 指定的幾個切片會使用 update_computation 更新為 updates 的值。

下圖以具體範例說明 updates... 中的元素如何對應到 results... 中的元素。圖中挑選了一些 updates... 索引範例,並詳細說明這些 results... 代表對應的索引。

更正式地是,針對 index_space(updates[0]) 中的所有 update_index

  • update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims].
  • update_scatter_index = update_index[update_scatter_dims...].
  • start_index 定義為:
    • scatter_indices[si0, ..., :, ..., siN],其中 siupdate_scatter_index 中的個別元素,如果 index_vector_dim < rank(scatter_indices),則 : 會插入 index_vector_dim 索引。
    • 否則傳回 [scatter_indices[update_scatter_index]]
  • 針對 axes(inputs[0]) 中的 d_input
    • 如果為 d_input = scatter_dims_to_operand_dims[d_start],則為 full_start_index[d_input] = start_index[d_start]
    • 否則傳回 full_start_index[d_input] = 0
  • update_window_index = update_index[update_window_dims...].
  • full_window_index = [wi0, ..., 0, ..., wiN],其中 wiupdate_window_index 中的個別元素,而 0 會插入 inserted_window_dims 的索引。
  • result_index = full_start_index + full_window_index.

因此,results = exec(schedule, inputs),其中:

  • scheduleindex_space(updates[0]) 的實作定義的排列方式。
  • exec([update_index, ...], results) = exec([...], updated_results),其中:
    • 如果 result_indexshape(results...) 的邊界內
    • updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
    • updated_values = update_computation(results...[result_index], updates_converted)
    • updated_resultsresults 的副本,並將 results...[result_index] 設為 updated_values...
    • 你也可以
    • updated_results = results.
  • exec([], results) = results.

如果 indices_are_sortedtrue,實作會假設 scatter_indices 會依照 scatter_dims_to_operand_dims 排序,否則行為未定義。更正式地是,針對 indices(result) 中的所有 i1 < i2full_start_index(i1) <= full_start_index(i2)

如果 unique_indicestrue,則實作會假設所有 result_index 索引均散發不重複。如果 unique_indicestrue,但索引被散佈為重複,表示行為未定義。

輸入內容

標籤 名稱 類型 限制
(I1)。 inputs 各種張量或每個張量量化張量 (C1)、(C2)、(C4-C6)、(C10)、(C13)、(C15-C16)
(I2)。 scatter_indices 整數類型的張量 (C4)、(C11)、(C14)
(I3)。 updates 各種張量或每個張量量化張量 (C3-C6)、(C8)
(I4)。 update_window_dims si64 類型的 1 維張量常數 (C2)、(C4)、(C7)、(C8)
(I5) inserted_window_dims si64 類型的 1 維張量常數 (C2)、(C4)、(C9)、(C10)
(I6)。 scatter_dims_to_operand_dims si64 類型的 1 維張量常數 (C11-C13)
(I7) index_vector_dim si64 類型的常數 (C4)、(C11)、(C14)
(I8) indices_are_sorted i1 類型的常數
(I9) unique_indices i1 類型的常數
(I10)。 update_computation 函式 (C15)

輸出

名稱 類型 限制
results 各種張量或每個張量量化張量 (C15-C17)

限制

  • (C1) same(shape(inputs...))
  • (C2) rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims)
  • (C3) same(shape(updates...))
  • (C4) shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes),其中:
    • update_scatter_dim_sizes = shape(scatter_indices),但資料不會包含與 index_vector_dim 對應的 scatter_indices 尺寸大小。
    • update_window_dim_sizes <= shape(inputs[0]),但不會包含 inputs[0] 中與 inserted_window_dims 對應的尺寸大小。
    • combine 會將 update_scatter_dim_sizes 置於與 update_scatter_dimsupdate_window_dim_sizes 對應的軸上,並放在與 update_window_dims 對應的軸。
  • (C5) 0 < size(inputs) = size(updates) = N
  • (C6) element_type(updates...) = element_type(inputs...)
  • (C7) is_unique(update_window_dims) and is_sorted(update_window_dims)
  • (C8) 0 <= update_window_dims < rank(updates[0])
  • (C9) is_unique(inserted_window_dims) and is_sorted(update_window_dims)
  • (C10) 0 <= inserted_window_dims < rank(inputs[0])
  • (C11) size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1
  • (C12) is_unique(scatter_dims_to_operand_dims)
  • (C13) 0 <= scatter_dims_to_operand_dims < rank(inputs[0])
  • (C14) 0 <= index_vector_dim <= rank(scatter_indices)
  • (C15) update_computation 具有 (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>) 類型,其中 is_promotable(element_type(inputs[i]), Ei)
  • (C16) shape(inputs...) = shape(results...)
  • (C17) element_type(results[i]) = Ei 適用於 [0,N) 的所有 i

範例

// %input: [
//          [[1, 2], [3, 4], [5, 6], [7, 8]],
//          [[9, 10], [11, 12], [13, 14], [15, 16]],
//          [[17, 18], [19, 20], [21, 22], [23, 24]]
//         ]
// %scatter_indices: [[[0, 2], [1, 0], [2, 1]], [[0, 1], [1, 0], [0, 9]]]
// %update: [
//           [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
//           [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
//          ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension_numbers = #stablehlo.scatter<
    update_window_dims = [2, 3],
    inserted_window_dims = [0],
    scatter_dims_to_operand_dims = [1, 0],
    index_vector_dim = 2>,
  indices_are_sorted = false,
  unique_indices = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<2x3x2x2xi64>) -> tensor<3x4x2xi64>
// %result: [
//           [[1, 2], [5, 6], [7, 8], [7, 8]],
//           [[10, 11], [12, 13], [14, 15], [16, 17]],
//           [[18, 19], [20, 21], [21, 22], [23, 24]]
//          ]

更多範例

選取

語義

產生 result 張量,並根據 pred 的對應元素值,從 on_trueon_false 張量選取每個元素。正式名稱為 result[result_index] = pred_element ? on_true[result_index] : on_false[result_index],其中 pred_element = rank(pred) = 0 ? pred[] : pred[result_index]。如果是量化類型,請執行 dequantize_select_quantize(pred, on_true, on_false, type(result))

輸入內容

標籤 名稱 類型 限制
(I1)。 pred i1 類型的張量 (C1)。
(I2)。 on_true 張量或每個張量化張量 (C1-C2)
(I3)。 on_false 張量或每個張量化張量 (C2)

輸出

名稱 類型 限制
result 張量或每個張量化張量 (C2)

限制

  • (C1) rank(pred) = 0 or shape(pred) = shape(on_true)
  • (C2) baseline_type(on_true) = baseline_type(on_false) = baseline_type(result)

範例

// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]

更多範例

select_and_scatter

語義

使用 select 根據 input 張量的 reduce_window 範圍,使用 scatter 分散 source 張量的值,然後產生 result 張量。

下圖顯示如何使用具體範例,從 operandsource 計算 result 中的元素。

更正式:

  • 含有下列輸入內容的 selected_values = reduce_window_without_init(...)

    • `inputs = [運算元]。
    • 依原樣使用 window_dimensionswindow_stridespadding
    • base_dilations = windows_dilations = 1.
    • body 定義為:
    def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>:
      return select(arg0, arg1) ? arg0 : arg1;
    

    其中 E = element_type(operand)reduce_window_without_init 的運作方式與 reduce_window 類似,但基礎 reduceschedule (請參閱 reduce) 不包含 init 值。目前未指出如果對應的視窗沒有值,會發生什麼情況。(#731)。

  • result[result_index] = reduce([source_values], [init_value], [0], scatter),其中:

    • source_values = [source[source_index] for source_index in source_indices].
    • 如果 selected_values[source_index] 具有來自 operand_indexoperand 元素,則為 selected_index(source_index) = operand_index
    • source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index].

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 張量或每個張量化張量 (C1-C4)、(C6)、(C8-C11)
(I2)。 source 張量或每個張量化張量 (C1)、(C2)
(I3)。 init_value 0 維張量或按張量化張量 (C3)
(I4)。 window_dimensions si64 類型的 1 維張量常數 (C2)、(C4)、(C5)
(I5) window_strides si64 類型的 1 維張量常數 (C2)、(C6)、(C7)
(I6)。 padding si64 類型的 2 維張量常數 (C2)、(C8)
(I7) select 函式 (C9)
(I8) scatter 函式 (C10)

輸出

名稱 類型 限制
result 張量或每個張量化張量 (C11-C12)

限制

  • (C1) element_type(operand) = element_type(source)
  • (C2) shape(source) = num_windows,其中:
    • padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1].
    • is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1.
  • (C3) element_type(init_value) = element_type(operand)
  • (C4) size(window_dimensions) = rank(operand)
  • (C5) 0 < window_dimensions
  • (C6) size(window_strides) = rank(operand)
  • (C7) 0 < window_strides
  • (C8) shape(padding) = [rank(operand), 2]
  • (C9) select 具有類型 (tensor<E>, tensor<E>) -> tensor<i1>,其中 E = element_type(operand)
  • (C10) scatter 具有類型 (tensor<E>, tensor<E>) -> tensor<E>,其中 is_promotable(element_type(operand), E)
  • (C11) shape(operand) = shape(result)
  • (C12) element_type(result) = E

範例

// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]

更多範例

傳送

語義

inputs 傳送至管道 channel_id,並產生 result 權杖。

如果 is_host_transfertrue,則作業會將資料轉移至主機。否則,系統會將資料轉移到其他裝置。這代表的是實作定義。這個旗標會複製 channel_type 中提供的資訊,因此未來我們打算只會保留其中一個 (#666)。

輸入內容

標籤 名稱 類型 限制
(I1)。 inputs 各種張量或量化張量
(I2)。 token token
(I3)。 channel_id si64 類型的常數
(I4)。 channel_type DEVICE_TO_DEVICEDEVICE_TO_HOST 的列舉項目 (C1)。
(I5) is_host_transfer i1 類型的常數 (C1)。

輸出

名稱 類型
result token

限制

  • (C1) channel_type 定義為:
    • DEVICE_TO_HOST 表示 is_host_transfer = true
    • 否則傳回 DEVICE_TO_DEVICE

範例

%result = "stablehlo.send"(%operand, %token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
  is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token

更多範例

shift_left

語義

rhs 位元數對 lhs 張量執行元素相關左移作業,並產生 result 張量。

輸入內容

標籤 名稱 類型 限制
(I1)。 lhs 整數類型的張量 (C1)。
(I2)。 rhs 整數類型的張量 (C1)。

輸出

名稱 類型 限制
result 整數類型的張量 (C1)。

限制

  • (C1) type(lhs) = type(rhs) = type(result)

範例

// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]

更多範例

shift_right_arithmetic

語義

以位元數lhs張數,rhs張數對元素做出元素右移運算,然後產生 result張量。

輸入內容

標籤 名稱 類型 限制
(I1)。 lhs 整數類型的張量 (C1)。
(I2)。 rhs 整數類型的張量 (C1)。

輸出

名稱 類型 限制
result 整數類型的張量 (C1)。

限制

  • (C1) type(lhs) = type(rhs) = type(result)

範例

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]

更多範例

shift_right_logical

語義

lhs 張量執行 rhs 張量的元素右移邏輯,然後產生 result 張量。

輸入內容

標籤 名稱 類型 限制
(I1)。 lhs 整數類型的張量 (C1)。
(I2)。 rhs 整數類型的張量 (C1)。

輸出

名稱 類型 限制
result 整數類型的張量 (C1)。

限制

  • (C1) type(lhs) = type(rhs) = type(result)

範例

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]

更多範例

簽署

語義

依序傳回 operand 元素的符號,並產生 result 張量。 更正式地說,每個元素 x 都能使用 Python 語法來表示語意,如下所示:

def sign(x):
  if is_integer(x):
    if compare(x, 0, LT, SIGNED): return -1
    if compare(x, 0, EQ, SIGNED): return 0
    return 1
  elif is_float(x):
    if is_nan(x): return NaN
    if compare(x, -0.0, EQ, FLOAT): return -0.0
    if compare(x, +0.0, EQ, FLOAT): return +0.0
    if compare(x, 0.0, LT, FLOAT): return -1.0
    return 1.0
  elif is_complex(x):
    if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
    if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
    return divide(x, convert(abs(x), type(x)))

如果是量化類型,請執行 dequantize_op_quantize(sign, operand, type(result))

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 帶正負號整數、浮點、複雜類型,或每個張量的量化張量的張量 (C1)。

輸出

名稱 類型 限制
result 帶正負號整數、浮點、複雜類型,或每個張量的量化張量的張量 (C1)。

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]

更多範例

正弦

語義

operand 張量執行元素相關正弦運算,並產生 result 張量。根據元素類型執行以下操作:

  • 浮點值:由 IEEE-754 距離 sin
  • 複數:複數的正弦。
  • 量化類型:dequantize_op_quantize(sine, operand, type(result))

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 浮點、複雜類型或每個張量化張量的張量 (C1)。

輸出

名稱 類型 限制
result 浮點、複雜類型或每個張量化張量的張量 (C1)。

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]

更多範例

配量

語義

使用靜態計算的啟動索引從 operand 擷取切片,然後產生 result 張量。start_indices 包含每個維度的切片起始索引,limit_indices 包含每個維度的切片的結尾索引 (不含),而 strides 則包含每個維度的步距。

正式結果是 result[result_index] = operand[operand_index],其中 operand_index = start_indices + result_index * strides

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 張量或每個張量化張量 (C1-C3)、(C5)
(I2)。 start_indices si64 類型的 1 維張量常數 (C2)、(C3)、(C5)
(I3)。 limit_indices si64 類型的 1 維張量常數 (C2)、(C3)、(C5)
(I4)。 strides si64 類型的 1 維張量常數 (C2)、(C4)

輸出

名稱 類型 限制
result 張量或每個張量化張量 (C1)、(C5)

限制

  • (C1) element_type(operand) = element_type(result)
  • (C2) size(start_indices) = size(limit_indices) = size(strides) = rank(operand)
  • (C3) 0 <= start_indices <= limit_indices <= shape(operand)
  • (C4) 0 < strides
  • (C5) shape(result) = ceil((limit_indices - start_indices) / strides)

範例

// %operand: [
//            [0, 0, 0, 0],
//            [0, 0, 1, 1],
//            [0, 0, 1, 1]
//           ]
%result = "stablehlo.slice"(%operand) {
  start_indices = dense<[1, 2]> : tensor<2xi64>,
  limit_indices = dense<[3, 4]> : tensor<2xi64>,
  strides = dense<1> : tensor<2xi64>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
//            [1, 1],
//            [1, 1]
//           ]

更多範例

排序

語義

根據 comparator 並按維度 dimension,一併排序 inputs 的 1 維切片,並產生 results

與其他作業中的類似輸入內容不同,dimension 允許負值,但使用下述語意。日後,基於一致性原因,可能會禁止這項操作。(#1377)。

如果 is_stable 為 true,排序功能就會穩定,也就是視為與比較元素相等的元素相對順序。如果有多個輸入,e1e2 兩個元素會視為相等,只有在 comparator(e1, e2) = comparator(e2, e1) = false 時才會被視為相等。請參閱下方的正規化以如何歸納為多個輸入內容。

更正式地是,針對 index_space(results[0]) 中的所有 result_index

  • adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension.
  • result_slice = [ri0, ..., :, ..., riR-1],其中 riNresult_index 中的個別元素,而 : 會在 adjusted_dimension 插入。
  • inputs_together = (inputs[0]..., ..., inputs[N-1]...).
  • results_together[result_slice] = sort(inputs_together[result_slice], comparator_together).
  • 其中 sort 會以非遞減順序排序 1 維配量,以預期在左側引數小於右側第二個引數時,comparator_together 會傳回 true
  • def comparator_together(lhs_together, rhs_together):
      args = []
      for (lhs_el, rhs_el) in zip(lhs_together, rhs_together):
        args.append(lhs_el)
        args.append(rhs_el)
      return comparator(*args)
    
  • (results[0]..., ..., results[N-1]...) = results_together.

輸入內容

標籤 名稱 類型 限制
(I1)。 inputs 各種張量或每個張量量化張量 (C1-C5)
(I2)。 dimension si64 類型的常數 (C4)
(I3)。 is_stable i1 類型的常數
(I4)。 comparator 函式 (C5)

輸出

名稱 類型 限制
results 各種張量或每個張量量化張量 (C2)、(C3)

限制

  • (C1) 0 < size(inputs)
  • (C2) type(inputs...) = type(results...)
  • (C3) same(shape(inputs...) + shape(results...))
  • (C4) -R <= dimension < R,其中 R = rank(inputs[0])
  • (C5) comparator 具有 (tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1> 類型,其中 Ei = element_type(inputs[i])

範例

// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
    %predicate = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
  dimension = 0 : i64,
  is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]

更多範例

sqrt

語義

operand 張量執行元素層級平方根運算,並產生 result 張量。根據元素類型執行以下操作:

  • 浮點值:由 IEEE-754 距離 squareRoot
  • 複數:複數的平方根。
  • 量化類型:dequantize_op_quantize(sqrt, operand, type(result))

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 浮點、複雜類型或每個張量化張量的張量 (C1)。

輸出

名稱 類型 限制
result 浮點、複雜類型或每個張量化張量的張量 (C1)。

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]

更多範例

subtract

語義

對元素執行元素減法將兩個張量 lhsrhs 相減,然後產生 result 張量。根據元素類型執行以下操作:

  • 整數:整數減。
  • 浮點值:由 IEEE-754 距離 subtraction
  • 複數:複數的減法。
  • 量化類型:
    • dequantize_op_quantize(subtract, lhs, rhs, type(result)).

輸入內容

標籤 名稱 類型 限制
(I1)。 lhs 整數、浮點、複雜類型或每個張量量化張量的張量 (C1)。
(I2)。 rhs 整數、浮點、複雜類型或每個張量量化張量的張量 (C1)。

輸出

名稱 類型 限制
result 整數、浮點、複雜類型或每個張量量化張量的張量 (C1)。

限制

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

範例

// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]

更多範例

坦赫

語義

operand 張量執行元素層級的雙曲線正切作業,並產生 result 張量。根據元素類型執行以下操作:

  • 浮點值:由 IEEE-754 距離 tanh
  • 如果數字為複數:複雜的雙曲正切。
  • 量化類型:
    • dequantize_op_quantize(tanh, operand, type(result)).

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 浮點、複雜類型或每個張量化張量的張量 (C1)。

輸出

名稱 類型 限制
result 浮點、複雜類型或每個張量化張量的張量 (C1)。

限制

  • (C1) baseline_type(operand) = baseline_type(result)

範例

// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]

更多範例

轉置

語義

使用 permutation 保留 operand 張量的維度,並產生 result 張量。更正式的,也就是 result[result_index] = operand[operand_index],其中 result_index[d] = operand_index[permutation[d]]

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 張量或量化張量 (C1-C4)
(I2)。 permutation si64 類型的 1 維張量常數 (C2-C4)

輸出

名稱 類型 限制
result 張量或量化張量 (C1)、(C3-C4)

限制

  • (C1) element_type(result) 的提供者:
    • element_type(operand) (如果為 !is_per_axis_quantized(operand))。
    • element_type(operand) 除外,但 quantization_dimension(operand)quantization_dimension(result) 可能不同。
  • (C2) permutationrange(rank(operand)) 的排列方式。
  • (C3) shape(result) = dim(operand, permutation...)
  • (C4) 如果為 is_per_axis_quantized(result),則 quantization_dimension(operand) = permutation(quantization_dimension(result))

範例

// %operand: [
//            [[1,2], [3,4], [5,6]],
//            [[7,8], [9,10], [11,12]]
//           ]
%result = "stablehlo.transpose"(%operand) {
  permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
//           [[1,7], [3,9], [5,11]],
//           [[2,8], [4,10], [6,12]]
//          ]

更多範例

triangular_solve

語義

以較低或上三角係數來解開聯立線性方程式的批次。

更正式地說,由於 ab,當 left_sidetrue 時,result[i0, ..., iR-3, :, :]op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :] 的解決方案,當 left_sidefalse 時,則解決 op(a) 變數的 x 解決方案,其中 op(a) 可為 transpose_a,可以是下列其中一項:x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :]

  • NO_TRANSPOSE:依原樣使用 a 執行作業。
  • TRANSPOSE:對轉置 a 執行作業。
  • ADJOINT:對 a 的共錐轉出執行作業。

系統只會在 lowertruea 的上三角形時,從 a 的底部三角形讀取輸入資料。輸出資料會在同一個三角形中傳回,其他三角形的值則是實作定義。

如果 unit_diagonal 為 true,實作會假設 a 的對角元素等於 1,否則行為未定義。

如果是量化類型,請執行 dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower, unit_diagonal, transpose_a), a, b, type(result))

輸入內容

標籤 名稱 類型 限制
(I1)。 a 浮點、複雜類型或每個張量化張量的張量 (C1-C3)
(I2)。 b 浮點、複雜類型或每個張量化張量的張量 (C1-C4)
(I3)。 left_side i1 類型的常數 (C3)
(I4)。 lower i1 類型的常數
(I5) unit_diagonal i1 類型的常數
(I6)。 transpose_a NO_TRANSPOSETRANSPOSEADJOINT 的列舉項目

輸出

名稱 類型 限制
result 浮點、複雜類型或每個張量化張量的張量 (C1)。

限制

  • (C1) baseline_element_type(a) = baseline_element_type(b)
  • (C2) 2 <= rank(a) = rank(b) = R
  • (C3) shape(a)shape(b) 之間的關係定義如下:
    • shape(a)[:-3] = shape(b)[:-3].
    • dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1).
  • (C4) baseline_type(b) = baseline_type(result)

範例

// %a = [
//       [1.0, 0.0, 0.0],
//       [2.0, 4.0, 0.0],
//       [3.0, 5.0, 6.0]
//      ]
// %b = [
//       [2.0, 0.0, 0.0],
//       [4.0, 8.0, 0.0],
//       [6.0, 10.0, 12.0]
//      ]
%result = "stablehlo.triangular_solve"(%a, %b) {
  left_side = true,
  lower = true,
  unit_diagonal = false,
  transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
//           [2.0, 0.0, 0.0],
//           [0.0, 2.0, 0.0],
//           [0.0, 0.0, 2.0]
//          ]

tuple

語義

使用值 val 產生 result 元組。

輸入內容

標籤 名稱 類型 限制
(I1)。 val 變異數值 (C1)。

輸出

名稱 類型 限制
result tuple (C1)。

限制

  • (C1) result 屬於 tuple<E0, ..., EN-1> 類型,其中 Ei = type(val[i])

範例

// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))

更多範例

uniform_dequantize

語義

根據 operand 類型定義的量化參數,將量化張量 operand 的元素層級轉換為浮點張量 result

正式的是 result = dequantize(operand)

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 量化張量 (C1)、(C2)

輸出

名稱 類型 限制
result 浮點類型的張量 (C1)、(C2)

限制

  • (C1) shape(operand) = shape(result)
  • (C2) element_type(result) = expressed_type(operand)

範例

// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]

uniform_quantize

語義

根據 result 類型定義的量化參數,將浮點張量或量化張量 operand 轉換成量化張量 result 的元素層級轉換。

更正式的

  • 如果 is_float(operand)
    • result = quantize(operand, type(result)).
  • 如果 is_quantized(operand)
    • float_result = dequantize(operand).
    • result = quantize(float_result, type(result)).

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 浮點或量化類型的張量 (C1)、(C2)

輸出

名稱 類型 限制
result 量化張量 (C1)、(C2)

限制

  • (C1) shape(operand) = shape(result)
  • (C2) expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand)

範例

// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]

// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]

語義

cond 函式輸出 true 時,執行 body 函式 0 次以上會產生輸出內容。正式的語意可透過 Python 語法表示,如下所示:

internal_state = operand
while cond(*internal_state):
  internal_state = body(*internal_state)
results = internal_state

無限迴圈的行為是待定。(#383)。

輸入內容

標籤 名稱 類型 限制
(I1)。 operand 各種張量、量化張量或符記 (C1-C3)
(I2)。 cond 函式 (C1)。
(I3)。 body 函式 (C2)

輸出

名稱 類型 限制
results 各種張量、量化張量或符記 (C3)

限制

  • (C1) cond 具有 (T0, ..., TN-1) -> tensor<i1> 類型,其中 Ti = type(operand[i])
  • (C2) body 具有 (T0, ..., TN-1) -> (T0, ..., TN-1) 類型,其中 Ti = type(operand[i])
  • (C3) type(results...) = type(operand...)

範例

// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %cond = "stablehlo.compare"(%arg0, %ten) {
      comparison_direction = #stablehlo<comparison_direction LT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    stablehlo.return %cond : tensor<i1>
  }, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %new_sum = stablehlo.add %arg1, %one : tensor<i64>
    %new_i = stablehlo.add %arg0, %one : tensor<i64>
    stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10

更多範例

xor

語義

執行兩個張量 lhsrhs 的元素層級 XOR,並產生 result 張量。根據元素類型執行以下操作:

  • 布林值:邏輯 XOR。
  • 整數:位元 XOR。

輸入內容

標籤 名稱 類型 限制
(I1)。 lhs 布林值或整數類型的張量 (C1)。
(I2)。 rhs 布林值或整數類型的張量 (C1)。

輸出

名稱 類型 限制
result 布林值或整數類型的張量 (C1)。

限制

  • (C1) type(lhs) = type(rhs) = type(result)

範例

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]

執行

依序執行

系統會透過提供輸入值給 main 函式並計算輸出值來執行 StableHLO 程式。系統會執行對應 return 運算中的運算圖形,來計算函式的輸出值。

只要執行順序與 Dataflow 一致,系統就會實作執行順序,即作業是在使用前執行。在 StableHLO 中,所有連帶效果運算都會使用一個符記並產生一個符記 (可透過 after_all 將多個符記加倍為一個符記),因此副作用的執行順序也會與 Dataflow 保持一致。上述範例程式可能的執行順序為 %0%1%2%3%4return%3%0%2%4return%1

更正式的,StableHLO 程序結合了:1) StableHLO 程式、2) 作業狀態 (尚未執行、已經執行) 和 3) 程序處理的中繼值。程序從 main 函式的輸入值開始,逐步查看運算狀態和中間值的更新作業,並結束輸出值。最終正規化為待定。(#484)。

平行執行

可平行執行 StableHLO 程式,並依 num_partitions 整理為 num_replicas 的 2D 程序格線,兩者都具有類型 ui32

StableHLO 程序格線中,有 num_replicas * num_partitions 個 StableHLO 程序會同時執行。每個程序都有專屬的 process_id = (replica_id, partition_id),其中 replica_ids = range(num_replicas)partition_ids = range(num_partitions) 中的 replica_id,兩者類型皆為 ui32partition_id

每個程式都會以靜態方式得知程序格線的大小 (未來,我們計劃將其使其成為 StableHLO 程式 #650 中的明確部分),而且每個程序在程序格線中的位置都是靜態的。每個程序都可以透過 replica_idpartition_id 運算存取程序格線內的位置。

在程序格線中,所有程式皆可相同 (在「單一程式,多重資料」樣式中),每一種 (在「多重計畫」、「多重資料」樣式中) 皆可相同,或者也可以相同。我們預計日後會支援定義平行 StableHLO 程式的其他慣用語,包括 GSPMD (#619)。

在程序格線中,程序主要彼此獨立 - 程序具有不同的作業狀態、獨立的輸入/中間/輸出值,以及大多數運算,會在程序之間分開執行,但下文所述的少數集體運算除外。

由於大部分的運算只會使用相同程序的值,因此使用這些值的名稱通常並不明確。不過,在描述共同運算的語意時,會不足並引發標記 name@process_id 參照特定程序中的值 name。(在這種情況下,不符資格的 name 可視為 name@(replica_id(), partition_id()) 的簡寫)。

程序之間的執行順序是由實作定義,但點對點通訊與集體運算而引入的同步處理除外 (如下所述)。

點對點通訊

StableHLO 程序可透過 StableHLO 管道相互通訊。管道會以 si64 類型的正 ID 表示。透過各種作業,您可以將值傳送至管道並從管道接收。

進一步的正式化,例如管道 ID 來源、程序如何得知程序,以及所導入的同步處理類型 (#484)。

串流通訊

每個 StableHLO 程序都可存取兩種串流介面:

  • 可以讀取的動態內廣告
  • 可寫入的 Outfeed

與管道不同,管道用於在程序之間進行通訊,因此在兩者的兩端都有程序,而動態內和外動態則由其他的實作定義。

如果需要正式化,例如串流通訊對執行順序的影響,以及其執行何種同步處理作業,均未定案。(#484)。

集體作業

StableHLO 中有六項共同作業:all_gatherall_reduceall_to_allcollective_broadcastcollective_permutereduce_scatter。這些作業會將 StableHLO 程序中的程序拆分為 StableHLO 程序群組,並在每個程序群組內執行聯合運算,而獨立於其他程序群組之外。

在每個程序群組中,共同運算可能會造成同步處理障礙。更進一步的正式程序,例如分析此同步處理作業的確切發生時間、程序如何達到此障礙,以及不然會發生的情況 (#484)。

如果程序群組涉及跨分區通訊 (亦即程序群組中有分區 ID 不同的程序),則共同運算的執行作業需要管道,且共同運算必須提供 si64 類型的 channel_id。跨備用資源通訊不需要管道。

由共同運算執行的運算僅適用於個別運算,在上述的個別運算區段中說明。不過,程序格線拆分為程序群組採用的策略會在這些作業之間共用,本節將進行說明。更正式的,StableHLO 支援下列四種策略。

cross_replica

每個程序群組只會執行跨備用資源通訊。此策略會採用 replica_groups - 備用資源 ID 清單,並用 partition_ids 計算 replica_groups 的笛卡兒乘積。replica_groups 必須有不重複的元素,並涵蓋所有 replica_ids。更正式的做法是使用 Python 語法:

def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    for partition_id in partition_ids:
      process_group = []
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
      yield process_group

例如,針對 replica_groups = [[0, 1], [2, 3]]num_partitions = 2cross_replica 會產生 [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]]

cross_partition

每個程序群組只會進行跨分區通訊。此策略會採用 partition_groups - 分區 ID 清單,並透過 replica_ids 計算 partition_groups 的笛卡兒乘積。partition_groups 必須有不重複的元素,並涵蓋所有 partition_ids。更正式,是使用 Python 語法:

def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
  for partition_group in partition_groups:
    for replica_id in replica_ids:
      process_group = []
      for partition_id in partition_group:
        process_group.append((replica_id, partition_id))
      yield process_group

例如,針對 partition_groups = [[0, 1]]num_replicas = 4cross_partition 會產生 [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]]

cross_replica_and_partition

每個程序群組內可能會同時進行跨備用資源和跨分區通訊。這項策略會採用 replica_groups - 備用資源 ID 清單,並以 partition_ids 計算每個 replica_group 的笛卡兒產品。replica_groups 必須有不重複的元素,並涵蓋所有 replica_ids。更正式,是使用 Python 語法:

def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    process_group = []
    for partition_id in partition_ids:
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
    yield process_group

例如,針對 replica_groups = [[0, 1], [2, 3]]num_partitions = 2cross_replica_and_partition 會產生 [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]]

flattened_ids

這項策略會採用 flattened_id_groups - 以 replica_id * num_partitions + partition_id 形式表示的「整併」程序 ID 清單,並將其轉換成程序 ID。flattened_id_groups 必須有不重複的元素,並涵蓋所有 process_ids。更正式,是使用 Python 語法:

def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
  for flattened_id_group in flattened_id_groups:
    process_group = []
    for flattened_id in flattened_id_group:
      replica_id = flattened_id // num_partitions
      partition_id = flattened_id % num_partitions
      process_group.append((replica_id, partition_id))
    yield process_group

例如,如果是 flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]num_replicas = 4num_partitions = 2flattened_ids 會產生 [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]]

準確

StableHLO 目前不保證數值準確度,但日後可能會改變。(#1156)。

錯誤

StableHLO 程式會透過個別作業的一系列限制來驗證 StableHLO 程式,該程式會在執行時間之前排除許多錯誤類別。 但是,錯誤狀況仍可能發生,例如透過整數溢位、外界存取等。除非明確呼叫,否則這些錯誤都會導致實作定義的行為,但這項設定日後可能會改變 (#1157)。

除了這項規則以外,StableHLO 程式中的浮點例外狀況還具有明確定義的行為。如果作業發生由 IEEE-754 標準定義的例外狀況 (無效作業、除以零、溢位、反向溢位或不完全相同的例外狀況),系統就會產生預設結果 (依標準定義),並在不提高對應的狀態旗標的情況下繼續執行作業,與標準的 raiseNoFlag 例外狀況處理類似。但定義了非標準運算 (例如複雜算術和特定傳輸函式) 的例外狀況。

Notation

為描述語法,本文件使用經過修改的 EBNF 語法 (ISO/IEC 14977:1996Wikipedia) 變體,其中有兩項修改:1) 規則是使用 ::= 而非 = 定義。

2) 串連表示要以其他形式表示,而不是 ,

為描述語意 (即在「類型」、「常數」和「Ops」區段內),我們使用以 Python 語法為基礎的公式,並支援簡要表示陣列運算,如下所示。這種方法適用於小片段程式碼片段,但在極少數情況下,如果需要較大的程式碼,我們會使用一律明確引入的香草 Python 語法。

公式

現在,讓我們根據 dot_general 規格的範例,瞭解公式的運作方式。這項作業的其中一個限制如下:dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)

這個公式使用的名稱來自兩個來源:1) 全域函式,即 dim、2) 對應程式元素的成員定義,即 dot_general 的「Inputs」區段中定義的 lhslhs_batching_dimensionsrhsrhs_batching_dimensions 輸入內容。

如上所述,這個公式的語法以 Python 為基礎,加上一些以精簡為導向的擴充功能。為了簡單說明這個公式 我們先將其轉換為變換的 Python 語法

A) 在這些公式中,我們使用 = 表示等式,因此取得 Python 語法的第一步將 = 替換為 ==,如下所示:dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...)

B) 此外,這些公式支援刪節號 (...),可將純量運算式轉換成張量運算式。簡單來說,f(xs...) 大致是指「針對張量 xs 中的每個純量 x,計算純量 f(x),然後再將這些純量結果傳回做為張量結果」。在香草 Python 語法中,我們的公式範例會變為:[dim(lhs, dim1) for dim1 in lhs_batching_dimensions] == [dim(rhs, dim2) for dim2 in rhs_batching_dimensions]

藉助刪節號,通常可以避免在個別純量的層級工作。不過在棘手情況下,您可能會在 gather 規格的 start_indices[bi0, ..., :, ..., biN] 公式中使用較低層級的半不正式語法。簡而言之,我們無法提供將這類語法翻譯成香草 Python 的確切正式正規,原因是在每個情況下,依個別情況仍易於理解。如果部分公式看起來不透明,請告訴我們,我們會嘗試改善這些公式。

此外,您會注意到公式使用刪節號來擴充所有類型的清單,包括張量、張量清單 (例如因張數不同的張量清單) 等。這個領域又沒有確切的正式程度 (例如清單甚至不是 StableHLO 型別系統的一部分)

C) 我們採用的最終值得注意的車輛為隱性廣播。雖然 StableHLO 運算集不支援隱含廣播,但公式也可在精簡服務中使用。簡單來說,如果在預期有張量的情況下使用純量,純量就會播送至預期的形狀。

如要延續 dot_general 範例,以下是另一項限制:0 <= lhs_batching_dimensions < rank(lhs)。如 dot_general 規格所定義,lhs_batching_dimensions 是張量,但 0rank(lhs) 都是純量。套用隱式廣播後,公式會變成 [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)]

套用至特定 dot_general 運算時,這個公式會評估布林值的張量。將公式當做限制使用時,如果公式評估為 true 或只有 true 元素的張量,該限制就會生效。

名稱

在公式中,詞法範圍包括:1) 全域函式、2) 成員定義,

3) 當地定義。下方提供全域函式清單。元素定義清單取決於標記該標記的程式元素:

  • 針對作業,成員定義包含在「Inputs」(輸入) 和「Outputs」(輸出) 區段中導入的名稱。
  • 除此之外,成員定義都涵蓋程式元素的結構部分,並以對應的 EBNF 非終端機命名。在多數情況下,這些結構零件的名稱可透過將非終端的名稱轉換為蛇的大小寫 (例如 IntegerLiteral => integer_literal),但有時名稱會在程序中縮寫縮寫 (例如:QuantizationStorageType 輸出 => storage_type),在這種情況下,名稱在專有部分與「Inputs」(輸入)/「運算」 (運算) 相關。
  • 此外,成員定義一律包含 self,以參照對應的程式元素。

公式評估時會使用下列類型的值:1) Value (實際值,例如 dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>;這些值一律知道其類型)、2) Placeholder (未來值,例如 lhsrhsresult;其實際值尚未得知,只有類型才會知道類型)、3) Type (在「函式」一節中定義的類型)、「4」) Function (「函式」一節中定義的類型)。

根據內容而定,名稱可能參照不同的值。具體來說,運算 (以及其他程式元素的對等項目) 的「語意」區段會定義執行階段邏輯,因此所有輸入內容都能以 Value 的形式提供。相反地,運算 (和對等項目) 的「限制」部分會定義「編譯時間」邏輯,也就是通常在執行階段之前執行的項目,因此只有常數輸入內容可做為 Value 使用,其他輸入內容則只能做為 Placeholder 使用。

名稱 在「語意」中 在「限制」中
全域函式 Function Function
常數輸入內容 Value Value
非常數的輸入內容 Value Placeholder
輸出 Value Placeholder
本機定義 視定義而定 視定義而定

我們來看看一個 transpose 運算範例:

%result = "stablehlo.transpose"(%operand) {
  permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>

這項作業的 permutation 是常數,因此可在語意和限制中以 Value 的形式提供。相反地,operandresult 可做為語意中的 Value 使用,但僅限做為限制條件中的 Placeholder

函式

類型建構

沒有可用來建構類型的函式。而是直接使用類型語法,因為這類語法通常較精簡。例如:(tensor<E>, tensor<E>) -> (tensor<E>),而非 function_type( [tensor_type([], E), tensor_type([], E)], [tensor_type([], E)])

類型的函式

  • element_type 是在張量類型和量化張量類型上定義,並分別回傳相應 TensorTypeQuantizedTensorTypeTensorElementTypeQuantizedTensorElementType 部分。
def element_type(x: Value | Placeholder | Type):
 if type(x) == TensorType:
    return tensor_element_type(x)
  if type(x) == QuantizedTensorType:
    return quantized_tensor_element_type(x)
  if type(x) is not Type:
    return element_type(type(x))
  • is_per_axis_quantized(x: Value | Placeholder | Type) -> Valueis_quantized(x) and quantization_dimension(x) is not None 的捷徑。

  • is_per_tensor_quantized(x: Value | Placeholder | Type) -> Valueis_quantized(x) and quantization_dimension(x) is None 的捷徑。

  • is_promotable(x: Type, y: Type) -> bool 會檢查 x 類型是否可升級為 y 類型。當 xyQuantizedTensorElementType 時,促銷活動只會套用至 storage_type。此特定升級版本目前用於減少運算作業 (詳情請參閱 RFC 相關說明)。

def is_promotable(x: Type, y: Type) -> Value:
  is_same_type = (is_bool(x) and is_bool(y)) or
    (is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
    (is_complex(x) and is_complex(y)) or
    (is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))

  if is_same_type == False:
    return False

  if is_integer(x) or is_float(x):
    return bitwidth(x) <= bitwidth(y)

  if is_complex(x):
    return bitwidth(element_type(x)) <= bitwidth(element_type(y))

  if is_quantized(x):
    return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))

  return false
  • is_quantized(x: Value | Placeholder | Type) -> Valueis_quantized_tensor_element_type(x) 的捷徑。

  • is_type_name(x: Value | Placeholder | Type) -> Value:適用於所有類型。舉例來說,如果 xFloatTypeis_float(x) 會傳回 true。如果 x 是值或預留位置,此函式是 is_type_name(type(x)) 的捷徑。

  • max_value(x: Type) -> Value 會傳回 TensorElementType 的最大值。如果 x 不是 TensorElementType,就會傳回 None

  • min_value(x: Type) -> Value 會傳回 TensorElementType 的最小可能值。如果 x 不是 TensorElementType,就會傳回 None

  • member_name(x: Value | Placeholder | Type) -> Any:適用於所有類型的成員定義 member_name。舉例來說,tensor_element_type(x) 會傳回對應 TensorTypeTensorElementType 部分。如果 x 是值或預留位置,此函式是 member_name(type(x)) 的捷徑。如果 x 類型不包含適當成員,或是該類型的值或預留位置,則會傳回 None

值的建構

  • operation_name(*xs: Value | Type) -> Value:適用於所有作業。舉例來說,add(lhs, rhs) 會取用 lhsrhs 兩個張量值,並傳回利用這些輸入內容評估 add 運算的輸出結果。對 broadcast_in_dim 等某些作業來說,其輸出類型為「載入傳輸」,也就是評估作業所需的項目。在此情況下,函式會將這些類型視為引數。

用於值的函式

  • 所有 Python 的運算子和函式皆可使用。例如,Python 的訂閱切割標記都可以為張量、量化張量和元組建立索引。

  • to_destination_type(x: Value, destination_type: Type) -> Value 是在張量上定義,並根據 type(x)destination_type 傳回 x 的轉換值,如下所示:

def to_destination_type(x: Value, destination_type: Type) -> Value:
  if type(x) == destination_type:
    return x

  if is_quantized(destination_type):
    if is_quantized(type(x)):
      return quantize(x, destination_type)
    assert is_float(type(x))
    return quantize(x, destination_type)

  if is_quantized(type(x)):
    assert destination_type = expressed_type(type(x))
    return dequantize(type(x))

  return convert(x, destination_type)

目前我們已早期探討合併 convertuniform_quantizeuniform_dequantize 作業 (#1576)。合併後,我們不需要上述函式,可改用 convert 的作業名稱。

  • is_nan(x: Value) -> Value 是在張量上定義,如果 x 的所有元素都是 NaNfalse,則會傳回 true。如果 x 不是張量,會傳回 None

  • is_sorted(x: Value) -> Value 已在張量定義,如果 x 的元素是按照其索引的遞增字母順序排序,則會傳回 truefalse如果 x 不是張量,就會傳回 None

  • is_unique(x: Value) -> Value 是在張量上定義,如果 x 沒有重複的元素,則會傳回 true,否則會傳回 false。如果 x 不是張量,會傳回 None

  • member_name(x: Value) -> Any 已定義所有值的成員定義 member_name。舉例來說,real_part(x) 會傳回對應 ComplexConstantRealPart 部分。如果 x 的值不是包含適當成員的值,則會傳回 None

  • same(x: Value) -> Value 是在張量上定義,如果 x 的元素彼此相等,則傳回 true,否則會傳回 false。如果張量沒有元素,系統會將其視為「彼此相等」,即函式會傳回 true。如果 x 不是張量,就會傳回 None

  • split(x: Value, num_results: Value, axis: Value) -> Value 是在張量上定義,且會傳回軸 axis 軸的 x 配量。num_results如果 x 不是張量或 dim(x, axis) % num_results != 0,就會傳回 None

形狀計算

  • axes(x: Value | Placeholder | Type) -> Valuerange(rank(x)) 的捷徑。

  • dim(x: Value | Placeholder | Type, axis: Value) -> Valueshape(x)[axis] 的捷徑。

  • dims(x: Value | Placeholder | Type, axes: List) -> Listlist(map(lambda axis: dim(x, axis), axes)) 的捷徑。

  • index_space(x: Value | Placeholder | Type) -> Value 是在張量上定義,並會針對按字母順序排列的對應 TensorType 傳回 size(x) 索引,即 [0, ..., 0][0, ..., 1]、...、shape(x) - 1。如果 x 不是張量類型、量化張量類型、值或是其中一種類型的預留位置,就會傳回 None

  • rank(x: Value | Placeholder | Type) -> Valuesize(shape(x)) 的捷徑。

  • shape(x: Value | Placeholder | Type) -> Value 是透過 member_name 在「類型上的函式」部分中定義。

  • size(x: Value | Placeholder | Type) -> Valuereduce(lambda x, y: x * y, shape(x)) 的捷徑。

量化運算

  • def baseline_element_type(x: Value | Placeholder | Type) -> Typeelement_type(baseline_type(x)) 的捷徑。

  • baseline_type 已在張量類型和量化張量類型上定義,並轉換為「基準」,即形狀相同但帶有量化參數的類型會重設為預設值。這很適合做為平均比較張量和量化張量類型比較的實用技巧。以量化類型來說,這可讓比較忽略量化參數的類型 (shapestorage_typeexpressed_typestorage_minstorage_maxquantization_dimension (針對每軸量化類型) 必須全部相符,但 scaleszero points 則可能有所不同。

def baseline_type(x: Value | Placeholder | Type) -> Type:
  if type(x) == TensorType:
    return x
  if type(x) == QuantizedTensorType:
    element_type = quantized_tensor_element_type(x)
    baseline_element_type = QuantizedTensorElementType(
      storage_type = storage_type(element_type),
      storage_min = storage_min(element_type),
      storage_max = storage_max(element_type),
      expressed_type = expressed_type(element_type),
      quantization_dimension = quantization_dimension(element_type),
      scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
      zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
    return QuantizedTensorType(shape(x), baseline_element_type)
  if type(x) is not Type:
    return baseline_element_type(type(x))
  • dequantize 是在量化張量類型上定義,並將其轉換為浮點張量類型。方法是將代表儲存空間類型的整數值,轉換為以零點和量化元素類型相關聯的量化元素,轉換為表示類型的對應浮點值。
def compute_zero_points(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      zero_points[i] = zero_points(quantized_type)[i[d]]
    return zero_points

def compute_scales(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
            type(result_type))
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      scales[i] = scales(quantized_type)[i[d]]
    return scales

def dequantize(x: Value) -> Value:
  assert is_quantized(x)
  x_storage = bitcast_convert(x, storage_type(x))
  x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
  x_expressed_sub = convert(x_storage_sub, expressed_type(x))
  return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
  • quantize 是在浮點張量類型上定義,並轉換為量化張量類型。方法是使用與量化元素類型相關聯的零點和縮放,將表示類型的浮點值轉換為儲存類型的對應整數值。
def quantize(x: Value, type: Type) -> Value:
  assert is_float(x) and is_quantized(type)
  x_expressed_rounded = round_nearest_even(x / compute_scales(type, type(x)))
  x_storage_rounded = convert(x_expressed_rounded, storage_type(type))
  x_storage_add = x_storage_rounded + compute_zero_points(type, type(x_storage_rounded))
  x_storage = clamp(storage_min(type), x_storage_add, storage_max(type))
  return bitcast_convert(x_storage, type)
  • dequantize_op_quantize 可用來在量化張量上指定元素相關計算。這樣做會去量化,即將量化元素轉換成其表達類型,然後執行運算,然後量化結果 (即將結果轉回其儲存空間類型)。目前此函式僅適用於按張量量化。系統仍在處理每軸量化。(#1574)。
def dequantize_op_quantize(op, *inputs_and_output_type):
  inputs = inputs_and_output_type[:-1]
  output_type = inputs_and_output_type[-1]

  float_inputs = map(dequantize, inputs)
  float_result = op(*float_inputs)
  return quantize(float_result, output_type)

def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
  inputs = inputs_and_output_type[:-3]
  float_inputs = map(dequantize, inputs)
  float_results = op(*float_inputs)
  return map(quantize, float_results, inputs_and_output_type[-3:])

def dequantize_compare(lhs, rhs, comparison_direction):
  float_lhs = dequantize(lhs)
  float_rhs = dequantize(rhs)
  return compare(float_lhs, float_rhs, comparison_direction, FLOAT)

def dequantize_select_quantize(pred, on_true, on_false, output_type):
  float_on_true = dequantize(on_true)
  float_on_false = dequantize(on_false)
  float_result = select(pred, float_on_true, float_on_false)
  return quantize(float_result, output_type)

網格運算

  • cross_partition(replica_groups: Value) -> Value。請參閱上述的「cross_copy」部分。

  • cross_replica(replica_groups: Value) -> Value。請參閱上述的「cross_copy」部分。

  • cross_replica_and_partition(replica_groups: Value) -> Value。請參閱上方的「cross_pli_and_partition」一節。

  • flattened_ids(replica_groups: Value) -> Value。請參閱上述的「flattened_ids」一節。