StableHLO 是機器學習 (ML) 模型中高階作業 (HLO) 的作業集。StableHLO 可做為不同機器學習架構和機器學習編譯器之間的可攜層:產生 StableHLO 程式的機器學習架構,與使用 StableHLO 程式的機器學習編譯器相容。
我們的目標是在不同的機器學習架構 (例如 TensorFlow、JAX 和 PyTorch) 與機器學習編譯器 (例如 XLA 和 IREE) 之間建立更互通性,藉此簡化及加快機器學習開發作業。因此,本文件提供 StableHLO 程式設計語言的規格。
這個規格包含三個主要部分首先,「程式」一節說明瞭 StableHLO 程式的結構,StableHLO 函式本身是由 StableHLO 運算組成。在這個結構中,「Ops」區段會指定個別運算的語意。「Execution」一節會針對在程式中一起執行的所有作業提供語意。最後,「標記法」一節將討論整個規格中使用的標記法。
程式
Program ::= {Func}
StableHLO 程式包含任意數量的 StableHLO 函式。以下是含有 @main
函式的範例程式,函式包含 3 個輸入 (%image
、%weights
和 %bias
) 和 1 個輸出內容。函式主體有 6 個運算。
func.func @main(
%image: tensor<28x28xf32>,
%weights: tensor<784x10xf32>,
%bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
%0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
%1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
%2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
%3 = "stablehlo.constant"() { value = dense<0.0> : tensor<1x10xf32> } : () -> tensor<1x10xf32>
%4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
"func.return"(%4): (tensor<1x10xf32>) -> ()
}
函式
Func ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput ::= '%' ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput ::= ValueType
FuncBody ::= {Op}
StableHLO 函式 (也稱為「已命名函式」) 具有 ID、輸入/輸出和主體。日後,我們計劃為函式推出其他中繼資料,以便提高與 HLO 的相容性 (#425、#626、#740、#744)。
ID
FuncId ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
| '%' letter {letter | digit}
letter ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit ::= '0' | ... | '9'
StableHLO ID 與許多程式設計語言中的 ID 相似,但基本上有兩大意義:1) 所有 ID 都有能夠區分不同種類的 ID,2) 值 ID 可以是完全數字,簡化 StableHLO 程式的產生作業。
類型
Type ::= ValueType | NonValueType
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
StableHLO 類型分為「值類型」 (也稱為「一類類型」),代表 StableHLO 值和描述其他程式元素的非值類型。StableHLO 類型與許多程式設計語言的類型類似,主要謹慎度是 StableHLO 的特定領域,導致某些不尋常的結果 (例如純量類型不是值類型)。
TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit}
Tensor 類型代表張量,即多維陣列。它們具有形狀和元素類型,其中形狀代表非負的維度大小,按照 0
到 R-1
的對應維度 (也稱為「軸」) 遞增順序排列。維度數量 R
稱為「排名」。舉例來說,tensor<2x3xf32>
是形狀類型為 2x3
且元素類型為 f32
的張量類型。這項特徵有兩個維度 (或兩個軸,也就是兩個軸) - 第 0 個維度和 1 個維度,大小為 2 和 3。其排名為 2。
這會定義對靜態形狀的支援,其中尺寸是靜態的。未來,我們也計劃引進對動態形狀的支援,其中尺寸大小為部分或完全不明 (#8)。此外,我們也打算探索除了尺寸大小和元素類型以外的張量類型,例如納入版面配置 (#629) 和稀疏度 (#1078)。
QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
QuantizationStorageType
['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
':' QuantizationExpressedType
[':' QuantizationDimension]
',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerConstant
QuantizationStorageMax ::= IntegerConstant
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerConstant
QuantizationParameters ::= QuantizationParameter
| '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale ':' QuantizationZeroPoint
QuantizationScale ::= FloatConstant
QuantizationZeroPoint ::= IntegerConstant
名稱 | 類型 | 限制 |
---|---|---|
storage_type |
整數類型 | (C1-C4)、(C9) |
storage_min |
整數常數 | (C2)、(C4)、(C8) |
storage_max |
整數常數 | (C3)、(C4)、(C8) |
expressed_type |
浮點類型 | (C1)、(C5) |
quantization_dimension |
選擇性整數常數 | (C11-C13) |
scales |
變異數浮點常數 | (C5-C7)、(C10)、(C11)、(C13) |
zero_points |
變異數整數常數 | (C8-C10) |
「量化元素類型」代表儲存空間類型在 storage_min
到 storage_max
(含) 之間的整數值,對應於表達類型的浮點值。以指定整數值 i
來說,對應的浮點值 f
可視為 f = (i - zero_point) * scale
計算,其中 scale
和 zero_point
稱為量化參數。文法中的 storage_min
和 storage_max
為選用項目,但預設值分別為 min_value(storage_type)
和 max_value(storage_type)
。量化元素類型具有下列限制:
- (C1)
num_bits(storage_type) < num_bits(expressed_type)
。 - (C2)
type(storage_min) = storage_type
。 - (C3)
type(storage_max) = storage_type
。 - (C4)
min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type)
。 - (C5)
type(scales...) = expressed_type
。 - (C6)
0 < scales
。 - (C7)
is_finite(scales...)
。 - (C8)
storage_min <= zero_points <= storage_max
。 - (C9)
type(zero_points...) = storage_type
。 - (C10)
size(scales) = size(zero_points)
。 - (C11) 如為
is_empty(quantization_dimension)
,則預設為size(scales) = 1
。 - (C12)
0 <= quantization_dimension
。
目前,QuantizationScale
是浮點常數,但對於以整數為基礎的量表來說,以乘法和位移表示。我們計劃在未來進行探索。(#1404)。
有關 QuantizationZeroPoint
的語意持續討論,包括類型、值,以及量化張量類型中是否只能有一個或可能多個零點。根據本討論的結果,零點左右的規格日後可能會有所異動 (#1405)。
另一項持續進行中的討論涉及 QuantizationStorageMin
和 QuantizationStorageMax
的語意,以判斷是否應對這些值和量化張量設定任何限制 (#1406)。
最後,我們打算研究代表未知量表和零點,做法類似於計劃呈現不明維度大小 (#1407)。
量化張量類型代表有量化元素的張量。這些張量與一般張量相同,但其元素具有量化元素類型,而非一般元素類型。
在量化張量中,量化可以是「各張量」,也就是為整個張量包含一個 scale
和 zero_point
,或者可以是每軸,即具有多個 scales
和 zero_points
,每個維度 quantization_dimension
的每個配量一個對組。正式上例,在採用每軸量化的張量 t
中,quantization_dimension
有 dim(t, quantization_dimension)
配量:t[:, ..., 0, ..., :], t[:, ..., 1, ..., :]
等。i
片段中的所有元素都會使用 scales[i]
和 zero_points[i]
做為量化參數。量化張量類型具有下列限制:
- 針對每張量量化:
- 沒有其他限制條件。
- 如為每軸量化:
- (C12)
quantization_dimension < rank(self)
。 - (C13)
dim(self, quantization_dimension) = size(scales)
。
- (C12)
TokenType ::= 'token'
「權杖類型」代表符記,即部分作業產生和使用的不透明值。如執行一節所述,憑證的用途是強制執行作業的執行順序。
TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]
元組類型代表元組,即異質清單。Tuples 是舊版功能,只為與 HLO 相容而存在。在 HLO 中,元組用於表示各變的輸入和輸出內容。在 StableHLO 中,系統原生支援變異的輸入和輸出,而且 StableHLO 中元組的唯一使用方法是全面代表 HLO ABI,例如 T
、tuple<T>
和 tuple<tuple<T>>
可能會因特定實作方式而有明顯差異。我們日後計劃變更 HLO ABI,這樣或許能從 StableHLO 移除元組類型 (#598)。
TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ'
| 'f8E4M3B11FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
元素類型代表張量類型的元素。與許多程式設計語言不同,這些類型並不是 StableHLO 中的第一類別。這表示 StableHLO 程式無法直接表示這些類型的值 (因此,以 tensor<T>
類型的 0 個維度張量值表示 T
類型的純量值是慣用做法。
- 「布林值類型」代表布林值
true
和false
。 - 整數類型可以是正負號 (
si
) 或無正負號 (ui
),而且包含支援的其中一種位元寬度 (4
、8
、16
、32
或64
)。已簽署的siN
類型代表從-2^(N-1)
到2^(N-1)-1
的整數值,無正負號uiN
類型代表從0
到2^N-1
之間的整數值。 - 浮點類型可以是下列任一值:
- 對應到 FP8 格式的
E4M3
和E5M2
編碼的f8E4M3FN
和f8E5M2
類型,如「深度學習的 FP8 格式」所述。 - 與
E4M3
和E5M2
編碼 FP8 格式對應的f8E4M3FNUZ
和f8E5M2FNUZ
類型,如「深層類神經網路的 8 位元數值格式」所述。 f8E4M3B11FNUZ
類型,對應於「適用於深層類神經網路的混合式 8 位元浮點 (HFP8) 訓練和推論」一文中說明的 FP8 格式E4M3
編碼。- 與「BFloat16:在 Cloud TPU 上發揮高效能」一節所述
bfloat16
格式的bf16
類型。 f16
、f32
和f64
類型分別對應到binary16
(「半精確度」、「單一精確度」) 和binary64
(「雙精確度」) 格式,如 IEEE 754 標準所述。binary32
- 對應到 FP8 格式的
- 「複雜類型」代表具有「實際部分」以及屬於相同元素類型的虛部分的複雜值。支援的複雜類型為
complex<f32>
(這兩個部分都是f32
類型) 和complex<f64>
(這兩個部分都是f64
類型)。
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]
函式類型代表已命名和匿名函式。它們具有輸入類型 (->
左側的類型清單) 和輸出類型 (->
右側的類型清單)。在許多程式設計語言中,函式類型都是第一類,但在 StableHLO 中則不是。
StringType ::= 'string'
字串類型代表位元組序列。與許多程式設計語言不同,字串類型並非 StableHLO 中的第一個類別,只會用來為程式元素指定靜態中繼資料。
作業套件
StableHLO 作業 (也稱為「作業」) 代表機器學習模型中的一組高階作業。如上所述,StableHLO 語法非常受到 MLIR 啟發,這不一定是最能體工學的替代方法,但或許最適合 StableHLO 的目標是在機器學習架構與機器學習編譯器之間建立更互通性的目標。
Op ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic ::= 'abs' | 'add' | ...
StableHLO 作業 (也稱為「操作」) 具有名稱、輸入/輸出和簽名。名稱包含 stablehlo.
前置字元和一個可明確識別其中一項支援的運算的記憶。請參閱下方的完整支援作業清單。
目前,野外的 StableHLO 程式有時會包含本文未提及的作業。未來我們打算將這些作業吸收到 StableHLO 運算組中,或禁止這些作業出現在 StableHLO 程式中。與此同時,以下是這些作業的清單:
builtin.module
、func.func
、func.call
和func.return
(#425)。chlo
運算 (#602)。- 「不在 HLO」 StableHLO 運算類別 - 最初是 StableHLO 運算集的一部分,但後來被視為不適合其:
broadcast
、create_token
、cross-replica-sum
、dot
、einsum
、torch_index_select
、unary_einsum
(#3)。 - StableHLO 運算的「Dynamism」類別是由 MHLO 自行啟動,但我們尚未找到:
compute_reshape_shape
、cstr_reshapable
、dynamic_broadcast_in_dim
、dynamic_conv
、dynamic_gather
、dynamic_iota
、dynamic_pad
、dynamic_reshape
、real_dynamic_slice
、set_dimension_size
(#8)。 - 形狀計算,包括
arith
、shape
和tensor
運算 (#8)。
OpInputs ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue ::= ValueId
OpInputFuncs ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs ::= [OpOutput {',' OpOutput} '=']
OpOutput ::= ValueId
作業會取用輸入並產生「輸出」。輸入內容分為輸入值 (執行期間計算)、輸入函式 (因為在 StableHLO 函式中並非第一類別值) 和輸入屬性 (同樣以靜態方式提供)。運算使用和產生的輸入與輸出種類取決於其記憶。舉例來說,add
運算會使用 2 個輸入值並產生 1 個輸出值。相較之下,select_and_scatter
運算會使用 3 個輸入值、2 個輸入函式和 3 個輸入屬性。
OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused ::= '^' digit {digit}
| '^' letter {letter | digit}
輸入函式 (又稱為「匿名函式」) 與已命名函式非常類似,但前者沒有 ID (因此稱為「匿名」),2) 它們不會宣告輸出類型 (從函式中的 return
運算推斷輸出類型)。
輸入函式的語法包含目前未使用的部分 (請參閱上方的 Unused
實際工作環境),該部分可與 MLIR 相容。在 MLIR 中,有一個較通用的「地區」概念,可以有多個「區塊」透過跳躍運算連結。這些區塊的 ID 與 Unused
實際工作環境相對應,因此可以彼此區分。StableHLO 沒有跳躍作業,因此不會使用 MLIR 語法的對應部分 (但仍然存在)。
OpInputAttr ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName ::= letter {letter | digit}
OpInputAttrValue ::= Constant
輸入屬性具有名稱和值,這是支援的常數之一。是為節目元素指定靜態中繼資料的主要方式。舉例來說,concatenate
運算會使用 dimension
屬性指定串連輸入值的維度。同樣地,slice
運算會使用 start_indices
和 limit_indices
等多個屬性來指定用於切分輸入值的邊界。
目前,野外的 StableHLO 程式有時會包含本文未說明的屬性。未來我們打算將這些屬性吸收到 StableHLO 運算組中,或禁止這些屬性出現在 StableHLO 程式中。同時,以下是這些屬性的清單:
layout
(#629)。mhlo.frontend_attributes
(#628)。mhlo.sharding
(#619)。output_operand_aliases
(#740)。- 位置中繼資料 (#594)。
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'
「運算簽名」由所有輸入值的類型 (->
左側的類型清單) 和所有輸出值的類型 (->
右側的類型清單) 組成。嚴格說來,輸入類型都是多餘的,輸出類型也幾乎都是多餘的 (因為多數 StableHLO 運算的輸出類型可根據輸入推論)。不過,運算簽名是刻意安排為 StableHLO 語法的一部分,以與 MLIR 相容。
以下是記憶為 select_and_scatter
的運算範例。這個公式會使用 3 個輸入值 (%operand
、%source
和 %init_value
)、2 個輸入函式和 3 個輸入屬性 (window_dimensions
、window_strides
和 padding
)。請注意,運算的簽章僅包含其輸入值的類型 (但不含內嵌提供的輸入函式和屬性類型)。
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>
常數
Constant ::= BooleanConstant
| IntegerConstant
| FloatConstant
| ComplexConstant
| TensorConstant
| QuantizedTensorConstant
| StringConstant
| EnumConstant
StableHLO 常數具有常值和類型,一起代表 StableHLO 值。一般來說,這個類型是常數語法的一部分,除非是不明確的情況 (例如布林值常數含有 i1
類型,而整數常數可以有多種可能類型)。
BooleanConstant ::= BooleanLiteral
BooleanLiteral ::= 'true' | 'false'
布林值常數代表布林值 true
和 false
。布林常數具有 i1
類型。
IntegerConstant ::= IntegerLiteral ':' IntegerType
IntegerLiteral ::= ['-' | '+'] DecimalDigits
| ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit ::= '0' | ... | '9'
hexadecimalDigit ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'
整數常數會透過使用十進位或十六進位標記法的字串來表示整數值。系統不支援其他底數 (例如二進位或八進位)。 整數常數具有下列限制:
- (C1)
is_wellformed(integer_literal, integer_type)
。
FloatConstant ::= FloatLiteral ':' FloatType
FloatLiteral ::= SignPart IntegerPart FractionalPart ScientificPart
| '0x' [HexadecimalDigits]
SignPart ::= ['-' | '+']
IntegerPart ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]
浮點常數透過使用小數或科學標記法的字串表示浮點值。此外,十六進位標記法也可用來直接指定對應類型浮點格式的基礎位元。浮點常數具有下列限制:
- (C1) 如果使用非十六進位標記法,請設為
is_wellformed(float_literal, float_type)
。 - (C2) 如果使用十六進位標記法,則
size(hexadecimal_digits) = num_bits(float_type) / 4
。
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral ::= '(' RealPart ',' ImaginaryPart ')'
RealPart ::= FloatLiteral
ImaginaryPart ::= FloatLiteral
複雜常數會使用實際部分 (先推出) 和虛部分 (第二) 的清單表示複雜的值。例如,(1.0, 0.0) : complex<f32>
代表 1.0 + 0.0i
,(0.0, 1.0) : complex<f32>
則代表 0.0 + 1.0i
。這些部分在記憶體中的儲存順序是由實作定義。複雜常數具有下列限制:
- (C1)
is_wellformed(real_part, complex_element_type(complex_type))
。 - (C2)
is_wellformed(imaginary_part, complex_element_type(complex_type))
。
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral
Tensor 常數表示使用透過 NumPy 標記指定的巢狀清單表示張量值。舉例來說,dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
代表張量值,下列從索引對應至元素:{0, 0} => 1
、{0, 1} => 2
、{0, 2} => 3
、{1, 0} => 4
、{1, 1} => 5
、{1, 2} => 6
。這些元素會在記憶體中的儲存順序實作。張量常數具有下列限制:
- (C1)
has_syntax(tensor_literal, element_type(tensor_type))
,其中:has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type)
.has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type)
.
- (C2)
has_shape(tensor_literal, shape(tensor_type))
,其中:has_shape(element_literal: Syntax, []) = true
.has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:])
.- 否則為
false
。
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
「量化張量常數」代表量化張量值,並使用與張量常數相同的標記法表示,而且元素會指定為其儲存空間類型的常數。量化張量常數具有下列限制:
- (C1)
has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type))
。 - (C2)
has_shape(quantized_tensor_literal, shape(quantized_tensor_type))
。
StringConstant ::= StringLiteral
StringLiteral ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))
字串常值由使用 ASCII 字元和逸出序列指定的位元組組成。這些位元組各不具有編碼功能,因此對這些位元組的解釋是由實作定義。字串常值具有 string
類型。
作業數
abs
語義
對 operand
張量執行元素層級抽象運算,並產生 result
張量。根據元素類型執行以下操作:
- 帶正負號整數:整數模數。
- 浮點值:由 IEEE-754 距離
abs
。 - 複數:複數的模數。
- 量化類型:
dequantize_op_quantize(abs, operand, type(result))
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
帶正負號整數、浮點、複雜類型,或每個張量的量化張量的張量 | (C1-C2) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
帶正負號整數、浮點類型,或每個張量量化張量的張量 | (C1-C2) |
限制
- (C1)
shape(result) = shape(operand)
。 - (C2)
baseline_element_type(result)
定義為:- 如果
is_complex(operand)
則為complex_element_type(element_type(operand))
。 - 否則傳回
baseline_element_type(operand)
。
- 如果
範例
// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]
add
語義
對兩個張量 lhs
和 rhs
執行元素相關新增作業,並產生 result
張量。根據元素類型執行以下操作:
- 布林值:邏輯 OR。
- 整數:加整數。
- 浮點值:由 IEEE-754 距離
addition
。 - 複數:複數的加法。
- 量化類型:
dequantize_op_quantize(add, lhs, rhs, type(result))
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | lhs |
張量或每個張量化張量 | (C1)。 |
(I2)。 | rhs |
張量或每個張量化張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
張量或每個張量化張量 | (C1)。 |
限制
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
。
範例
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]
after_all
語義
確保產生 inputs
的作業在依附 result
的任何作業之前執行。這項作業不會執行任何動作,系統只會建立從 result
到 inputs
的資料依附元件。
輸入內容
標籤 | 名稱 | 類型 |
---|---|---|
(I1)。 | inputs |
變異數token |
輸出
名稱 | 類型 |
---|---|
result |
token |
範例
// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token
all_gather
語義
在 StableHLO 程序格線中的每個程序群組中,按照 all_gather_dim
將每個程序的 operand
張量值串連起來,然後產生 result
張量。
這項作業會將 StableHLO 程序格線分割為 process_groups
,定義如下:
cross_replica(replica_groups)
如果為channel_id <= 0 and use_global_device_ids = false
。cross_replica_and_partition(replica_groups)
如果為channel_id > 0 and use_global_device_ids = false
。flattened_ids(replica_groups)
如果為channel_id > 0 and use_global_device_ids = true
。
之後,在每個 process_group
中:
operands@receiver = [operand@sender for sender in process_group]
適用於process_group
中所有receiver
。result@process = concatenate(operands@process, all_gather_dim)
適用於process_group
中所有process
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
張量或每個張量化張量 | (C1)、(C6) |
(I2)。 | all_gather_dim |
si64 類型的常數 |
(C1)、(C6) |
(I3)。 | replica_groups |
si64 類型的 2 維張量常數 |
(C2-C4) |
(I4)。 | channel_id |
si64 類型的常數 |
(C5) |
(I5) | use_global_device_ids |
i1 類型的常數 |
(C5) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
張量或每個張量化張量 | (C6) |
限制
- (C1)
0 <= all_gather_dim < rank(operand)
。 - (C2)
is_unique(replica_groups)
。 - (C3)
size(replica_groups)
定義為:- 如果使用
cross_replica
,則為num_replicas
。 - 如果使用
cross_replica_and_partition
,則為num_replicas
。 - 如果使用
flattened_ids
,則為num_processes
。
- 如果使用
- (C4)
0 <= replica_groups < size(replica_groups)
。 - (C5) 如為
use_global_device_ids = true
,則預設為channel_id > 0
。 - (C6)
type(result) = type(operand)
,但下列項目除外:dim(result, all_gather_dim) = dim(operand, all_gather_dim) * dim(process_groups, 1)
.
範例
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
%result = "stablehlo.all_gather"(%operand) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<2x2xi64>) -> tensor<2x4xi64>
// %result@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
all_reduce
語義
在 StableHLO 程序格線中的每個程序群組中,將縮減函式 computation
套用至每個程序中的 operand
張量值,然後產生 result
張量。
這項作業會將 StableHLO 程序格線分割為 process_groups
,定義如下:
cross_replica(replica_groups)
如果為channel_id <= 0 and use_global_device_ids = false
。cross_replica_and_partition(replica_groups)
如果為channel_id > 0 and use_global_device_ids = false
。flattened_ids(replica_groups)
如果為channel_id > 0 and use_global_device_ids = true
。
之後,在每個 process_group
中:
result@process[result_index] = exec(schedule)
代表二進位樹狀結構schedule
,其中:exec(node)
=computation(exec(node.left), exec(node.right))
。exec(leaf)
=leaf.value
。
schedule
是實作定義的二進位樹狀結構,其順序週遊為to_destination_type(operands@process_group...[result_index], type(func_inputs(computation)[0]))
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
張量或每個張量化張量 | (C5)、(C6) |
(I2)。 | replica_groups |
si64 類型的 1 維張量常數的變異數 |
(C1-C3) |
(I3)。 | channel_id |
si64 類型的常數 |
(C4) |
(I4)。 | use_global_device_ids |
i1 類型的常數 |
(C4) |
(I5) | computation |
函式 | (C5) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
張量或每個張量化張量 | (C6-C7) |
限制
- (C1)
is_unique(replica_groups)
。 - (C2)
size(replica_groups)
定義為:- 如果使用
cross_replica
,則為num_replicas
。 - 如果使用
cross_replica_and_partition
,則為num_replicas
。 - 如果使用
flattened_ids
,則為num_processes
。
- 如果使用
- (C3)
0 <= replica_groups < size(replica_groups)
。 - (C4) 如為
use_global_device_ids = true
,則預設為channel_id > 0
。 - (C5)
computation
具有類型(tensor<E>, tensor<E>) -> (tensor<E>)
,其中is_promotable(element_type(operand), E)
。 - (C6)
shape(result) = shape(operand)
。 - (C7)
element_type(result) = E
。
範例
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [1, 2, 3, 4]
// %operand@(1, 0): [5, 6, 7, 8]
%result = "stablehlo.all_reduce"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<i64>) -> tensor<i64>
// %result@(0, 0): [6, 8, 10, 12]
// %result@(1, 0): [6, 8, 10, 12]
all_to_all
語義
在 StableHLO 程序格線的每個程序群組中,將 operand
張量的值沿著 split_dimension
分割為部分,將分割部分分散到程序之間,沿著 concat_dimension
串連分散的部分,然後產生 result
張量。
這項作業會將 StableHLO 程序格線分割為 process_groups
,定義如下:
- 如果
channel_id <= 0
則為cross_replica(replica_groups)
。 - 如果
channel_id > 0
則為cross_partition(replica_groups)
。
之後,在每個 process_group
中:
split_parts@sender = split(operand@sender, split_count, split_dimension)
適用於process_group
中所有sender
。scattered_parts@receiver = [split_parts@sender[receiver_index] for sender in process_group]
,即receiver_index = process_group.index(receiver)
。result@process = concatenate(scattered_parts@process, concat_dimension)
.
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
張量或每個張量化張量 | (C1-C3)、(C9) |
(I2)。 | split_dimension |
si64 類型的常數 |
(C1)、(C2)、(C9) |
(I3)。 | concat_dimension |
si64 類型的常數 |
(C3)、(C9) |
(I4)。 | split_count |
si64 類型的常數 |
(C2)、(C4)、(C8)、(C9) |
(I5) | replica_groups |
si64 類型的 2 維張量常數 |
(C5-C8) |
(I6)。 | channel_id |
si64 類型的常數 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
張量或每個張量化張量 | (C9) |
限制
- (C1)
0 <= split_dimension < rank(operand)
。 - (C2)
dim(operand, split_dimension) % split_count = 0
。 - (C3)
0 <= concat_dimension < rank(operand)
。 - (C4)
0 < split_count
。 - (C5)
is_unique(replica_groups)
。 - (C6)
size(replica_groups)
定義為:- 如果使用
cross_replica
,則為num_replicas
。 - 如果使用
cross_partition
,則為num_partitions
。
- 如果使用
- (C7)
0 <= replica_groups < size(replica_groups)
。 - (C8)
dim(replica_groups, 1) = split_count
。 - (C9)
type(result) = type(operand)
,但下列項目除外:dim(result, split_dimension) = dim(operand, split_dimension) / split_count
.dim(result, concat_dimension) = dim(operand, concat_dimension) * split_count
.
範例
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.all_to_all"(%operand) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
} : (tensor<2x4xi64>) -> tensor<4x2xi64>
// %result@(0, 0): [[1, 2],
// [5, 6],
// [9, 10],
// [13, 14]]
// %result@(1, 0): [[3, 4],
// [7, 8],
// [11, 12],
// [15, 16]]
以及
語義
以 lhs
和 rhs
兩個張量為目標執行元素 AND,然後產生 result
張量。根據元素類型執行以下操作:
- 布林值:邏輯 AND。
- 整數:位元 AND。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | lhs |
布林值或整數類型的張量 | (C1)。 |
(I2)。 | rhs |
布林值或整數類型的張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
布林值或整數類型的張量 | (C1)。 |
限制
- (C1)
type(lhs) = type(rhs) = type(result)
。
範例
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]
atan2
語義
對 lhs
和 rhs
張量執行元素相關 atan2 作業,並產生 result
張量。根據元素類型執行以下操作:
- 浮點值:由 IEEE-754 距離
atan2
。 - 複數:複數:複數。
- 量化類型:
dequantize_op_quantize(atan2, lhs, rhs, type(result))
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | lhs |
浮點、複雜類型或每個張量化張量的張量 | (C1)。 |
(I2)。 | rhs |
浮點、複雜類型或每個張量化張量的張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
浮點、複雜類型或每個張量化張量的張量 | (C1)。 |
限制
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
。
範例
// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]
batch_norm_grad
語義
計算從 grad_output
反向傳播 batch_norm_training
的數個輸入梯度,並產生 grad_operand
、grad_scale
和 grad_offset
張量。更正式地說,這個運算可透過 Python 語法以分解的形式表示到現有 StableHLO 作業,如下所示:
def compute_sum(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
return sum
def compute_mean(operand, feature_index):
sum = compute_sum(operand, feature_index)
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
# Broadcast inputs to type(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance`
# Intermediate values will be useful for computing gradients
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
# Use the implementation from batchnorm_expander.cc in XLA
# Temporary variables have exactly the same names as in the C++ code
elements_per_feature = broadcast_in_dim(
constant(divide(size(operand), dim(operand, feature_index)),
element_type(grad_output)),
[], type(operand))
i1 = multiply(grad_output, elements_per_feature)
i2 = broadcast_in_dim(
compute_sum(grad_output, feature_index), [feature_index], type(operand))
i3 = broadcast_in_dim(
compute_sum(multiply(grad_output, centered_operand), feature_index),
[feature_index], type(operand))
i4 = multiply(i3, centered_operand)
i5 = divide(i4, add(variance_bcast, epsilon_bcast))
i6 = subtract(subtract(i1, i2), i5)
grad_operand =
multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
grad_scale =
compute_sum(multiply(grad_output, normalized_operand), feature_index)
grad_offset = compute_sum(grad_output, feature_index)
return grad_operand, grad_scale, grad_offset
如果是量化類型,請執行 dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean,
variance, grad_output: batch_norm_grad(operand, scale, mean, variance,
grad_output, epsilon, feature_index), operand, scale, mean, variance,
grad_output, type(grad_operand), type(grad_scale), type(feature_index))
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
浮點類型或每個張量量化張量的張量 | (C1-C3)、(C5) |
(I2)。 | scale |
浮點或每個張量量化類型的 1 維張量 | (C2)、(C4)、(C5) |
(I3)。 | mean |
浮點或每個張量量化類型的 1 維張量 | (C2)、(C4) |
(I4)。 | variance |
浮點或每個張量量化類型的 1 維張量 | (C2)、(C4) |
(I5) | grad_output |
浮點類型或每個張量量化張量的張量 | (C2)、(C3) |
(I6)。 | epsilon |
f32 類型的常數 |
|
(I7) | feature_index |
si64 類型的常數 |
(C1)、(C5) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
grad_operand |
浮點類型或每個張量量化張量的張量 | (C2)、(C3) |
grad_scale |
浮點或每個張量量化類型的 1 維張量 | (C2)、(C4) |
grad_offset |
浮點或每個張量量化類型的 1 維張量 | (C2)、(C4) |
限制
- (C1)
0 <= feature_index < rank(operand)
。 - (C2)
operand
、scale
、mean
、variance
、grad_output
、grad_operand
、grad_scale
和grad_offset
有相同的baseline_element_type
。 - (C3)
operand
、grad_output
和grad_operand
形狀相同。 - (C4)
scale
、mean
、variance
、grad_scale
和grad_offset
的形狀相同。 - (C5)
size(scale) = dim(operand, feature_index)
。
範例
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
// [[0.1, 0.1], [0.1, 0.1]],
// [[0.1, 0.1], [0.1, 0.1]]
// ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
// [[0.0, 0.0], [0.0, 0.0]],
// [[0.0, 0.0], [0.0, 0.0]]
// ]
// %grad_scale: [0.0, 0.0]
// %grad_offset: [0.4, 0.4]
batch_norm_inference
語義
將所有維度的 operand
張量正規化 (feature_index
維度除外),並產生 result
張量。更正式的,這個運算可透過 Python 語法以分解的形式表示到現有 StableHLO 作業,如下所示:
def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
# Broadcast inputs to shape(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance` instead of
# computing them like `batch_norm_training` does.
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
return add(multiply(scale_bcast, normalized_operand), offset_bcast)
如果是量化類型,請執行 dequantize_op_quantize(lambda operand, scale, offset, mean, variance:
batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index), operand, scale, offset, mean, variance, type(result))
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
浮點類型或每個張量量化張量的張量 | (C1-C7) |
(I2)。 | scale |
浮點或每個張量量化類型的 1 維張量 | (C2)、(C3) |
(I3)。 | offset |
浮點或每個張量量化類型的 1 維張量 | (C2)、(C4) |
(I4)。 | mean |
浮點或每個張量量化類型的 1 維張量 | (C5) |
(I5) | variance |
浮點或每個張量量化類型的 1 維張量 | (C2)、(C6) |
(I6)。 | epsilon |
f32 類型的常數 |
|
(I7) | feature_index |
si64 類型的常數 |
(C1)、(C3-C6) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
浮點類型或每個張量量化張量的張量 | (C2)、(C7) |
限制
- (C1)
0 <= feature_index < rank(operand)
。 - (C2)
operand
、scale
、offset
、mean
、variance
和result
有相同的baseline_element_type
。 - (C3)
size(scale) = dim(operand, feature_index)
。 - (C4)
size(offset) = dim(operand, feature_index)
。 - (C5)
size(mean) = dim(operand, feature_index)
。 - (C6)
size(variance) = dim(operand, feature_index)
。 - (C7)
baseline_type(operand) = baseline_type(result)
。
範例
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
batch_norm_training
語義
計算所有維度的平均值和變異數,但 feature_index
維度除外,並將產生 output
、batch_mean
和 batch_var
張量的 operand
張量正規化。更正式的做法是,使用 Python 語法以分解為現有 StableHLO 作業的形式表示,如下所示:
def compute_mean(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def compute_variance(operand, feature_index):
mean = compute_mean(operand, feature_index)
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
centered_operand = subtract(operand, mean_bcast)
return compute_mean(mul(centered_operand, centered_operand), feature_index)
def batch_norm_training(operand, scale, offset, epsilon, feature_index):
mean = compute_mean(operand, feature_index)
variance = compute_variance(operand, feature_index)
return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index),
mean, variance
如果是量化類型,請執行 dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset:
batch_norm_training(operand, scale, offset, epsilon, feature_index), operand,
scale, offset, type(output), type(batch_mean), type(batch_var))
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
浮點類型或每個張量量化張量的張量 | (C1)。 |
(I2)。 | scale |
浮點或每個張量的 1 維張量 | (C2)、(C3) |
(I3)。 | offset |
浮點或每個張量的 1 維張量 | (C2)、(C4) |
(I4)。 | epsilon |
f32 類型的常數 |
(C1)、(C3-C6) |
(I5) | feature_index |
si64 類型的常數 |
(C1)、(C3-C6) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
output |
浮點類型或每個張量量化張量的張量 | (C7) |
batch_mean |
浮點或每個張量的 1 維張量 | (C2)、(C5) |
batch_var |
浮點或每個張量的 1 維張量 | (C2)、(C6) |
限制
- (C1)
0 <= feature_index < rank(operand)
。 - (C2)
operand
、scale
、offset
、batch_mean
、batch_var
和output
有相同的baseline_element_type
。 - (C3)
size(scale) = dim(operand, feature_index)
。 - (C4)
size(offset) = dim(operand, feature_index)
。 - (C5)
size(batch_mean) = dim(operand, feature_index)
。 - (C6)
size(batch_var) = dim(operand, feature_index)
。 - (C7)
baseline_type(output) = baseline_type(operand)
。
範例
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
(tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]
bitcast_convert
語義
對 operand
張量執行 Bitcast 作業,並產生 result
張量,其中整個 operand
張量的位元會使用 result
張量類型重新解譯。
更正式,有 E = element_type(operand)
、E' = element_type(result)
和 R = rank(operand)
:
- 如果值為
num_bits(E') < num_bits(E)
,則為bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])
。 - 如果值為
num_bits(E') > num_bits(E)
,則為bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])
。 - 如果值為
num_bits(E') = num_bits(E)
,則為bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])
。
bits
會傳回指定值的記憶體內表示法,且其行為是由實作定義,因為實作方式已定義實際代表張量,且元素類型的確切表示法也經過實作定義。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
張量或量化張量 | (C1-C2) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
張量或量化張量 | (C1-C2) |
限制
- (C1) 如
E = is_quantized(operand) ? storage_type(operand) : element_type(operand)
、E' = is_quantized(result) ? storage_type(result) : element_type(result)
和R = rank(operand)
:- 如果值為
num_bits(E') = num_bits(E)
,則值為shape(result) = shape(operand)
。 - 如為
num_bits(E') < num_bits(E)
: rank(result) = R + 1
.- 所有
0 <= i < R
的dim(result, i) = dim(operand, i)
。 dim(result, R) * num_bits(E') = num_bits(E)
.- 如為
num_bits(E') > num_bits(E)
: rank(result) = R - 1
.- 所有
0 <= i < R
的dim(result, i) = dim(operand, i)
。 dim(operand, R - 1) * num_bits(E) = num_bits(E')
.
- 如果值為
- (C2) 如果為
is_complex(operand) or is_complex(result)
,則is_complex(operand) and is_complex(result)
。
範例
// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation
broadcast_in_dim
語義
複製 operand
張量中的資料並產生 result
張量,以展開輸入張量的維度和/或排名。更正式的 result[result_index] = operand[operand_index]
,其中 axes(operand)
中所有 d
:
- 如果
dim(operand, d) = 1
則為operand_index[d] = 0
。 - 否則傳回
operand_index[d] = result_index[broadcast_dimensions[d]]
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
張量或量化張量 | (C1-C2)、(C5-C6) |
(I2)。 | broadcast_dimensions |
si64 類型的 1 維張量常數 |
(C2-C6) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
張量或量化張量 | (C1)、(C3)、(C5-C6) |
限制
- (C1)
element_type(result)
的提供者:element_type(operand)
(如果為!is_per_axis_quantized(operand)
)。element_type(operand)
除外,但quantization_dimension(operand)
、scales(operand)
和zero_points(operand)
可能與quantization_dimension(result)
、scales(result)
和zero_points(result)
不同。
- (C2)
size(broadcast_dimensions) = rank(operand)
。 - (C3)
0 <= broadcast_dimensions < rank(result)
。 - (C4)
is_unique(broadcast_dimensions)
。 - (C5) 針對
axes(operand)
中的所有d
:dim(operand, d) = 1
或dim(operand, d) = dim(result, broadcast_dimensions[d])
.
- (C6) 如果為
is_per_axis_quantized(result)
:quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
.- 如果為
dim(operand, quantization_dimension(operand)) = 1
,則值為scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
。
範例
// %operand: [
// [1, 2, 3]
// ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
客服案件
語義
根據 index
的值,從 branches
僅執行一個函式產生輸出內容。正式來說,result = selected_branch()
其中:
- 如果
0 <= index < size(branches)
則為selected_branch = branches[index]
。 - 否則傳回
selected_branch = branches[-1]
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | index |
si32 類型的 0 維張量 |
|
(I2)。 | branches |
變異函數數量 | (C1-C4) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
results |
各種張量、量化張量或符記 | (C4) |
限制
- (C1)
0 < size(branches)
。 - (C2)
input_types(branches...) = []
。 - (C3)
same(output_types(branches...))
。 - (C4)
type(results...) = output_types(branches[0])
。
範例
// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
"stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
"stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]
Cbrt
語義
對 operand
張量執行元素相關立方根作業,並產生 result
張量。根據元素類型執行以下操作:
- 浮點值:由 IEEE-754 距離
rootn(x, 3)
。 - 複數:複雜的立方根。
- 量化類型:
dequantize_op_quantize(cbrt, operand, type(result))
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
浮點、複雜類型或每個張量化張量的張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
浮點、複雜類型或每個張量化張量的張量 | (C1)。 |
限制
- (C1)
baseline_type(operand) = baseline_type(result)
。
範例
// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]
水泥
語義
執行 operand
張量的元素圓形,並產生 result
張量。從 IEEE-754 規格實作 roundToIntegralTowardPositive
作業。如果是量化類型,請執行 dequantize_op_quantize(ceil, operand, type(result))
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
浮點類型或每個張量量化張量的張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
浮點類型或每個張量量化張量的張量 | (C1)。 |
限制
- (C1)
baseline_type(operand) = baseline_type(result)
。
範例
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]
白洞
語義
計算一批矩陣的膽天分解法。
更確切地說,對於 index_space(result)
中的所有 i
,result[i0, ..., iR-3, :, :]
是 a[i0, ..., iR-3, :, :]
的孔洞分解,採用下三角形 (如果 lower
為 true
) 或上三角形 (如果 lower
為 false
) 矩陣。相對三角形的輸出值 (對應所對應的嚴格上限三角形或嚴格三角形) 的輸出值是實作定義。
如果有 i
的輸入矩陣不是隱士頓陽性矩陣,則行為將未定義。
如果是量化類型,請執行 dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result))
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | a |
浮點、複雜類型或每個張量化張量的張量 | (C1-C3) |
(I2)。 | lower |
i1 類型的 0 維張量常數 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
浮點、複雜類型或每個張量化張量的張量 | (C1)。 |
限制
- (C1)
baseline_type(a) = baseline_type(result)
。 - (C2)
2 <= rank(a)
。 - (C3)
dim(a, -2) = dim(a, -1)
。
範例
// %a: [
// [1.0, 2.0, 3.0],
// [2.0, 20.0, 26.0],
// [3.0, 26.0, 70.0]
// ]
%result = "stablehlo.cholesky"(%a) {
lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
限制取值範圍
語義
將 operand
張量的每個元素限制在最小值和最大值之間,產生 result
張量。更正式的是 result[result_index] =
minimum(maximum(operand[result_index], min_element), max_element)
,其中 min_element = rank(min) = 0 ? min[] : min[result_index]
、max_element = rank(max) = 0 ? max[] : max[result_index]
。如果是量化類型,則執行 dequantize_op_quantize(clamp, min, operand, max, type(result))
。
在複雜數字上排序複雜的語意會帶來令人意想不到的語意,因此我們未來計劃移除對此運算複雜數字的支援 (#560)。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | min |
張量或每個張量化張量 | (C1)、(C3) |
(I2)。 | operand |
張量或每個張量化張量 | (C1-C4) |
(I3)。 | max |
張量或每個張量化張量 | (C2)、(C3) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
張量或每個張量化張量 | (C4) |
限制
- (C1)
rank(min) = 0 or shape(min) = shape(operand)
。 - (C2)
rank(max) = 0 or shape(max) = shape(operand)
。 - (C3)
baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max)
。 - (C4)
baseline_type(operand) = baseline_type(result)
。
範例
// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]
collective_broadcast
語義
在 StableHLO 程序格線中的每個程序群組中,將 operand
張量的值從來源程序傳送至目標程序,並產生 result
張量。
這項作業會將 StableHLO 程序格線分割為 process_groups
,定義如下:
- 如果
channel_id <= 0
則為cross_replica(replica_groups)
。 - 如果
channel_id > 0
則為cross_partition(replica_groups)
。
之後,result@process
將由以下提供:
- 如果有
i
,因此程序位於process_groups[i]
,則為operand@process_groups[i, 0]
。 broadcast_in_dim(constant(0, element_type(result)), [], type(result))
否則。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
張量 | (C3) |
(I2)。 | replica_groups |
si64 類型的 1 維張量常數的變異數 |
(C1)、(C2) |
(I3)。 | channel_id |
si64 類型的常數 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
張量 | (C3) |
限制
- (C1)
is_unique(replica_groups)
。 - (C2)
0 <= replica_groups < N
,其中N
定義為:- 如果使用
cross_replica
,則為num_replicas
。 - 如果使用
cross_partition
,則為num_partitions
。
- 如果使用
- (C3)
type(result) = type(operand)
。
範例
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
collective_permute
語義
在 StableHLO 程序格線中的每個程序群組中,將 operand
張量的值從來源程序傳送至目標程序,並產生 result
張量。
這項作業會將 StableHLO 程序格線分割為 process_groups
,定義如下:
- 如果
channel_id <= 0
則為cross_replica(source_target_pairs)
。 - 如果
channel_id > 0
則為cross_partition(source_target_pairs)
。
之後,result@process
將由以下提供:
operand@process_groups[i, 0]
(如果i
存在這類process_groups[i, 1] = process
)。broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
否則。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
張量或每個張量化張量 | (C5) |
(I2)。 | source_target_pairs |
si64 類型的 2 維張量常數 |
(C1-C4) |
(I3)。 | channel_id |
si64 類型的常數 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
張量或每個張量化張量 | (C1)。 |
限制
- (C1)
dim(source_target_pairs, 1) = 2
。 - (C2)
is_unique(source_target_pairs[:, 0])
。 - (C3)
is_unique(source_target_pairs[:, 1])
。 - (C4)
0 <= source_target_pairs < N
,其中N
定義為:- 如果使用
cross_replica
,則為num_replicas
。 - 如果使用
cross_partition
,則為num_partitions
。
- 如果使用
- (C5)
type(result) = type(operand)
。
範例
// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]
compare
語義
根據 comparison_direction
和 compare_type
對 lhs
和 rhs
張量執行元素層級比較,並產生 result
張量。
comparison_direction
和 compare_type
的值包含下列語意:
布林值和整數元素類型:
EQ
:lhs = rhs
。NE
:lhs != rhs
。GE
:lhs >= rhs
。GT
:lhs > rhs
。LE
:lhs <= rhs
。LT
:lhs < rhs
。
如果是含有 compare_type = FLOAT
的浮點元素類型,運算會實作下列 IEEE-754 作業:
EQ
:compareQuietEqual
。NE
:compareQuietNotEqual
。GE
:compareQuietGreaterEqual
。GT
:compareQuietGreater
。LE
:compareQuietLessEqual
。LT
:compareQuietLess
。
如果是具有 compare_type = TOTALORDER
的浮點元素類型,這項運算會使用 IEEE-754 的 totalOrder
和 compareQuietEqual
運算。這項功能似乎未使用,因此我們計劃在日後將其移除 (#584)。
如果是複雜的元素類型,系統會使用提供的 comparison_direction
和 compare_type
對 (real, imag)
組合的字母組合比較。在複雜數字上排序複雜數字會帶來令人意想不到的語意,因此我們未來計劃在 comparison_direction
為 GE
、GT
、LE
或 LT
時移除對複數的支援 (#560)。
如果是量化類型,則執行 dequantize_compare(lhs, rhs,
comparison_direction)
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | lhs |
張量或每個張量化張量 | (C1-C3) |
(I2)。 | rhs |
張量或每個張量化張量 | (C1-C2) |
(I3)。 | comparison_direction |
EQ 、NE 、GE 、GT 、LE 和 LT 的列舉項目 |
|
(I4)。 | compare_type |
FLOAT 、TOTALORDER 、SIGNED 和 UNSIGNED 的列舉 |
(C3) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
布林類型的張量 | (C2) |
限制
- (C1)
baseline_element_type(lhs) = baseline_element_type(rhs)
。 - (C2)
shape(lhs) = shape(rhs) = shape(result)
。 - (C3)
compare_type
定義為:- 如果
is_signed_integer(element_type(lhs))
則為SIGNED
。 - 如果
is_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs))
則為UNSIGNED
。 FLOAT
或TOTALORDER
表示is_float(element_type(lhs))
。- 如果
is_complex(element_type(lhs))
則為FLOAT
。
- 如果
範例
// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
comparison_direction = #stablehlo<comparison_direction LT>,
compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]
複雜
語義
從實際和虛值 (lhs
和 rhs
) 的一對複雜值執行元素相關轉換,然後產生 result
張量。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | lhs |
f32 或 f64 類型的張量 |
(C1-C3) |
(I2)。 | rhs |
f32 或 f64 類型的張量 |
(C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
複雜類型的張量 | (C2)、(C3) |
限制
- (C1)
type(lhs) = type(rhs)
。 - (C2)
shape(result) = shape(lhs)
。 - (C3)
element_type(result)
具有類型complex<E>
,其中E = element_type(lhs)
。
範例
// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]
concatenate
語義
按照指定引數的順序,將 inputs
沿著 dimension
維度串連,並產生 result
張量。更正式的 result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1]
,其中:
id = d0 + ... + dk-1 + kd
.d
等於dimension
,而d0
是d
inputs
的維度大小。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | inputs |
各種張量或每個張量量化張量 | (C1-C6) |
(I2)。 | dimension |
si64 類型的常數 |
(C2)、(C4)、(C6) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
張量或每個張量化張量 | (C5-C6) |
限制
- (C1)
same(element_type(inputs...))
。 - (C2)
same(shape(inputs...))
,dim(inputs..., dimension)
除外。 - (C3)
0 < size(inputs)
。 - (C4)
0 <= dimension < rank(inputs[0])
。 - (C5)
element_type(result) = element_type(inputs[0])
。 - (C6)
shape(result) = shape(inputs[0])
,但以下項目除外:dim(result, dimension) = dim(inputs[0], dimension) + ...
.
範例
// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
常數
語義
從常數 value
產生 output
張量。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | value |
常數 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
output |
張量或量化張量 | (C1)。 |
限制
- (C1)
type(value) = type(output)
。
範例
%output = "stablehlo.constant"() {
value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]
完成轉換
語義
這個外掛程式能在 operand
張量上,從某個元素類型到另一種元素類型轉換,然後產生 result
張量。
如果是 boolean-to-any-supported-type 轉換,系統會將 false
值轉換為零,而 true
值會轉換成一。針對any-supported-type-to-boolean轉換,0 值會轉換為 false
,非零值則會轉換為 true
。請參閱下文瞭解這種做法在複雜類型中的運作方式。
如果轉換包含整數到整數、整數到浮點點或 floating-point-to-floating-point,如果來源值可在目的地類型中精準呈現,結果值就是確切的表示法。否則行為為待定。(#180)。
如果轉換涉及floating-point-to-integer,小數部分會遭到截斷。如果目的地類型無法呈現截斷的值,則行為為未定 (#180)。
如果轉換涉及複雜至複雜,其行為與浮點對浮點點轉換行為相同,可用於轉換實際部分和虛部分。
針對complex-to-any-other-type及 complex-to-any-other-type 轉換,系統會忽略來源虛值,或目的地虛值的值為零。實際轉換次數會跟著浮點轉換
原則上,這個運算可以表示去量化 (從量化張量轉換為一般張量)、量化 (從一般張量轉換為量化張量) 和重新量化 (量化張量之間的轉換),但目前我們針對第一個用途有 uniform_dequantize
的專屬運算,第二和第三種用途則為 uniform_quantize
。這兩個運算日後可能會合併為 convert
(#1576)。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
張量 | (C1)。 |
限制
- (C1)
shape(operand) = shape(result)
。
範例
// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]
卷積
語義
計算 lhs
視窗和 rhs
配量之間的點積,並產生 result
。下圖顯示如何使用具體範例,從 lhs
和 rhs
計算 result
中的元素。
更正式的做法是以下列為 lhs
修正輸入內容,以便表達 lhs
的區間:
lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension))
.lhs_window_strides = lhs_shape(1, window_strides, 1)
.lhs_padding = lhs_shape([0, 0], padding, [0, 0])
.lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)
.lhs_window_dilations = lhs_shape(1, rhs_dilation, 1)
.
這項參照使用下列輔助函式:
lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])
.result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])
.permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]
,即j[d] = i[permutation[d]]
。
如果為 feature_group_count = 1
和 batch_group_count = 1
,則針對 index_space(dim(result, output_spatial_dimensions...))
中的所有 output_spatial_index
,result[result_shape(:, output_spatial_index, :)] = dot_product
其中:
padding_value = constant(0, element_type(lhs))
.padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)
.lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides
.lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations)
.reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true])
。這項功能似乎未使用,因此我們預計在日後將其移除 (#1181)。dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension])
.
如為 feature_group_count > 1
:
lhses = split(lhs, feature_group_count, input_feature_dimension)
.rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)
.results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
.
如為 batch_group_count > 1
:
lhses = split(lhs, batch_group_count, input_batch_dimension)
.rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)
.results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
。
如果是量化類型,請執行 dequantize_op_quantize(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs,
type(result))
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | lhs |
張量或每個張量化張量 | (C1)、(C10-C11)、(C14) (C25)、(C27-C30) |
(I2)。 | rhs |
張量或量化張量 | (C1)、(C14-C16)、(C25)、(C27-C32) |
(I3)。 | window_strides |
si64 類型的 1 維張量常數 |
(C2-C3)、(C25) |
(I4)。 | padding |
si64 類型的 2 維張量常數 |
(C4)、(C25) |
(I5) | lhs_dilation |
si64 類型的 1 維張量常數 |
(C5-C6)、(C25) |
(I6)。 | rhs_dilation |
si64 類型的 1 維張量常數 |
(C7-C8)、(C25) |
(I7) | window_reversal |
i1 類型的 1 維張量常數 |
(C9) |
(I8) | input_batch_dimension |
si64 類型的常數 |
(C10)、(C13)、(C25) |
(I9) | input_feature_dimension |
si64 類型的常數 |
(C11)、(C13-C14) |
(I10)。 | input_spatial_dimensions |
si64 類型的 1 維張量常數 |
(C12)、(C13)、(C25) |
(I11)。 | kernel_input_feature_dimension |
si64 類型的常數 |
(C14)、(C18) |
(I12)。 | kernel_output_feature_dimension |
si64 類型的常數 |
(C15-C16)、(C18)、(C25)、(C32) |
(I13)。 | kernel_spatial_dimensions |
si64 類型的 1 維張量常數 |
(C17-C18)、(C25) |
(I14)。 | output_batch_dimension |
si64 類型的常數 |
(C20)、(C25) |
(I15) | output_feature_dimension |
si64 類型的常數 |
(C20)、(C25)、(C33) |
(I16)。 | output_spatial_dimensions |
si64 類型的 1 維張量常數 |
(C19-C20)、(C25) |
(I17)。 | feature_group_count |
si64 類型的常數 |
(C11)、(C14)、(C16)、(C21)、(C23) |
(I18) | batch_group_count |
si64 類型的常數 |
(C10)、(C15)、(C22)、(C23)、(C25) |
(I19) | precision_config |
DEFAULT 、HIGH 和 HIGHEST 的列舉數量 |
(C24) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
張量或量化張量 | (C25-C28)、(C30-C31)、(C33) |
限制
- (C1)
N = rank(lhs) = rank(rhs)
。 - (C2)
size(window_strides) = N - 2
。 - (C3)
0 < window_strides
。 - (C4)
shape(padding) = [N - 2, 2]
。 - (C5)
size(lhs_dilation) = N - 2
。 - (C6)
0 < lhs_dilation
。 - (C7)
size(rhs_dilation) = N - 2
。 - (C8)
0 < rhs_dilation
。 - (C9)
size(window_reversal) = N - 2
。 - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
。 - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
。 - (C12)
size(input_spatial_dimensions) = N - 2
。 - (C13) 為
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
:is_unique(input_dimensions)
.0 <= input_dimensions < N
.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
。 - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
。 - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
。 - (C17)
size(kernel_spatial_dimensions) = N - 2
。 - (C18) 為
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
:is_unique(kernel_dimensions)
.0 <= kernel_dimensions < N
.
- (C19)
size(output_spatial_dimensions) = N - 2
。 - (C20) 為
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
:is_unique(output_dimensions)
.0 <= output_dimensions < N
.
- (C21)
0 < feature_group_count
。 - (C22)
0 < batch_group_count
。 - (C23)
feature_group_count = 1 or batch_group_count = 1
。 - (C24)
size(precision_config) = 2
。 - (C25)
dim(result, result_dim)
定義為:- 如果
result_dim = output_batch_dimension
則為dim(lhs, input_batch_dimension) / batch_group_count
。 - 如果
result_dim = output_feature_dimension
則為dim(rhs, kernel_output_feature_dimension)
。 num_windows
否則,其中:output_spatial_dimensions[spatial_dim] = result_dim
.lhs_dim = input_spatial_dimensions[spatial_dim]
.rhs_dim = kernel_spatial_dimensions[spatial_dim]
.dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
.dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
.num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
- 如果
- (C26)
rank(result) = N
。 - 如果作業使用非量化張量:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
。
- (C27)
- 如果作業使用量化張量:
- (C28)
is_quantized_tensor(lhs) and is_quantized_tensor(rhs) and is_quantized_tensor(result)
。 - (C29)
storage_type(lhs) = storage_type(rhs)
。 - (C30)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
。 - (C31) 如為
is_per_tensor_quantized(rhs)
,則值為is_per_tensor_quantized(result)
。 - (C32) 如為
is_per_axis_quantized(rhs)
,則值為quantization_dimension(rhs) = kernel_output_feature_dimension
。 - (C33) 如為
is_per_axis_quantized(result)
,則值為quantization_dimension(result) = output_feature_dimension
。
- (C28)
範例
// %lhs: [[
// [
// [1], [2], [5], [6]
// ],
// [
// [3], [4], [7], [8]
// ],
// [
// [10], [11], [14], [15]
// ],
// [
// [12], [13], [16], [17]
// ]
// ]]
//
// %rhs : [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strides = dense<4> : tensor<2xi64>,
padding = dense<0> : tensor<2x2xi64>,
lhs_dilation = dense<2> : tensor<2xi64>,
rhs_dilation = dense<1> : tensor<2xi64>,
window_reversal = dense<false> : tensor<2xi1>,
// In the StableHLO dialect, dimension numbers are encoded via:
// `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
// "b" is batch dimension, "f" is feature dimension,
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" are spatial dimensions.
dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi32>, tensor<3x3x1x1xi32>) -> tensor<1x2x2x1xi32>
// %result: [[
// [[10], [26]],
// [[46], [62]]
// ]]
餘弦
語義
對 operand
張量執行元素相關餘弦運算,並產生 result
張量。根據元素類型執行以下操作:
- 浮點值:由 IEEE-754 距離
cos
。 - 複數:複數的餘弦。
- 量化類型:
dequantize_op_quantize(cosine, operand, type(result))
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
浮點、複雜類型或每個張量化張量的張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
浮點、複雜類型或每個張量化張量的張量 | (C1)。 |
限制
- (C1)
baseline_type(operand) = baseline_type(result)
。
範例
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]
count_leading_zeros
語義
針對 operand
張量中的前置零位元執行元素相關計數,並產生 result
張量。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
整數類型的張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
整數類型的張量 | (C1)。 |
限制
- (C1)
type(operand) = type(result)
。
範例
// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]
custom_call
語義
封裝實作定義作業 call_target_name
,該作業採用 inputs
和 called_computations
並產生 results
。has_side_effect
、backend_config
和 api_version
可用於提供其他的實作定義中繼資料。
目前,這項作業包含一系列結構相當多元的中繼資料集合,反映了其在 XLA 編譯器中對應作業的自然演進。我們計劃日後統合這類中繼資料。(#741)。
輸入內容
標籤 | 名稱 | 類型 |
---|---|---|
(I1)。 | inputs |
變異數值 |
(I2)。 | call_target_name |
string 類型的常數 |
(I3)。 | has_side_effect |
i1 類型的常數 |
(I4)。 | backend_config |
string 類型的常數 |
(I5) | api_version |
si32 類型的常數 |
(I6)。 | called_computations |
string 類型的常數數 |
輸出
名稱 | 類型 |
---|---|
results |
變異數值 |
範例
%results = "stablehlo.custom_call"(%input0) {
call_target_name = "foo",
has_side_effect = false,
backend_config = "bar",
api_version = 1 : i32,
called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>
除
語義
執行被除 lhs
和除數 rhs
張量的元素相關除法,並產生 result
張量。根據元素類型執行以下操作:
- 整數:整數除法,產生捨棄任何小數部分的代數商。
- 浮點值:由 IEEE-754 距離
division
。 - 複數:複數。
- 量化類型:
dequantize_op_quantize(divide, lhs, rhs, type(result))
.
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | lhs |
整數、浮點、複雜類型,或每個張量的量化張量 | (C1)。 |
(I2)。 | rhs |
整數、浮點、複雜類型,或每個張量的量化張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
整數、浮點、複雜類型或每個張量量化張量的張量 | (C1)。 |
限制
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
。
範例
// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]
dot_general
語義
計算 lhs
切片和 rhs
配量之間的內積,並產生 result
張量。
更正式的 result[result_index] = dot_product
,其中:
lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions]
.rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions]
.result_batching_index + result_lhs_index + result_rhs_index = result_index
其中size(result_batching_index) = size(lhs_batching_dimensions)
、size(result_lhs_index) = size(lhs_result_dimensions)
和size(result_rhs_index) = size(rhs_result_dimensions)
。transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions)
.transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :])
.reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions))
.transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions)
.transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :])
.reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions))
.dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y))
。
如果是量化類型,請執行 dequantize_op_quantize(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs, type(result))
。
這個方法只會指定每個張量量化的語意。每軸量化目前正在處理中 (#1574)。此外,我們日後可能會考慮新增對混合量化支援 (#1575)。
precision_config
可控制加速器後端運算速度與準確率之間的取捨。可能的值如下 (於當下,這些列舉值的語意尚未指定,但我們計劃在 #755 中解決這個問題):
DEFAULT
:最快計算,但近似原始數字的精確度較不精確。HIGH
:計算速度較慢,但近似原始數字的估計值會更準確。HIGHEST
:計算速度最慢,但近似原始數字最準確的預測結果。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | lhs |
張量或每個張量化張量 | (C5-C6)、(C9-C10)、(C12-C16) |
(I2)。 | rhs |
張量或每個張量化張量 | (C7-C10)、(C12) |
(I3)。 | lhs_batching_dimensions |
si64 類型的 1 維張量常數 |
(C1)、(C3)、(C5)、(C9)、(C12) |
(I4)。 | rhs_batching_dimensions |
si64 類型的 1 維張量常數 |
(C1)、(C4)、(C7)、(C9) |
(I5) | lhs_contracting_dimensions |
si64 類型的 1 維張量常數 |
(C2)、(C3)、(C6)、(C10) |
(I6)。 | rhs_contracting_dimensions |
si64 類型的 1 維張量常數 |
(C2)、(C4)、(C8)、(C10) |
(I7) | precision_config |
DEFAULT 、HIGH 和 HIGHEST 的列舉數量 |
(C11) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
張量或每個張量化張量 | (C12)、(C14)、(C16) |
限制
- (C1)
size(lhs_batching_dimensions) = size(rhs_batching_dimensions)
。 - (C2)
size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions)
。 - (C3)
is_unique(lhs_batching_dimensions + lhs_contracting_dimensions)
。 - (C4)
is_unique(rhs_batching_dimensions + rhs_contracting_dimensions)
。 - (C5)
0 <= lhs_batching_dimensions < rank(lhs)
。 - (C6)
0 <= lhs_contracting_dimensions < rank(lhs)
。 - (C7)
0 <= rhs_batching_dimensions < rank(rhs)
。 - (C8)
0 <= rhs_contracting_dimensions < rank(rhs)
。 - (C9)
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
。 - (C10)
dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...)
。 - (C11)
size(precision_config) = 2
。 - (C12)
shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions)
。 - 如果作業使用非量化張量:
- (C13)
element_type(lhs) = element_type(rhs)
。
- (C13)
- 如果作業使用量化張量:
- (C14)
is_quantized(lhs) and is_quantized(rhs) and is_quantized(result)
。 - (C15)
storage_type(lhs) = storage_type(rhs)
。 - (C16)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
。 - (C17)
zero_points(rhs) = 0
。
- (C14)
範例
// %lhs: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
// %rhs: [
// [[1, 0],
// [0, 1]],
// [[1, 0],
// [0, 1]]
// ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [1]
>,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
dynamic_slice
語義
使用動態計算的啟動索引從 operand
擷取切片,然後產生 result
張量。start_indices
包含每個維度的切片起始索引,slice_sizes
包含每個維度的切片大小。更正式的 result[result_index] = operand[operand_index]
,其中:
adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes)
.operand_index = adjusted_start_indices + result_index
.
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
張量或每個張量化張量 | (C1)、(C2)、(C4) |
(I2)。 | start_indices |
整數類型的 0 維張量變異數 | (C2)、(C3) |
(I3)。 | slice_sizes |
si64 類型的 1 維張量常數 |
(C2)、(C4)、(C5) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
張量或每個張量化張量 | (C1)、(C5) |
限制
- (C1)
element_type(operand) = element_type(result)
。 - (C2)
size(start_indices) = size(slice_sizes) = rank(operand)
。 - (C3)
same(type(start_indices...))
。 - (C4)
0 <= slice_sizes <= shape(operand)
。 - (C5)
shape(result) = slice_sizes
。
範例
// %operand: [
// [0, 0, 1, 1],
// [0, 0, 1, 1],
// [0, 0, 0, 0],
// [0, 0, 0, 0]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
slice_sizes = dense<[2, 2]> : tensor<2xi64>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
// [1, 1],
// [1, 1]
// ]
dynamic_update_slice
語義
產生等於 operand
張量的 result
張量,但從 start_indices
開始的切片會以 update
中的值更新。更正式的,result[result_index]
定義為:
- 如果
0 <= update_index < shape(update)
且符合以下條件,則為update[update_index]
:adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update))
.update_index = result_index - adjusted_start_indices
.
- 否則傳回
operand[result_index]
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
張量或每個張量化張量 | (C1-C4)、(C6) |
(I2)。 | update |
張量或每個張量化張量 | (C2)、(C3)、(C6) |
(I3)。 | start_indices |
整數類型的 0 維張量變異數 | (C4)、(C5) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
張量或每個張量化張量 | (C1)。 |
限制
- (C1)
type(operand) = type(result)
。 - (C2)
element_type(update) = element_type(operand)
。 - (C3)
rank(update) = rank(operand)
。 - (C4)
size(start_indices) = rank(operand)
。 - (C5)
same(type(start_indices...))
。 - (C6)
0 <= shape(update) <= shape(operand)
。
範例
// %operand: [
// [1, 1, 0, 0],
// [1, 1, 0, 0],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
// %update: [
// [1, 1],
// [1, 1]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
: (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
指數
語義
對 operand
張量執行元素相關指數運算,並產生 result
張量。根據元素類型執行以下操作:
- 浮點值:由 IEEE-754 距離
exp
。 - 複數:複數的指數。
- 量化類型:
dequantize_op_quantize(exponential, operand, type(result))
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
浮點、複雜類型或每個張量化張量的張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
浮點、複雜類型或每個張量化張量的張量 | (C1)。 |
限制
- (C1)
baseline_type(operand) = baseline_type(result)
。
範例
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]
exponential_minus_one
語義
對 operand
張量執行元素指數減去一次作業,然後產生 result
張量。根據元素類型執行以下操作:
- 浮點值:由 IEEE-754 距離
expm1
。 - 複數:複數的指數減 1。
- 量化類型:
dequantize_op_quantize(exponential_minus_one, operand, type(result))
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
浮點、複雜類型或每個張量化張量的張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
浮點、複雜類型或每個張量化張量的張量 | (C1)。 |
限制
- (C1)
baseline_type(operand) = baseline_type(result)
。
範例
// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]
fft
語義
針對真實複雜的輸入/輸出執行向前和反向 Fourier 轉換。
fft_type
是下列其中一項:
FFT
:將複雜到複雜的 FFT。IFFT
:複數到複雜 FFT。RFFT
:正向複雜 FFT。IRFFT
:非實際為複雜 FFT (即需要複雜、傳回真實)。
更正式的說,由於 fft
函式會取用複雜類型的 1 維張量做為輸入,會產生與輸出相同類型的 1 維張量,並計算離散的 Fourier 轉換:
如果是 fft_type = FFT
,result
是一系列 L 運算的最終結果,其中 L = size(fft_length)
。例如,L = 3
:
result1[i0, ..., :] = fft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
.result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
此外,由於 ifft
函式的類型簽章相同,且會計算 fft
的反向運算:
對於 fft_type = IFFT
,result
定義為 fft_type = FFT
運算的反函式。例如,L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
.result[i0, ..., :] = ifft(result2[i0, ..., :])
.
此外,由於 rfft
函式會取用浮點類型的 1 維張量,因此會產生相同浮點語意的 1 維複雜張量,且運作方式如下:
rfft(real_operand) = truncated_result
,其中complex_operand... = (real_operand..., 0.0)
.complex_result = fft(complex_operand)
.truncated_result = complex_result[:(rank(complex_result) / 2 + 1)]
.
(為實際運算元計算獨立的 Fourier 轉換時,結果的第一個 N/2 + 1
元素會明確定義其餘結果,因此系統會截斷 rfft
的結果,避免計算多餘的元素。
如果是 fft_type = RFFT
,result
是一系列 L 運算的最終結果,其中 L = size(fft_length)
。例如,L = 3
:
result1[i0, ..., :] = rfft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
.result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
最後,假設 irfft
函式的類型簽章相同,且會計算 rfft
的反向運算:
對於 fft_type = IRFFT
,result
定義為 fft_type = RFFT
運算的反函式。例如,L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
.result[i0, ..., :] = irfft(result2[i0, ..., :])
.
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
浮點或複雜類型的張量 | (C1)、(C2)、(C4)、(C5) |
(I2)。 | fft_type |
FFT 、IFFT 、RFFT 和 IRFFT 的列舉 |
(C2)、(C5) |
(I3)。 | fft_length |
si64 類型的 1 維張量常數 |
(C1)、(C3)、(C4) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
浮點或複雜類型的張量 | (C2)、(C4)、(C5) |
限制
- (C1)
size(fft_length) <= rank(operand)
。 - (C2)
operand
和result
元素類型之間的關係如下:- 如果
fft_type = FFT
、element_type(operand)
和element_type(result)
的類型相同, - 如果
fft_type = IFFT
、element_type(operand)
和element_type(result)
的類型相同, - 如果
fft_type = RFFT
,element_type(operand)
是浮點類型,element_type(result)
是相同浮點語意的複雜類型。 - 如果
fft_type = IRFFT
,element_type(operand)
是複雜類型,element_type(result)
是相同浮點語意的浮點類型。
- 如果
- (C3)
1 <= size(fft_length) <= 3
。 - (C4) 如果介於
operand
和result
之間,則會有浮點類型的張量real
,然後是shape(real)[-size(fft_length):] = fft_length
。 - (C5)
shape(result) = shape(operand)
,但以下項目除外:- 如果值為
fft_type = RFFT
,則為dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1
。 - 如果值為
fft_type = IRFFT
,則為dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1
。
- 如果值為
範例
// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
fft_type = #stablehlo<fft_type FFT>,
fft_length = dense<4> : tensor<1xi64>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]
floor
語義
執行 operand
張量的元素樓層,並產生 result
張量。從 IEEE-754 規格實作 roundToIntegralTowardNegative
作業。如果是量化類型,請執行 dequantize_op_quantize(floor, operand, type(result))
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
浮點類型或每個張量量化張量的張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
浮點類型或每個張量量化張量的張量 | (C1)。 |
限制
- (C1)
baseline_type(operand) = baseline_type(result)
。
範例
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]
收集
語義
從 start_indices
中指定的偏移值收集 operand
張量,並產生 result
張量。
下圖以具體範例說明 result
中的元素如何對應到 operand
中的元素。圖中挑選了一些 result
索引範例,並詳細說明它們對應的 operand
索引。
更正式的 result[result_index] = operand[operand_index]
,其中:
batch_dims = [d for d in axes(result) and d not in offset_dims]
.batch_index = result_index[batch_dims...]
.start_index
定義為:start_indices[bi0, ..., :, ..., biN]
,其中bi
是batch_index
中的個別元素,如果index_vector_dim
<rank(start_indices)
,則:
會插入index_vector_dim
索引。- 否則傳回
[start_indices[batch_index]]
。
- 針對
axes(operand)
中的d_operand
,full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])
如果為d_operand = start_index_map[d_start]
。- 否則傳回
full_start_index[d_operand] = 0
。
offset_index = result_index[offset_dims...]
.full_offset_index = [oi0, ..., 0, ..., oiN]
,其中oi
是offset_index
中的個別元素,而0
會插入collapsed_slice_dims
的索引。operand_index = full_start_index + full_offset_index
。
如果 indices_are_sorted
是 true
,實作會假設 start_indices
會依 start_index_map
排序,否則行為未定義。更正式地是,針對 indices(result)
、full_start_index(i1) <= full_start_index(i2)
的所有 i1 < i2
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
張量或每個張量化張量 | (C1)、(C7)、(C10-C12)、(C14) |
(I2)。 | start_indices |
整數類型的張量 | (C2)、(C3)、(C13) |
(I3)。 | offset_dims |
si64 類型的 1 維張量常數 |
(C1)、(C4-C5)、(C13) |
(I4)。 | collapsed_slice_dims |
si64 類型的 1 維張量常數 |
(C1)、(C6-C8)、(C13) |
(I5) | start_index_map |
si64 類型的 1 維張量常數 |
(C3)、(C9)、(C10) |
(I6)。 | index_vector_dim |
si64 類型的常數 |
(C2)、(C3)、(C13) |
(I7) | slice_sizes |
si64 類型的 1 維張量常數 |
(C8)、(C11-C13) |
(I8) | indices_are_sorted |
i1 類型的常數 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
張量或每個張量化張量 | (C5)、(C13-C14) |
限制
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims)
。 - (C2)
0 <= index_vector_dim <= rank(start_indices)
。 - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
。 - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
。 - (C5)
0 <= offset_dims < rank(result)
。 - (C6)
is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims)
。 - (C7)
0 <= collapsed_slice_dims < rank(operand)
。 - (C8)
slice_sizes[collapsed_slice_dims...] <= 1
。 - (C9)
is_unique(start_index_map)
。 - (C10)
0 <= start_index_map < rank(operand)
。 - (C11)
size(slice_sizes) = rank(operand)
。 - (C12)
0 <= slice_sizes <= shape(operand)
。 - (C13)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
,其中:batch_dim_sizes = shape(start_indices)
,但資料不會包含與index_vector_dim
對應的start_indices
尺寸大小。offset_dim_sizes = shape(slice_sizes)
,但不會包含slice_sizes
中與collapsed_slice_dims
對應的尺寸大小。combine
會將batch_dim_sizes
置於與batch_dims
和offset_dim_sizes
對應的軸上,並放在與offset_dims
相對應的軸。
- (C14)
element_type(operand) = element_type(result)
。
範例
// %operand: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %start_indices: [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 2]]
// ]
%result = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stablehlo.gather<
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vector_dim = 2>,
slice_sizes = dense<[1, 2, 2]> : tensor<3xi64>,
indices_are_sorted = false
} : (tensor<3x4x2xi32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xi32>
// %result: [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[9, 10], [11, 12]],
// [[11, 12], [13, 14]],
// [[17, 18], [19, 20]]
// ]
// ]
get_dimension_size
語義
產生 operand
的指定 dimension
大小。更正式的為 result = dim(operand, dimension)
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
張量 | (C1)。 |
(I2)。 | dimension |
si64 類型的常數 |
(C1)。 |
輸出
名稱 | 類型 |
---|---|
result |
si32 類型的 0 維張量 |
限制
- (C1)
0 <= dimension < rank(operand)
。
範例
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3
get_tuple_element
語義
擷取 operand
元組 index
位置的元素,並產生 result
。正式的是 result = operand[index]
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
tuple | (C1)、(C2) |
(I2)。 | index |
si32 類型的常數 |
(C1)、(C2) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
任何支援的類型 | (C2) |
限制
- (C1)
0 <= index < size(operand)
。 - (C2)
type(result) = tuple_element_types(operand)[index]
。
範例
// %operand: ([1.0, 2.0], (3))
%result = "stablehlo.get_tuple_element"(%operand) {
index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]
if
語義
根據 pred
的值,從 true_branch
或 false_branch
僅執行一個函式產生輸出內容。正式的是 result =
pred ? true_branch() : false_branch()
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | pred |
i1 類型的 0 維張量 |
|
(I2)。 | true_branch |
函式 | (C1-C3) |
(I3)。 | false_branch |
函式 | (C1)、(C2) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
results |
各種張量、量化張量或符記 | (C3) |
限制
- (C1)
input_types(true_branch) = input_types(false_branch) = []
。 - (C2)
output_types(true_branch) = output_types(false_branch)
。 - (C3)
type(results...) = output_types(true_branch)
。
範例
// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
"stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
"stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10
想像力
語義
從 operand
擷取元素部分,然後產生 result
張量。更正式的做法是每個元素 x
:imag(x) = is_complex(x) ? imaginary_part(x) :
constant(0, element_type(result))
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
浮點或複雜類型的張量 | (C1)、(C2) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
浮點類型的張量 | (C1)、(C2) |
限制
- (C1)
shape(result) = shape(operand)
。 - (C2)
element_type(result)
定義為:- 如果
is_complex(operand)
則為complex_element_type(element_type(operand))
。 - 否則傳回
element_type(operand)
。
- 如果
範例
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]
動態內廣告
語義
這個外掛程式能讀取動態饋給中的資料,並產生 results
。
infeed_config
的語意是實作定義。
results
包含最先出現的酬載值和最後的符記。我們日後打算將酬載和權杖分成兩個不同的輸出內容,以便清楚呈現。(#670)。
輸入內容
標籤 | 名稱 | 類型 |
---|---|---|
(I1)。 | token |
token |
(I2)。 | infeed_config |
string 類型的常數 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
results |
各種張量、量化張量或符記 | (C1-C3) |
限制
- (C1)
0 < size(results)
。 - (C2)
is_empty(result[:-1])
或is_tensor(type(results[:-1]))
。 - (C3)
is_token(type(results[-1]))
。
範例
// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]
Iota
語義
為 output
張量填入 iota_dimension
維度,從零開始遞增的值。更正式的
output[result_index] = constant(is_quantized(output) ?
quantize(result_index[iota_dimension], element_type(output)) :
result_index[iota_dimension], element_type(output))
.
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | iota_dimension |
si64 |
(C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
output |
整數、浮點、複雜類型,或每個張量的量化張量 | (C1)。 |
限制
- (C1)
0 <= iota_dimension < rank(output)
。
範例
%output = "stablehlo.iota"() {
iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
%output = "stablehlo.iota"() {
iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4]
// ]
is_finite
語義
根據元素執行元素檢查 x
中的值是否為有限 (即不是 +Inf、-Inf 或 NaN),並產生 y
張量。根據 IEEE-754 規格實作 isFinite
作業。如果是量化類型,結果一律為 true
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | x |
浮點類型或每個張量量化張量的張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
y |
布林類型的張量 | (C1)。 |
限制
- (C1)
shape(x) = shape(y)
。
範例
// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]
紀錄/記錄檔
語義
對 operand
張量執行元素相關對數運算,並產生 result
張量。根據元素類型執行以下操作:
- 浮點值:由 IEEE-754 距離
log
。 - 複數:複數。
- 量化類型:
dequantize_op_quantize(log, operand, type(result))
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
浮點、複雜類型或每個張量化張量的張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
浮點、複雜類型或每個張量化張量的張量 | (C1)。 |
限制
- (C1)
baseline_type(operand) = baseline_type(result)
。
範例
// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]
log_plus_one
語義
對 operand
張量執行元素相關對數加上一項運算,然後產生 result
張量。根據元素類型執行以下操作:
- 浮點值:由 IEEE-754 距離
logp1
。 - 複數:複數的對數加上一。
- 量化類型:
dequantize_op_quantize(log_plus_one, operand, type(result))
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
浮點、複雜類型或每個張量化張量的張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
浮點、複雜類型或每個張量化張量的張量 | (C1)。 |
限制
- (C1)
baseline_type(operand) = baseline_type(result)
。
範例
// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
物流
語義
對 operand
張量執行元素相關邏輯運算,並產生 result
張量。根據元素類型執行以下操作:
- 浮點值:由 IEEE-754 距離
division(1, addition(1, exp(-x)))
。 - 複數:複雜的邏輯。
- 量化類型:
dequantize_op_quantize(logistic, operand, type(result))
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
浮點、複雜類型或每個張量化張量的張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
浮點、複雜類型或每個張量化張量的張量 | (C1)。 |
限制
- (C1)
baseline_type(operand) = baseline_type(result)
。
範例
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]
地圖
語義
沿著 dimensions
將對應函式 computation
套用至 inputs
,並建立 result
張量。
正式是 result[result_index] = computation(inputs...[result_index])
。請注意,dimensions
目前並未使用,未來可能會移除 (#487)。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | inputs |
各種張量或每個張量量化張量 | (C1-C4) |
(I2)。 | dimensions |
si64 類型的 1 維張量常數 |
(C3) |
(I3)。 | computation |
函式 | (C4) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
張量或每個張量化張量 | (C1)、(C4) |
限制
- (C1)
shape(inputs...) = shape(result)
。 - (C2)
0 < size(inputs) = N
。 - (C3)
dimensions = range(rank(inputs[0]))
。 - (C4)
computation
具有(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>
類型,其中Ei = element_type(inputs[i])
和E' = element_type(result)
。
範例
// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
stablehlo.return %0 : tensor<i64>
}) {
dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]
最高
語義
對張量 lhs
和 rhs
執行元素相關最大值作業,並產生 result
張量。根據元素類型執行以下操作:
- 布林值:邏輯 OR。
- 整數:整數最大值。
- 浮點值:由 IEEE-754 距離
maximum
。 - 複數:
(real, imaginary)
組合的字母順序上限。在複雜數字上排序複雜的語意會帶來令人意想不到的語意,因此我們未來計劃移除對此運算複雜數字的支援 (#560)。 - 量化類型:
dequantize_op_quantize(maximum, lhs, rhs, type(result))
.
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | lhs |
張量或每個張量化張量 | (C1)。 |
(I2)。 | rhs |
張量或每個張量化張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
張量或每個張量化張量 | (C1)。 |
限制
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
。
範例
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]
最低
語義
對張量 lhs
和 rhs
執行元素相關最低作業,並產生 result
張量。根據元素類型執行以下操作:
- 布林值:邏輯 AND。
- 整數:整數的最小值。
- 浮點值:由 IEEE-754 距離
minimum
。 - 複數:
(real, imaginary)
組合的字母順序下限。在複雜數字上排序複雜的語意會帶來令人意想不到的語意,因此我們未來計劃移除對此運算複雜數字的支援 (#560)。 - 量化類型:
dequantize_op_quantize(minimum, lhs, rhs, type(result))
.
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | lhs |
張量或每個張量化張量 | (C1)。 |
(I2)。 | rhs |
張量或每個張量化張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
張量或每個張量化張量 | (C1)。 |
限制
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
。
範例
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]
乘
語義
執行兩個張量 lhs
和 rhs
的元素相關乘積,並產生 result
張量。根據元素類型執行以下操作:
- 布林值:邏輯 AND。
- 整數:整數乘法。
- 浮點值:由 IEEE-754 距離
multiplication
。 - 複數:複雜的乘法。
- 量化類型:
dequantize_op_quantize(multiply, lhs, rhs, type(result))
.
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | lhs |
張量或每個張量化張量 | (C1)。 |
(I2)。 | rhs |
張量或每個張量化張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
張量或每個張量化張量 | (C1)。 |
限制
- (C1)
baseline_type(operand) = baseline_type(result)
。
範例
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]
negate
語義
執行 operand
張量的元素否定,並產生 result
張量。根據元素類型執行以下操作:
- 帶正負號整數:整數否定。
- 無正負號整數:bitcast 到帶正負號整數、整數否定,位元轉換回無正負號整數。
- 浮點值:由 IEEE-754 距離
negate
。 - 複數:複雜的否定。
- 量化類型:
dequantize_op_quantize(negate, operand, type(result))
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
整數、浮點、複雜類型或每個張量量化張量的張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
整數、浮點、複雜類型或每個張量量化張量的張量 | (C1)。 |
限制
- (C1)
baseline_type(operand) = baseline_type(result)
。
範例
// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]
// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]
不是
語義
於「不」執行張量 operand
的元素時,產生 result
張量。根據元素類型執行以下操作:
- 布林值:邏輯 NOT。
- 整數:位元 NOT。
引數
名稱 | 類型 | 限制 |
---|---|---|
operand |
布林值或整數類型的張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
布林值或整數類型的張量 | (C1)。 |
限制
- (C1)
type(operand) = type(result)
。
範例
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]
// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]
optimization_barrier
語義
確保產生 operand
的作業會在依附 result
的任何作業之前執行,並防止編譯器轉換作業在障礙之間移動作業。除此之外,作業則是身分 (例如 result = operand
)。
引數
名稱 | 類型 | 限制 |
---|---|---|
operand |
各種張量、各張量化張量或符記 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
各種張量、各張量化張量或符記 | (C1)。 |
限制
- (C1)
type(operand...) = type(result...)
。
範例
// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0
或
語義
以元素或 rhs
的方式執行兩個張量 lhs
和 rhs
,並產生 result
張量。根據元素類型執行以下操作:
- 布林值:邏輯 OR。
- 整數:位元 OR。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | lhs |
整數或布林值類型的張量 | (C1)。 |
(I2)。 | rhs |
整數或布林值類型的張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
整數或布林值類型的張量 | (C1)。 |
限制
- (C1)
type(lhs) = type(rhs) = type(result)
。
範例
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]
外動態
語義
將 inputs
寫入外動態饋給,並產生 result
權杖。
outfeed_config
的語意是實作定義。
輸入內容
標籤 | 名稱 | 類型 |
---|---|---|
(I1)。 | inputs |
各種張量或量化張量 |
(I2)。 | token |
token |
(I3)。 | outfeed_config |
string 類型的常數 |
輸出
名稱 | 類型 |
---|---|
result |
token |
範例
%result = "stablehlo.outfeed"(%inputs0, %token) {
outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token
墊片
語義
透過在張量周圍以及具有指定的 padding_value
的張量元素之間加上邊框間距來展開 operand
。
edge_padding_low
和 edge_padding_high
分別指定在低端 (索引 0 旁邊) 和頂端 (最高索引旁邊) 新增的邊框間距量。邊框間距量可以是負數,其中負邊框間距的絕對值代表要從指定維度中移除的元素數量。
interior_padding
會指定在每個維度的任意兩個元素之間加入的邊框間距量,但不得為負數。內部邊框間距發生在邊緣邊框間距之前,因此負的邊緣邊框間距會從內部填充運算元中移除元素。
更正式的,result[result_index]
定義為:
- 如果為
result_index = edge_padding_low + operand_index * (interior_padding + 1)
,則為operand[operand_index]
。 - 否則傳回
padding_value
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
張量或每個張量化張量 | (C1)、(C2)、(C4) |
(I2)。 | padding_value |
0 維張量或按張量化張量 | (C1)。 |
(I3)。 | edge_padding_low |
si64 類型的 1 維張量常數 |
(C1)、(C4) |
(I4)。 | edge_padding_high |
si64 類型的 1 維張量常數 |
(C1)、(C4) |
(I5) | interior_padding |
si64 類型的 1 維張量常數 |
(C2-C4) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
張量或每個張量化張量 | (C3-C6) |
限制
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
。 - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
。 - (C3)
0 <= interior_padding
。 - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
。
範例
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
edge_padding_low = dense<[0, 1]> : tensor<2xi64>,
edge_padding_high = dense<[2, 1]> : tensor<2xi64>,
interior_padding = dense<[1, 2]> : tensor<2xi64>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
partition_id
語義
產生目前程序的 partition_id
。
輸出
名稱 | 類型 |
---|---|
result |
ui32 類型的 0 維張量 |
範例
%result = "stablehlo.partition_id"() : () -> tensor<ui32>
罌粟紅
語義
針對 operand
張量中設定的位元數執行元素相關位元數,並產生 result
張量。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
整數類型的張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
整數類型的張量 | (C1)。 |
限制
- (C1)
type(operand) = type(result)
。
範例
// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]
功率
語義
透過 rhs
張量執行 lhs
張量的元素指數,並產生 result
張量。根據元素類型執行以下操作:
- 整數:整數指數。
- 浮點值:由 IEEE-754 距離
pow
。 - 複數:複數的指數。
- 量化類型:
dequantize_op_quantize(power, lhs, rhs, type(result))
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | lhs |
整數、浮點、複雜類型或每個張量量化張量的張量 | (C1)。 |
(I2)。 | rhs |
整數、浮點、複雜類型或每個張量量化張量的張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
整數、浮點、複雜類型或每個張量量化張量的張量 | (C1)。 |
限制
- (C1)
baseline_type(operand) = baseline_type(result)
。
範例
// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]
real
語義
從 operand
逐元素擷取實際部分,並產生 result
張量。更正式的做法是每個元素 x
:real(x) = is_complex(x) ? real_part(x) : x
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
浮點或複雜類型的張量 | (C1)、(C2) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
浮點類型的張量 | (C1)、(C2) |
限制
- (C1)
shape(result) = shape(operand)
。 - (C2)
element_type(result)
定義為:- 如果
is_complex(operand)
則為complex_element_type(element_type(operand))
。 - 否則傳回
element_type(operand)
。
- 如果
範例
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]
Recv
語義
接收含有 channel_id
的頻道資料,並產生 results
。
如果 is_host_transfer
為 true
,則作業會從主機移轉資料。否則,系統會轉移其他裝置中的資料。這代表的是實作定義。這個旗標會複製 channel_type
中提供的資訊,因此未來我們打算只會保留其中一個 (#666)。
results
包含最先出現的酬載值和最後的符記。我們日後打算將酬載和權杖分成兩個不同的輸出內容,以便清楚呈現。(#670)。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | token |
token |
(C4) |
(I2)。 | channel_id |
si64 類型的常數 |
|
(I3)。 | channel_type |
DEVICE_TO_DEVICE 和 HOST_TO_DEVICE 的列舉項目 |
(C1)。 |
(I4)。 | is_host_transfer |
i1 類型的常數 |
(C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
results |
各種張量、量化張量或符記 | (C2-C4) |
限制
- (C1)
channel_type
定義為:HOST_TO_DEVICE
表示is_host_transfer = true
,- 否則傳回
DEVICE_TO_DEVICE
。
- (C2)
0 < size(results)
。 - (C3)
is_empty(result[:-1])
或is_tensor(type(results[:-1]))
。 - (C4)
is_token(type(results[-1]))
。
範例
%results0, %results1 = "stablehlo.recv"(%token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
減少
語義
將縮減函式 body
沿著 dimensions
套用至 inputs
和 init_values
,並產生 results
張張。
縮減順序是由實作定義,這表示 body
和 init_values
必須形成單聲道屬性,以確保該作業在所有實作項目上,都會產生相同的結果。然而,這項條件並不適用於許多常見的縮減項目。例如,body
的浮點加至 init_values
與 0 實際上並未形成單聲道,因為浮點加上沒有關聯。
更正式的 results...[j0, ..., jR-1] = reduce(input_slices_converted)
,其中:
input_slices = inputs...[j0, ..., :, ..., jR-1]
,其中:
會在dimensions
插入。input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...)
.init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)
.reduce(input_slices_converted) = exec(schedule)
代表二進位樹狀結構schedule
,其中:exec(node) = body(exec(node.left), exec(node.right))
.exec(leaf) = leaf.value
.
schedule
是實作定義的完整二進位樹狀結構,其按順序週遊包含:input_slices_converted...[index]
值,適用於index_space(input_slices_converted)
中所有index
,並依index
的字母順序排列。- 在實作定義的位置,會與實作定義的
init_values_converted
數量交錯。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | inputs |
各種張量或每個張量量化張量 | (C1-C4)、(C6)、(C7) |
(I2)。 | init_values |
各種 0 維張量或每個張量量化張量的數量 | (C2)、(C3) |
(I3)。 | dimensions |
si64 類型的 1 維張量常數 |
(C4)、(C5)、(C7) |
(I4)。 | body |
函式 | (C6) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
results |
各種張量或每個張量量化張量 | (C3)、(C7)、(C8) |
限制
- (C1)
same(shape(inputs...))
。 - (C2)
element_type(inputs...) = element_type(init_values...)
。 - (C3)
0 < size(inputs) = size(init_values) = size(results) = N
。 - (C4)
0 <= dimensions < rank(inputs[0])
。 - (C5)
is_unique(dimensions)
。 - (C6)
body
具有(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
類型,其中is_promotable(element_type(inputs[i]), Ei)
。 - (C7)
shape(results...) = shape(inputs...)
,但不含與dimensions
對應的inputs...
尺寸大小。 - (C8)
[0,N)
中所有i
的element_type(results[i]) = Ei
。
範例
// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
dimensions = dense<1> : tensor<1xi64>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]
reduce_precision
語義
將 operand
的元素導向元素轉換,使用 exponent_bits
和 mantissa_bits
並回到原始浮點類型,然後產生 output
張量。
更正式:
- 原始值的 mantisa 位元會更新,使用
mantissa_bits
語意將原始值四捨五入至最接近的值。 :roundToIntegralTiesToEven
- 然後,如果
mantissa_bits
小於原始值的 mantisa 位元數,系統會將 mantisa 位元截斷為mantissa_bits
。 - 接著,如果中繼結果的指數位元不符合
exponent_bits
提供的範圍,則中繼結果會使用原始記號溢位至無限,或使用原始符號後溢位至零。 - 如果是量化類型,請執行
dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result))
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
浮點類型或每個張量量化張量的張量 | (C1)。 |
(I2)。 | exponent_bits |
si32 類型的常數 |
(C2) |
(I3)。 | mantissa_bits |
si32 類型的常數 |
(C3) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
output |
浮點類型或每個張量量化張量的張量 | (C1)。 |
限制
- (C1)
baseline_type(operand) = baseline_type(output)
。 - (C2)
1 <= exponent_bits
。 - (C3)
0 <= mantissa_bits
。
範例
// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
exponent_bits = 5 : i32,
mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]
reduce_scatter
語義
在 StableHLO 程序格線中的每個程序群組中,使用 computations
對每個程序的 operand
張量值執行縮減,將縮減結果沿著 scatter_dimension
分割成部分,並將分割部分分散到不同程序之間以產生 result
。
這項作業會將 StableHLO 程序格線分割為 process_groups
,定義如下:
cross_replica(replica_groups)
如果為channel_id <= 0 and use_global_device_ids = false
。cross_replica_and_partition(replica_groups)
如果為channel_id > 0 and use_global_device_ids = false
。flattened_ids(replica_groups)
如果為channel_id > 0 and use_global_device_ids = true
。
之後,在每個 process_group
中:
reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation)
.parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension)
.- 為
process_group
中所有sender
的result@receiver = parts@sender[receiver_index]
,其中receiver_index = process_group.index(receiver)
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
張量或每個張量化張量 | (C1)、(C2)、(C7)、(C8) |
(I2)。 | scatter_dimension |
si64 類型的常數 |
(C1)、(C2)、(C8) |
(I3)。 | replica_groups |
si64 類型的 2 維張量常數 |
(C3-C5) |
(I4)。 | channel_id |
si64 類型的常數 |
(C6) |
(I5) | use_global_device_ids |
i1 類型的常數 |
(C6) |
(I6)。 | computation |
函式 | (C7) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
張量或每個張量化張量 | (C8-C9) |
限制
- (C1)
dim(operand, scatter_dimension) % dim(process_groups, 1) = 0
。 - (C2)
0 <= scatter_dimension < rank(operand)
。 - (C3)
is_unique(replica_groups)
。 - (C4)
size(replica_groups)
定義為:- 如果使用
cross_replica
,則為num_replicas
。 - 如果使用
cross_replica_and_partition
,則為num_replicas
。 - 如果使用
flattened_ids
,則為num_processes
。
- 如果使用
- (C5)
0 <= replica_groups < size(replica_groups)
。 - (C6) 如為
use_global_device_ids = true
,則預設為channel_id > 0
。 - (C7)
computation
具有類型(tensor<E>, tensor<E>) -> (tensor<E>)
,其中is_promotable(element_type(operand), E)
。 - (C8)
shape(result) = shape(operand)
,但下列項目除外:dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1)
.
- (C9)
element_type(result) = E
。
範例
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
// [18, 20]]
// %result@(1, 0): [[14, 16],
// [22, 24]]
reduce_window
語義
將縮減函式 body
套用至 inputs
和 init_values
的視窗,並產生 results
。
下圖以具體範例說明系統如何從 inputs...
計算 results...
中的元素。
正式來說,results...[result_index] = reduce(windows, init_values, axes(inputs...), body)
(請參閱 reduce) 其中:
padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1)
.window_start = result_index * window_strides
.window_end = window_start + (window_dimensions - 1) * window_dilations + 1
.windows = slice(padded_inputs..., window_start, window_end, window_dilations)
.
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | inputs |
各種張量或每個張量量化張量 | (C1-C4)、(C6)、(C8)、(C10)、(C12)、(C13)、(C15) |
(I2)。 | init_values |
各種 0 維張量或每個張量量化張量的數量 | (C1)、(C13) |
(I3)。 | window_dimensions |
si64 類型的 1 維張量常數 |
(C4)、(C5)、(C15) |
(I4)。 | window_strides |
si64 類型的 1 維張量常數 |
(C6)、(C7)、(C15) |
(I5) | base_dilations |
si64 類型的 1 維張量常數 |
(C8)、(C9)、(C15) |
(I6)。 | window_dilations |
si64 類型的 1 維張量常數 |
(C10)、(C11)、(C15) |
(I7) | padding |
si64 類型的 2 維張量常數 |
(C12)、(C15) |
(I8) | body |
函式 | (C13) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
results |
各種張量或每個張量量化張量 | (C1)、(C14-C16) |
限制
- (C1)
0 < size(inputs) = size(init_values) = size(results) = N
。 - (C2)
same(shape(inputs...))
。 - (C3)
element_type(inputs...) = element_type(init_values...)
。 - (C4)
size(window_dimensions) = rank(inputs[0])
。 - (C5)
0 < window_dimensions
。 - (C6)
size(window_strides) = rank(inputs[0])
。 - (C7)
0 < window_strides
。 - (C8)
size(base_dilations) = rank(inputs[0])
。 - (C9)
0 < base_dilations
。 - (C10)
size(window_dilations) = rank(inputs[0])
。 - (C11)
0 < window_dilations
。 - (C12)
shape(padding) = [rank(inputs[0]), 2]
。 - (C13)
body
具有(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
類型,其中is_promotable(element_type(inputs[i]), Ei)
。 - (C14)
same(shape(results...))
。 - (C15)
shape(results[0]) = num_windows
,其中:dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1
.padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]
.dilated_window_shape = (window_dimensions - 1) * window_dilations + 1
.is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape
.num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1
.
- (C16)
element_type(results[i]) = Ei
適用於[0,N)
中的所有i
。
範例
// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = dense<[2, 1]> : tensor<2xi64>,
window_strides = dense<[4, 1]> : tensor<2xi64>,
base_dilations = dense<[2, 1]> : tensor<2xi64>,
window_dilations = dense<[3, 1]> : tensor<2xi64>,
padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]
剩餘
語義
執行被除 lhs
和除數 rhs
張量的元素相關餘數,然後產生 result
張量。
更正式的,表示結果的正負號從被除數取下,而結果的絕對值一律小於除數的絕對值。餘數的計算方式為 lhs - d * rhs
,其中 d
為:
- 整數:
stablehlo.divide(lhs, rhs)
。 - 浮點值:來自 IEEE-754 的
division(lhs, rhs)
,帶有捨入屬性roundTowardZero
。 - 如為複數:未定 (#997)。
- 量化類型:
dequantize_op_quantize(remainder, lhs, rhs, type(result))
.
就浮點元素類型而言,這個運算與 IEEE-754 規格的 remainder
運算有差異,其中 d
是最接近 lhs/rhs
值的整數值,而且相對值相等。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | lhs |
整數、浮點、複雜類型,或每個張量的量化張量 | (C1)。 |
(I2)。 | rhs |
整數、浮點、複雜類型,或每個張量的量化張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
整數、浮點、複雜類型,或每個張量的量化張量 | (C1)。 |
限制
- (C1)
baseline_type(operand) = baseline_type(result)
。
範例
// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]
replica_id
語義
產生目前程序的 replica_id
。
輸出
名稱 | 類型 |
---|---|
result |
ui32 類型的 0 維張量 |
範例
%result = "stablehlo.replica_id"() : () -> tensor<ui32>
重塑
語義
執行將 operand
張量重新塑形為 result
張量。從概念上來說,最好維持相同的標準表示法,但可能需要改變形狀,例如從 tensor<2x3xf32>
變更為 tensor<3x2xf32>
或 tensor<6xf32>
。
一般而言,result[result_index] = operand[operand_index]
,其中 result_index
和 operand_index
按照 index_space(result)
和 index_space(operand)
的字母順序排列相同的位置。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
張量或量化張量 | (C1-C3) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
張量或量化張量 | (C1-C3) |
限制
- (C1)
element_type(result)
的提供者:element_type(operand)
(如果為!is_per_axis_quantized(operand)
)。element_type(operand)
除外,但quantization_dimension(operand)
和quantization_dimension(result)
可能不同。
- (C2)
size(operand) = size(result)
。 - (C3) 如為
is_per_axis_quantized(operand)
:reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
.reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
範例
// %operand: [[1, 2, 3], [4, 5, 6]]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]
反向排序
語義
依照指定的 dimensions
反轉 operand
中的元素順序,並產生 result
張量。更正式的 result[result_index] = operand[operand_index]
,其中:
- 如果
dimensions
中的d
,則為operand_index[d] = dim(result, d) - result_index[d] - 1
。 - 否則傳回
operand_index[d] = result_index[d]
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
張量或每個張量化張量 | (C1)、(C3) |
(I2)。 | dimensions |
si64 類型的 1 維張量常數 |
(C2)、(C3) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
張量或每個張量化張量 | (C1)、(C3) |
限制
- (C1)
type(operand) = type(result)
。 - (C2)
is_unique(dimensions)
。 - (C3)
0 <= dimensions < rank(result)
。
範例
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
dimensions = dense<1> : tensor<1xi64>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]
奈米
語義
使用 rng_distribution
演算法產生隨機號碼,並產生指定形狀 shape
的 result
張量。
如果為 rng_distribution = UNIFORM
,則隨機數會根據間隔 [a, b)
的統一分佈產生。如果為 a >= b
,表示行為未定義。
如果為 rng_distribution = NORMAL
,則隨機數字會在常態分佈後產生,平均值為 a
,標準差 = b
。如果為 b < 0
,表示行為未定義。
隨機號碼的產生方式確切的產生方式是由實作定義。例如,不一定有確定性,而且不一定使用隱藏狀態。
在與許多利害關係人的對話中,這項運算實際上已確實淘汰,因此日後我們計劃設法移除這個功能 (#597)。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | a |
整數、布林值或浮點類型的 0 維張量 | (C1)、(C2) |
(I2)。 | b |
整數、布林值或浮點類型的 0 維張量 | (C1)、(C2) |
(I3)。 | shape |
si64 類型的 1 維張量常數 |
(C3) |
(I4)。 | rng_distribution |
UNIFORM 和 NORMAL 的列舉項目 |
(C2) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
整數、布林值或浮點類型的張量 | (C1-C3) |
限制
- (C1)
element_type(a) = element_type(b) = element_type(result)
。 - (C2) 如為
rng_distribution = NORMAL
,則預設為is_float(a)
。 - (C3)
shape(result) = shape
。
範例
// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
// [1, 0, 1],
// [1, 1, 1],
// [0, 0, 0]
// ]
rng_bit_generator
語義
使用虛擬隨機號碼產生器演算法 rng_algorithm
,為初始狀態 initial_state
傳回已填入統一隨機位元的 output
,並更新輸出狀態 output_state
。保證提供 initial_state
的確定性函式,但不保證在實作之間具確定性。
rng_algorithm
是下列其中一項:
DEFAULT
:導入定義的演算法。THREE_FRY
:導入 Threefry 演算法定義的變化版本*。PHILOX
:實作定義的 Philox 演算法變化版本*。
* 請參閱:Salmon 等人,SC 2011。平行隨機數:和 1、2、3 一樣簡單。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | rng_algorithm |
DEFAULT 、THREE_FRY 和 PHILOX 的列舉項目 |
(C2) |
(I2)。 | initial_state |
ui64 類型的 1 維張量 |
(C1)、(C2) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
output_state |
ui64 類型的 1 維張量 |
(C1)。 |
output |
整數或浮點類型的張量 |
限制
- (C1)
type(initial_state) = type(output_state)
。 - (C2)
size(initial_state)
定義為:- 在
rng_algorithm = DEFAULT
時定義。 - 如果
rng_algorithm = THREE_FRY
則為2
。 2
或3
表示rng_algorithm = PHILOX
。
- 在
範例
// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
// [9236835810183407956, 16087790271692313299],
// [18212823393184779219, 2658481902456610144]
// ]
round_nearest_afz
語義
針對 operand
張量執行元素相關進位法,將元素四捨五入至最接近的整數,從零打斷,並產生 result
張量。從 IEEE-754 規格實作 roundToIntegralTiesToAway
作業。如果是量化類型,請執行 dequantize_op_quantize(round_nearest_afz, operand, type(result))
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
浮點類型或每個張量量化張量的張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
浮點類型或每個張量量化張量的張量 | (C1)。 |
限制
- (C1)
baseline_type(operand) = baseline_type(result)
。
範例
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]
round_nearest_even
語義
針對 operand
張量執行元素相關進位法,將元素四捨五入至最接近的整數,並將鏈結至偶數整數,並產生 result
張。從 IEEE-754 規格實作 roundToIntegralTiesToEven
作業。如果是量化類型,請執行 dequantize_op_quantize(round_nearest_even, operand, type(result))
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
浮點類型或每個張量量化張量的張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
浮點類型或每個張量量化張量的張量 | (C1)。 |
限制
- (C1)
baseline_type(operand) = baseline_type(result)
。
範例
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
rsqrt
語義
對 operand
張量執行元素相關平方根運算,並產生 result
張量。根據元素類型執行以下操作:
- 浮點值:由 IEEE-754 距離
rSqrt
。 - 如為複數:複雜的倒數平方根。
- 量化類型:
dequantize_op_quantize(rsqrt, operand, type(result))
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
浮點、複雜類型或每個張量化張量的張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
浮點、複雜類型或每個張量化張量的張量 | (C1)。 |
限制
- (C1)
baseline_type(operand) = baseline_type(result)
。
範例
// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]
scatter
語義
產生等於 inputs
張量的 results
張量,但 scatter_indices
指定的幾個切片會使用 update_computation
更新為 updates
的值。
下圖以具體範例說明 updates...
中的元素如何對應到 results...
中的元素。圖中挑選了一些 updates...
索引範例,並詳細說明這些 results...
代表對應的索引。
更正式地是,針對 index_space(updates[0])
中的所有 update_index
:
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]
.update_scatter_index = update_index[update_scatter_dims...]
.start_index
定義為:scatter_indices[si0, ..., :, ..., siN]
,其中si
是update_scatter_index
中的個別元素,如果index_vector_dim
<rank(scatter_indices)
,則:
會插入index_vector_dim
索引。- 否則傳回
[scatter_indices[update_scatter_index]]
。
- 針對
axes(inputs[0])
中的d_input
,- 如果為
d_input = scatter_dims_to_operand_dims[d_start]
,則為full_start_index[d_input] = start_index[d_start]
。 - 否則傳回
full_start_index[d_input] = 0
。
- 如果為
update_window_index = update_index[update_window_dims...]
.full_window_index = [wi0, ..., 0, ..., wiN]
,其中wi
是update_window_index
中的個別元素,而0
會插入inserted_window_dims
的索引。result_index = full_start_index + full_window_index
.
因此,results = exec(schedule, inputs)
,其中:
schedule
是index_space(updates[0])
的實作定義的排列方式。exec([update_index, ...], results) = exec([...], updated_results)
,其中:- 如果
result_index
在shape(results...)
的邊界內 updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
updated_values = update_computation(results...[result_index], updates_converted)
updated_results
是results
的副本,並將results...[result_index]
設為updated_values...
。- 你也可以
updated_results = results
.
- 如果
exec([], results) = results
.
如果 indices_are_sorted
是 true
,實作會假設 scatter_indices
會依照 scatter_dims_to_operand_dims
排序,否則行為未定義。更正式地是,針對 indices(result)
中的所有 i1 < i2
,full_start_index(i1)
<= full_start_index(i2)
。
如果 unique_indices
為 true
,則實作會假設所有 result_index
索引均散發不重複。如果 unique_indices
為 true
,但索引被散佈為重複,表示行為未定義。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | inputs |
各種張量或每個張量量化張量 | (C1)、(C2)、(C4-C6)、(C10)、(C13)、(C15-C16) |
(I2)。 | scatter_indices |
整數類型的張量 | (C4)、(C11)、(C14) |
(I3)。 | updates |
各種張量或每個張量量化張量 | (C3-C6)、(C8) |
(I4)。 | update_window_dims |
si64 類型的 1 維張量常數 |
(C2)、(C4)、(C7)、(C8) |
(I5) | inserted_window_dims |
si64 類型的 1 維張量常數 |
(C2)、(C4)、(C9)、(C10) |
(I6)。 | scatter_dims_to_operand_dims |
si64 類型的 1 維張量常數 |
(C11-C13) |
(I7) | index_vector_dim |
si64 類型的常數 |
(C4)、(C11)、(C14) |
(I8) | indices_are_sorted |
i1 類型的常數 |
|
(I9) | unique_indices |
i1 類型的常數 |
|
(I10)。 | update_computation |
函式 | (C15) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
results |
各種張量或每個張量量化張量 | (C15-C17) |
限制
- (C1)
same(shape(inputs...))
。 - (C2)
rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims)
。 - (C3)
same(shape(updates...))
。 - (C4)
shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes)
,其中:update_scatter_dim_sizes = shape(scatter_indices)
,但資料不會包含與index_vector_dim
對應的scatter_indices
尺寸大小。update_window_dim_sizes <= shape(inputs[0])
,但不會包含inputs[0]
中與inserted_window_dims
對應的尺寸大小。combine
會將update_scatter_dim_sizes
置於與update_scatter_dims
和update_window_dim_sizes
對應的軸上,並放在與update_window_dims
對應的軸。
- (C5)
0 < size(inputs) = size(updates) = N
。 - (C6)
element_type(updates...) = element_type(inputs...)
。 - (C7)
is_unique(update_window_dims) and is_sorted(update_window_dims)
。 - (C8)
0 <= update_window_dims < rank(updates[0])
。 - (C9)
is_unique(inserted_window_dims) and is_sorted(update_window_dims)
。 - (C10)
0 <= inserted_window_dims < rank(inputs[0])
。 - (C11)
size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1
。 - (C12)
is_unique(scatter_dims_to_operand_dims)
。 - (C13)
0 <= scatter_dims_to_operand_dims < rank(inputs[0])
。 - (C14)
0 <= index_vector_dim <= rank(scatter_indices)
。 - (C15)
update_computation
具有(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
類型,其中is_promotable(element_type(inputs[i]), Ei)
。 - (C16)
shape(inputs...) = shape(results...)
。 - (C17)
element_type(results[i]) = Ei
適用於[0,N)
的所有i
。
範例
// %input: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10], [11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %scatter_indices: [[[0, 2], [1, 0], [2, 1]], [[0, 1], [1, 0], [0, 9]]]
// %update: [
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
// ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [2, 3],
inserted_window_dims = [0],
scatter_dims_to_operand_dims = [1, 0],
index_vector_dim = 2>,
indices_are_sorted = false,
unique_indices = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<2x3x2x2xi64>) -> tensor<3x4x2xi64>
// %result: [
// [[1, 2], [5, 6], [7, 8], [7, 8]],
// [[10, 11], [12, 13], [14, 15], [16, 17]],
// [[18, 19], [20, 21], [21, 22], [23, 24]]
// ]
選取
語義
產生 result
張量,並根據 pred
的對應元素值,從 on_true
或 on_false
張量選取每個元素。正式名稱為 result[result_index] = pred_element ? on_true[result_index] :
on_false[result_index]
,其中 pred_element = rank(pred) = 0 ? pred[] :
pred[result_index]
。如果是量化類型,請執行 dequantize_select_quantize(pred, on_true, on_false, type(result))
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | pred |
i1 類型的張量 |
(C1)。 |
(I2)。 | on_true |
張量或每個張量化張量 | (C1-C2) |
(I3)。 | on_false |
張量或每個張量化張量 | (C2) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
張量或每個張量化張量 | (C2) |
限制
- (C1)
rank(pred) = 0 or shape(pred) = shape(on_true)
。 - (C2)
baseline_type(on_true) = baseline_type(on_false) = baseline_type(result)
。
範例
// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]
select_and_scatter
語義
使用 select
根據 input
張量的 reduce_window
範圍,使用 scatter
分散 source
張量的值,然後產生 result
張量。
下圖顯示如何使用具體範例,從 operand
和 source
計算 result
中的元素。
更正式:
含有下列輸入內容的
selected_values = reduce_window_without_init(...)
:- `inputs = [運算元]。
- 依原樣使用
window_dimensions
、window_strides
和padding
。 base_dilations = windows_dilations = 1
.body
定義為:
def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>: return select(arg0, arg1) ? arg0 : arg1;
其中
E = element_type(operand)
和reduce_window_without_init
的運作方式與reduce_window
類似,但基礎reduce
的schedule
(請參閱 reduce) 不包含 init 值。目前未指出如果對應的視窗沒有值,會發生什麼情況。(#731)。result[result_index] = reduce([source_values], [init_value], [0], scatter)
,其中:source_values = [source[source_index] for source_index in source_indices]
.- 如果
selected_values[source_index]
具有來自operand_index
的operand
元素,則為selected_index(source_index) = operand_index
。 source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index]
.
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
張量或每個張量化張量 | (C1-C4)、(C6)、(C8-C11) |
(I2)。 | source |
張量或每個張量化張量 | (C1)、(C2) |
(I3)。 | init_value |
0 維張量或按張量化張量 | (C3) |
(I4)。 | window_dimensions |
si64 類型的 1 維張量常數 |
(C2)、(C4)、(C5) |
(I5) | window_strides |
si64 類型的 1 維張量常數 |
(C2)、(C6)、(C7) |
(I6)。 | padding |
si64 類型的 2 維張量常數 |
(C2)、(C8) |
(I7) | select |
函式 | (C9) |
(I8) | scatter |
函式 | (C10) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
張量或每個張量化張量 | (C11-C12) |
限制
- (C1)
element_type(operand) = element_type(source)
。 - (C2)
shape(source) = num_windows
,其中:padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1]
.is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape
.num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1
.
- (C3)
element_type(init_value) = element_type(operand)
。 - (C4)
size(window_dimensions) = rank(operand)
。 - (C5)
0 < window_dimensions
。 - (C6)
size(window_strides) = rank(operand)
。 - (C7)
0 < window_strides
。 - (C8)
shape(padding) = [rank(operand), 2]
。 - (C9)
select
具有類型(tensor<E>, tensor<E>) -> tensor<i1>
,其中E = element_type(operand)
。 - (C10)
scatter
具有類型(tensor<E>, tensor<E>) -> tensor<E>
,其中is_promotable(element_type(operand), E)
。 - (C11)
shape(operand) = shape(result)
。 - (C12)
element_type(result) = E
。
範例
// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]
傳送
語義
將 inputs
傳送至管道 channel_id
,並產生 result
權杖。
如果 is_host_transfer
是 true
,則作業會將資料轉移至主機。否則,系統會將資料轉移到其他裝置。這代表的是實作定義。這個旗標會複製 channel_type
中提供的資訊,因此未來我們打算只會保留其中一個 (#666)。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | inputs |
各種張量或量化張量 | |
(I2)。 | token |
token |
|
(I3)。 | channel_id |
si64 類型的常數 |
|
(I4)。 | channel_type |
DEVICE_TO_DEVICE 和 DEVICE_TO_HOST 的列舉項目 |
(C1)。 |
(I5) | is_host_transfer |
i1 類型的常數 |
(C1)。 |
輸出
名稱 | 類型 |
---|---|
result |
token |
限制
- (C1)
channel_type
定義為:DEVICE_TO_HOST
表示is_host_transfer = true
,- 否則傳回
DEVICE_TO_DEVICE
。
範例
%result = "stablehlo.send"(%operand, %token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token
shift_left
語義
以 rhs
位元數對 lhs
張量執行元素相關左移作業,並產生 result
張量。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | lhs |
整數類型的張量 | (C1)。 |
(I2)。 | rhs |
整數類型的張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
整數類型的張量 | (C1)。 |
限制
- (C1)
type(lhs) = type(rhs) = type(result)
。
範例
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]
shift_right_arithmetic
語義
以位元數lhs
張數,rhs
張數對元素做出元素右移運算,然後產生 result
張量。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | lhs |
整數類型的張量 | (C1)。 |
(I2)。 | rhs |
整數類型的張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
整數類型的張量 | (C1)。 |
限制
- (C1)
type(lhs) = type(rhs) = type(result)
。
範例
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]
shift_right_logical
語義
對 lhs
張量執行 rhs
張量的元素右移邏輯,然後產生 result
張量。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | lhs |
整數類型的張量 | (C1)。 |
(I2)。 | rhs |
整數類型的張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
整數類型的張量 | (C1)。 |
限制
- (C1)
type(lhs) = type(rhs) = type(result)
。
範例
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]
簽署
語義
依序傳回 operand
元素的符號,並產生 result
張量。
更正式地說,每個元素 x
都能使用 Python 語法來表示語意,如下所示:
def sign(x):
if is_integer(x):
if compare(x, 0, LT, SIGNED): return -1
if compare(x, 0, EQ, SIGNED): return 0
return 1
elif is_float(x):
if is_nan(x): return NaN
if compare(x, -0.0, EQ, FLOAT): return -0.0
if compare(x, +0.0, EQ, FLOAT): return +0.0
if compare(x, 0.0, LT, FLOAT): return -1.0
return 1.0
elif is_complex(x):
if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
return divide(x, convert(abs(x), type(x)))
如果是量化類型,請執行 dequantize_op_quantize(sign, operand, type(result))
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
帶正負號整數、浮點、複雜類型,或每個張量的量化張量的張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
帶正負號整數、浮點、複雜類型,或每個張量的量化張量的張量 | (C1)。 |
限制
- (C1)
baseline_type(operand) = baseline_type(result)
。
範例
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
正弦
語義
對 operand
張量執行元素相關正弦運算,並產生 result
張量。根據元素類型執行以下操作:
- 浮點值:由 IEEE-754 距離
sin
。 - 複數:複數的正弦。
- 量化類型:
dequantize_op_quantize(sine, operand, type(result))
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
浮點、複雜類型或每個張量化張量的張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
浮點、複雜類型或每個張量化張量的張量 | (C1)。 |
限制
- (C1)
baseline_type(operand) = baseline_type(result)
。
範例
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]
配量
語義
使用靜態計算的啟動索引從 operand
擷取切片,然後產生 result
張量。start_indices
包含每個維度的切片起始索引,limit_indices
包含每個維度的切片的結尾索引 (不含),而 strides
則包含每個維度的步距。
正式結果是 result[result_index] = operand[operand_index]
,其中 operand_index = start_indices + result_index * strides
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
張量或每個張量化張量 | (C1-C3)、(C5) |
(I2)。 | start_indices |
si64 類型的 1 維張量常數 |
(C2)、(C3)、(C5) |
(I3)。 | limit_indices |
si64 類型的 1 維張量常數 |
(C2)、(C3)、(C5) |
(I4)。 | strides |
si64 類型的 1 維張量常數 |
(C2)、(C4) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
張量或每個張量化張量 | (C1)、(C5) |
限制
- (C1)
element_type(operand) = element_type(result)
。 - (C2)
size(start_indices) = size(limit_indices) = size(strides) = rank(operand)
。 - (C3)
0 <= start_indices <= limit_indices <= shape(operand)
。 - (C4)
0 < strides
。 - (C5)
shape(result) = ceil((limit_indices - start_indices) / strides)
。
範例
// %operand: [
// [0, 0, 0, 0],
// [0, 0, 1, 1],
// [0, 0, 1, 1]
// ]
%result = "stablehlo.slice"(%operand) {
start_indices = dense<[1, 2]> : tensor<2xi64>,
limit_indices = dense<[3, 4]> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
// [1, 1],
// [1, 1]
// ]
排序
語義
根據 comparator
並按維度 dimension
,一併排序 inputs
的 1 維切片,並產生 results
。
與其他作業中的類似輸入內容不同,dimension
允許負值,但使用下述語意。日後,基於一致性原因,可能會禁止這項操作。(#1377)。
如果 is_stable
為 true,排序功能就會穩定,也就是視為與比較元素相等的元素相對順序。如果有多個輸入,e1
和 e2
兩個元素會視為相等,只有在 comparator(e1, e2) = comparator(e2, e1) = false
時才會被視為相等。請參閱下方的正規化以如何歸納為多個輸入內容。
更正式地是,針對 index_space(results[0])
中的所有 result_index
:
adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension
.result_slice = [ri0, ..., :, ..., riR-1]
,其中riN
是result_index
中的個別元素,而:
會在adjusted_dimension
插入。inputs_together = (inputs[0]..., ..., inputs[N-1]...)
.results_together[result_slice] = sort(inputs_together[result_slice], comparator_together)
.- 其中
sort
會以非遞減順序排序 1 維配量,以預期在左側引數小於右側第二個引數時,comparator_together
會傳回true
。 def comparator_together(lhs_together, rhs_together): args = [] for (lhs_el, rhs_el) in zip(lhs_together, rhs_together): args.append(lhs_el) args.append(rhs_el) return comparator(*args)
(results[0]..., ..., results[N-1]...) = results_together
.
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | inputs |
各種張量或每個張量量化張量 | (C1-C5) |
(I2)。 | dimension |
si64 類型的常數 |
(C4) |
(I3)。 | is_stable |
i1 類型的常數 |
|
(I4)。 | comparator |
函式 | (C5) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
results |
各種張量或每個張量量化張量 | (C2)、(C3) |
限制
- (C1)
0 < size(inputs)
。 - (C2)
type(inputs...) = type(results...)
。 - (C3)
same(shape(inputs...) + shape(results...))
。 - (C4)
-R <= dimension < R
,其中R = rank(inputs[0])
。 - (C5)
comparator
具有(tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>
類型,其中Ei = element_type(inputs[i])
。
範例
// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
dimension = 0 : i64,
is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]
sqrt
語義
對 operand
張量執行元素層級平方根運算,並產生 result
張量。根據元素類型執行以下操作:
- 浮點值:由 IEEE-754 距離
squareRoot
。 - 複數:複數的平方根。
- 量化類型:
dequantize_op_quantize(sqrt, operand, type(result))
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
浮點、複雜類型或每個張量化張量的張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
浮點、複雜類型或每個張量化張量的張量 | (C1)。 |
限制
- (C1)
baseline_type(operand) = baseline_type(result)
。
範例
// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]
subtract
語義
對元素執行元素減法將兩個張量 lhs
和 rhs
相減,然後產生 result
張量。根據元素類型執行以下操作:
- 整數:整數減。
- 浮點值:由 IEEE-754 距離
subtraction
。 - 複數:複數的減法。
- 量化類型:
dequantize_op_quantize(subtract, lhs, rhs, type(result))
.
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | lhs |
整數、浮點、複雜類型或每個張量量化張量的張量 | (C1)。 |
(I2)。 | rhs |
整數、浮點、複雜類型或每個張量量化張量的張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
整數、浮點、複雜類型或每個張量量化張量的張量 | (C1)。 |
限制
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
。
範例
// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]
坦赫
語義
對 operand
張量執行元素層級的雙曲線正切作業,並產生 result
張量。根據元素類型執行以下操作:
- 浮點值:由 IEEE-754 距離
tanh
。 - 如果數字為複數:複雜的雙曲正切。
- 量化類型:
dequantize_op_quantize(tanh, operand, type(result))
.
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
浮點、複雜類型或每個張量化張量的張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
浮點、複雜類型或每個張量化張量的張量 | (C1)。 |
限制
- (C1)
baseline_type(operand) = baseline_type(result)
。
範例
// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]
轉置
語義
使用 permutation
保留 operand
張量的維度,並產生 result
張量。更正式的,也就是 result[result_index] = operand[operand_index]
,其中 result_index[d] = operand_index[permutation[d]]
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
張量或量化張量 | (C1-C4) |
(I2)。 | permutation |
si64 類型的 1 維張量常數 |
(C2-C4) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
張量或量化張量 | (C1)、(C3-C4) |
限制
- (C1)
element_type(result)
的提供者:element_type(operand)
(如果為!is_per_axis_quantized(operand)
)。element_type(operand)
除外,但quantization_dimension(operand)
和quantization_dimension(result)
可能不同。
- (C2)
permutation
是range(rank(operand))
的排列方式。 - (C3)
shape(result) = dim(operand, permutation...)
。 - (C4) 如果為
is_per_axis_quantized(result)
,則quantization_dimension(operand) = permutation(quantization_dimension(result))
。
範例
// %operand: [
// [[1,2], [3,4], [5,6]],
// [[7,8], [9,10], [11,12]]
// ]
%result = "stablehlo.transpose"(%operand) {
permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
// [[1,7], [3,9], [5,11]],
// [[2,8], [4,10], [6,12]]
// ]
triangular_solve
語義
以較低或上三角係數來解開聯立線性方程式的批次。
更正式地說,由於 a
和 b
,當 left_side
為 true
時,result[i0, ..., iR-3, :, :]
是 op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :]
的解決方案,當 left_side
為 false
時,則解決 op(a)
變數的 x
解決方案,其中 op(a)
可為 transpose_a
,可以是下列其中一項:x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :]
NO_TRANSPOSE
:依原樣使用a
執行作業。TRANSPOSE
:對轉置a
執行作業。ADJOINT
:對a
的共錐轉出執行作業。
系統只會在 lower
為 true
或 a
的上三角形時,從 a
的底部三角形讀取輸入資料。輸出資料會在同一個三角形中傳回,其他三角形的值則是實作定義。
如果 unit_diagonal
為 true,實作會假設 a
的對角元素等於 1,否則行為未定義。
如果是量化類型,請執行 dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower,
unit_diagonal, transpose_a), a, b, type(result))
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | a |
浮點、複雜類型或每個張量化張量的張量 | (C1-C3) |
(I2)。 | b |
浮點、複雜類型或每個張量化張量的張量 | (C1-C4) |
(I3)。 | left_side |
i1 類型的常數 |
(C3) |
(I4)。 | lower |
i1 類型的常數 |
|
(I5) | unit_diagonal |
i1 類型的常數 |
|
(I6)。 | transpose_a |
NO_TRANSPOSE 、TRANSPOSE 和 ADJOINT 的列舉項目 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
浮點、複雜類型或每個張量化張量的張量 | (C1)。 |
限制
- (C1)
baseline_element_type(a) = baseline_element_type(b)
。 - (C2)
2 <= rank(a) = rank(b) = R
。 - (C3)
shape(a)
和shape(b)
之間的關係定義如下:shape(a)[:-3] = shape(b)[:-3]
.dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1)
.
- (C4)
baseline_type(b) = baseline_type(result)
。
範例
// %a = [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
// %b = [
// [2.0, 0.0, 0.0],
// [4.0, 8.0, 0.0],
// [6.0, 10.0, 12.0]
// ]
%result = "stablehlo.triangular_solve"(%a, %b) {
left_side = true,
lower = true,
unit_diagonal = false,
transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
// [2.0, 0.0, 0.0],
// [0.0, 2.0, 0.0],
// [0.0, 0.0, 2.0]
// ]
tuple
語義
使用值 val
產生 result
元組。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | val |
變異數值 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
tuple | (C1)。 |
限制
- (C1)
result
屬於tuple<E0, ..., EN-1>
類型,其中Ei = type(val[i])
。
範例
// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))
uniform_dequantize
語義
根據 operand
類型定義的量化參數,將量化張量 operand
的元素層級轉換為浮點張量 result
。
正式的是 result = dequantize(operand)
。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
量化張量 | (C1)、(C2) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
浮點類型的張量 | (C1)、(C2) |
限制
- (C1)
shape(operand) = shape(result)
。 - (C2)
element_type(result) = expressed_type(operand)
。
範例
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]
uniform_quantize
語義
根據 result
類型定義的量化參數,將浮點張量或量化張量 operand
轉換成量化張量 result
的元素層級轉換。
更正式的
- 如果
is_float(operand)
:result = quantize(operand, type(result))
.
- 如果
is_quantized(operand)
:float_result = dequantize(operand)
.result = quantize(float_result, type(result))
.
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
浮點或量化類型的張量 | (C1)、(C2) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
量化張量 | (C1)、(C2) |
限制
- (C1)
shape(operand) = shape(result)
。 - (C2)
expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand)
。
範例
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]
// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]
時
語義
當 cond
函式輸出 true
時,執行 body
函式 0 次以上會產生輸出內容。正式的語意可透過 Python 語法表示,如下所示:
internal_state = operand
while cond(*internal_state):
internal_state = body(*internal_state)
results = internal_state
無限迴圈的行為是待定。(#383)。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | operand |
各種張量、量化張量或符記 | (C1-C3) |
(I2)。 | cond |
函式 | (C1)。 |
(I3)。 | body |
函式 | (C2) |
輸出
名稱 | 類型 | 限制 |
---|---|---|
results |
各種張量、量化張量或符記 | (C3) |
限制
- (C1)
cond
具有(T0, ..., TN-1) -> tensor<i1>
類型,其中Ti = type(operand[i])
。 - (C2)
body
具有(T0, ..., TN-1) -> (T0, ..., TN-1)
類型,其中Ti = type(operand[i])
。 - (C3)
type(results...) = type(operand...)
。
範例
// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%cond = "stablehlo.compare"(%arg0, %ten) {
comparison_direction = #stablehlo<comparison_direction LT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %cond : tensor<i1>
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%new_sum = stablehlo.add %arg1, %one : tensor<i64>
%new_i = stablehlo.add %arg0, %one : tensor<i64>
stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10
xor
語義
執行兩個張量 lhs
和 rhs
的元素層級 XOR,並產生 result
張量。根據元素類型執行以下操作:
- 布林值:邏輯 XOR。
- 整數:位元 XOR。
輸入內容
標籤 | 名稱 | 類型 | 限制 |
---|---|---|---|
(I1)。 | lhs |
布林值或整數類型的張量 | (C1)。 |
(I2)。 | rhs |
布林值或整數類型的張量 | (C1)。 |
輸出
名稱 | 類型 | 限制 |
---|---|---|
result |
布林值或整數類型的張量 | (C1)。 |
限制
- (C1)
type(lhs) = type(rhs) = type(result)
。
範例
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]
執行
依序執行
系統會透過提供輸入值給 main
函式並計算輸出值來執行 StableHLO 程式。系統會執行對應 return
運算中的運算圖形,來計算函式的輸出值。
只要執行順序與 Dataflow 一致,系統就會實作執行順序,即作業是在使用前執行。在 StableHLO 中,所有連帶效果運算都會使用一個符記並產生一個符記 (可透過 after_all
將多個符記加倍為一個符記),因此副作用的執行順序也會與 Dataflow 保持一致。上述範例程式可能的執行順序為 %0
→ %1
→ %2
→ %3
→ %4
→ return
或 %3
→ %0
→ %2
→ %4
→ return
。%1
更正式的,StableHLO 程序結合了:1) StableHLO 程式、2) 作業狀態 (尚未執行、已經執行) 和 3) 程序處理的中繼值。程序從 main
函式的輸入值開始,逐步查看運算狀態和中間值的更新作業,並結束輸出值。最終正規化為待定。(#484)。
平行執行
可平行執行 StableHLO 程式,並依 num_partitions
整理為 num_replicas
的 2D 程序格線,兩者都具有類型 ui32
。
在 StableHLO 程序格線中,有 num_replicas * num_partitions
個 StableHLO 程序會同時執行。每個程序都有專屬的 process_id = (replica_id, partition_id)
,其中 replica_ids = range(num_replicas)
和 partition_ids = range(num_partitions)
中的 replica_id
,兩者類型皆為 ui32
。partition_id
每個程式都會以靜態方式得知程序格線的大小 (未來,我們計劃將其使其成為 StableHLO 程式 #650 中的明確部分),而且每個程序在程序格線中的位置都是靜態的。每個程序都可以透過 replica_id
和 partition_id
運算存取程序格線內的位置。
在程序格線中,所有程式皆可相同 (在「單一程式,多重資料」樣式中),每一種 (在「多重計畫」、「多重資料」樣式中) 皆可相同,或者也可以相同。我們預計日後會支援定義平行 StableHLO 程式的其他慣用語,包括 GSPMD (#619)。
在程序格線中,程序主要彼此獨立 - 程序具有不同的作業狀態、獨立的輸入/中間/輸出值,以及大多數運算,會在程序之間分開執行,但下文所述的少數集體運算除外。
由於大部分的運算只會使用相同程序的值,因此使用這些值的名稱通常並不明確。不過,在描述共同運算的語意時,會不足並引發標記 name@process_id
參照特定程序中的值 name
。(在這種情況下,不符資格的 name
可視為 name@(replica_id(), partition_id())
的簡寫)。
程序之間的執行順序是由實作定義,但點對點通訊與集體運算而引入的同步處理除外 (如下所述)。
點對點通訊
StableHLO 程序可透過 StableHLO 管道相互通訊。管道會以 si64
類型的正 ID 表示。透過各種作業,您可以將值傳送至管道並從管道接收。
進一步的正式化,例如管道 ID 來源、程序如何得知程序,以及所導入的同步處理類型 (#484)。
串流通訊
每個 StableHLO 程序都可存取兩種串流介面:
- 可以讀取的動態內廣告。
- 可寫入的 Outfeed。
與管道不同,管道用於在程序之間進行通訊,因此在兩者的兩端都有程序,而動態內和外動態則由其他的實作定義。
如果需要正式化,例如串流通訊對執行順序的影響,以及其執行何種同步處理作業,均未定案。(#484)。
集體作業
StableHLO 中有六項共同作業:all_gather
、all_reduce
、all_to_all
、collective_broadcast
、collective_permute
和 reduce_scatter
。這些作業會將 StableHLO 程序中的程序拆分為 StableHLO 程序群組,並在每個程序群組內執行聯合運算,而獨立於其他程序群組之外。
在每個程序群組中,共同運算可能會造成同步處理障礙。更進一步的正式程序,例如分析此同步處理作業的確切發生時間、程序如何達到此障礙,以及不然會發生的情況 (#484)。
如果程序群組涉及跨分區通訊 (亦即程序群組中有分區 ID 不同的程序),則共同運算的執行作業需要管道,且共同運算必須提供 si64
類型的 channel_id
。跨備用資源通訊不需要管道。
由共同運算執行的運算僅適用於個別運算,在上述的個別運算區段中說明。不過,程序格線拆分為程序群組採用的策略會在這些作業之間共用,本節將進行說明。更正式的,StableHLO 支援下列四種策略。
cross_replica
每個程序群組只會執行跨備用資源通訊。此策略會採用 replica_groups
- 備用資源 ID 清單,並用 partition_ids
計算 replica_groups
的笛卡兒乘積。replica_groups
必須有不重複的元素,並涵蓋所有 replica_ids
。更正式的做法是使用 Python 語法:
def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
for partition_id in partition_ids:
process_group = []
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
例如,針對 replica_groups = [[0, 1], [2, 3]]
和 num_partitions = 2
,cross_replica
會產生 [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]]
。
cross_partition
每個程序群組只會進行跨分區通訊。此策略會採用 partition_groups
- 分區 ID 清單,並透過 replica_ids
計算 partition_groups
的笛卡兒乘積。partition_groups
必須有不重複的元素,並涵蓋所有 partition_ids
。更正式,是使用 Python 語法:
def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
for partition_group in partition_groups:
for replica_id in replica_ids:
process_group = []
for partition_id in partition_group:
process_group.append((replica_id, partition_id))
yield process_group
例如,針對 partition_groups = [[0, 1]]
和 num_replicas = 4
,cross_partition
會產生 [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]]
。
cross_replica_and_partition
每個程序群組內可能會同時進行跨備用資源和跨分區通訊。這項策略會採用 replica_groups
- 備用資源 ID 清單,並以 partition_ids
計算每個 replica_group
的笛卡兒產品。replica_groups
必須有不重複的元素,並涵蓋所有 replica_ids
。更正式,是使用 Python 語法:
def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
process_group = []
for partition_id in partition_ids:
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
例如,針對 replica_groups = [[0, 1], [2, 3]]
和 num_partitions = 2
,cross_replica_and_partition
會產生 [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]]
。
flattened_ids
這項策略會採用 flattened_id_groups
- 以 replica_id * num_partitions + partition_id
形式表示的「整併」程序 ID 清單,並將其轉換成程序 ID。flattened_id_groups
必須有不重複的元素,並涵蓋所有 process_ids
。更正式,是使用 Python 語法:
def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
for flattened_id_group in flattened_id_groups:
process_group = []
for flattened_id in flattened_id_group:
replica_id = flattened_id // num_partitions
partition_id = flattened_id % num_partitions
process_group.append((replica_id, partition_id))
yield process_group
例如,如果是 flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]
、num_replicas = 4
和 num_partitions = 2
,flattened_ids
會產生 [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]]
。
準確
StableHLO 目前不保證數值準確度,但日後可能會改變。(#1156)。
錯誤
StableHLO 程式會透過個別作業的一系列限制來驗證 StableHLO 程式,該程式會在執行時間之前排除許多錯誤類別。 但是,錯誤狀況仍可能發生,例如透過整數溢位、外界存取等。除非明確呼叫,否則這些錯誤都會導致實作定義的行為,但這項設定日後可能會改變 (#1157)。
除了這項規則以外,StableHLO 程式中的浮點例外狀況還具有明確定義的行為。如果作業發生由 IEEE-754 標準定義的例外狀況 (無效作業、除以零、溢位、反向溢位或不完全相同的例外狀況),系統就會產生預設結果 (依標準定義),並在不提高對應的狀態旗標的情況下繼續執行作業,與標準的 raiseNoFlag
例外狀況處理類似。但定義了非標準運算 (例如複雜算術和特定傳輸函式) 的例外狀況。
Notation
為描述語法,本文件使用經過修改的 EBNF 語法 (ISO/IEC 14977:1996、Wikipedia) 變體,其中有兩項修改:1) 規則是使用 ::=
而非 =
定義。
2) 串連表示要以其他形式表示,而不是 ,
。
為描述語意 (即在「類型」、「常數」和「Ops」區段內),我們使用以 Python 語法為基礎的公式,並支援簡要表示陣列運算,如下所示。這種方法適用於小片段程式碼片段,但在極少數情況下,如果需要較大的程式碼,我們會使用一律明確引入的香草 Python 語法。
公式
現在,讓我們根據 dot_general
規格的範例,瞭解公式的運作方式。這項作業的其中一個限制如下:dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
。
這個公式使用的名稱來自兩個來源:1) 全域函式,即 dim
、2) 對應程式元素的成員定義,即 dot_general
的「Inputs」區段中定義的 lhs
、lhs_batching_dimensions
、rhs
和 rhs_batching_dimensions
輸入內容。
如上所述,這個公式的語法以 Python 為基礎,加上一些以精簡為導向的擴充功能。為了簡單說明這個公式 我們先將其轉換為變換的 Python 語法
A) 在這些公式中,我們使用 =
表示等式,因此取得 Python 語法的第一步將 =
替換為 ==
,如下所示:dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...)
。
B) 此外,這些公式支援刪節號 (...
),可將純量運算式轉換成張量運算式。簡單來說,f(xs...)
大致是指「針對張量 xs
中的每個純量 x
,計算純量 f(x)
,然後再將這些純量結果傳回做為張量結果」。在香草 Python 語法中,我們的公式範例會變為:[dim(lhs, dim1) for dim1 in lhs_batching_dimensions] ==
[dim(rhs, dim2) for dim2 in rhs_batching_dimensions]
。
藉助刪節號,通常可以避免在個別純量的層級工作。不過在棘手情況下,您可能會在 gather
規格的 start_indices[bi0, ..., :, ..., biN]
公式中使用較低層級的半不正式語法。簡而言之,我們無法提供將這類語法翻譯成香草 Python 的確切正式正規,原因是在每個情況下,依個別情況仍易於理解。如果部分公式看起來不透明,請告訴我們,我們會嘗試改善這些公式。
此外,您會注意到公式使用刪節號來擴充所有類型的清單,包括張量、張量清單 (例如因張數不同的張量清單) 等。這個領域又沒有確切的正式程度 (例如清單甚至不是 StableHLO 型別系統的一部分)
C) 我們採用的最終值得注意的車輛為隱性廣播。雖然 StableHLO 運算集不支援隱含廣播,但公式也可在精簡服務中使用。簡單來說,如果在預期有張量的情況下使用純量,純量就會播送至預期的形狀。
如要延續 dot_general
範例,以下是另一項限制:0 <= lhs_batching_dimensions < rank(lhs)
。如 dot_general
規格所定義,lhs_batching_dimensions
是張量,但 0
和 rank(lhs)
都是純量。套用隱式廣播後,公式會變成 [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)]
。
套用至特定 dot_general
運算時,這個公式會評估布林值的張量。將公式當做限制使用時,如果公式評估為 true
或只有 true
元素的張量,該限制就會生效。
名稱
在公式中,詞法範圍包括:1) 全域函式、2) 成員定義,
3) 當地定義。下方提供全域函式清單。元素定義清單取決於標記該標記的程式元素:
- 針對作業,成員定義包含在「Inputs」(輸入) 和「Outputs」(輸出) 區段中導入的名稱。
- 除此之外,成員定義都涵蓋程式元素的結構部分,並以對應的 EBNF 非終端機命名。在多數情況下,這些結構零件的名稱可透過將非終端的名稱轉換為蛇的大小寫 (例如
IntegerLiteral
=>integer_literal
),但有時名稱會在程序中縮寫縮寫 (例如:QuantizationStorageType
輸出 =>storage_type
),在這種情況下,名稱在專有部分與「Inputs」(輸入)/「運算」 (運算) 相關。 - 此外,成員定義一律包含
self
,以參照對應的程式元素。
值
公式評估時會使用下列類型的值:1) Value
(實際值,例如 dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
;這些值一律知道其類型)、2) Placeholder
(未來值,例如 lhs
、rhs
或 result
;其實際值尚未得知,只有類型才會知道類型)、3) Type
(在「函式」一節中定義的類型)、「4」) Function
(「函式」一節中定義的類型)。
根據內容而定,名稱可能參照不同的值。具體來說,運算 (以及其他程式元素的對等項目) 的「語意」區段會定義執行階段邏輯,因此所有輸入內容都能以 Value
的形式提供。相反地,運算 (和對等項目) 的「限制」部分會定義「編譯時間」邏輯,也就是通常在執行階段之前執行的項目,因此只有常數輸入內容可做為 Value
使用,其他輸入內容則只能做為 Placeholder
使用。
名稱 | 在「語意」中 | 在「限制」中 |
---|---|---|
全域函式 | Function |
Function |
常數輸入內容 | Value |
Value |
非常數的輸入內容 | Value |
Placeholder |
輸出 | Value |
Placeholder |
本機定義 | 視定義而定 | 視定義而定 |
我們來看看一個 transpose
運算範例:
%result = "stablehlo.transpose"(%operand) {
permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
這項作業的 permutation
是常數,因此可在語意和限制中以 Value
的形式提供。相反地,operand
和 result
可做為語意中的 Value
使用,但僅限做為限制條件中的 Placeholder
。
函式
類型建構
沒有可用來建構類型的函式。而是直接使用類型語法,因為這類語法通常較精簡。例如:(tensor<E>, tensor<E>) -> (tensor<E>)
,而非 function_type(
[tensor_type([], E), tensor_type([], E)], [tensor_type([], E)])
。
類型的函式
element_type
是在張量類型和量化張量類型上定義,並分別回傳相應TensorType
或QuantizedTensorType
的TensorElementType
或QuantizedTensorElementType
部分。
def element_type(x: Value | Placeholder | Type):
if type(x) == TensorType:
return tensor_element_type(x)
if type(x) == QuantizedTensorType:
return quantized_tensor_element_type(x)
if type(x) is not Type:
return element_type(type(x))
is_per_axis_quantized(x: Value | Placeholder | Type) -> Value
是is_quantized(x) and quantization_dimension(x) is not None
的捷徑。is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value
是is_quantized(x) and quantization_dimension(x) is None
的捷徑。is_promotable(x: Type, y: Type) -> bool
會檢查x
類型是否可升級為y
類型。當x
和y
為QuantizedTensorElementType
時,促銷活動只會套用至storage_type
。此特定升級版本目前用於減少運算作業 (詳情請參閱 RFC 相關說明)。
def is_promotable(x: Type, y: Type) -> Value:
is_same_type = (is_bool(x) and is_bool(y)) or
(is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
(is_complex(x) and is_complex(y)) or
(is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
if is_same_type == False:
return False
if is_integer(x) or is_float(x):
return bitwidth(x) <= bitwidth(y)
if is_complex(x):
return bitwidth(element_type(x)) <= bitwidth(element_type(y))
if is_quantized(x):
return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))
return false
is_quantized(x: Value | Placeholder | Type) -> Value
是is_quantized_tensor_element_type(x)
的捷徑。is_type_name(x: Value | Placeholder | Type) -> Value
:適用於所有類型。舉例來說,如果x
是FloatType
,is_float(x)
會傳回true
。如果x
是值或預留位置,此函式是is_type_name(type(x))
的捷徑。max_value(x: Type) -> Value
會傳回TensorElementType
的最大值。如果x
不是TensorElementType
,就會傳回None
。min_value(x: Type) -> Value
會傳回TensorElementType
的最小可能值。如果x
不是TensorElementType
,就會傳回None
。member_name(x: Value | Placeholder | Type) -> Any
:適用於所有類型的成員定義member_name
。舉例來說,tensor_element_type(x)
會傳回對應TensorType
的TensorElementType
部分。如果x
是值或預留位置,此函式是member_name(type(x))
的捷徑。如果x
類型不包含適當成員,或是該類型的值或預留位置,則會傳回None
。
值的建構
operation_name(*xs: Value | Type) -> Value
:適用於所有作業。舉例來說,add(lhs, rhs)
會取用lhs
和rhs
兩個張量值,並傳回利用這些輸入內容評估add
運算的輸出結果。對broadcast_in_dim
等某些作業來說,其輸出類型為「載入傳輸」,也就是評估作業所需的項目。在此情況下,函式會將這些類型視為引數。
用於值的函式
to_destination_type(x: Value, destination_type: Type) -> Value
是在張量上定義,並根據type(x)
和destination_type
傳回x
的轉換值,如下所示:
def to_destination_type(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x
if is_quantized(destination_type):
if is_quantized(type(x)):
return quantize(x, destination_type)
assert is_float(type(x))
return quantize(x, destination_type)
if is_quantized(type(x)):
assert destination_type = expressed_type(type(x))
return dequantize(type(x))
return convert(x, destination_type)
目前我們已早期探討合併 convert
、uniform_quantize
和 uniform_dequantize
作業 (#1576)。合併後,我們不需要上述函式,可改用 convert
的作業名稱。
is_nan(x: Value) -> Value
是在張量上定義,如果x
的所有元素都是NaN
或false
,則會傳回true
。如果x
不是張量,會傳回None
。is_sorted(x: Value) -> Value
已在張量定義,如果x
的元素是按照其索引的遞增字母順序排序,則會傳回true
。false
如果x
不是張量,就會傳回None
。is_unique(x: Value) -> Value
是在張量上定義,如果x
沒有重複的元素,則會傳回true
,否則會傳回false
。如果x
不是張量,會傳回None
。member_name(x: Value) -> Any
已定義所有值的成員定義member_name
。舉例來說,real_part(x)
會傳回對應ComplexConstant
的RealPart
部分。如果x
的值不是包含適當成員的值,則會傳回None
。same(x: Value) -> Value
是在張量上定義,如果x
的元素彼此相等,則傳回true
,否則會傳回false
。如果張量沒有元素,系統會將其視為「彼此相等」,即函式會傳回true
。如果x
不是張量,就會傳回None
。split(x: Value, num_results: Value, axis: Value) -> Value
是在張量上定義,且會傳回軸axis
軸的x
配量。num_results
如果x
不是張量或dim(x, axis) % num_results != 0
,就會傳回None
。
形狀計算
axes(x: Value | Placeholder | Type) -> Value
是range(rank(x))
的捷徑。dim(x: Value | Placeholder | Type, axis: Value) -> Value
是shape(x)[axis]
的捷徑。dims(x: Value | Placeholder | Type, axes: List) -> List
是list(map(lambda axis: dim(x, axis), axes))
的捷徑。index_space(x: Value | Placeholder | Type) -> Value
是在張量上定義,並會針對按字母順序排列的對應TensorType
傳回size(x)
索引,即[0, ..., 0]
、[0, ..., 1]
、...、shape(x) - 1
。如果x
不是張量類型、量化張量類型、值或是其中一種類型的預留位置,就會傳回None
。rank(x: Value | Placeholder | Type) -> Value
是size(shape(x))
的捷徑。shape(x: Value | Placeholder | Type) -> Value
是透過member_name
在「類型上的函式」部分中定義。size(x: Value | Placeholder | Type) -> Value
是reduce(lambda x, y: x * y, shape(x))
的捷徑。
量化運算
def baseline_element_type(x: Value | Placeholder | Type) -> Type
是element_type(baseline_type(x))
的捷徑。baseline_type
已在張量類型和量化張量類型上定義,並轉換為「基準」,即形狀相同但帶有量化參數的類型會重設為預設值。這很適合做為平均比較張量和量化張量類型比較的實用技巧。以量化類型來說,這可讓比較忽略量化參數的類型 (shape
、storage_type
、expressed_type
、storage_min
、storage_max
和quantization_dimension
(針對每軸量化類型) 必須全部相符,但scales
和zero points
則可能有所不同。
def baseline_type(x: Value | Placeholder | Type) -> Type:
if type(x) == TensorType:
return x
if type(x) == QuantizedTensorType:
element_type = quantized_tensor_element_type(x)
baseline_element_type = QuantizedTensorElementType(
storage_type = storage_type(element_type),
storage_min = storage_min(element_type),
storage_max = storage_max(element_type),
expressed_type = expressed_type(element_type),
quantization_dimension = quantization_dimension(element_type),
scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
return QuantizedTensorType(shape(x), baseline_element_type)
if type(x) is not Type:
return baseline_element_type(type(x))
dequantize
是在量化張量類型上定義,並將其轉換為浮點張量類型。方法是將代表儲存空間類型的整數值,轉換為以零點和量化元素類型相關聯的量化元素,轉換為表示類型的對應浮點值。
def compute_zero_points(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
zero_points[i] = zero_points(quantized_type)[i[d]]
return zero_points
def compute_scales(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
type(result_type))
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
scales[i] = scales(quantized_type)[i[d]]
return scales
def dequantize(x: Value) -> Value:
assert is_quantized(x)
x_storage = bitcast_convert(x, storage_type(x))
x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
x_expressed_sub = convert(x_storage_sub, expressed_type(x))
return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
quantize
是在浮點張量類型上定義,並轉換為量化張量類型。方法是使用與量化元素類型相關聯的零點和縮放,將表示類型的浮點值轉換為儲存類型的對應整數值。
def quantize(x: Value, type: Type) -> Value:
assert is_float(x) and is_quantized(type)
x_expressed_rounded = round_nearest_even(x / compute_scales(type, type(x)))
x_storage_rounded = convert(x_expressed_rounded, storage_type(type))
x_storage_add = x_storage_rounded + compute_zero_points(type, type(x_storage_rounded))
x_storage = clamp(storage_min(type), x_storage_add, storage_max(type))
return bitcast_convert(x_storage, type)
dequantize_op_quantize
可用來在量化張量上指定元素相關計算。這樣做會去量化,即將量化元素轉換成其表達類型,然後執行運算,然後量化結果 (即將結果轉回其儲存空間類型)。目前此函式僅適用於按張量量化。系統仍在處理每軸量化。(#1574)。
def dequantize_op_quantize(op, *inputs_and_output_type):
inputs = inputs_and_output_type[:-1]
output_type = inputs_and_output_type[-1]
float_inputs = map(dequantize, inputs)
float_result = op(*float_inputs)
return quantize(float_result, output_type)
def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
inputs = inputs_and_output_type[:-3]
float_inputs = map(dequantize, inputs)
float_results = op(*float_inputs)
return map(quantize, float_results, inputs_and_output_type[-3:])
def dequantize_compare(lhs, rhs, comparison_direction):
float_lhs = dequantize(lhs)
float_rhs = dequantize(rhs)
return compare(float_lhs, float_rhs, comparison_direction, FLOAT)
def dequantize_select_quantize(pred, on_true, on_false, output_type):
float_on_true = dequantize(on_true)
float_on_false = dequantize(on_false)
float_result = select(pred, float_on_true, float_on_false)
return quantize(float_result, output_type)
網格運算
cross_partition(replica_groups: Value) -> Value
。請參閱上述的「cross_copy」部分。cross_replica(replica_groups: Value) -> Value
。請參閱上述的「cross_copy」部分。cross_replica_and_partition(replica_groups: Value) -> Value
。請參閱上方的「cross_pli_and_partition」一節。flattened_ids(replica_groups: Value) -> Value
。請參閱上述的「flattened_ids」一節。