StableHLO 是机器学习 (ML) 模型中高级操作 (HLO) 的一组操作。StableHLO 可作为不同机器学习框架和机器学习编译器之间的可移植性层:生成 StableHLO 程序的机器学习框架与使用 StableHLO 程序的机器学习编译器兼容。
我们的目标是,通过在各种机器学习框架(例如 TensorFlow、JAX 和 PyTorch)和机器学习编译器(例如 XLA 和 IREE)之间创建更多互操作性,来简化和加速机器学习开发。为此,本文档提供了 StableHLO 编程语言的规范。
此规范包含三个主要部分。首先,程序部分介绍了 StableHLO 程序的结构,该程序由 StableHLO 函数组成,而 StableHLO 函数本身又由 StableHLO 操作组成。在该结构中,Ops 部分指定了各个操作的语义。执行部分提供了在程序中一起执行的所有这些操作的语义。最后,“表示法”部分讨论了整个规范中使用的表示法。
如需查看之前版本的 StableHLO 的规范,请打开感兴趣的已标记的版本的代码库。例如,StableHLO v0.19.0 规范。如需查看 StableHLO 在每次次要版本升级时发生的更改,请参阅 VhloDialect.td 中的版本日志。
计划
Program ::= {Func}
StableHLO 程序由任意数量的 StableHLO 函数组成。以下是一个示例程序,其中包含一个函数 @main,该函数有 3 个输入(%image、%weights 和 %bias)和 1 个输出。函数的正文有 6 个操作。
func.func @main(
%image: tensor<28x28xf32>,
%weights: tensor<784x10xf32>,
%bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
%0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
%1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
%2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
%3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
%4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
"func.return"(%4): (tensor<1x10xf32>) -> ()
}
函数
Func ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput ::= ValueType
FuncBody ::= {Op}
StableHLO 函数(也称为命名函数)具有标识符、输入/输出和正文。未来,我们计划为函数引入更多元数据,以实现与 HLO 的更好兼容性(#425、#626、#740、#744)。
标识符
FuncId ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
| '%' letter {letter | digit}
letter ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit ::= '0' | ... | '9'
StableHLO 标识符与许多编程语言中的标识符类似,但有两个特殊之处:1) 所有标识符都有标记,用于区分不同类型的标识符;2) 值标识符可以是纯数字,以简化 StableHLO 程序的生成。
类型
Type ::= ValueType | NonValueType
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType | BufferType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
StableHLO 类型分为值类型(也称为一等类型),表示 StableHLO 值;以及非值类型,用于描述其他程序元素。StableHLO 类型与许多编程语言中的类型类似,主要特点是 StableHLO 的领域特定性质,这会导致一些不寻常的结果(例如,标量类型不是值类型)。
TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'
张量类型表示张量,即多维数组。它们具有形状和元素类型,其中形状表示非负或未知的维度大小,按相应维度(也称为轴)的升序排列,编号从 0 到 R-1。维度数 R 称为秩。例如,tensor<2x3xf32> 是一个形状为 2x3、元素类型为 f32 的张量类型。它有两个维度(或者说两个轴)- 第 0 个维度和第 1 个维度 - 大小分别为 2 和 3。其等级为 2。
形状可以是部分未知或完全未知(动态),例如 tensor<?x2xf64> 是部分未知,而 tensor<?x?xf64> 是完全未知。动态维度大小使用 ? 表示。无法取消对形状的排名。
未来,我们计划探索将张量类型扩展到维度大小和元素类型之外,例如,纳入布局 (#629) 和稀疏性 (#1078)。
QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
QuantizationStorageType
['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
':' QuantizationExpressedType
[':' QuantizationDimension]
',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerLiteral
QuantizationStorageMax ::= IntegerLiteral
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerLiteral
QuantizationParameters ::= QuantizationParameter
| '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale [':' QuantizationZeroPoint]
QuantizationScale ::= FloatLiteral
QuantizationZeroPoint ::= IntegerLiteral
| 名称 | 类型 | 限制条件 |
|---|---|---|
storage_type |
integer type | (C1-C3)、(C8) |
storage_min |
整数常量 | (C1)、(C3)、(C7) |
storage_max |
整数常量 | (C2)、(C3)、(C7) |
expressed_type |
浮点类型 | (C4) |
quantization_dimension |
可选的整数常量 | (C10-C12) |
scales |
可变数量的浮点常量 | (C4-C6)、(C9)、(C10)、(C13) |
zero_points |
可变数量的整数常量 | (C7-C9) |
量化元素类型表示存储类型的整数值,范围从 storage_min 到 storage_max(含),对应于表达类型的浮点值。对于给定的整数值 i,相应的浮点值 f 可计算为 f = (i - zero_point) * scale,其中 scale 和 zero_point 称为量化参数。在语法中,storage_min 和 storage_max 是可选的,但默认值分别为 min_value(storage_type) 和 max_value(storage_type)。量化元素类型具有以下限制:
- (C1)
type(storage_min) = storage_type。 - (C2)
type(storage_max) = storage_type。 - (C3)
min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type)。 - (C4)
type(scales...) = expressed_type。 - (C5)
0 < scales。 - (C6)
is_finite(scales...)。 - (C7)
storage_min <= zero_points <= storage_max。 - (C8)
type(zero_points...) = storage_type。 - (C9)
size(scales) = size(zero_points)。 - (C10) 如果
is_empty(quantization_dimension),则size(scales) = 1。 - (C11)
0 <= quantization_dimension。
目前,QuantizationScale 是一个浮点常量,但人们对基于整数的缩放(用乘数和偏移量表示)非常感兴趣。我们计划在不久的将来探索此功能 (#1404)。
目前,我们正在讨论 QuantizationZeroPoint 的语义,包括类型、值以及量化张量类型中是否只能有一个或可能存在多个零点。根据此次讨论的结果,未来可能会更改有关零分的相关规范 (#1405)。
另一项正在进行的讨论涉及 QuantizationStorageMin 和 QuantizationStorageMax 的语义,以确定是否应针对这些值和量化张量的值施加任何限制 (#1406)。
最后,我们计划探索如何表示未知的比例和零点,这与我们计划探索如何表示未知的维度大小 (#1407) 类似。
量化张量类型表示具有量化元素的张量。这些张量与常规张量完全相同,只是其元素具有量化元素类型,而不是常规元素类型。
在量化张量中,量化可以是按张量进行的,即整个张量只有一个 scale 和 zero_point;也可以是按轴进行的,即有多个 scales 和 zero_points,每个特定维度 quantization_dimension 的切片都有一对。更正式地说,在具有按轴量化的张量 t 中,quantization_dimension 有 dim(t, quantization_dimension) 个切片:t[:, ..., 0, ..., :], t[:, ..., 1, ..., :] 等。第 i 个切片中的所有元素都使用 scales[i] 和 zero_points[i] 作为其量化参数。量化张量类型具有以下限制:
- 对于按张量进行量化:
- 没有其他限制。
- 对于按轴量化:
- (C12)
quantization_dimension < rank(self)。 - (C13)
dim(self, quantization_dimension) = size(scales)。
- (C12)
TokenType ::= 'token'
令牌类型表示令牌,即由某些操作生成和使用的不透明值。令牌用于对操作强制执行顺序,如执行部分中所述。
TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]
缓冲区类型表示缓冲区。例如,在 XLA 中,缓冲区是具有一致存储的多维数组。与张量类型类似,缓冲区类型具有形状和元素类型,其中形状表示非负或未知的维度大小,按相应维度(也称为轴)的升序排列,编号从 0 到 R-1。维度数 R 称为秩。例如,memref<2x3xf32> 是一个形状为 2x3、元素类型为 f32 的缓冲区类型。它有两个维度(或者说两个轴)- 第 0 个维度和第 1 个维度 - 大小分别为 2 和 3。其等级为 2。
可以使用 custom_call 到 CreateBuffer 或 Pin 分配缓冲区,并通过 custom_call 到 Unpin 取消分配缓冲区。只有 custom_call 操作可以读取和写入缓冲区内的内容。如需了解详情,请参阅 custom_call。
元组类型表示元组,即异构列表。元组是一项旧版功能,仅用于与 HLO 兼容。在 HLO 中,元组用于表示可变输入和输出。在 StableHLO 中,可变数量的输入和输出受到原生支持,并且 StableHLO 中元组的唯一用途是全面表示 HLO ABI,其中例如 T、tuple<T> 和 tuple<tuple<T>> 可能因特定实现而有很大差异。未来,我们计划对 HLO ABI 进行更改,这可能使我们能够从 StableHLO 中移除元组类型 (#598)。
TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3'
| 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2'
| 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
元素类型表示张量类型的元素。与许多编程语言不同,这些类型在 StableHLO 中不是一等类型。这意味着 StableHLO 程序无法直接表示这些类型的值(因此,惯例是使用类型为 tensor<T> 的 0 维张量值来表示类型为 T 的标量值)。
- 布尔值类型表示布尔值
true和false。 - 整数类型可以是有符号 (
si) 或无符号 (ui),并且具有支持的位宽(2、4、8、16、32或64)。有符号siN类型表示从-2^(N-1)到2^(N-1)-1(含)的整数值,无符号uiN类型表示从0到2^N-1(含)的整数值。 - 浮点类型可以是以下类型之一:
f8E3M4、f8E4M3和f8E5M2遵循 IEEE-754 惯例的 8 位浮点数。f8E4M3FN和f8E5M2类型分别对应于 FP8 Formats for Deep Learning 中描述的 FP8 格式的E4M3和E5M2编码。f8E4M3FNUZ和f8E5M2FNUZ类型分别对应于用于深度神经网络的 8 位数值格式中所述 FP8 格式的E4M3和E5M2编码。f8E4M3B11FNUZ类型,对应于用于深度神经网络的混合 8 位浮点 (HFP8) 训练和推理中所述的 FP8 格式的E4M3编码。- 与 BFloat16:Cloud TPU 上实现高性能的秘诀中所述的
bfloat16格式对应的bf16类型。 f16、f32和f64类型分别对应于 IEEE 754 标准中描述的binary16(“半精度”)、binary32(“单精度”)和binary64(“双精度”)格式。tf32类型对应于 TensorFloat32 格式,在 StableHLO 中仅受到有限的支持。- OCP 微缩格式规范中描述的
f4E2M1FN、f6E2M3FN、f6E3M2FN和f8E8M0FNUMX(微缩)类型。
- 复数类型表示具有相同元素类型的实部和虚部的复数值。支持的复杂类型包括
complex<f32>(两个部分均为f32类型)和complex<f64>(两个部分均为f64类型)。
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]
函数类型表示命名函数和匿名函数。它们具有输入类型(-> 左侧的类型列表)和输出类型(-> 右侧的类型列表)。在许多编程语言中,函数类型是一等类型,但在 StableHLO 中不是。
StringType ::= 'string'
字符串类型表示字节序列。与许多编程语言不同,字符串类型在 StableHLO 中不是一等公民,仅用于为程序元素指定静态元数据。
运维
StableHLO 操作(也称为 op)表示机器学习模型中的一组封闭的高级操作。如上所述,StableHLO 语法在很大程度上受到 MLIR 的启发,虽然 MLIR 不一定是最符合人体工程学的替代方案,但可以说是最适合 StableHLO 的,因为 StableHLO 的目标是在机器学习框架和机器学习编译器之间创建更好的互操作性。
Op ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic ::= 'abs' | 'add' | ...
StableHLO 操作(也称为 op)具有名称、输入/输出和签名。该名称由 stablehlo. 前缀和可唯一标识某个受支持操作的助记符 组成。如需查看所有受支持的运算的完整列表,请参阅下文。
OpInputs ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue ::= ValueId
OpInputFuncs ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs ::= [OpOutput {',' OpOutput} '=']
OpOutput ::= ValueId
操作会使用输入并生成输出。输入分为输入值(在执行期间计算)、输入函数(静态提供,因为在 StableHLO 中,函数不是一级值)和输入属性(也是静态提供)。操作所消耗和产生的输入和输出类型取决于其助记符。例如,add op 会消耗 2 个输入值并生成 1 个输出值。相比之下,select_and_scatter 操作会消耗 3 个输入值、2 个输入函数和 3 个输入属性。
OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused ::= '^' digit {digit}
| '^' letter {letter | digit}
输入函数(也称为匿名函数)与命名函数非常相似,但有以下区别:1) 它们没有标识符(因此称为“匿名”),2) 它们不声明输出类型(输出类型是从函数中的 return 操作推断出来的)。
输入函数的语法包含一个目前未使用的部分(请参阅上面的 Unused 产品),该部分用于与 MLIR 兼容。在 MLIR 中,有一个更通用的“区域”概念,该区域可以包含多个通过跳转操作连接在一起的操作“块”。这些块具有与 Unused 产品对应的 ID,以便彼此区分。
StableHLO 没有跳转操作,因此 MLIR 语法的相应部分未使用(但仍然存在)。
OpInputAttr ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName ::= letter {letter | digit}
OpInputAttrValue ::= Constant
输入属性具有名称和值,该值是受支持的常量之一。它们是为程序元素指定静态元数据的主要方式。例如,concatenate 操作使用属性 dimension 来指定串联输入值的维度。同样,slice 操作使用多个属性(例如 start_indices 和 limit_indices)来指定用于对输入值进行切片的边界。
目前,实际应用中的 StableHLO 程序有时包含本文档中未描述的属性。未来,我们计划将这些属性纳入 StableHLO opset,或者禁止它们出现在 StableHLO 程序中。在此期间,您可以查看以下属性列表:
layout(#629)。mhlo.frontend_attributes(#628)。mhlo.sharding(#619)。output_operand_aliases(#740)。- 位置元数据 (#594)。
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'
操作签名包含所有输入值的类型(-> 左侧的类型列表)和所有输出值的类型(-> 右侧的类型列表)。严格来说,输入类型是冗余的,输出类型也几乎总是冗余的(因为对于大多数 StableHLO 操作,输出类型可以从输入中推断出来)。不过,为了与 MLIR 兼容,op 签名特意纳入了 StableHLO 语法。
以下是一个助记符为 select_and_scatter 的操作示例。它使用 3 个输入值(%operand、%source 和 %init_value)、2 个输入函数和 3 个输入属性(window_dimensions、window_strides 和 padding)。请注意,该操作的签名仅包含其输入值的类型(但不包含以内联方式提供的输入函数和属性的类型)。
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>
常量
Constant ::= BooleanConstant
| IntegerConstant
| FloatConstant
| ComplexConstant
| TensorConstant
| QuantizedTensorConstant
| StringConstant
| EnumConstant
StableHLO 常量具有字面量和类型,两者共同表示一个 StableHLO 值。一般来说,类型是常量语法的一部分,除非类型明确无误(例如,布尔常量明确具有 i1 类型,而整数常量可能具有多种类型)。
BooleanConstant ::= BooleanLiteral
BooleanLiteral ::= 'true' | 'false'
布尔常量表示布尔值 true 和 false。布尔值常量的类型为 i1。
IntegerConstant ::= IntegerLiteral ':' IntegerType
IntegerLiteral ::= ['-' | '+'] DecimalDigits
| ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit ::= '0' | ... | '9'
hexadecimalDigit ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'
整数常量通过使用十进制或十六进制表示法的字符串来表示整数值。不支持其他进制,例如二进制或八进制。 整数常量具有以下限制:
- (C1)
is_wellformed(integer_literal, integer_type)。
FloatConstant ::= FloatLiteral ':' FloatType
FloatLiteral ::= SignPart IntegerPart FractionalPart ScientificPart
| '0x' [HexadecimalDigits]
SignPart ::= ['-' | '+']
IntegerPart ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]
浮点常量通过使用十进制或科学记数法的字符串表示浮点值。此外,十六进制表示法可用于直接指定相应类型的浮点格式中的底层位。浮点常量具有以下限制:
- (C1) 如果使用非十六进制表示法,则为
is_wellformed(float_literal, float_type)。 - (C2) 如果使用十六进制表示法,则为
size(hexadecimal_digits) = num_bits(float_type) / 4。
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral ::= '(' RealPart ',' ImaginaryPart ')'
RealPart ::= FloatLiteral
ImaginaryPart ::= FloatLiteral
复数常量使用实部(在前)和虚部(在后)的列表来表示复数值。例如,(1.0, 0.0) : complex<f32> 表示 1.0 + 0.0i,(0.0, 1.0) : complex<f32> 表示 0.0 + 1.0i。然后,这些部分在内存中的存储顺序由实现定义。复杂常数具有以下限制:
- (C1)
is_wellformed(real_part, complex_element_type(complex_type))。 - (C2)
is_wellformed(imaginary_part, complex_element_type(complex_type))。
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral
张量常量使用通过 NumPy 表示法指定的嵌套列表来表示张量值。例如,dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> 表示一个张量值,其从索引到元素的映射关系如下:{0, 0} => 1、{0, 1} => 2、{0, 2} => 3、{1, 0} => 4、{1, 1} => 5、{1, 2} => 6。然后,这些元素在内存中的存储顺序由实现定义。张量常量具有以下限制:
- (C1)
has_syntax(tensor_literal, element_type(tensor_type)),其中:has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type)。has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type)。
- (C2)
has_shape(tensor_literal, shape(tensor_type)),其中:has_shape(element_literal: Syntax, []) = true。has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:])。- 否则,
false。
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
量化张量常量使用与张量常量相同的表示法来表示量化张量值,元素指定为其存储类型的常量。量化张量常量具有以下限制:
- (C1)
has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type))。 - (C2)
has_shape(quantized_tensor_literal, shape(quantized_tensor_type))。
StringConstant ::= StringLiteral
StringLiteral ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))
字符串字面值由使用 ASCII 字符和转义序列指定的字节组成。它们与编码无关,因此这些字节的解释由实现定义。字符串字面量的类型为 string。
操作/运算
abs
语义
对 operand 张量执行元素级绝对值运算,并生成 result 张量。根据元素类型,执行以下操作:
- 对于有符号整数:整数模数。
- 对于浮点数:来自 IEEE-754 的
abs。 - 对于复数:复数模。
- 对于量化类型:
dequantize_op_quantize(abs, operand, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
有符号整数、浮点或复数类型的张量或每张量量化张量 | (C1-C2) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
有符号整数或浮点类型的张量,或每个张量的量化张量 | (C1-C2) |
限制条件
- (C1)
shape(result) = shape(operand)。 - (C2)
baseline_element_type(result)的定义如下:- 如果
is_complex(operand),则为complex_element_type(element_type(operand))。 - 否则为
baseline_element_type(operand)。
- 如果
示例
// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]
添加
语义
对两个张量 lhs 和 rhs 执行逐元素加法,并生成一个 result 张量。根据元素类型,执行以下操作:
- 对于布尔值:逻辑或。
- 对于整数:整数加法。
- 对于浮点数:来自 IEEE-754 的
addition。 - 对于复数:复数加法。
- 对于量化类型:
dequantize_op_quantize(add, lhs, rhs, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | lhs |
张量或量化张量 | (C1-C6) |
| (I2) | rhs |
张量或量化张量 | (C1-C5)、(C7) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
张量或量化张量 | (C1-C7) |
限制条件
- 如果运算使用非量化张量:
- (C1)
type(lhs) = type(rhs) = type(result)。
- (C1)
- 如果操作使用量化张量:
- (C2)
is_quantized(lhs) and is_quantized(rhs) and is_quantized(result)。 - (C3)
storage_type(lhs) = storage_type(rhs) = storage_type(result)。 - (C4)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)。 - (C5)
(is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result)。 - (C6) 如果
is_per_axis_quantized(lhs),则quantization_dimension(lhs) = quantization_dimension(result)。 - (C7) 如果
is_per_axis_quantized(rhs),则quantization_dimension(rhs) = quantization_dimension(result)。
- (C2)
示例
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]
after_all
语义
确保生成 inputs 的操作在任何依赖于 result 的操作之前执行。执行此操作不会执行任何操作,它仅用于建立从 result 到 inputs 的数据依赖关系。
输入
| 标签 | 名称 | 类型 |
|---|---|---|
| (I1) | inputs |
可变数量的 token |
输出
| 名称 | 类型 |
|---|---|
result |
token |
示例
// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token
all_gather
语义
在 StableHLO 进程网格中的每个进程组内,沿 all_gather_dim 连接来自每个进程的 operands 张量的值,并生成 results 张量。
该操作将 StableHLO 进程网格拆分为 process_groups,定义如下:
- 如果
channel_id <= 0 and use_global_device_ids = false,则为cross_replica(replica_groups)。 - 如果
channel_id > 0 and use_global_device_ids = false,则为cross_replica_and_partition(replica_groups)。 - 如果
channel_id > 0 and use_global_device_ids = true,则为flattened_ids(replica_groups)。
之后,在每个 process_group 中:
operands...@receiver = [operand@sender for sender in process_group](适用于process_group中的所有receiver)。results...@process = concatenate(operands...@process, all_gather_dim)(适用于process_group中的所有process)。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operands |
数量可变的张量或每个张量的量化张量 | (C1)、(C6) |
| (I2) | all_gather_dim |
类型为 si64 的常量 |
(C1)、(C6) |
| (I3) | replica_groups |
类型为 si64 的二维张量常量 |
(C2-C4) |
| (I4) | channel_id |
类型为 si64 的常量 |
(C5) |
| (I5) | use_global_device_ids |
类型为 i1 的常量 |
(C5) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
results |
数量可变的张量或每个张量的量化张量 | (C6) |
限制条件
- (C1)
0 <= all_gather_dim < rank(operands...)。 - (C2)
is_unique(replica_groups)。 - (C3)
size(replica_groups)的定义如下:- 如果使用
cross_replica,则为num_replicas。 - 如果使用
cross_replica_and_partition,则为num_replicas。 - 如果使用
flattened_ids,则为num_processes。
- 如果使用
- (C4)
0 <= replica_groups < size(replica_groups)。 - (C5) 如果
use_global_device_ids = true,则channel_id > 0。 - (C6)
type(results...) = type(operands...)除外:dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1)。
示例
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
all_reduce
语义
在 StableHLO 进程网格中的每个进程组内,将缩减函数 computation 应用于每个进程的 operands 张量的值,并生成 results 张量。
该操作将 StableHLO 进程网格拆分为 process_groups,定义如下:
- 如果
channel_id <= 0 and use_global_device_ids = false,则为cross_replica(replica_groups)。 - 如果
channel_id > 0 and use_global_device_ids = false,则为cross_replica_and_partition(replica_groups)。 - 如果
channel_id > 0 and use_global_device_ids = true,则为flattened_ids(replica_groups)。
之后,在每个 process_group 中:
- 对于某个二叉树
results...@process[result_index] = exec(schedule),schedule其中:exec(node)=computation(exec(node.left), exec(node.right))。exec(leaf)=leaf.value。
schedule是一个由实现定义的二叉树,其按顺序遍历为to_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0]))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operands |
数量可变的张量或每个张量的量化张量 | (C5)、(C6) |
| (I2) | replica_groups |
可变数量的 1 维张量常量,类型为 si64 |
(C1-C3) |
| (I3) | channel_id |
类型为 si64 的常量 |
(C4) |
| (I4) | use_global_device_ids |
类型为 i1 的常量 |
(C4) |
| (I5) | computation |
函数 | (C5) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
results |
数量可变的张量或每个张量的量化张量 | (C6-C7) |
限制条件
- (C1)
is_unique(replica_groups)。 - (C2)
size(replica_groups)的定义如下:- 如果使用
cross_replica,则为num_replicas。 - 如果使用
cross_replica_and_partition,则为num_replicas。 - 如果使用
flattened_ids,则为num_processes。
- 如果使用
- (C3)
0 <= replica_groups < size(replica_groups)。 - (C4) 如果
use_global_device_ids = true,则channel_id > 0。 - (C5)
computation的类型为(tensor<E>, tensor<E>) -> (tensor<E>),其中is_promotable(element_type(operand), E)。 - (C6)
shape(results...) = shape(operands...)。 - (C7)
element_type(results...) = E。
示例
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]
all_to_all
语义
在 StableHLO 进程网格中的每个进程组内,沿 split_dimension 将 operands 张量的值拆分为若干部分,在进程之间分散拆分后的部分,沿 concat_dimension 连接分散的部分,并生成 results 张量。
该操作将 StableHLO 进程网格拆分为 process_groups,定义如下:
- 如果
channel_id <= 0,则为cross_replica(replica_groups)。 - 如果
channel_id > 0,则为cross_partition(replica_groups)。
之后,在每个 process_group 中:
split_parts...@sender = split(operands...@sender, split_count, split_dimension)适用于process_group中的所有sender。scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group],其中receiver_index = process_group.index(receiver)。results...@process = concatenate(scattered_parts...@process, concat_dimension)。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operands |
数量可变的张量或每个张量的量化张量 | (C1-C3)、(C9) |
| (I2) | split_dimension |
类型为 si64 的常量 |
(C1)、(C2)、(C9) |
| (I3) | concat_dimension |
类型为 si64 的常量 |
(C3)、(C9) |
| (I4) | split_count |
类型为 si64 的常量 |
(C2)、(C4)、(C8)、(C9) |
| (I5) | replica_groups |
类型为 si64 的二维张量常量 |
(C5-C8) |
| (I6) | channel_id |
类型为 si64 的常量 |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
results |
数量可变的张量或每个张量的量化张量 | (C9) |
限制条件
- (C1)
0 <= split_dimension < rank(operands...)。 - (C2)
dim(operands..., split_dimension) % split_count = 0。 - (C3)
0 <= concat_dimension < rank(operands...)。 - (C4)
0 < split_count。 - (C5)
is_unique(replica_groups)。 - (C6)
size(replica_groups)的定义如下:- 如果使用
cross_replica,则为num_replicas。 - 如果使用
cross_partition,则为num_partitions。
- 如果使用
- (C7)
0 <= replica_groups < size(replica_groups)。 - (C8)
dim(replica_groups, 1) = split_count。 - (C9)
type(results...) = type(operands...)但如果split_dimension != concat_dimension,则除外:dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count。dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count。
示例
// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
// [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
// [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
// channel_id = 0
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]
和
语义
对两个张量 lhs 和 rhs 执行按元素与运算,并生成一个 result 张量。根据元素类型,执行以下操作:
- 对于布尔值:逻辑 AND。
- 对于整数:按位 AND。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | lhs |
布尔值或整数类型的张量 | (C1) |
| (I2) | rhs |
布尔值或整数类型的张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
布尔值或整数类型的张量 | (C1) |
限制条件
- (C1)
type(lhs) = type(rhs) = type(result)。
示例
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]
atan2
语义
对 lhs 和 rhs 张量执行逐元素 atan2 运算,并生成 result 张量。根据元素类型,执行以下操作:
- 对于浮点数:来自 IEEE-754 的
atan2。 - 对于复数:复数 atan2。
- 对于量化类型:
dequantize_op_quantize(atan2, lhs, rhs, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | lhs |
浮点或复杂类型的张量或每个张量的量化张量 | (C1) |
| (I2) | rhs |
浮点或复杂类型的张量或每个张量的量化张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
浮点或复杂类型的张量或每个张量的量化张量 | (C1) |
限制条件
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)。
示例
// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]
batch_norm_grad
语义
计算 batch_norm_training 的多个输入的梯度(从 grad_output 反向传播),并生成 grad_operand、grad_scale 和 grad_offset 张量。更正式地说,此操作可以使用 Python 语法表示为对现有 StableHLO 操作的分解,如下所示:
def compute_sum(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
return sum
def compute_mean(operand, feature_index):
sum = compute_sum(operand, feature_index)
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
# Broadcast inputs to type(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance`
# Intermediate values will be useful for computing gradients
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
# Use the implementation from batchnorm_expander.cc in XLA
# Temporary variables have exactly the same names as in the C++ code
elements_per_feature = broadcast_in_dim(
constant(divide(size(operand), dim(operand, feature_index)),
element_type(grad_output)),
[], type(operand))
i1 = multiply(grad_output, elements_per_feature)
i2 = broadcast_in_dim(
compute_sum(grad_output, feature_index), [feature_index], type(operand))
i3 = broadcast_in_dim(
compute_sum(multiply(grad_output, centered_operand), feature_index),
[feature_index], type(operand))
i4 = multiply(i3, centered_operand)
i5 = divide(i4, add(variance_bcast, epsilon_bcast))
i6 = subtract(subtract(i1, i2), i5)
grad_operand =
multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
grad_scale =
compute_sum(multiply(grad_output, normalized_operand), feature_index)
grad_offset = compute_sum(grad_output, feature_index)
return grad_operand, grad_scale, grad_offset
对于量化类型,执行 dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean,
variance, grad_output: batch_norm_grad(operand, scale, mean, variance,
grad_output, epsilon, feature_index), operand, scale, mean, variance,
grad_output, type(grad_operand), type(grad_scale), type(feature_index))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
浮点类型张量或逐张量量化张量 | (C1-C3)、(C5) |
| (I2) | scale |
浮点或按张量量化类型的 1 维张量 | (C2)、(C4)、(C5) |
| (I3) | mean |
浮点或按张量量化类型的 1 维张量 | (C2)、(C4) |
| (I4) | variance |
浮点或按张量量化类型的 1 维张量 | (C2)、(C4) |
| (I5) | grad_output |
浮点类型张量或逐张量量化张量 | (C2)、(C3) |
| (I6) | epsilon |
类型为 f32 的常量 |
|
| (I7) | feature_index |
类型为 si64 的常量 |
(C1)、(C5) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
grad_operand |
浮点类型张量或逐张量量化张量 | (C2)、(C3) |
grad_scale |
浮点或按张量量化类型的 1 维张量 | (C2)、(C4) |
grad_offset |
浮点或按张量量化类型的 1 维张量 | (C2)、(C4) |
限制条件
- (C1)
0 <= feature_index < rank(operand)。 - (C2)
operand、scale、mean、variance、grad_output、grad_operand、grad_scale和grad_offset具有相同的baseline_element_type。 - (C3)
operand、grad_output和grad_operand具有相同的形状。 - (C4)
scale、mean、variance、grad_scale和grad_offset具有相同的形状。 - (C5)
size(scale) = dim(operand, feature_index)。
示例
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
// [[0.1, 0.1], [0.1, 0.1]],
// [[0.1, 0.1], [0.1, 0.1]]
// ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
// [[0.0, 0.0], [0.0, 0.0]],
// [[0.0, 0.0], [0.0, 0.0]]
// ]
// %grad_scale: [0.0, 0.0]
// %grad_offset: [0.4, 0.4]
batch_norm_inference
语义
对 operand 张量进行归一化处理,但 feature_index 维度除外,并生成 result 张量。更正式地说,此操作可以使用 Python 语法表示为对现有 StableHLO 操作的分解,如下所示:
def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
# Broadcast inputs to shape(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance` instead of
# computing them like `batch_norm_training` does.
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
return add(multiply(scale_bcast, normalized_operand), offset_bcast)
对于量化类型,执行 dequantize_op_quantize(lambda operand, scale, offset, mean, variance:
batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index), operand, scale, offset, mean, variance, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
浮点类型张量或逐张量量化张量 | (C1-C7) |
| (I2) | scale |
浮点或按张量量化类型的 1 维张量 | (C2)、(C3) |
| (I3) | offset |
浮点或按张量量化类型的 1 维张量 | (C2)、(C4) |
| (I4) | mean |
浮点或按张量量化类型的 1 维张量 | (C5) |
| (I5) | variance |
浮点或按张量量化类型的 1 维张量 | (C2)、(C6) |
| (I6) | epsilon |
类型为 f32 的常量 |
|
| (I7) | feature_index |
类型为 si64 的常量 |
(C1)、(C3-C6) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
浮点类型张量或逐张量量化张量 | (C2)、(C7) |
限制条件
- (C1)
0 <= feature_index < rank(operand)。 - (C2)
operand、scale、offset、mean、variance和result具有相同的baseline_element_type。 - (C3)
size(scale) = dim(operand, feature_index)。 - (C4)
size(offset) = dim(operand, feature_index)。 - (C5)
size(mean) = dim(operand, feature_index)。 - (C6)
size(variance) = dim(operand, feature_index)。 - (C7)
baseline_type(operand) = baseline_type(result)。
示例
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
batch_norm_training
语义
计算除 feature_index 维度之外的所有维度上的平均值和方差,并对 operand 张量进行归一化处理,生成 output、batch_mean 和 batch_var 张量。更正式地说,此操作可以使用 Python 语法分解为现有的 StableHLO 操作,如下所示:
def compute_mean(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def compute_variance(operand, feature_index):
mean = compute_mean(operand, feature_index)
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
centered_operand = subtract(operand, mean_bcast)
return compute_mean(mul(centered_operand, centered_operand), feature_index)
def batch_norm_training(operand, scale, offset, epsilon, feature_index):
mean = compute_mean(operand, feature_index)
variance = compute_variance(operand, feature_index)
return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index),
mean, variance
对于量化类型,执行 dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset:
batch_norm_training(operand, scale, offset, epsilon, feature_index), operand,
scale, offset, type(output), type(batch_mean), type(batch_var))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
浮点类型张量或逐张量量化张量 | (C1) |
| (I2) | scale |
浮点或按张量量化的 1 维张量 | (C2)、(C3) |
| (I3) | offset |
浮点或按张量量化的 1 维张量 | (C2)、(C4) |
| (I4) | epsilon |
类型为 f32 的常量 |
(C1)、(C3-C6) |
| (I5) | feature_index |
类型为 si64 的常量 |
(C1)、(C3-C6) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
output |
浮点类型张量或逐张量量化张量 | (C7) |
batch_mean |
浮点或按张量量化的 1 维张量 | (C2)、(C5) |
batch_var |
浮点或按张量量化的 1 维张量 | (C2)、(C6) |
限制条件
- (C1)
0 <= feature_index < rank(operand)。 - (C2)
operand、scale、offset、batch_mean、batch_var和output具有相同的baseline_element_type。 - (C3)
size(scale) = dim(operand, feature_index)。 - (C4)
size(offset) = dim(operand, feature_index)。 - (C5)
size(batch_mean) = dim(operand, feature_index)。 - (C6)
size(batch_var) = dim(operand, feature_index)。 - (C7)
baseline_type(output) = baseline_type(operand)。
示例
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
(tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]
bitcast_convert
语义
对 operand 张量执行位转换操作,并生成一个 result 张量,其中整个 operand 张量的位使用 result 张量的类型重新解释。
更正式地说,给定 E = element_type(operand)、E' = element_type(result) 和 R = rank(operand):
- 如果
num_bits(E') < num_bits(E),则为bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])。 - 如果
num_bits(E') > num_bits(E),则为bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])。 - 如果
num_bits(E') = num_bits(E),则为bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])。
bits 会返回给定值的内存中表示形式,并且其行为由实现定义,因为张量的确切表示形式由实现定义,并且元素类型的确切表示形式也由实现定义。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
张量或量化张量 | (C1-C2) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
张量或量化张量 | (C1-C2) |
限制条件
- (C1) 假定有
E = is_quantized(operand) ? storage_type(operand) : element_type(operand)、E' = is_quantized(result) ? storage_type(result) : element_type(result)和R = rank(operand):- 如果为
num_bits(E') = num_bits(E),则为shape(result) = shape(operand)。 - 如果
num_bits(E') < num_bits(E): rank(result) = R + 1。dim(result, i) = dim(operand, i)适用于所有0 <= i < R。dim(result, R) * num_bits(E') = num_bits(E)。- 如果
num_bits(E') > num_bits(E): rank(result) = R - 1。dim(result, i) = dim(operand, i)适用于所有0 <= i < R。dim(operand, R - 1) * num_bits(E) = num_bits(E')。
- 如果为
- (C2) 如果
is_complex(operand) or is_complex(result),则is_complex(operand) and is_complex(result)。
示例
// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation
broadcast_in_dim
语义
通过复制 operand 张量中的数据来扩展输入张量的维度和/或秩,并生成 result 张量。更正式地说,result[result_index] = operand[operand_index],其中对于 axes(operand) 中的所有 d:
- 如果
dim(operand, d) = 1,则为operand_index[d] = 0。 - 否则为
operand_index[d] = result_index[broadcast_dimensions[d]]。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
张量或量化张量 | (C1-C2)、(C5-C6) |
| (I2) | broadcast_dimensions |
类型为 si64 的 1 维张量常量 |
(C2-C6) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
张量或量化张量 | (C1)、(C3)、(C5-C6) |
限制条件
- (C1)
element_type(result)由以下人员提供:element_type(operand)(如果!is_per_axis_quantized(operand))。element_type(operand),但quantization_dimension(operand)、scales(operand)和zero_points(operand)可能分别与quantization_dimension(result)、scales(result)和zero_points(result)不同,否则。
- (C2)
size(broadcast_dimensions) = rank(operand)。 - (C3)
0 <= broadcast_dimensions < rank(result)。 - (C4)
is_unique(broadcast_dimensions)。 - (C5) 对于
axes(operand)中的所有d:dim(operand, d) = 1或dim(operand, d) = dim(result, broadcast_dimensions[d])。
- (C6) 如果
is_per_axis_quantized(result):quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]。- 如果值为
dim(operand, quantization_dimension(operand)) = 1,则值为scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))。
示例
// %operand: [
// [1, 2, 3]
// ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
场景
语义
根据 index 的值,通过执行 branches 中的一个函数来生成输出。更正式地说,result = selected_branch(),其中:
- 如果
0 <= index < size(branches),则为selected_branch = branches[index]。 - 否则为
selected_branch = branches[-1]。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | index |
类型为 si32 的 0 维张量 |
|
| (I2) | branches |
可变数量的函数 | (C1-C4) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
results |
数量可变的张量、量化张量或令牌 | (C4) |
限制条件
- (C1)
0 < size(branches)。 - (C2)
input_types(branches...) = []。 - (C3)
same(output_types(branches...))。 - (C4)
type(results...) = output_types(branches[0])。
示例
// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
"stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
"stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]
cbrt
语义
对 operand 张量执行逐元素立方根运算,并生成 result 张量。根据元素类型,执行以下操作:
- 对于浮点数:来自 IEEE-754 的
rootn(x, 3)。 - 对于复数:复数立方根。
- 对于量化类型:
dequantize_op_quantize(cbrt, operand, type(result))
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
浮点或复杂类型的张量或每个张量的量化张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
浮点或复杂类型的张量或每个张量的量化张量 | (C1) |
限制条件
- (C1)
baseline_type(operand) = baseline_type(result)。
示例
// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]
ceil
语义
对 operand 张量执行逐元素向上取整运算,并生成 result 张量。
实现 IEEE-754 规范中的 roundToIntegralTowardPositive 操作。对于量化类型,执行 dequantize_op_quantize(ceil, operand, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
浮点类型张量或逐张量量化张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
浮点类型张量或逐张量量化张量 | (C1) |
限制条件
- (C1)
baseline_type(operand) = baseline_type(result)。
示例
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]
cholesky
语义
计算一批矩阵的 Cholesky 分解。
更正式地说,对于 index_space(result) 中的所有 i,result[i0, ..., iR-3, :, :] 是 a[i0, ..., iR-3, :, :] 的 Cholesky 分解,形式为下三角矩阵(如果 lower 为 true)或上三角矩阵(如果 lower 为 false)。
相反三角形(即严格上三角或严格下三角)中的输出值由实现定义。
如果存在输入矩阵不是 Hermitian 正定矩阵的 i,则行为未定义。
对于量化类型,执行 dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | a |
浮点或复杂类型的张量或每个张量的量化张量 | (C1-C3) |
| (I2) | lower |
类型为 i1 的常量 |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
浮点或复杂类型的张量或每个张量的量化张量 | (C1) |
限制条件
- (C1)
baseline_type(a) = baseline_type(result)。 - (C2)
2 <= rank(a)。 - (C3)
dim(a, -2) = dim(a, -1)。
示例
// %a: [
// [1.0, 2.0, 3.0],
// [2.0, 20.0, 26.0],
// [3.0, 26.0, 70.0]
// ]
%result = "stablehlo.cholesky"(%a) {
lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
限制取值范围
语义
将 operand 张量的每个元素限制在最小值和最大值之间,并生成 result 张量。更正式地说,result[result_index] =
minimum(maximum(operand[result_index], min_element), max_element),其中 min_element = rank(min) = 0 ? min[] : min[result_index],max_element = rank(max) = 0 ? max[] : max[result_index]。对于量化类型,执行 dequantize_op_quantize(clamp, min, operand, max, type(result))。
对复数施加排序会涉及令人惊讶的语义,因此我们计划在未来移除对该操作的复数支持 (#560)。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | min |
张量或按张量量化的张量 | (C1)、(C3) |
| (I2) | operand |
张量或按张量量化的张量 | (C1-C4) |
| (I3) | max |
张量或按张量量化的张量 | (C2)、(C3) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
张量或按张量量化的张量 | (C4) |
限制条件
- (C1)
rank(min) = 0 or shape(min) = shape(operand)。 - (C2)
rank(max) = 0 or shape(max) = shape(operand)。 - (C3)
baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max)。 - (C4)
baseline_type(operand) = baseline_type(result)。
示例
// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]
collective_broadcast
语义
在 StableHLO 进程网格中的每个进程组内,将源进程中的 operand 张量的值发送到目标进程,并生成 result 张量。
该操作将 StableHLO 进程网格拆分为 process_groups,定义如下:
- 如果
channel_id <= 0,则为cross_replica(replica_groups)。 - 如果
channel_id > 0,则为cross_partition(replica_groups)。
之后,result@process 由以下公式给出:
- 如果存在
i使得进程处于process_groups[i]中,则为operand@process_groups[i, 0]。 - 否则为
broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
张量或按张量量化的张量 | (C3) |
| (I2) | replica_groups |
可变数量的 1 维张量常量,类型为 si64 |
(C1)、(C2) |
| (I3) | channel_id |
类型为 si64 的常量 |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
张量或按张量量化的张量 | (C3) |
限制条件
- (C1)
is_unique(replica_groups)。 - (C2)
0 <= replica_groups < N,其中N定义如下:- 如果使用
cross_replica,则为num_replicas。 - 如果使用
cross_partition,则为num_partitions。
- 如果使用
- (C3)
type(result) = type(operand)。
示例
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
collective_permute
语义
在 StableHLO 进程网格中的每个进程组内,将源进程中的 operand 张量的值发送到目标进程,并生成 result 张量。
该操作将 StableHLO 进程网格拆分为 process_groups,定义如下:
- 如果
channel_id <= 0,则为cross_replica(source_target_pairs)。 - 如果
channel_id > 0,则为cross_partition(source_target_pairs)。
之后,result@process 由以下公式给出:
operand@process_groups[i, 0],如果存在i使得process_groups[i, 1] = process。- 否则为
broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
张量或按张量量化的张量 | (C5) |
| (I2) | source_target_pairs |
类型为 si64 的二维张量常量 |
(C1-C4) |
| (I3) | channel_id |
类型为 si64 的常量 |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
张量或按张量量化的张量 | (C1) |
限制条件
- (C1)
dim(source_target_pairs, 1) = 2。 - (C2)
is_unique(source_target_pairs[:, 0])。 - (C3)
is_unique(source_target_pairs[:, 1])。 - (C4)
0 <= source_target_pairs < N,其中N定义为:- 如果使用
cross_replica,则为num_replicas。 - 如果使用
cross_partition,则为num_partitions。
- 如果使用
- (C5)
type(result) = type(operand)。
示例
// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]
比较
语义
根据 comparison_direction 和 compare_type 对 lhs 和 rhs 张量执行逐元素比较,并生成 result 张量。
comparison_direction 和 compare_type 的值具有以下语义:
对于布尔值和整数元素类型:
EQ:lhs = rhs。NE:lhs != rhs。GE:lhs >= rhs。GT:lhs > rhs。LE:lhs <= rhs。LT:lhs < rhs。
对于具有 compare_type = FLOAT 的浮点元素类型,相应操作会实现以下 IEEE-754 操作:
EQ:compareQuietEqual。NE:compareQuietNotEqual。GE:compareQuietGreaterEqual。GT:compareQuietGreater。LE:compareQuietLessEqual。LT:compareQuietLess。
对于具有 compare_type = TOTALORDER 的浮点元素类型,相应操作使用 IEEE-754 中的 totalOrder 和 compareQuietEqual 操作的组合。
对于复杂元素类型,系统会使用提供的 comparison_direction 和 compare_type 对 (real, imag) 对进行字典序比较。对复数施加排序会涉及令人惊讶的语义,因此我们计划在未来移除当 comparison_direction 为 GE、GT、LE 或 LT 时的复数支持 (#560)。
对于量化类型,执行 dequantize_compare(lhs, rhs,
comparison_direction)。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | lhs |
张量或按张量量化的张量 | (C1-C3) |
| (I2) | rhs |
张量或按张量量化的张量 | (C1-C2) |
| (I3) | comparison_direction |
EQ、NE、GE、GT、LE 和 LT 的枚举 |
|
| (I4) | compare_type |
FLOAT、TOTALORDER、SIGNED 和 UNSIGNED 的枚举 |
(C3) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
布尔值类型的张量 | (C2) |
限制条件
- (C1)
baseline_element_type(lhs) = baseline_element_type(rhs)。 - (C2)
shape(lhs) = shape(rhs) = shape(result)。 - (C3)
compare_type的定义如下:- 如果
is_signed_integer(element_type(lhs)),则为SIGNED。 - 如果
is_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs)),则为UNSIGNED。 - 如果
is_float(element_type(lhs)),则为FLOAT或TOTALORDER。 - 如果
is_complex(element_type(lhs)),则为FLOAT。
- 如果
示例
// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
comparison_direction = #stablehlo<comparison_direction LT>,
compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]
复杂
语义
根据实数值和虚数值对 lhs 和 rhs 执行逐元素转换,生成 result 张量。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | lhs |
类型为 f32 或 f64 的张量 |
(C1-C3) |
| (I2) | rhs |
类型为 f32 或 f64 的张量 |
(C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
复杂类型的张量 | (C2)、(C3) |
限制条件
- (C1)
type(lhs) = type(rhs)。 - (C2)
shape(result) = shape(lhs)。 - (C3)
element_type(result)的类型为complex<E>,其中E = element_type(lhs)。
示例
// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]
复合型
语义
封装了由其他 StableHLO 操作组成的运算,接受 inputs 和 composite_attributes 并生成 results。相应操作的语义由 decomposition 属性实现。composite 操作可以替换为其分解,而不会更改程序语义。如果内嵌分解无法提供相同的操作语义,请优先使用 custom_call。
version 字段(默认为 0)用于表示复合语义何时发生变化。
输入
| 标签 | 名称 | 类型 |
|---|---|---|
| (I1) | inputs |
可变数量的值 |
| (I2) | name |
类型为 string 的常量 |
| (I3) | composite_attributes |
属性字典 |
| (I4) | decomposition |
类型为 string 的常量 |
| (I5) | version |
类型为 si32 的常量 |
输出
| 名称 | 类型 |
|---|---|
results |
可变数量的值 |
限制条件
- (C1)
is_namespaced_op_name(name) - (C2)
is_defined_in_parent_scope(decomposition) - (C3)
types(inputs...) == input_types(decomposition) - (C4)
types(results...) == output_types(decomposition)
示例
%results = "stablehlo.composite"(%input0, %input1) {
name = "my_namespace.my_op",
composite_attributes = {
my_attribute = "my_value"
},
decomposition = @my_op,
version = 1 : i32
} : (tensor<f32>, tensor<f32>) -> tensor<f32>
concatenate
语义
沿 dimension 维度按给定实参的相同顺序串联 inputs,并生成 result 张量。更正式地说,result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1],其中:
id = d0 + ... + dk-1 + kd。d等于dimension,而d0、... 是inputs的第d个维度的大小。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | inputs |
数量可变的张量或每个张量的量化张量 | (C1-C6) |
| (I2) | dimension |
类型为 si64 的常量 |
(C2)、(C4)、(C6) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
张量或按张量量化的张量 | (C5-C6) |
限制条件
- (C1)
same(element_type(inputs...))。 - (C2)
same(shape(inputs...))(dim(inputs..., dimension)除外)。 - (C3)
0 < size(inputs)。 - (C4)
0 <= dimension < rank(inputs[0])。 - (C5)
element_type(result) = element_type(inputs[0])。 - (C6)
shape(result) = shape(inputs[0])但以下情况除外:dim(result, dimension) = dim(inputs[0], dimension) + ...。
示例
// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
常量
语义
根据常量 value 生成 output 张量。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | value |
常量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
output |
张量或量化张量 | (C1) |
限制条件
- (C1)
type(value) = type(output)。
示例
%output = "stablehlo.constant"() {
value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]
转化
语义
对 operand 张量执行从一种元素类型到另一种元素类型的逐元素转换,并生成 result 张量。
对于boolean-to-any-supported-type的转换,值 false 会转换为零,值 true 会转换为一。对于any-supported-type-to-boolean的转换,零值会转换为 false,非零值会转换为 true。如需了解此功能如何处理复杂类型,请参阅下文。
对于涉及整数到整数、整数到浮点数或浮点数到浮点数的转换,如果源值可以用目标类型精确表示,则结果值就是该精确表示。否则,行为尚待确定 (#180)。
对于涉及floating-point-to-integer的转换,系统会截断小数部分。如果截断后的值无法用目标类型表示,则行为待定 (#180)。
涉及复数到复数的转换遵循与浮点数到浮点数转换相同的行为,用于转换实部和虚部。
对于从复数到任何其他类型和从任何其他类型到复数的转换,系统会分别忽略源虚数值或将目标虚数值归零。实部的转换遵循浮点转换。
原则上,此操作可以表示反量化(从量化张量转换为常规张量)、量化(从常规张量转换为量化张量)和重新量化(在量化张量之间转换),但目前我们有专门的操作来处理这些情况 - uniform_dequantize 用于第一种使用情形,uniform_quantize 用于第二种和第三种使用情形。未来,这两个操作可能会合并为 convert (#1576)。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
张量 | (C1) |
限制条件
- (C1)
shape(operand) = shape(result)。
示例
// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]
卷积
语义
计算 lhs 的窗口与 rhs 的切片之间的点积,并生成 result。下图通过一个具体示例展示了如何根据 lhs 和 rhs 计算 result 中的元素。
更正式地说,请考虑从 lhs 的角度重新构建输入,以便能够表达 lhs 的窗口:
lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension))。lhs_window_strides = lhs_shape(1, window_strides, 1)。lhs_padding = lhs_shape([0, 0], padding, [0, 0])。lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)。lhs_window_dilations = lhs_shape(1, rhs_dilation, 1)。
此重构使用以下辅助函数:
lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])。result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])。permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1],其中j[d] = i[permutation[d]]。
如果 feature_group_count = 1 和 batch_group_count = 1,则对于 index_space(dim(result, output_spatial_dimensions...)) 中的所有 output_spatial_index,result[result_shape(:, output_spatial_index, :)] = dot_product,其中:
padding_value = constant(0, element_type(lhs))。padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)。lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides。lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations)。reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true])。此功能似乎未被使用,因此我们计划在未来将其移除 (#1181)。dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension])。
如果 feature_group_count > 1:
lhses = split(lhs, feature_group_count, input_feature_dimension)。rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)。results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)。result = concatenate(results, output_feature_dimension)。
如果 batch_group_count > 1:
lhses = split(lhs, batch_group_count, input_batch_dimension)。rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)。results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...)。result = concatenate(results, output_feature_dimension)。
对于量化类型,执行 dequantize_op_quantize(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs,
type(result))。
对于混合量化类型,执行 hybrid_dequantize_then_op(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs)。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | lhs |
张量或按张量量化的张量 | (C1)、(C10-C11)、(C14)、(C25)、(C27-C28)、(C31-C32)、(C34) |
| (I2) | rhs |
张量或量化张量 | (C1)、(C14-C16)、(C25)、(C27-C29)、(C31-C34) |
| (I3) | window_strides |
类型为 si64 的 1 维张量常量 |
(C2-C3)、(C25) |
| (I4) | padding |
类型为 si64 的二维张量常量 |
(C4)、(C25) |
| (I5) | lhs_dilation |
类型为 si64 的 1 维张量常量 |
(C5-C6)、(C25) |
| (I6) | rhs_dilation |
类型为 si64 的 1 维张量常量 |
(C7-C8)、(C25) |
| (I7) | window_reversal |
类型为 i1 的 1 维张量常量 |
(C9) |
| (I8) | input_batch_dimension |
类型为 si64 的常量 |
(C10)、(C13)、(C25) |
| (I9) | input_feature_dimension |
类型为 si64 的常量 |
(C11)、(C13-C14) |
| (I10) | input_spatial_dimensions |
类型为 si64 的 1 维张量常量 |
(C12)、(C13)、(C25) |
| (I11) | kernel_input_feature_dimension |
类型为 si64 的常量 |
(C14)、(C18) |
| (I12) | kernel_output_feature_dimension |
类型为 si64 的常量 |
(C15-C16)、(C18)、(C25)、(C29) |
| (I13) | kernel_spatial_dimensions |
类型为 si64 的 1 维张量常量 |
(C17-C18)、(C25) |
| (I14) | output_batch_dimension |
类型为 si64 的常量 |
(C20)、(C25) |
| (I15) | output_feature_dimension |
类型为 si64 的常量 |
(C20)、(C25)、(C30) |
| (I16) | output_spatial_dimensions |
类型为 si64 的 1 维张量常量 |
(C19-C20)、(C25) |
| (I17) | feature_group_count |
类型为 si64 的常量 |
(C11)、(C14)、(C16)、(C21)、(C23) |
| (I18) | batch_group_count |
类型为 si64 的常量 |
(C10)、(C15)、(C22)、(C23)、(C25) |
| (I19) | precision_config |
DEFAULT、HIGH 和 HIGHEST 的可变数量的枚举 |
(C24) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
张量或量化张量 | (C25-C28)、(C30)、(C32-34) |
限制条件
- (C1)
N = rank(lhs) = rank(rhs)。 - (C2)
size(window_strides) = N - 2。 - (C3)
0 < window_strides。 - (C4)
shape(padding) = [N - 2, 2]。 - (C5)
size(lhs_dilation) = N - 2。 - (C6)
0 < lhs_dilation。 - (C7)
size(rhs_dilation) = N - 2。 - (C8)
0 < rhs_dilation。 - (C9)
size(window_reversal) = N - 2。 - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0。 - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0。 - (C12)
size(input_spatial_dimensions) = N - 2。 - (C13) 给定
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:is_unique(input_dimensions)。0 <= input_dimensions < N。
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count。 - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0。 - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0。 - (C17)
size(kernel_spatial_dimensions) = N - 2。 - (C18) 给定
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:is_unique(kernel_dimensions)。0 <= kernel_dimensions < N。
- (C19)
size(output_spatial_dimensions) = N - 2。 - (C20) 给定
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:is_unique(output_dimensions)。0 <= output_dimensions < N。
- (C21)
0 < feature_group_count。 - (C22)
0 < batch_group_count。 - (C23)
feature_group_count = 1 or batch_group_count = 1。 - (C24)
size(precision_config) = 2。 - (C25)
dim(result, result_dim)的定义如下:- 如果
result_dim = output_batch_dimension,则为dim(lhs, input_batch_dimension) / batch_group_count。 - 如果
result_dim = output_feature_dimension,则为dim(rhs, kernel_output_feature_dimension)。 - 否则为
num_windows,其中: output_spatial_dimensions[spatial_dim] = result_dim。lhs_dim = input_spatial_dimensions[spatial_dim]。rhs_dim = kernel_spatial_dimensions[spatial_dim]。dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1。padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]。dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1。is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]。num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1。
- 如果
- (C26)
rank(result) = N。 - 如果运算使用非量化张量:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)。
- (C27)
- 如果操作使用量化张量:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)。 - (C29) 如果
is_per_axis_quantized(rhs),则quantization_dimension(rhs) = kernel_output_feature_dimension。 - (C30) 如果
is_per_axis_quantized(result),则quantization_dimension(result) = output_feature_dimension。 - 如果
is_quantized(lhs): - (C31)
storage_type(lhs) = storage_type(rhs)。 - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)。 - (C33) 如果
is_per_tensor_quantized(rhs),则is_per_tensor_quantized(result)。 - 如果
!is_quantized(lhs): - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)。
- (C28)
示例
// %lhs: [[
// [
// [1], [2], [5], [6]
// ],
// [
// [3], [4], [7], [8]
// ],
// [
// [10], [11], [14], [15]
// ],
// [
// [12], [13], [16], [17]
// ]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strides = array<i64: 4, 4>,
padding = dense<0> : tensor<2x2xi64>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
// In the StableHLO dialect, dimension numbers are encoded via:
// `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
// "b" is batch dimension, "f" is feature dimension,
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" are spatial dimensions.
dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
batch_group_count = 1 : i64,
feature_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[10], [26]],
// [[46], [62]]
// ]]
余弦
语义
对 operand 张量执行逐元素余弦运算,并生成 result 张量。根据元素类型,执行以下操作:
- 对于浮点数:来自 IEEE-754 的
cos。 - 对于复数:复余弦。
- 对于量化类型:
dequantize_op_quantize(cosine, operand, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
浮点或复杂类型的张量或每个张量的量化张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
浮点或复杂类型的张量或每个张量的量化张量 | (C1) |
限制条件
- (C1)
baseline_type(operand) = baseline_type(result)。
示例
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]
count_leading_zeros
语义
逐元素计算 operand 张量中前导零位的数量,并生成 result 张量。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
整数类型的张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
整数类型的张量 | (C1) |
限制条件
- (C1)
type(operand) = type(result)。
示例
// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]
custom_call
语义
封装了实现定义的运算 call_target_name,该运算接受 inputs 和 called_computations 并生成 results。has_side_effect、backend_config 和 api_version 可用于提供其他实现定义的元数据。
目前,此操作包含相当杂乱的元数据集合,反映了 XLA 编译器中对应操作的自然演变。未来,我们计划统一此元数据 (#741)。
输入
| 标签 | 名称 | 类型 |
|---|---|---|
| (I1) | inputs |
可变数量的值 |
| (I2) | call_target_name |
类型为 string 的常量 |
| (I3) | has_side_effect |
类型为 i1 的常量 |
| (I4) | backend_config |
类型为 string 的常量或属性字典 |
| (I5) | api_version |
类型为 si32 的常量 |
| (I6) | called_computations |
string 类型的可变数量的常量 |
| (I7) | output_operand_aliases |
指定输出和操作数中的别名部分 |
输出
| 名称 | 类型 |
|---|---|
results |
可变数量的值 |
(XLA GPU 支持)特殊 custom_call 目标
有三个与 buffer 类型相关的特殊 call_target_name:CreateBuffer 创建未初始化的 buffer,Pin 创建已初始化的 buffer,Unpin 释放 buffer 并返回 buffer 的内容。
%uninitialized_buffer = "stablehlo.custom_call"() {
call_target_name = "CreateBuffer",
api_version = 4 : i32,
} : () -> memref<4xf64>
%initialized_buffer = "stablehlo.custom_call"(%init_value) {
call_target_name = "Pin",
api_version = 4 : i32,
} : (tensor<4xf64>) -> memref<4xf64>
%dealloc_buffer = "stablehlo.custom_call"(%initialized_buffer) {
call_target_name = "Unpin",
api_version = 4 : i32,
} : (memref<4xf64>) -> tensor<4xf64>
别名
某些 custom_call 操作可能需要输出中的一部分和操作数中的一部分共享同一内存。这可以通过 output_operand_aliases 来表示。别名对表示法包含一个表示输出部分的输出元组索引列表,以及一个 operand_index 和一个表示运算数部分的运算数元组索引列表。如果相应类型不是 tuple 类型,则输出或操作数元组索引的列表为空;对于任意嵌套的元组类型,该列表可以任意长。这与 XLA 别名表示法类似。
别名对中的输出部分和输入部分必须具有相同的类型。对于不是对 CreateBuffer、Pin 和 Unpin 的调用的 custom_call 操作,buffer 实参最多可出现在一个别名对中,而 buffer 输出必须出现在一个别名对中。
示例
%results = "stablehlo.custom_call"(%input0) {
call_target_name = "foo",
has_side_effect = false,
backend_config = {bar = 42 : i32},
api_version = 4 : i32,
called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>
%updated_buffer = "stablehlo.custom_call"(%buffer) {
call_target_name = "Update",
api_version = 4 : i32,
output_operand_aliases = [
#stablehlo.output_operand_alias<output_tuple_indices = [],
operand_index = 0,
operand_tuple_indices = []>]
} : (memref<4xf64>) -> memref<4xf64>
除
语义
对被除数 lhs 张量和除数 rhs 张量执行逐元素除法,并生成 result 张量。根据元素类型,执行以下操作:
- 对于整数:整数除法,生成代数商,并舍弃任何小数部分。
- 对于浮点数:来自 IEEE-754 的
division。 - 对于复数:复数除法。
- 对于量化类型:
dequantize_op_quantize(divide, lhs, rhs, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | lhs |
整数、浮点或复数类型的张量或按张量量化的张量 | (C1) |
| (I2) | rhs |
整数、浮点或复数类型的张量或按张量量化的张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
整数、浮点或复数类型的张量或每个张量的量化张量 | (C1) |
限制条件
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)。
示例
// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]
dot_general
语义
计算 lhs 的切片与 rhs 的切片之间的点积,并生成 result 张量。
更正式地说,result[result_index] = dot_product,其中:
lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions]。rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions]。result_batching_index + result_lhs_index + result_rhs_index = result_index其中size(result_batching_index) = size(lhs_batching_dimensions)、size(result_lhs_index) = size(lhs_result_dimensions)和size(result_rhs_index) = size(rhs_result_dimensions)。transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions)。transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :])。reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions))。transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions)。transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :])。reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions))。dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y))。
对于量化类型,执行 dequantize_op_quantize(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs, type(result))。
对于混合量化类型,执行 hybrid_dequantize_then_op(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs)。
precision_config 用于控制在加速器后端上进行计算时速度与准确性之间的权衡。可以是以下值之一(目前,这些枚举值的语义尚未完全指定,但我们计划在 #755 中解决此问题):
DEFAULT:计算速度最快,但对原始数字的近似程度最低。HIGH:计算速度较慢,但能更准确地逼近原始数字。HIGHEST:计算速度最慢,但最接近原始数字。
DotAlgorithm 用于定义实现点运算的算法的主要属性,同时还定义了精度。如果设置了算法属性字段,则 precision_config 必须为 DEFAULT。DotAlgorithms 没有默认值,因为默认参数是由实现定义的。因此,所有点算法字段都可以设置为 None,以指定空点算法,该算法将改用 precision_config 值。
DotAlgorithm 字段包括:
lhs_precision_type和rhs_precision_type,运算的左侧和右侧四舍五入到的精度。精度类型与输入和输出的存储类型无关。accumulation_type用于累积的精度。- 当我们在执行将左侧实参和/或右侧实参分解为多个组成部分并对这些值执行多个“原始”点运算的算法时,会应用
lhs_component_count、rhs_component_count和num_primitive_operations- 通常是为了模拟更高的精度(例如利用 bfloat16 人工智能数据类型进行更高精度的计算:bf16_6x tf32_3x 等)。对于没有分解的算法,这些值应设置为1。 allow_imprecise_accumulation用于指定是否允许在某些步骤(例如CUBLASLT_MATMUL_DESC_FAST_ACCUM)中以较低精度进行累积。
DotAlgorithm 属性示例:
// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false}
// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
rhs_precision_type = bf16,
accumulation_type = f32,
lhs_component_count = 3,
rhs_component_count = 3,
num_primitive_operations = 6,
allow_imprecise_accumulation = false}
// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
rhs_precision_type = f8e5m2,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = true}
具体支持哪些组合由实现方决定。一般来说,StableHLO 的使用者无法保证每种算法都受每种加速器类型的支持。如果不支持给定的算法,则应引发错误,而不是回退到替代算法。StableHLO 验证将尽力进行验证,防止使用已知在任何硬件上都不受支持的算法。
如需查看一些受支持的算法值,请参阅 xla_data.proto > Algorithm。工单 #2483 记录了创建关于后端支持的算法的集中式文档的计划。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | lhs |
张量或按张量量化的张量 | (C5-C6)、(C9-C10)、(C12-C14)、(C17-C18)、(C20) |
| (I2) | rhs |
张量或量化张量 | (C7-C10)、(C12-C20) |
| (I3) | lhs_batching_dimensions |
类型为 si64 的 1 维张量常量 |
(C1)、(C3)、(C5)、(C9)、(C12) |
| (I4) | rhs_batching_dimensions |
类型为 si64 的 1 维张量常量 |
(C1)、(C4)、(C7)、(C9) |
| (I5) | lhs_contracting_dimensions |
类型为 si64 的 1 维张量常量 |
(C2)、(C3)、(C6)、(C10) |
| (I6) | rhs_contracting_dimensions |
类型为 si64 的 1 维张量常量 |
(C2)、(C4)、(C8)、(C10)、(C16) |
| (I7) | precision_config |
DEFAULT、HIGH 和 HIGHEST 的可变数量的枚举 |
(C11)、(C21) |
| (I8) | lhs_precision_type |
FloatType 或 TensorFloat32 | (C21) |
| (I9) | rhs_precision_type |
FloatType 或 TensorFloat32 | (C21) |
| (I10) | accumulation_type |
FloatType 或 TensorFloat32 | (C21) |
| (I11) | lhs_component_count |
类型为 si32 的常量 |
(C21)、(C22) |
| (I12) | rhs_component_count |
类型为 si32 的常量 |
(C21)、(C23) |
| (I13) | num_primitive_operations |
类型为 si32 的常量 |
(C21)、(C24) |
| (I14) | allow_imprecise_accumulation |
类型为 bool 的常量 |
(C21) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
张量或量化张量 | (C12)、(C14)、(C18-C20) |
限制条件
- (C1)
size(lhs_batching_dimensions) = size(rhs_batching_dimensions)。 - (C2)
size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions)。 - (C3)
is_unique(lhs_batching_dimensions + lhs_contracting_dimensions)。 - (C4)
is_unique(rhs_batching_dimensions + rhs_contracting_dimensions)。 - (C5)
0 <= lhs_batching_dimensions < rank(lhs)。 - (C6)
0 <= lhs_contracting_dimensions < rank(lhs)。 - (C7)
0 <= rhs_batching_dimensions < rank(rhs)。 - (C8)
0 <= rhs_contracting_dimensions < rank(rhs)。 - (C9)
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)。 - (C10)
dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...)。 - (C11)
size(precision_config) = 2。 - (C12)
shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions)。 - 如果运算使用非量化张量:
- (C13)
element_type(lhs) = element_type(rhs)。
- (C13)
- 如果操作使用量化张量:
- (C14)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)。 - (C15)
zero_points(rhs) = 0。 - (C16) 如果
is_per_axis_quantized(rhs),则quantization_dimension(rhs)不在rhs_contracting_dimensions中。 - 如果
is_quantized(lhs): - (C17)
storage_type(lhs) = storage_type(rhs)。 - (C18)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)。 - (C19) 如果
is_per_tensor_quantized(rhs),则is_per_tensor_quantized(result)。 - 如果
!is_quantized(lhs): - (C20)
element_type(lhs) = expressed_type(rhs) = element_type(result)。
- (C14)
- 如果
!is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation):- (C21)
precision_config... = DEFAULT。 - (C22)
0 < lhs_component_count。 - (C23)
0 < rhs_component_count。 - (C24)
0 < num_primitive_operations。
- (C21)
示例
// %lhs: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
// %rhs: [
// [[1, 0],
// [0, 1]],
// [[1, 0],
// [0, 1]]
// ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [1]
>,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
algorithm = #stablehlo.dot_algorithm<
lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false
>
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
dynamic_broadcast_in_dim
语义
此操作在功能上与 broadcast_in_dim 操作相同,但结果形状是通过 output_dimensions 动态指定的。
该操作还接受可选属性 known_expanding_dimensions、known_nonexpanding_dimensions,以表达有关维度扩展行为的静态知识。如果未指定,则假定所有维度都可能扩大。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
张量或量化张量 | (C1-C2)、(C5-C6)、(C9) |
| (I2) | output_dimensions |
整数类型的一维张量 | (C7) |
| (I3) | broadcast_dimensions |
整数类型的一维常量张量 | (C2-C6) |
| (I4) | known_expanding_dimensions |
整数类型的一维常量张量 | (C8-C9) |
| (I5) | known_nonexpanding_dimensions |
整数类型的一维常量张量 | (C8-C9) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
张量或量化张量 | (C1)、(C3)、(C5-C7) |
限制条件
- (C1)
element_type(result)由以下人员提供:element_type(operand)(如果!is_per_axis_quantized(operand))。element_type(operand),但quantization_dimension(operand)、scales(operand)和zero_points(operand)可能分别与quantization_dimension(result)、scales(result)和zero_points(result)不同,否则。
- (C2)
size(broadcast_dimensions) = rank(operand)。 - (C3)
0 <= broadcast_dimensions < rank(result)。 - (C4)
is_unique(broadcast_dimensions)。 - (C5) 对于
axes(operand)中的所有d:dim(operand, d) = 1或dim(operand, d) = dim(result, broadcast_dimensions[d])。
- (C6) 如果
is_per_axis_quantized(result):quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]。- 如果值为
dim(operand, quantization_dimension(operand)) = 1,则值为scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))。
- (C7)
size(output_dimensions) = rank(result)。 - (C8)
is_unique(known_expanding_dimensions + known_nonexpanding_dimensions)。 - (C9)
0 <= known_expanding_dimensions < rank(operand)。 - (C10)
0 <= known_nonexpanding_dimensions < rank(operand)。
示例
// %operand: [
// [1, 2, 3]
// ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
broadcast_dimensions = array<i64: 2, 1>,
known_expanding_dimensions = array<i64: 0>,
known_nonexpanding_dimensions = array<i64: 1>
} : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
dynamic_conv
语义
此操作在功能上与卷积操作相同,但填充是通过 padding 动态指定的。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | lhs |
张量或按张量量化的张量 | (C1)、(C10-C11)、(C14)、(C25)、(C26-C27)、(C30-C31)、(C33) |
| (I2) | rhs |
张量或量化张量 | (C1)、(C14-C16)、(C26-C28)、(C30-C33) |
| (I3) | padding |
整数类型的二维张量 | (C4) |
| (I4) | window_strides |
类型为 si64 的 1 维张量常量 |
(C2-C3) |
| (I5) | lhs_dilation |
类型为 si64 的 1 维张量常量 |
(C5-C6) |
| (I6) | rhs_dilation |
类型为 si64 的 1 维张量常量 |
(C7-C8) |
| (I7) | window_reversal |
类型为 i1 的 1 维张量常量 |
(C9) |
| (I8) | input_batch_dimension |
类型为 si64 的常量 |
(C10)、(C13) |
| (I9) | input_feature_dimension |
类型为 si64 的常量 |
(C11)、(C13-C14) |
| (I10) | input_spatial_dimensions |
类型为 si64 的 1 维张量常量 |
(C12)、(C13) |
| (I11) | kernel_input_feature_dimension |
类型为 si64 的常量 |
(C14)、(C18) |
| (I12) | kernel_output_feature_dimension |
类型为 si64 的常量 |
(C15-C16)、(C18)、(C28) |
| (I13) | kernel_spatial_dimensions |
类型为 si64 的 1 维张量常量 |
(C17-C18) |
| (I14) | output_batch_dimension |
类型为 si64 的常量 |
(C20) |
| (I15) | output_feature_dimension |
类型为 si64 的常量 |
(C20)、(C29) |
| (I16) | output_spatial_dimensions |
类型为 si64 的 1 维张量常量 |
(C19-C20) |
| (I17) | feature_group_count |
类型为 si64 的常量 |
(C11)、(C14)、(C16)、(C21)、(C23) |
| (I18) | batch_group_count |
类型为 si64 的常量 |
(C10)、(C15)、(C22)、(C23) |
| (I19) | precision_config |
DEFAULT、HIGH 和 HIGHEST 的可变数量的枚举 |
(C24) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
张量或量化张量 | (C25-C27)、(C29)、(C31-C33) |
限制条件
- (C1)
N = rank(lhs) = rank(rhs)。 - (C2)
size(window_strides) = N - 2。 - (C3)
0 < window_strides。 - (C4)
shape(padding) = [N - 2, 2]。 - (C5)
size(lhs_dilation) = N - 2。 - (C6)
0 < lhs_dilation。 - (C7)
size(rhs_dilation) = N - 2。 - (C8)
0 < rhs_dilation。 - (C9)
size(window_reversal) = N - 2。 - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0。 - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0。 - (C12)
size(input_spatial_dimensions) = N - 2。 - (C13) 给定
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:is_unique(input_dimensions)。0 <= input_dimensions < N。
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count。 - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0。 - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0。 - (C17)
size(kernel_spatial_dimensions) = N - 2。 - (C18) 给定
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:is_unique(kernel_dimensions)。0 <= kernel_dimensions < N。
- (C19)
size(output_spatial_dimensions) = N - 2。 - (C20) 给定
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:is_unique(output_dimensions)。0 <= output_dimensions < N。
- (C21)
0 < feature_group_count。 - (C22)
0 < batch_group_count。 - (C23)
feature_group_count = 1 or batch_group_count = 1。 - (C24)
size(precision_config) = 2。 - (C25)
dim(result, result_dim)的定义如下:- 如果
result_dim = output_batch_dimension,则为dim(lhs, input_batch_dimension) / batch_group_count。 - 如果
result_dim = output_feature_dimension,则为dim(rhs, kernel_output_feature_dimension)。 - 否则为
num_windows,其中: output_spatial_dimensions[spatial_dim] = result_dim。lhs_dim = input_spatial_dimensions[spatial_dim]。rhs_dim = kernel_spatial_dimensions[spatial_dim]。dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1。padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]。dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1。is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]。num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1。
- 如果
- (C26)
rank(result) = N。 - 如果运算使用非量化张量:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)。
- (C27)
- 如果操作使用量化张量:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)。 - (C29) 如果
is_per_axis_quantized(rhs),则quantization_dimension(rhs) = kernel_output_feature_dimension。 - (C30) 如果
is_per_axis_quantized(result),则quantization_dimension(result) = output_feature_dimension。 - 如果
is_quantized(lhs): - (C31)
storage_type(lhs) = storage_type(rhs)。 - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)。 - (C33) 如果
is_per_tensor_quantized(rhs),则is_per_tensor_quantized(result)。 - 如果
!is_quantized(lhs): - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)。
- (C28)
示例
// %lhs: [[
// [[1], [2], [5], [6]],
// [[3], [4], [7], [8]],
// [[10], [11], [14], [15]],
// [[12], [13], [16], [17]]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
// %padding: [[1, 1],
// [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
window_strides = array<i64: 4, 4>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
dimension_numbers = #stablehlo.conv<raw
input_batch_dimension = 0,
input_feature_dimension = 3,
input_spatial_dimensions = [0, 1],
kernel_input_feature_dimension = 2,
kernel_output_feature_dimension = 3,
kernel_spatial_dimensions = [0, 1],
output_batch_dimension = 0,
output_feature_dimension = 3,
output_spatial_dimensions = [1, 2]
>,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[1], [5]],
// [[10], [14]]
// ]]
dynamic_gather
语义
此操作在功能上与 gather 操作完全相同,只是 slice_sizes 是以值的形式动态指定的。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
张量或按张量量化的张量 | (C1)、(C7)、(C10-C12)、(C14) |
| (I2) | start_indices |
整数类型的张量 | (C2)、(C3)、(C13) |
| (I3) | slice_sizes |
整数类型的一维张量 | (C8)、(C11-C13) |
| (I4) | offset_dims |
类型为 si64 的 1 维张量常量 |
(C1)、(C4-C5)、(C13) |
| (I5) | collapsed_slice_dims |
类型为 si64 的 1 维张量常量 |
(C1)、(C6-C8)、(C13) |
| (I6) | start_index_map |
类型为 si64 的 1 维张量常量 |
(C3)、(C9)、(C10) |
| (I7) | index_vector_dim |
类型为 si64 的常量 |
(C2)、(C3)、(C13) |
| (I8) | indices_are_sorted |
类型为 i1 的常量 |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
张量或按张量量化的张量 | (C5)、(C13-C14) |
限制条件
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims)。 - (C2)
0 <= index_vector_dim <= rank(start_indices)。 - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1。 - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)。 - (C5)
0 <= offset_dims < rank(result)。 - (C6)
is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims)。 - (C7)
0 <= collapsed_slice_dims < rank(operand)。 - (C8)
slice_sizes[collapsed_slice_dims...] <= 1。 - (C9)
is_unique(start_index_map)。 - (C10)
0 <= start_index_map < rank(operand)。 - (C11)
size(slice_sizes) = rank(operand)。 - (C12)
0 <= slice_sizes <= shape(operand)。 - (C13)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)其中:batch_dim_sizes = shape(start_indices),但其中不包含与index_vector_dim对应的start_indices的维度大小。offset_dim_sizes = shape(slice_sizes),但slice_sizes中与collapsed_slice_dims对应的维度大小不包含在内。combine将batch_dim_sizes放置在与batch_dims对应的轴上,并将offset_dim_sizes放置在与offset_dims对应的轴上。
- (C14)
element_type(operand) = element_type(result)。
示例
// %operand: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %start_indices: [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 2]]
// ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
dimension_numbers = #stablehlo.gather<
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vector_dim = 2>,
indices_are_sorted = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi64>
// %result: [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[9, 10], [11, 12]],
// [[11, 12], [13, 14]],
// [[17, 18], [19, 20]]
// ]
// ]
dynamic_iota
语义
此操作在功能上与 iota 操作完全相同,但结果形状是通过 output_shape 动态指定的。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | output_shape |
整数类型的一维张量 | (C1)、(C2) |
| (I2) | iota_dimension |
si64 |
(C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
整数、浮点或复数类型的张量或每个张量的量化张量 | (C2) |
限制条件
- (C1)
0 <= iota_dimension < size(output_shape)。 - (C2)
rank(result) = size(output_shape)。
示例
%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
iota_dimension = 0 : i64
} : (tensor<2xi64>) -> tensor<4x5xi64>
// %result: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
dynamic_pad
语义
此操作在功能上与 pad 操作完全相同,但 edge_padding_low、edge_padding_high 和 interior_padding 是以值的形式动态指定的。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
张量或按张量量化的张量 | (C1)、(C2)、(C4) |
| (I2) | padding_value |
0 维张量或每个张量的量化张量 | (C1) |
| (I3) | edge_padding_low |
整数类型的一维张量 | (C1)、(C4) |
| (I4) | edge_padding_high |
整数类型的一维张量 | (C1)、(C4) |
| (I5) | interior_padding |
整数类型的一维张量 | (C2-C4) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
张量或按张量量化的张量 | (C3-C6) |
限制条件
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)。 - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)。 - (C3)
0 <= interior_padding。 - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high。
示例
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
%edge_padding_low, %edge_padding_high, %interior_padding
) : (tensor<2x3xi64>, tensor<i64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x9xi64>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
dynamic_reshape
语义
此操作在功能上与 reshape 操作完全相同,但结果形状是通过 output_shape 动态指定的。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
张量或量化张量 | (C1-C3) |
| (I2) | output_shape |
整数类型的一维张量 | (C4) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
张量或量化张量 | (C1-C4) |
限制条件
- (C1)
element_type(result)由以下人员提供:element_type(operand)(如果!is_per_axis_quantized(operand))。element_type(operand),但quantization_dimension(operand)和quantization_dimension(result)可能有所不同,其他方面则相同。
- (C2)
size(operand) = size(result)。 - (C3) 如果
is_per_axis_quantized(operand):reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)。dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))。reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)。
- (C4)
size(output_shape) = rank(result)。
示例
// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
// %result: [[1, 2], [3, 4], [5, 6]]
dynamic_slice
语义
使用动态计算的起始索引从 operand 中提取切片,并生成 result 张量。start_indices 包含每个维度(可能会进行调整)的切片的起始索引,slice_sizes 包含每个维度的切片大小。更正式地说,result[result_index] = operand[operand_index],其中:
adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes)。operand_index = adjusted_start_indices + result_index。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
张量或按张量量化的张量 | (C1)、(C2)、(C4) |
| (I2) | start_indices |
数量可变的 0 维整数类型张量 | (C2)、(C3) |
| (I3) | slice_sizes |
类型为 si64 的 1 维张量常量 |
(C2)、(C4)、(C5) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
张量或按张量量化的张量 | (C1)、(C5) |
限制条件
- (C1)
element_type(operand) = element_type(result)。 - (C2)
size(start_indices) = size(slice_sizes) = rank(operand)。 - (C3)
same(type(start_indices...))。 - (C4)
0 <= slice_sizes <= shape(operand)。 - (C5)
shape(result) = slice_sizes。
示例
// %operand: [
// [0, 0, 1, 1],
// [0, 0, 1, 1],
// [0, 0, 0, 0],
// [0, 0, 0, 0]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
slice_sizes = array<i64: 2, 2>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
// [1, 1],
// [1, 1]
// ]
dynamic_update_slice
语义
生成一个 result 张量,该张量与 operand 张量相同,只是从 start_indices 开始的切片已使用 update 中的值进行更新。更正式地说,result[result_index] 的定义如下:
update[update_index]如果0 <= update_index < shape(update),其中:adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update))。update_index = result_index - adjusted_start_indices。
- 否则为
operand[result_index]。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
张量或按张量量化的张量 | (C1-C4)、(C6) |
| (I2) | update |
张量或按张量量化的张量 | (C2)、(C3)、(C6) |
| (I3) | start_indices |
数量可变的 0 维整数类型张量 | (C4)、(C5) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
张量或按张量量化的张量 | (C1) |
限制条件
- (C1)
type(operand) = type(result)。 - (C2)
element_type(update) = element_type(operand)。 - (C3)
rank(update) = rank(operand)。 - (C4)
size(start_indices) = rank(operand)。 - (C5)
same(type(start_indices...))。 - (C6)
0 <= shape(update) <= shape(operand)。
示例
// %operand: [
// [1, 1, 0, 0],
// [1, 1, 0, 0],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
// %update: [
// [1, 1],
// [1, 1]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
: (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
指数函数
语义
对 operand 张量执行逐元素指数运算,并生成 result 张量。根据元素类型,执行以下操作:
- 对于浮点数:来自 IEEE-754 的
exp。 - 对于复数:复指数。
- 对于量化类型:
dequantize_op_quantize(exponential, operand, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
浮点或复杂类型的张量或每个张量的量化张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
浮点或复杂类型的张量或每个张量的量化张量 | (C1) |
限制条件
- (C1)
baseline_type(operand) = baseline_type(result)。
示例
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]
exponential_minus_one
语义
对 operand 张量执行逐元素指数减 1 运算,并生成 result 张量。根据元素类型,执行以下操作:
- 对于浮点数:来自 IEEE-754 的
expm1。 - 对于复数:复指数减 1。
- 对于量化类型:
dequantize_op_quantize(exponential_minus_one, operand, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
浮点或复杂类型的张量或每个张量的量化张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
浮点或复杂类型的张量或每个张量的量化张量 | (C1) |
限制条件
- (C1)
baseline_type(operand) = baseline_type(result)。
示例
// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]
fft
语义
针对实数和复数输入/输出执行正向和反向傅里叶转换。
fft_type 是以下值之一:
FFT:正向复数到复数 FFT。IFFT:逆复数到复数 FFT。RFFT:实数到复数正向 FFT。IRFFT:实数到复数 FFT 的逆运算(即输入为复数,输出为实数)。
更正式地说,假设函数 fft 以复杂类型的 1 维张量作为输入,生成相同类型的 1 维张量作为输出,并计算离散傅里叶转换:
对于 fft_type = FFT,result 定义为一系列 L 计算的最终结果,其中 L = size(fft_length)。例如,对于 L = 3:
result1[i0, ..., :] = fft(operand[i0, ..., :])。result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])。result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])。
此外,假设函数 ifft 具有相同的类型签名,并计算 fft 的逆:
对于 fft_type = IFFT,result 定义为 fft_type = FFT 计算的逆运算。例如,对于 L = 3:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])。result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])。result[i0, ..., :] = ifft(result2[i0, ..., :])。
此外,假设有一个函数 rfft,该函数接受浮点类型的一维张量,生成具有相同浮点语义的复数类型的一维张量,并按如下方式运行:
rfft(real_operand) = truncated_result,其中complex_operand... = (real_operand..., 0.0)。complex_result = fft(complex_operand)。truncated_result = complex_result[:(rank(complex_result) / 2 + 1)]。
(当为实数运算数计算离散傅里叶转换时,结果的前 N/2 + 1 个元素可明确定义结果的其余部分,因此 rfft 的结果会被截断,以避免计算冗余元素)。
对于 fft_type = RFFT,result 定义为一系列 L 计算的最终结果,其中 L = size(fft_length)。例如,对于 L = 3:
result1[i0, ..., :] = rfft(operand[i0, ..., :])。result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])。result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])。
最后,假设函数 irfft 具有相同的类型签名并计算 rfft 的逆:
对于 fft_type = IRFFT,result 定义为 fft_type = RFFT 计算的逆运算。例如,对于 L = 3:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])。result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])。result[i0, ..., :] = irfft(result2[i0, ..., :])。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
浮点或复杂类型的张量 | (C1)、(C2)、(C4)、(C5) |
| (I2) | fft_type |
FFT、IFFT、RFFT 和 IRFFT 的枚举 |
(C2)、(C5) |
| (I3) | fft_length |
类型为 si64 的 1 维张量常量 |
(C1)、(C3)、(C4) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
浮点或复杂类型的张量 | (C2)、(C4)、(C5) |
限制条件
- (C1)
size(fft_length) <= rank(operand)。 - (C2)
operand和result元素类型之间的关系各不相同:- 如果
fft_type = FFT、element_type(operand)和element_type(result)具有相同的复杂类型。 - 如果
fft_type = IFFT、element_type(operand)和element_type(result)具有相同的复杂类型。 - 如果
fft_type = RFFT,element_type(operand)是浮点类型,而element_type(result)是具有相同浮点语义的复数类型。 - 如果
fft_type = IRFFT,则element_type(operand)为复杂类型,而element_type(result)为具有相同浮点语义的浮点类型。
- 如果
- (C3)
1 <= size(fft_length) <= 3。 - (C4) 如果在
operand和result中存在浮点类型的张量real,则shape(real)[-size(fft_length):] = fft_length。 - (C5)
shape(result) = shape(operand)但以下情况除外:- 如果
fft_type = RFFT,则为dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1。 - 如果
fft_type = IRFFT,则为dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1。
- 如果
示例
// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
fft_type = #stablehlo<fft_type FFT>,
fft_length = array<i64: 4>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]
向下取整
语义
对 operand 张量执行逐元素向下取整运算,并生成 result 张量。
实现 IEEE-754 规范中的 roundToIntegralTowardNegative 操作。对于量化类型,执行 dequantize_op_quantize(floor, operand, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
浮点类型张量或逐张量量化张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
浮点类型张量或逐张量量化张量 | (C1) |
限制条件
- (C1)
baseline_type(operand) = baseline_type(result)。
示例
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]
gather
语义
从 operand 张量中收集偏移量在 start_indices 中指定的切片,并生成 result 张量。
下图通过一个具体示例展示了 result 中的元素如何映射到 operand 中的元素。该图表选择了一些示例 result 指数,并详细说明了它们对应的 operand 指数。
更正式地说,result[result_index] = operand[operand_index],其中:
batch_dims = [d for d in axes(result) and d not in offset_dims]。batch_index = result_index[batch_dims...]。start_index的定义如下:start_indices[bi0, ..., :, ..., biN],其中bi是batch_index中的各个元素,如果index_vector_dim<rank(start_indices),则将:插入到索引index_vector_dim处。- 否则为
[start_indices[batch_index]]。
- 对于
axes(operand)中的d_operand,- 如果
d_operand = start_index_map[d_start],则为full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])。 - 否则为
full_start_index[d_operand] = 0。
- 如果
- 对于
axes(operand)中的d_operand,full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)](如果d_operand = operand_batching_dims[i_batching]和d_start = start_indices_batching_dims[i_batching])。- 否则为
full_batching_index[d_operand] = 0。
offset_index = result_index[offset_dims...]。full_offset_index = [oi0, ..., 0, ..., oiN],其中oi是offset_index中的各个元素,而0插入到从collapsed_slice_dims到operand_batching_dims的索引处。operand_index = full_start_index + full_batching_index + full_offset_index。
如果 indices_are_sorted 为 true,则实现可以假定 start_indices 是按 start_index_map 排序的,否则行为未定义。更正式地说,对于所有 i1 < i2,从 indices(result) 到 full_start_index(i1) <= full_start_index(i2)。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
张量或按张量量化的张量 | (C1)、(C8)、(C11)、(C17)、(C19-C21)、(C23) |
| (I2) | start_indices |
整数类型的张量 | (C2-C3)、(C14)、(C17)、(C22) |
| (I3) | offset_dims |
类型为 si64 的 1 维张量常量 |
(C1)、(C4-C5)、(C22) |
| (I4) | collapsed_slice_dims |
类型为 si64 的 1 维张量常量 |
(C1)、(C6-C9)、(C22) |
| (I5) | operand_batching_dims |
类型为 si64 的 1 维张量常量 |
(C1)、(C6)、(C10-C12)、(C16-C18)、(C22) |
| (I6) | start_indices_batching_dims |
类型为 si64 的 1 维张量常量 |
(C13-C17) |
| (I7) | start_index_map |
类型为 si64 的 1 维张量常量 |
(C3)、(C18-C19) |
| (I8) | index_vector_dim |
类型为 si64 的常量 |
(C2-C3)、(C15)、(C22) |
| (I9) | slice_sizes |
类型为 si64 的 1 维张量常量 |
(C9)、(C12)、(C20-C22) |
| (I10) | indices_are_sorted |
类型为 i1 的常量 |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
张量或按张量量化的张量 | (C5)、(C22-C23) |
限制条件
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims)。 - (C2)
0 <= index_vector_dim <= rank(start_indices)。 - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1。 - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)。 - (C5)
0 <= offset_dims < rank(result)。 - (C6)
is_unique(concatenate(collapsed_slice_dims, operand_batching_dims)) - (C7)
is_sorted(collapsed_slice_dims)。 - (C8)
0 <= collapsed_slice_dims < rank(operand)。 - (C9)
slice_sizes[collapsed_slice_dims...] <= 1。 - (C10)
is_sorted(operand_batching_dims)。 - (C11)
0 <= operand_batching_dims < rank(operand)。 - (C12)
slice_sizes[operand_batching_dims...] <= 1。 - (C13)
is_unique(start_indices_batching_dims)。 - (C14)
0 <= start_indices_batching_dims < rank(start_indices)。 - (C15)
index_vector_dim not in start_indices_batching_dims。 - (C16)
size(operand_batching_dims) == size(start_indices_batching_dims)。 - (C17)
dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...)。 - (C18)
is_unique(concatenate(start_index_map, operand_batching_dims))。 - (C19)
0 <= start_index_map < rank(operand)。 - (C20)
size(slice_sizes) = rank(operand)。 - (C21)
0 <= slice_sizes <= shape(operand)。 - (C22)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)其中:batch_dim_sizes = shape(start_indices),但其中不包含与index_vector_dim对应的start_indices的维度大小。offset_dim_sizes = slice_sizes,但slice_sizes中与collapsed_slice_dims和operand_batching_dims对应的维度大小不包含在内。combine将batch_dim_sizes放置在与batch_dims对应的轴上,并将offset_dim_sizes放置在与offset_dims对应的轴上。
- (C23)
element_type(operand) = element_type(result)。
示例
// %operand: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %start_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
%result = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stablehlo.gather<
offset_dims = [3, 4],
collapsed_slice_dims = [1],
operand_batching_dims = [0],
start_indices_batching_dims = [1],
start_index_map = [2, 1],
index_vector_dim = 3>,
slice_sizes = array<i64: 1, 1, 2, 2>,
indices_are_sorted = false
} : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32>
// %result: [
// [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[33, 34], [35, 36]],
// [[35, 36], [37, 38]],
// [[41, 42], [43, 44]]
// ]
// ],
// [
// [
// [[1, 2], [3, 4]],
// [[13, 14], [15, 16]],
// [[21, 22], [23, 24]]
// ],
// [
// [[43, 44], [45, 46]],
// [[33, 34], [35, 36]],
// [[27, 28], [29, 30]]
// ]
// ]
// ]
get_dimension_size
语义
生成给定 operand 的 dimension 的大小。更正式地说,result = dim(operand, dimension)。语义仅与类型的形状组件有关。元素类型可以是任何类型。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
张量或量化张量 | (C1) |
| (I2) | dimension |
类型为 si64 的常量 |
(C1) |
输出
| 名称 | 类型 |
|---|---|
result |
类型为 si32 的 0 维张量 |
限制条件
- (C1)
0 <= dimension < rank(operand)。
示例
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3
get_tuple_element
语义
提取 operand 元组中位置 index 处的元素,并生成 result。更正式地说,result = operand[index]。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
tuple | (C1)、(C2) |
| (I2) | index |
类型为 si32 的常量 |
(C1)、(C2) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
任意值 | (C2) |
限制条件
- (C1)
0 <= index < size(operand)。 - (C2)
type(result) = tuple_element_types(operand)[index]。
示例
// %operand: ([1.0, 2.0], (3))
%result = "stablehlo.get_tuple_element"(%operand) <{index = 0 : i32}> : (tuple<tensor<2xf64>, tuple<tensor<i64>>>) -> tensor<2xf64>
// %result: [1.0, 2.0]
if
语义
根据 pred 的值,执行 true_branch 或 false_branch 中的一个函数,并生成输出。更正式地说,result =
pred ? true_branch() : false_branch()。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | pred |
类型为 i1 的 0 维张量 |
|
| (I2) | true_branch |
函数 | (C1-C3) |
| (I3) | false_branch |
函数 | (C1)、(C2) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
results |
数量可变的张量、量化张量或令牌 | (C3) |
限制条件
- (C1)
input_types(true_branch) = input_types(false_branch) = []。 - (C2)
output_types(true_branch) = output_types(false_branch)。 - (C3)
type(results...) = output_types(true_branch)。
示例
// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
"stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
"stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10
imag
语义
按元素方式从 operand 中提取虚部,并生成 result 张量。更正式地说,对于每个元素 x:imag(x) = is_complex(x) ? imaginary_part(x) :
constant(0, element_type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
浮点或复杂类型的张量 | (C1)、(C2) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
浮点型张量 | (C1)、(C2) |
限制条件
- (C1)
shape(result) = shape(operand)。 - (C2)
element_type(result)的定义如下:- 如果
is_complex(operand),则为complex_element_type(element_type(operand))。 - 否则为
element_type(operand)。
- 如果
示例
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]
infeed
语义
从输入源读取数据并生成 results。
infeed_config 的语义是由实现定义的。
results 由首先出现的载荷值和最后出现的令牌组成。未来,我们计划将载荷和令牌拆分为两个单独的输出,以提高清晰度 (#670)。
输入
| 标签 | 名称 | 类型 |
|---|---|---|
| (I1) | token |
token |
| (I2) | infeed_config |
类型为 string 的常量 |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
results |
数量可变的张量、量化张量或令牌 | (C1-C3) |
限制条件
- (C1)
0 < size(results)。 - (C2)
is_empty(result[:-1])或is_tensor(type(results[:-1]))。 - (C3)
is_token(type(results[-1]))。
示例
// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]
iota
语义
沿 iota_dimension 维度,用从零开始按升序排列的值填充 output 张量。更正式地说,
output[output_index] = constant(is_quantized(output) ?
quantize(output_index[iota_dimension], element_type(output)) :
output_index[iota_dimension], element_type(output))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | iota_dimension |
si64 |
(C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
output |
整数、浮点或复数类型的张量或每个张量的量化张量 | (C1) |
限制条件
- (C1)
0 <= iota_dimension < rank(output)。
示例
%output = "stablehlo.iota"() {
iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
%output = "stablehlo.iota"() {
iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4]
// ]
is_finite
语义
逐元素检查 x 中的值是否为有限值(即既不是 +Inf,也不是 -Inf 或 NaN),并生成 y 张量。实现 IEEE-754 规范中的 isFinite 操作。对于量化类型,结果始终为 true。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | x |
浮点类型张量或逐张量量化张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
y |
布尔值类型的张量 | (C1) |
限制条件
- (C1)
shape(x) = shape(y)。
示例
// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]
log
语义
对 operand 张量执行元素级对数运算,并生成 result 张量。根据元素类型,执行以下操作:
- 对于浮点数:来自 IEEE-754 的
log。 - 对于复数:复对数。
- 对于量化类型:
dequantize_op_quantize(log, operand, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
浮点或复杂类型的张量或每个张量的量化张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
浮点或复杂类型的张量或每个张量的量化张量 | (C1) |
限制条件
- (C1)
baseline_type(operand) = baseline_type(result)。
示例
// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]
log_plus_one
语义
对 operand 张量执行按元素取对数加 1 操作,并生成 result 张量。根据元素类型,执行以下操作:
- 对于浮点数:来自 IEEE-754 的
logp1。 - 对于复数:
complex(log(hypot(real(x) + 1, imag(x))), atan2(imag(x), real(x) + 1)) - 对于量化类型:
dequantize_op_quantize(log_plus_one, operand, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
浮点或复杂类型的张量或每个张量的量化张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
浮点或复杂类型的张量或每个张量的量化张量 | (C1) |
限制条件
- (C1)
baseline_type(operand) = baseline_type(result)。
示例
// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
logistic
语义
对 operand 张量执行元素级逻辑运算,并生成 result 张量。根据元素类型,执行以下操作:
- 对于浮点数:来自 IEEE-754 的
division(1, addition(1, exp(-x)))。 - 对于复数:复数逻辑。
- 对于量化类型:
dequantize_op_quantize(logistic, operand, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
浮点或复杂类型的张量或每个张量的量化张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
浮点或复杂类型的张量或每个张量的量化张量 | (C1) |
限制条件
- (C1)
baseline_type(operand) = baseline_type(result)。
示例
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]
地图
语义
沿 dimensions 对 inputs 应用映射函数 computation,并生成 result 张量。
更正式地说,result[result_index] = computation(inputs...[result_index])。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | inputs |
数量可变的张量或每个张量的量化张量 | (C1-C4) |
| (I2) | dimensions |
类型为 si64 的 1 维张量常量 |
(C3) |
| (I3) | computation |
函数 | (C4) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
张量或按张量量化的张量 | (C1)、(C4) |
限制条件
- (C1)
shape(inputs...) = shape(result)。 - (C2)
0 < size(inputs) = N。 - (C3)
dimensions = range(rank(inputs[0]))。 - (C4)
computation的类型为(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>,其中Ei = element_type(inputs[i])和E' = element_type(result)。
示例
// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
stablehlo.return %0 : tensor<i64>
}) {
dimensions = array<i64: 0, 1>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]
最大值
语义
对张量 lhs 和 rhs 执行元素级最大值运算,并生成 result 张量。根据元素类型,执行以下操作:
- 对于布尔值:逻辑或。
- 对于整数:整数最大值。
- 对于浮点数:来自 IEEE-754 的
maximum。 - 对于复数:
(real, imaginary)对的字典序最大值。 对复数施加排序会涉及令人惊讶的语义,因此我们计划在未来移除对该操作的复数支持 (#560)。 - 对于量化类型:
dequantize_op_quantize(maximum, lhs, rhs, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | lhs |
张量或按张量量化的张量 | (C1) |
| (I2) | rhs |
张量或按张量量化的张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
张量或按张量量化的张量 | (C1) |
限制条件
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)。
示例
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]
最小值
语义
对张量 lhs 和 rhs 执行逐元素最小值运算,并生成 result 张量。根据元素类型,执行以下操作:
- 对于布尔值:逻辑 AND。
- 对于整数:整数最小值。
- 对于浮点数:来自 IEEE-754 的
minimum。 - 对于复数:
(real, imaginary)对的字典序最小值。 对复数施加排序会涉及令人惊讶的语义,因此我们计划在未来移除对该操作的复数支持 (#560)。 - 对于量化类型:
dequantize_op_quantize(minimum, lhs, rhs, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | lhs |
张量或按张量量化的张量 | (C1) |
| (I2) | rhs |
张量或按张量量化的张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
张量或按张量量化的张量 | (C1) |
限制条件
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)。
示例
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]
相乘
语义
对两个张量 lhs 和 rhs 执行逐元素乘积运算,并生成一个 result 张量。根据元素类型,执行以下操作:
- 对于布尔值:逻辑 AND。
- 对于整数:整数乘法。
- 对于浮点数:来自 IEEE-754 的
multiplication。 - 对于复数:复数乘法。
- 对于量化类型:
dequantize_op_quantize(multiply, lhs, rhs, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | lhs |
张量或按张量量化的张量 | (C1) |
| (I2) | rhs |
张量或按张量量化的张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
张量或按张量量化的张量 | (C1) |
限制条件
- (C1)
baseline_type(operand) = baseline_type(result)。
示例
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]
negate
语义
对 operand 张量执行按元素求反运算,并生成 result 张量。根据元素类型,执行以下操作:
- 对于带符号整数:整数取反。
- 对于无符号整数:位转换为有符号整数、整数求反、位转换回无符号整数。
- 对于浮点数:来自 IEEE-754 的
negate。 - 对于复数:复数求反。
- 对于量化类型:
dequantize_op_quantize(negate, operand, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
整数、浮点或复数类型的张量或每个张量的量化张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
整数、浮点或复数类型的张量或每个张量的量化张量 | (C1) |
限制条件
- (C1)
baseline_type(operand) = baseline_type(result)。
示例
// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]
// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]
非
语义
对张量 operand 执行按元素 NOT 运算,并生成 result 张量。
根据元素类型,执行以下操作:
- 对于布尔值:逻辑非。
- 对于整数:按位 NOT。
参数
| 名称 | 类型 | 限制条件 |
|---|---|---|
operand |
布尔值或整数类型的张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
布尔值或整数类型的张量 | (C1) |
限制条件
- (C1)
type(operand) = type(result)。
示例
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]
// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]
optimization_barrier
语义
确保生成 operand 的操作在任何依赖于 result 的操作之前执行,并防止编译器转换将操作移过屏障。除此之外,该运算是恒等运算,即 result = operand。
参数
| 名称 | 类型 | 限制条件 |
|---|---|---|
operand |
数量可变的张量、每个张量的量化张量或 token | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
数量可变的张量、每个张量的量化张量或 token | (C1) |
限制条件
- (C1)
type(operand...) = type(result...)。
示例
// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0
或
语义
对两个张量 lhs 和 rhs 执行按元素或运算,并生成一个 result 张量。根据元素类型,执行以下操作:
- 对于布尔值:逻辑或。
- 对于整数:按位或。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | lhs |
整数或布尔值类型的张量 | (C1) |
| (I2) | rhs |
整数或布尔值类型的张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
整数或布尔值类型的张量 | (C1) |
限制条件
- (C1)
type(lhs) = type(rhs) = type(result)。
示例
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]
出料
语义
将 inputs 写入出馈,并生成 result 令牌。
outfeed_config 的语义是由实现定义的。
输入
| 标签 | 名称 | 类型 |
|---|---|---|
| (I1) | inputs |
数量可变的张量或量化张量 |
| (I2) | token |
token |
| (I3) | outfeed_config |
类型为 string 的常量 |
输出
| 名称 | 类型 |
|---|---|
result |
token |
示例
%result = "stablehlo.outfeed"(%input0, %token) {
outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token
pad
语义
通过在张量的周围以及张量的元素之间填充给定的 padding_value 来扩展 operand。
edge_padding_low 和 edge_padding_high 分别指定在每个维度的低端(紧邻索引 0)和高端(紧邻最高索引)添加的填充量。填充量可以是负值,其中负填充的绝对值表示要从指定维度中移除的元素数量。
interior_padding 指定在每个维度中任意两个元素之间添加的内边距量,该值不得为负。内边距在边缘边距之前出现,因此负边缘边距将从内边距操作数中移除元素。
更正式地说,result[result_index] 的定义如下:
- 如果
result_index = edge_padding_low + operand_index * (interior_padding + 1),则为operand[operand_index]。 - 否则为
padding_value。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
张量或按张量量化的张量 | (C1)、(C2)、(C4) |
| (I2) | padding_value |
0 维张量或每个张量的量化张量 | (C1) |
| (I3) | edge_padding_low |
类型为 si64 的 1 维张量常量 |
(C1)、(C4) |
| (I4) | edge_padding_high |
类型为 si64 的 1 维张量常量 |
(C1)、(C4) |
| (I5) | interior_padding |
类型为 si64 的 1 维张量常量 |
(C2-C4) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
张量或按张量量化的张量 | (C3-C6) |
限制条件
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)。 - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)。 - (C3)
0 <= interior_padding。 - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high。
示例
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
edge_padding_low = array<i64: 0, 1>,
edge_padding_high = array<i64: 2, 1>,
interior_padding = array<i64: 1, 2>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
partition_id
语义
生成当前进程的 partition_id。
输出
| 名称 | 类型 |
|---|---|
result |
类型为 ui32 的 0 维张量 |
示例
%result = "stablehlo.partition_id"() : () -> tensor<ui32>
popcnt
语义
按元素计算 operand 张量中设置的位数,并生成 result 张量。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
整数类型的张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
整数类型的张量 | (C1) |
限制条件
- (C1)
type(operand) = type(result)。
示例
// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]
幂数
语义
对 lhs 张量和 rhs 张量执行逐元素指数运算,并生成 result 张量。根据元素类型,执行以下操作:
- 对于整数:整数指数。
- 对于浮点数:来自 IEEE-754 的
pow。 - 对于复数:复数指数。
- 对于量化类型:
dequantize_op_quantize(power, lhs, rhs, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | lhs |
整数、浮点或复数类型的张量或每个张量的量化张量 | (C1) |
| (I2) | rhs |
整数、浮点或复数类型的张量或每个张量的量化张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
整数、浮点或复数类型的张量或每个张量的量化张量 | (C1) |
限制条件
- (C1)
baseline_type(operand) = baseline_type(result)。
示例
// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]
real
语义
从 operand 中按元素提取实部,并生成 result 张量。更正式地说,对于每个元素 x:real(x) = is_complex(x) ? real_part(x) : x。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
浮点或复杂类型的张量 | (C1)、(C2) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
浮点型张量 | (C1)、(C2) |
限制条件
- (C1)
shape(result) = shape(operand)。 - (C2)
element_type(result)的定义如下:- 如果
is_complex(operand),则为complex_element_type(element_type(operand))。 - 否则为
element_type(operand)。
- 如果
示例
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]
recv
语义
从具有 channel_id 的渠道接收数据,并生成 results。
如果 is_host_transfer 为 true,则该操作会从主机转移数据。否则,它会根据 source_target_pairs 的值从其他设备转移数据。此标志会重复提供 channel_type 中提供的信息,因此我们计划将来只保留其中一个 (#666)。如果 is_host_transfer = false 且 source_target_pairs 为 None 或空,则视为未定义行为。
results 由首先出现的载荷值和最后出现的令牌组成。未来,我们计划将载荷和令牌拆分为两个单独的输出,以提高清晰度 (#670)。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | token |
token |
|
| (I2) | channel_id |
类型为 si64 的常量 |
|
| (I3) | channel_type |
DEVICE_TO_DEVICE 和 DEVICE_TO_HOST 的枚举 |
(C5) |
| (I4) | is_host_transfer |
类型为 i1 的常量 |
(C5-C6) |
| (I5) | source_target_pairs |
类型为 si64 的二维张量常量 |
(C1-C4)、(C6) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
results |
数量可变的张量、量化张量或令牌 | (C2-C4) |
限制条件
- (C1)
dim(source_target_pairs, 1) = 2。 - (C2)
is_unique(source_target_pairs[:, 0])。 - (C3)
is_unique(source_target_pairs[:, 1])。 - (C4)
0 <= source_target_pairs < N,其中N定义为:- 如果使用
cross_replica,则为num_replicas。 - 如果使用
cross_partition,则为num_partitions。
- 如果使用
- (C5)
channel_type的定义如下:- 如果
is_host_transfer = true,则为DEVICE_TO_HOST, - 否则为
DEVICE_TO_DEVICE。
- 如果
示例
%results0, %results1 = "stablehlo.recv"(%token) {
channel_handle = #stablehlo.channel_handle<handle = 0, type = 1>,
is_host_transfer = false,
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
reduce
语义
沿 dimensions 对 inputs 和 init_values 应用归约函数 body,并生成 results 张量。
缩减的顺序由实现定义,这意味着 body 和 init_values 必须形成幺半群,才能保证该运算在所有实现中针对所有输入产生相同的结果。不过,许多热门的归约并不满足此条件。例如,body 的浮点数加法和 init_values 的零实际上并不构成幺半群,因为浮点数加法不遵守结合律。
更正式地说,results...[j0, ..., jR-1] = reduce(input_slices_converted),其中:
input_slices = inputs...[j0, ..., :, ..., jR-1],其中:插入在dimensions。input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...)。init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)。- 对于某个二叉树
reduce(input_slices_converted) = exec(schedule),schedule其中:exec(node) = body(exec(node.left), exec(node.right))。exec(leaf) = leaf.value。
schedule是一个由实现定义的完整二叉树,其按顺序遍历包括:input_slices_converted...[index]值,对于index_space(input_slices_converted)中的所有index,按index的字典升序排列。- 在实现定义的位置穿插实现定义的
init_values_converted数量。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | inputs |
数量可变的张量或每个张量的量化张量 | (C1-C4)、(C6)、(C7) |
| (I2) | init_values |
数量可变的 0 维张量或每个张量的量化张量 | (C2)、(C3) |
| (I3) | dimensions |
类型为 si64 的 1 维张量常量 |
(C4)、(C5)、(C7) |
| (I4) | body |
函数 | (C6) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
results |
数量可变的张量或每个张量的量化张量 | (C3)、(C7)、(C8) |
限制条件
- (C1)
same(shape(inputs...))。 - (C2)
element_type(inputs...) = element_type(init_values...)。 - (C3)
0 < size(inputs) = size(init_values) = size(results) = N。 - (C4)
0 <= dimensions < rank(inputs[0])。 - (C5)
is_unique(dimensions)。 - (C6)
body的类型为(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>),其中is_promotable(element_type(inputs[i]), Ei)。 - (C7)
shape(results...) = shape(inputs...),但与dimensions对应的inputs...的维度大小不包括在内。 - (C8) 对于
[0,N)中的所有i,element_type(results[i]) = Ei。
示例
// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]
reduce_precision
语义
将 operand 逐元素转换为使用 exponent_bits 和 mantissa_bits 的另一种浮点类型,然后再转换回原始浮点类型,并生成 output 张量。
更正式的说法:
- 更新原始值的尾数位,以使用
roundToIntegralTiesToEven语义将原始值舍入为最接近的mantissa_bits可表示值。 - 然后,如果
mantissa_bits小于原始值的尾数位数,则将尾数位数截断为mantissa_bits。 - 然后,如果中间结果的指数位不适合
exponent_bits提供的范围,则中间结果会使用原始符号溢出到无穷大或使用原始符号下溢到零。 - 对于量化类型,执行
dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
浮点类型张量或逐张量量化张量 | (C1) |
| (I2) | exponent_bits |
类型为 si32 的常量 |
(C2) |
| (I3) | mantissa_bits |
类型为 si32 的常量 |
(C3) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
output |
浮点类型张量或逐张量量化张量 | (C1) |
限制条件
- (C1)
baseline_type(operand) = baseline_type(output)。 - (C2)
1 <= exponent_bits。 - (C3)
0 <= mantissa_bits。
示例
// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
exponent_bits = 5 : i32,
mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]
reduce_scatter
语义
在 StableHLO 进程网格中的每个进程组内,使用 computations 对每个进程的 operand 张量的值执行缩减,沿 scatter_dimension 将缩减结果拆分为多个部分,并在进程之间分散拆分后的部分,以生成 result。
该操作将 StableHLO 进程网格拆分为 process_groups,定义如下:
- 如果
channel_id <= 0 and use_global_device_ids = false,则为cross_replica(replica_groups)。 - 如果
channel_id > 0 and use_global_device_ids = false,则为cross_replica_and_partition(replica_groups)。 - 如果
channel_id > 0 and use_global_device_ids = true,则为flattened_ids(replica_groups)。
之后,在每个 process_group 中:
reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation)。parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension)。result@receiver = parts@sender[receiver_index],对于process_group中的所有sender,其中receiver_index = process_group.index(receiver)。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
张量或按张量量化的张量 | (C1)、(C2)、(C7)、(C8) |
| (I2) | scatter_dimension |
类型为 si64 的常量 |
(C1)、(C2)、(C8) |
| (I3) | replica_groups |
类型为 si64 的二维张量常量 |
(C3-C5) |
| (I4) | channel_id |
类型为 si64 的常量 |
(C6) |
| (I5) | use_global_device_ids |
类型为 i1 的常量 |
(C6) |
| (I6) | computation |
函数 | (C7) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
张量或按张量量化的张量 | (C8-C9) |
限制条件
- (C1)
dim(operand, scatter_dimension) % dim(process_groups, 1) = 0。 - (C2)
0 <= scatter_dimension < rank(operand)。 - (C3)
is_unique(replica_groups)。 - (C4)
size(replica_groups)的定义如下:- 如果使用
cross_replica,则为num_replicas。 - 如果使用
cross_replica_and_partition,则为num_replicas。 - 如果使用
flattened_ids,则为num_processes。
- 如果使用
- (C5)
0 <= replica_groups < size(replica_groups)。 - (C6) 如果
use_global_device_ids = true,则channel_id > 0。 - (C7)
computation的类型为(tensor<E>, tensor<E>) -> (tensor<E>),其中is_promotable(element_type(operand), E)。 - (C8)
shape(result) = shape(operand)除以下情况外:dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1)。
- (C9)
element_type(result) = E。
示例
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
// [18, 20]]
// %result@(1, 0): [[14, 16],
// [22, 24]]
reduce_window
语义
将缩减函数 body 应用于大小为 inputs 和 init_values 的窗口,并生成 results。
下图通过一个具体示例展示了如何根据 inputs... 计算 results... 中的元素。
更正式地说,results...[result_index] = reduce(windows, init_values, axes(inputs...), body)(请参阅 reduce),其中:
padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1)。window_start = result_index * window_strides。window_end = window_start + (window_dimensions - 1) * window_dilations + 1。windows = slice(padded_inputs..., window_start, window_end, window_dilations)。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | inputs |
数量可变的张量或每个张量的量化张量 | (C1-C4)、(C6)、(C8)、(C10)、(C12)、(C13)、(C15) |
| (I2) | init_values |
数量可变的 0 维张量或每个张量的量化张量 | (C1)、(C13) |
| (I3) | window_dimensions |
类型为 si64 的 1 维张量常量 |
(C4)、(C5)、(C15) |
| (I4) | window_strides |
类型为 si64 的 1 维张量常量 |
(C6)、(C7)、(C15) |
| (I5) | base_dilations |
类型为 si64 的 1 维张量常量 |
(C8)、(C9)、(C15) |
| (I6) | window_dilations |
类型为 si64 的 1 维张量常量 |
(C10)、(C11)、(C15) |
| (I7) | padding |
类型为 si64 的二维张量常量 |
(C12)、(C15) |
| (I8) | body |
函数 | (C13) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
results |
数量可变的张量或每个张量的量化张量 | (C1)、(C14-C16) |
限制条件
- (C1)
0 < size(inputs) = size(init_values) = size(results) = N。 - (C2)
same(shape(inputs...))。 - (C3)
element_type(inputs...) = element_type(init_values...)。 - (C4)
size(window_dimensions) = rank(inputs[0])。 - (C5)
0 < window_dimensions。 - (C6)
size(window_strides) = rank(inputs[0])。 - (C7)
0 < window_strides。 - (C8)
size(base_dilations) = rank(inputs[0])。 - (C9)
0 < base_dilations。 - (C10)
size(window_dilations) = rank(inputs[0])。 - (C11)
0 < window_dilations。 - (C12)
shape(padding) = [rank(inputs[0]), 2]。 - (C13)
body的类型为(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>),其中is_promotable(element_type(inputs[i]), Ei)。 - (C14)
same(shape(results...))。 - (C15)
shape(results[0]) = num_windows其中:dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1。padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]。dilated_window_shape = (window_dimensions - 1) * window_dilations + 1。is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape。num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1。
- (C16)
element_type(results[i]) = Ei适用于[0,N)中的所有i。
示例
// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 2, 1>,
window_strides = array<i64: 4, 1>,
base_dilations = array<i64: 2, 1>,
window_dilations = array<i64: 3, 1>,
padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]
余数
语义
按元素计算被除数 lhs 和除数 rhs 张量的余数,并生成 result 张量。
更正式地说,结果的符号取自被除数,结果的绝对值始终小于除数的绝对值。余数计算为 lhs - d * rhs,其中 d 由以下公式给出:
- 对于整数:
stablehlo.divide(lhs, rhs)。 - 对于浮点数:来自 IEEE-754 的
division(lhs, rhs),带有舍入属性roundTowardZero。 - 对于复数:待定 (#997)。
- 对于量化类型:
dequantize_op_quantize(remainder, lhs, rhs, type(result))。
对于浮点元素类型,此运算与 IEEE-754 规范中的 remainder 运算相反,其中 d 是最接近 lhs/rhs 精确值的整数值,如果出现平局,则取偶数。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | lhs |
整数、浮点或复数类型的张量或按张量量化的张量 | (C1) |
| (I2) | rhs |
整数、浮点或复数类型的张量或按张量量化的张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
整数、浮点或复数类型的张量或按张量量化的张量 | (C1) |
限制条件
- (C1)
baseline_type(operand) = baseline_type(result)。
示例
// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]
replica_id
语义
生成当前进程的 replica_id。
输出
| 名称 | 类型 |
|---|---|
result |
类型为 ui32 的 0 维张量 |
示例
%result = "stablehlo.replica_id"() : () -> tensor<ui32>
reshape
语义
将 operand 张量重塑为 result 张量。从概念上讲,这相当于保持相同的规范表示形式,但可能会更改形状,例如从 tensor<2x3xf32> 更改为 tensor<3x2xf32> 或 tensor<6xf32>。
更正式地说,result[result_index] = operand[operand_index],其中 result_index 和 operand_index 在 index_space(result) 和 index_space(operand) 的字典序中具有相同的位置。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
张量或量化张量 | (C1-C3) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
张量或量化张量 | (C1-C3) |
限制条件
- (C1)
element_type(result)由以下人员提供:element_type(operand)(如果!is_per_axis_quantized(operand))。element_type(operand),但quantization_dimension(operand)和quantization_dimension(result)可能有所不同,其他方面则相同。
- (C2)
size(operand) = size(result)。 - (C3) 如果
is_per_axis_quantized(operand):reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)。dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))。reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)。
示例
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]
reverse
语义
沿指定 dimensions 反转 operand 中元素的顺序,并生成 result 张量。更正式地说,result[result_index] = operand[operand_index],其中:
operand_index[d] = dim(result, d) - result_index[d] - 1如果dimensions中存在d。- 否则为
operand_index[d] = result_index[d]。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
张量或按张量量化的张量 | (C1)、(C3) |
| (I2) | dimensions |
类型为 si64 的 1 维张量常量 |
(C2)、(C3) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
张量或按张量量化的张量 | (C1)、(C3) |
限制条件
- (C1)
type(operand) = type(result)。 - (C2)
is_unique(dimensions)。 - (C3)
0 <= dimensions < rank(result)。
示例
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]
rng
语义
使用 rng_distribution 算法生成随机数,并生成指定形状 shape 的 result 张量。
如果值为 rng_distribution = UNIFORM,则生成的随机数遵循区间 [a, b) 内的均匀分布。如果值为 a >= b,则行为未定义。
如果值为 rng_distribution = NORMAL,则生成的随机数遵循正态分布,均值为 a,标准差为 b。如果值为 b < 0,则行为未定义。
随机数的具体生成方式由实现定义。例如,它们可能具有确定性,也可能不具有确定性;它们可能使用隐藏状态,也可能不使用隐藏状态。
在与许多利益相关方的对话中,此操作已被视为已弃用,因此我们计划在未来探索移除此操作 (#597)。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | a |
整数、布尔值或浮点值类型的 0 维张量 | (C1)、(C2) |
| (I2) | b |
整数、布尔值或浮点值类型的 0 维张量 | (C1)、(C2) |
| (I3) | shape |
类型为 si64 的 1 维张量常量 |
(C3) |
| (I4) | rng_distribution |
UNIFORM 和 NORMAL 的枚举 |
(C2) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
整数、布尔值或浮点值类型的张量 | (C1-C3) |
限制条件
- (C1)
element_type(a) = element_type(b) = element_type(result)。 - (C2) 如果
rng_distribution = NORMAL,则is_float(a)。 - (C3)
shape(result) = shape。
示例
// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
// [1, 0, 1],
// [1, 1, 1],
// [0, 0, 0]
// ]
rng_bit_generator
语义
返回一个填充了均匀随机位的 output 和一个使用伪随机数生成器算法 rng_algorithm(给定初始状态 initial_state)更新的输出状态 output_state。输出保证是 initial_state 的确定性函数,但不保证在不同实现之间是确定性的。
rng_algorithm 是以下值之一:
DEFAULT:由实现定义的算法。THREE_FRY:Threefry 算法的实现定义变体。*PHILOX:Philox 算法的实现定义变体。*
* 参见:Salmon 等人,SC 2011。并行随机数:只需三步即可完成。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | rng_algorithm |
DEFAULT、THREE_FRY 和 PHILOX 的枚举 |
(C2) |
| (I2) | initial_state |
类型为 ui64 的一维张量 |
(C1)、(C2) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
output_state |
类型为 ui64 的一维张量 |
(C1) |
output |
整数或浮点类型的张量 |
限制条件
- (C1)
type(initial_state) = type(output_state)。 - (C2)
size(initial_state)的定义如下:- 如果为
rng_algorithm = DEFAULT,则由实现定义。 - 如果
rng_algorithm = THREE_FRY,则为2。 - 如果
rng_algorithm = PHILOX,则为2或3。
- 如果为
示例
// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
// [9236835810183407956, 16087790271692313299],
// [18212823393184779219, 2658481902456610144]
// ]
round_nearest_afz
语义
对 operand 张量执行按元素舍入到最接近的整数运算,舍入时远离零,并生成 result 张量。实现 IEEE-754 规范中的 roundToIntegralTiesToAway 操作。对于量化类型,执行 dequantize_op_quantize(round_nearest_afz, operand, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
浮点类型张量或逐张量量化张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
浮点类型张量或逐张量量化张量 | (C1) |
限制条件
- (C1)
baseline_type(operand) = baseline_type(result)。
示例
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]
round_nearest_even
语义
对 operand 张量执行按元素舍入到最接近的整数运算,如果出现平局,则舍入到偶数,并生成 result 张量。实现 IEEE-754 规范中的 roundToIntegralTiesToEven 操作。对于量化类型,执行 dequantize_op_quantize(round_nearest_even, operand, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
浮点类型张量或逐张量量化张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
浮点类型张量或逐张量量化张量 | (C1) |
限制条件
- (C1)
baseline_type(operand) = baseline_type(result)。
示例
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
rsqrt
语义
对 operand 张量执行逐元素倒数平方根运算,并生成 result 张量。根据元素类型,执行以下操作:
- 对于浮点数:来自 IEEE-754 的
rSqrt。 - 对于复数:复数倒数平方根。
- 对于量化类型:
dequantize_op_quantize(rsqrt, operand, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
浮点或复杂类型的张量或每个张量的量化张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
浮点或复杂类型的张量或每个张量的量化张量 | (C1) |
限制条件
- (C1)
baseline_type(operand) = baseline_type(result)。
示例
// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]
scatter
语义
生成 results 张量,该张量与 inputs 张量相同,只是使用 update_computation 将 scatter_indices 指定的几个切片更新为值 updates。
下图通过一个具体示例展示了 updates... 中的元素如何映射到 results... 中的元素。该图表选择了一些示例 updates... 指数,并详细说明了它们对应的 results... 指数。
更正式地说,对于 index_space(updates[0]) 中的所有 update_index:
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]。update_scatter_index = update_index[update_scatter_dims...]。start_index的定义如下:scatter_indices[si0, ..., :, ..., siN],其中si是update_scatter_index中的各个元素,如果index_vector_dim<rank(scatter_indices),则将:插入到索引index_vector_dim处。- 否则为
[scatter_indices[update_scatter_index]]。
- 对于
axes(inputs[0])中的d_input,- 如果
d_input = scatter_dims_to_operand_dims[d_start],则为full_start_index[d_input] = start_index[d_start]。 - 否则为
full_start_index[d_input] = 0。
- 如果
- 对于
axes(inputs[0])中的d_input,full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)](如果d_input = input_batching_dims[i_batching]和d_start = scatter_indices_batching_dims[i_batching])。- 否则为
full_batching_index[d_input] = 0。
update_window_index = update_index[update_window_dims...]。full_window_index = [wi0, ..., 0, ..., wiN],其中wi是update_window_index中的各个元素,而0插入到从inserted_window_dims到input_batching_dims的索引处。result_index = full_start_index + full_batching_index + full_window_index。
鉴于此,results = exec(schedule, inputs),其中:
schedule是index_space(updates[0])的实现定义的排列。exec([update_index, ...], results) = exec([...], updated_results)其中:- 如果
result_index在shape(results...)的范围内 updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )updated_values = update_computation(results...[result_index], updates_converted)updated_results是results的副本,其中results...[result_index]设置为updated_values...。- 否则
updated_results = results。
- 如果
exec([], results) = results。
如果 indices_are_sorted 为 true,则实现可以假定 scatter_indices 是按 scatter_dims_to_operand_dims 排序的,否则行为未定义。更正式地说,对于 indices(result) 中的所有 i1 < i2,full_start_index(i1) <= full_start_index(i2)。
如果 unique_indices 为 true,则实现可以假定要分散到的所有 result_index 索引都是唯一的。如果 unique_indices 为 true,但要分散到的索引不唯一,则行为未定义。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | inputs |
数量可变的张量或每个张量的量化张量 | (C1)、(C2)、(C4-C6)、(C11)、(C13)、(C18)、(C21)、(C23-C24) |
| (I2) | scatter_indices |
整数类型的张量 | (C4)、(C15)、(C19)、(C22) |
| (I3) | updates |
数量可变的张量或每个张量的量化张量 | (C3-C6)、(C8) |
| (I4) | update_window_dims |
类型为 si64 的 1 维张量常量 |
(C2)、(C4)、(C7-C8) |
| (I5) | inserted_window_dims |
类型为 si64 的 1 维张量常量 |
(C2)、(C4)、(C9-C11) |
| (I6) | input_batching_dims |
类型为 si64 的 1 维张量常量 |
(C2)、(C4)、(C9)、(C12-13)、(C17-18)、(C20) |
| (I7) | scatter_indices_batching_dims |
类型为 si64 的 1 维张量常量 |
(C14-C18) |
| (I8) | scatter_dims_to_operand_dims |
类型为 si64 的 1 维张量常量 |
(C19-C21) |
| (I9) | index_vector_dim |
类型为 si64 的常量 |
(C4)、(C16)、(C19)、(C22) |
| (I10) | indices_are_sorted |
类型为 i1 的常量 |
|
| (I11) | unique_indices |
类型为 i1 的常量 |
|
| (I12) | update_computation |
函数 | (C23) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
results |
数量可变的张量或每个张量的量化张量 | (C24-C25) |
限制条件
- (C1)
same(shape(inputs...))。 - (C2)
rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims) + size(input_batching_dims)。 - (C3)
same(shape(updates...))。 - (C4)
shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes)其中:update_scatter_dim_sizes = shape(scatter_indices),但其中不包含与index_vector_dim对应的scatter_indices的维度大小。update_window_dim_sizes <= shape(inputs[0]),但其中不包含inputs[0]中与inserted_window_dims和input_batching_dims对应的维度大小。combine将update_scatter_dim_sizes放置在与update_scatter_dims对应的轴上,并将update_window_dim_sizes放置在与update_window_dims对应的轴上。
- (C5)
0 < size(inputs) = size(updates) = N。 - (C6)
element_type(updates...) = element_type(inputs...)。 - (C7)
is_unique(update_window_dims) and is_sorted(update_window_dims)。 - (C8)
0 <= update_window_dims < rank(updates[0])。 - (C9)
is_unique(concatenate(inserted_window_dims, input_batching_dims)) - (C10)
is_sorted(inserted_window_dims)。 - (C11)
0 <= inserted_window_dims < rank(inputs[0])。 - (C12)
is_sorted(input_batching_dims)。 - (C13)
0 <= input_batching_dims < rank(inputs[0]))。 - (C14)
is_unique(scatter_indices_batching_dims)。 - (C15)
0 <= scatter_indices_batching_dims < rank(scatter_indices)。 - (C16)
index_vector_dim not in scatter_indices_batching_dims。 - (C17)
size(input_batching_dims) == size(scatter_indices_batching_dims)。 - (C18)
dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...)。 - (C19)
size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1。 - (C20)
is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims))。 - (C21)
0 <= scatter_dims_to_operand_dims < rank(inputs[0])。 - (C22)
0 <= index_vector_dim <= rank(scatter_indices)。 - (C23)
update_computation的类型为(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>),其中is_promotable(element_type(inputs[i]), Ei)。 - (C24)
shape(inputs...) = shape(results...)。 - (C25)
element_type(results[i]) = Ei适用于[0,N)中的所有i。
示例
// %input: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %scatter_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
// %update: [
// [
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
// ],
// [
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
// ]
// ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [3, 4],
inserted_window_dims = [1],
input_batching_dims = [0],
scatter_indices_batching_dims = [1],
scatter_dims_to_operand_dims = [2, 1],
index_vector_dim = 3>,
indices_are_sorted = false,
unique_indices = false
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
// %result: [
// [
// [[3, 4], [6, 7], [6, 7], [7, 8]],
// [[9, 10],[11, 12], [15, 16], [17, 18]],
// [[17, 18], [19, 20], [22, 23], [24, 25]]
// ],
// [
// [[25, 26], [28, 29], [30, 31], [31, 32]],
// [[35, 36], [38, 39], [38, 39], [39, 40]],
// [[41, 42], [44, 45], [46, 47], [47, 48]]
// ]
// ]
选择
语义
生成一个 result 张量,其中每个元素都是根据 pred 的相应元素的值从 on_true 或 on_false 张量中选择的。
更正式地说,result[result_index] = pred_element ? on_true[result_index] :
on_false[result_index],其中 pred_element = rank(pred) = 0 ? pred[] :
pred[result_index]。对于量化类型,执行 dequantize_select_quantize(pred, on_true, on_false, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | pred |
类型为 i1 的张量 |
(C1) |
| (I2) | on_true |
张量或按张量量化的张量 | (C1-C2) |
| (I3) | on_false |
张量或按张量量化的张量 | (C2) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
张量或按张量量化的张量 | (C2) |
限制条件
- (C1)
rank(pred) = 0 or shape(pred) = shape(on_true)。 - (C2)
baseline_type(on_true) = baseline_type(on_false) = baseline_type(result)。
示例
// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]
select_and_scatter
语义
使用 select 根据 input 张量的 reduce_window 的结果,使用 scatter 将 source 张量中的值分散,并生成 result 张量。
下图通过一个具体示例展示了如何根据 operand 和 source 计算 result 中的元素。
更正式的说法:
selected_values = reduce_window_without_init(...),并使用以下输入内容:inputs = [operand].window_dimensions、window_strides和padding,这些参数会按原样使用。base_dilations = windows_dilations = 1。body的定义如下:
def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>: return select(arg0, arg1) ? arg0 : arg1;其中,
E = element_type(operand)和reduce_window_without_init的工作方式与reduce_window完全相同,只不过底层reduce(请参阅 reduce)的schedule不包含初始值。目前尚未指定在相应窗口没有值时会发生什么情况 (#731)。result[result_index] = reduce([source_values], [init_value], [0], scatter)其中:source_values = [source[source_index] for source_index in source_indices]。- 如果
selected_values[source_index]具有来自operand_index的operand元素,则为selected_index(source_index) = operand_index。 source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index]。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
张量或按张量量化的张量 | (C1-C4)、(C6)、(C8-C11) |
| (I2) | source |
张量或按张量量化的张量 | (C1)、(C2) |
| (I3) | init_value |
0 维张量或每个张量的量化张量 | (C3) |
| (I4) | window_dimensions |
类型为 si64 的 1 维张量常量 |
(C2)、(C4)、(C5) |
| (I5) | window_strides |
类型为 si64 的 1 维张量常量 |
(C2)、(C6)、(C7) |
| (I6) | padding |
类型为 si64 的二维张量常量 |
(C2)、(C8) |
| (I7) | select |
函数 | (C9) |
| (I8) | scatter |
函数 | (C10) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
张量或按张量量化的张量 | (C11-C12) |
限制条件
- (C1)
element_type(operand) = element_type(source)。 - (C2)
shape(source) = num_windows其中:padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1]。is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape。num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1。
- (C3)
element_type(init_value) = element_type(operand)。 - (C4)
size(window_dimensions) = rank(operand)。 - (C5)
0 < window_dimensions。 - (C6)
size(window_strides) = rank(operand)。 - (C7)
0 < window_strides。 - (C8)
shape(padding) = [rank(operand), 2]。 - (C9)
select的类型为(tensor<E>, tensor<E>) -> tensor<i1>,其中E = element_type(operand)。 - (C10)
scatter的类型为(tensor<E>, tensor<E>) -> tensor<E>,其中is_promotable(element_type(operand), E)。 - (C11)
shape(operand) = shape(result)。 - (C12)
element_type(result) = E。
示例
// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 3, 1>,
window_strides = array<i64: 2, 1>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]
发送
语义
向频道 channel_id 发送 inputs。然后,输入会按 source_target_pairs 指定的顺序发送到其他设备。该操作会生成 result 令牌。
如果 is_host_transfer 为 true,则操作会将数据传输到主机。否则,它会根据 source_target_pairs 的值将数据传输到另一部设备。此标志会重复提供 channel_type 中提供的信息,因此我们计划将来只保留其中一个 (#666)。如果 is_host_transfer = false 且 source_target_pairs 为 None 或空,则视为未定义行为。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | inputs |
数量可变的张量或量化张量 | |
| (I2) | token |
token |
|
| (I3) | channel_id |
类型为 si64 的常量 |
|
| (I4) | channel_type |
DEVICE_TO_DEVICE 和 DEVICE_TO_HOST 的枚举 |
(C5) |
| (I5) | is_host_transfer |
类型为 i1 的常量 |
(C5-C6) |
| (I6) | source_target_pairs |
类型为 si64 的二维张量常量 |
(C1-C4)、(C6) |
输出
| 名称 | 类型 |
|---|---|
result |
token |
限制条件
- (C1)
dim(source_target_pairs, 1) = 2。 - (C2)
is_unique(source_target_pairs[:, 0])。 - (C3)
is_unique(source_target_pairs[:, 1])。 - (C4)
0 <= source_target_pairs < N,其中N定义为:- 如果使用
cross_replica,则为num_replicas。 - 如果使用
cross_partition,则为num_partitions。
- 如果使用
- (C5)
channel_type的定义如下:- 如果
is_host_transfer = true,则为DEVICE_TO_HOST, - 否则为
DEVICE_TO_DEVICE。
- 如果
示例
%result = "stablehlo.send"(%operand, %token) {
channel_handle = #stablehlo.channel_handle<handle = 0, type = 1>,
is_host_transfer = false,
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token
shift_left
语义
对 lhs 张量执行元素级左移运算,移动 rhs 位,并生成 result 张量。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | lhs |
整数类型的张量 | (C1) |
| (I2) | rhs |
整数类型的张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
整数类型的张量 | (C1) |
限制条件
- (C1)
type(lhs) = type(rhs) = type(result)。
示例
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]
shift_right_arithmetic
语义
对 lhs 张量执行按元素的算术右移运算,移动位数为 rhs,并生成 result 张量。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | lhs |
整数类型的张量 | (C1) |
| (I2) | rhs |
整数类型的张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
整数类型的张量 | (C1) |
限制条件
- (C1)
type(lhs) = type(rhs) = type(result)。
示例
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]
shift_right_logical
语义
对 lhs 张量执行按位逻辑右移运算,移动位数为 rhs,并生成 result 张量。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | lhs |
整数类型的张量 | (C1) |
| (I2) | rhs |
整数类型的张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
整数类型的张量 | (C1) |
限制条件
- (C1)
type(lhs) = type(rhs) = type(result)。
示例
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]
签名
语义
返回 operand 的逐元素符号,并生成 result 张量。
更正式地说,对于每个元素 x,可以使用 Python 语法按如下方式表达其语义:
def sign(x):
if is_integer(x):
if compare(x, 0, LT, SIGNED): return -1
if compare(x, 0, EQ, SIGNED): return 0
return 1
elif is_float(x):
if is_nan(x): return NaN
if compare(x, -0.0, EQ, FLOAT): return -0.0
if compare(x, +0.0, EQ, FLOAT): return +0.0
if compare(x, 0.0, LT, FLOAT): return -1.0
return 1.0
elif is_complex(x):
if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
return divide(x, convert(abs(x), type(x)))
对于量化类型,执行 dequantize_op_quantize(sign, operand, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
有符号整数、浮点或复数类型的张量或每张量量化张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
有符号整数、浮点或复数类型的张量或每张量量化张量 | (C1) |
限制条件
- (C1)
baseline_type(operand) = baseline_type(result)。
示例
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
正弦
语义
对 operand 张量执行逐元素正弦运算,并生成 result 张量。根据元素类型,执行以下操作:
- 对于浮点数:来自 IEEE-754 的
sin。 - 对于复数:复正弦。
- 对于量化类型:
dequantize_op_quantize(sine, operand, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
浮点或复杂类型的张量或每个张量的量化张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
浮点或复杂类型的张量或每个张量的量化张量 | (C1) |
限制条件
- (C1)
baseline_type(operand) = baseline_type(result)。
示例
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]
slice
语义
使用静态计算的起始索引从 operand 中提取切片,并生成 result 张量。start_indices 包含每个维度的切片的起始索引,limit_indices 包含每个维度的切片的结束索引(不含),strides 包含每个维度的步幅。
更正式地说,result[result_index] = operand[operand_index],其中 operand_index = start_indices + result_index * strides。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
张量或按张量量化的张量 | (C1-C3)、(C5) |
| (I2) | start_indices |
类型为 si64 的 1 维张量常量 |
(C2)、(C3)、(C5) |
| (I3) | limit_indices |
类型为 si64 的 1 维张量常量 |
(C2)、(C3)、(C5) |
| (I4) | strides |
类型为 si64 的 1 维张量常量 |
(C2)、(C4) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
张量或按张量量化的张量 | (C1)、(C5) |
限制条件
- (C1)
element_type(operand) = element_type(result)。 - (C2)
size(start_indices) = size(limit_indices) = size(strides) = rank(operand)。 - (C3)
0 <= start_indices <= limit_indices <= shape(operand)。 - (C4)
0 < strides。 - (C5)
shape(result) = ceil((limit_indices - start_indices) / strides)。
示例
// %operand: [
// [0, 0, 0, 0],
// [0, 0, 1, 1],
// [0, 0, 1, 1]
// ]
%result = "stablehlo.slice"(%operand) {
start_indices = array<i64: 1, 2>,
limit_indices = array<i64: 3, 4>,
strides = array<i64: 1, 1>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
// [1, 1],
// [1, 1]
// ]
排序
语义
根据 comparator 沿维度 dimension 对 inputs 的一维切片进行排序,并生成 results。
与其他操作中的类似输入不同,dimension 允许使用负值,其语义如下所述。出于一致性考虑,未来可能会禁止此行为 (#1377)。
如果 is_stable 为 true,则排序是稳定的,也就是说,比较器认为相等的元素的相对顺序会保留。对于单个输入的情况,当且仅当 comparator(e1, e2) = comparator(e2, e1) = false 时,比较器才会认为两个元素 e1 和 e2 相等。如需了解如何将此概念推广到多个输入,请参阅下面的形式化表示。
更正式地说,对于 index_space(results[0]) 中的所有 result_index:
adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension。result_slice = [ri0, ..., :, ..., riR-1]其中riN是result_index中的各个元素,:插入在adjusted_dimension处。inputs_together = (inputs[0]..., ..., inputs[N-1]...)。results_together[result_slice] = sort(inputs_together[result_slice], comparator_together)。- 其中,
sort以非降序对一维切片进行排序,并期望如果左侧实参小于右侧第二个实参,comparator_together返回true。 def comparator_together(lhs_together, rhs_together): args = [] for (lhs_el, rhs_el) in zip(lhs_together, rhs_together): args.append(lhs_el) args.append(rhs_el) return comparator(*args)(results[0]..., ..., results[N-1]...) = results_together。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | inputs |
数量可变的张量或每个张量的量化张量 | (C1-C5) |
| (I2) | dimension |
类型为 si64 的常量 |
(C4) |
| (I3) | is_stable |
类型为 i1 的常量 |
|
| (I4) | comparator |
函数 | (C5) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
results |
数量可变的张量或每个张量的量化张量 | (C2)、(C3) |
限制条件
- (C1)
0 < size(inputs)。 - (C2)
type(inputs...) = type(results...)。 - (C3)
same(shape(inputs...) + shape(results...))。 - (C4)
-R <= dimension < R,其中R = rank(inputs[0])。 - (C5)
comparator的类型为(tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>,其中Ei = element_type(inputs[i])。
示例
// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
dimension = 0 : i64,
is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]
sqrt
语义
对 operand 张量执行逐元素平方根运算,并生成 result 张量。根据元素类型,执行以下操作:
- 对于浮点数:来自 IEEE-754 的
squareRoot。 - 对于复数:复数平方根。
- 对于量化类型:
dequantize_op_quantize(sqrt, operand, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
浮点或复杂类型的张量或每个张量的量化张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
浮点或复杂类型的张量或每个张量的量化张量 | (C1) |
限制条件
- (C1)
baseline_type(operand) = baseline_type(result)。
示例
// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]
subtract
语义
对两个张量 lhs 和 rhs 执行逐元素减法,并生成一个 result 张量。根据元素类型,执行以下操作:
- 对于整数:整数减法。
- 对于浮点数:来自 IEEE-754 的
subtraction。 - 对于复数:复数减法。
- 对于量化类型:
dequantize_op_quantize(subtract, lhs, rhs, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | lhs |
整数、浮点或复数类型的张量或每个张量的量化张量 | (C1) |
| (I2) | rhs |
整数、浮点或复数类型的张量或每个张量的量化张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
整数、浮点或复数类型的张量或每个张量的量化张量 | (C1) |
限制条件
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)。
示例
// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]
tan
语义
对 operand 张量执行逐元素正切运算,并生成 result 张量。根据元素类型,执行以下操作:
- 对于浮点数:来自 IEEE-754 的
tan。 - 对于复数:复数正切。
- 对于量化类型:
dequantize_op_quantize(tan, operand, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
浮点或复杂类型的张量或每个张量的量化张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
浮点或复杂类型的张量或每个张量的量化张量 | (C1) |
限制条件
- (C1)
baseline_type(operand) = baseline_type(result)。
示例
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.tan"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [
// [0.0, 1.63312e+16],
// [0.0, 5.44375e+15]
// ]
tanh
语义
对 operand 张量执行逐元素双曲正切运算,并生成 result 张量。根据元素类型,执行以下操作:
- 对于浮点数:来自 IEEE-754 的
tanh。 - 对于复数:复数双曲正切。
- 对于量化类型:
dequantize_op_quantize(tanh, operand, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
浮点或复杂类型的张量或每个张量的量化张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
浮点或复杂类型的张量或每个张量的量化张量 | (C1) |
限制条件
- (C1)
baseline_type(operand) = baseline_type(result)。
示例
// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]
transpose
语义
使用 permutation 对 operand 张量的维度进行置换,并生成 result 张量。更正式地说,result[result_index] = operand[operand_index],其中 result_index[d] = operand_index[permutation[d]]。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
张量或量化张量 | (C1-C4) |
| (I2) | permutation |
类型为 si64 的 1 维张量常量 |
(C2-C4) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
张量或量化张量 | (C1)、(C3-C4) |
限制条件
- (C1)
element_type(result)由以下人员提供:element_type(operand)(如果!is_per_axis_quantized(operand))。element_type(operand),但quantization_dimension(operand)和quantization_dimension(result)可能不同,否则。
- (C2)
permutation是range(rank(operand))的一种排列。 - (C3)
shape(result) = dim(operand, permutation...)。 - (C4) 如果
is_per_axis_quantized(result),则quantization_dimension(operand) = permutation(quantization_dimension(result))。
示例
// %operand: [
// [[1,2], [3,4], [5,6]],
// [[7,8], [9,10], [11,12]]
// ]
%result = "stablehlo.transpose"(%operand) {
permutation = array<i64: 2, 1, 0>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
// [[1,7], [3,9], [5,11]],
// [[2,8], [4,10], [6,12]]
// ]
triangular_solve
语义
求解具有下三角或上三角系数矩阵的一批线性方程组。
更正式地说,给定 a 和 b,当 left_side 为 true 时,result[i0, ..., iR-3, :, :] 是 op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :] 的解;当 left_side 为 false 时,result[i0, ..., iR-3, :, :] 是 op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :] 的解。求解变量 x,其中 op(a) 由 transpose_a 确定,transpose_a 可以是以下值之一:x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :]
NO_TRANSPOSE:按原样使用a执行操作。TRANSPOSE:对a的转置执行操作。ADJOINT:对a的共轭转置执行运算。
如果 lower 为 true,则仅从 a 的下三角读取输入数据;否则,从 a 的上三角读取输入数据。输出数据在同一三角形中返回;另一三角形中的值由实现定义。
如果 unit_diagonal 为 true,则实现可以假定 a 的对角线元素等于 1,否则行为未定义。
对于量化类型,执行 dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower,
unit_diagonal, transpose_a), a, b, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | a |
浮点或复杂类型的张量或每个张量的量化张量 | (C1-C3) |
| (I2) | b |
浮点或复杂类型的张量或每个张量的量化张量 | (C1-C4) |
| (I3) | left_side |
类型为 i1 的常量 |
(C3) |
| (I4) | lower |
类型为 i1 的常量 |
|
| (I5) | unit_diagonal |
类型为 i1 的常量 |
|
| (I6) | transpose_a |
NO_TRANSPOSE、TRANSPOSE 和 ADJOINT 的枚举 |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
浮点或复杂类型的张量或每个张量的量化张量 | (C1) |
限制条件
- (C1)
baseline_element_type(a) = baseline_element_type(b)。 - (C2)
2 <= rank(a) = rank(b) = R。 - (C3)
shape(a)与shape(b)之间的关系定义如下:shape(a)[:-3] = shape(b)[:-3]。dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1)。
- (C4)
baseline_type(b) = baseline_type(result)。
示例
// %a = [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
// %b = [
// [2.0, 0.0, 0.0],
// [4.0, 8.0, 0.0],
// [6.0, 10.0, 12.0]
// ]
%result = "stablehlo.triangular_solve"(%a, %b) {
left_side = true,
lower = true,
unit_diagonal = false,
transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
// [2.0, 0.0, 0.0],
// [0.0, 2.0, 0.0],
// [0.0, 0.0, 2.0]
// ]
tuple
语义
根据值 val 生成 result 元组。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | val |
可变数量的值 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
tuple | (C1) |
限制条件
- (C1)
result的类型为tuple<E0, ..., EN-1>,其中Ei = type(val[i])。
示例
// %val0: memref[1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (memref<2xf32>, tuple<tensor<i32>>) -> tuple<memref<2xf32>, tuple<tensor<i32>>>
// %result: (memref[1.0, 2.0], (3))
uniform_dequantize
语义
根据 operand 类型定义的量化参数,将量化张量 operand 逐元素转换为浮点张量 result。
更正式地说,result = dequantize(operand)。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
量化张量 | (C1)、(C2) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
浮点型张量 | (C1)、(C2) |
限制条件
- (C1)
shape(operand) = shape(result)。 - (C2)
element_type(result) = expressed_type(operand)。
示例
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]
uniform_quantize
语义
根据 result 类型定义的量化参数,将浮点张量或量化张量 operand 逐元素转换为量化张量 result。
更正式地说,
- 如果
is_float(operand):result = quantize(operand, type(result))。
- 如果
is_quantized(operand):float_result = dequantize(operand)。result = quantize(float_result, type(result))。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
浮点或量化类型的张量 | (C1)、(C2) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
量化张量 | (C1)、(C2) |
限制条件
- (C1)
shape(operand) = shape(result)。 - (C2)
expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand)。
示例
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]
// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]
而
语义
在 cond 函数输出 true 时,执行 body 函数 0 次或多次,并生成输出。更正式地说,可以使用以下 Python 语法来表达语义:
internal_state = operand
while cond(*internal_state):
internal_state = body(*internal_state)
results = internal_state
无限循环的行为尚待确定 (#383)。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | operand |
可变数量的值 | (C1-C3) |
| (I2) | cond |
函数 | (C1) |
| (I3) | body |
函数 | (C2) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
results |
可变数量的值 | (C3) |
限制条件
- (C1)
cond的类型为(T0, ..., TN-1) -> tensor<i1>,其中Ti = type(operand[i])。 - (C2)
body的类型为(T0, ..., TN-1) -> (T0, ..., TN-1),其中Ti = type(operand[i])。 - (C3)
type(results...) = type(operand...)。
示例
// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%cond = "stablehlo.compare"(%arg0, %ten) {
comparison_direction = #stablehlo<comparison_direction LT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %cond : tensor<i1>
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%new_sum = stablehlo.add %arg1, %one : tensor<i64>
%new_i = stablehlo.add %arg0, %one : tensor<i64>
stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10
xor
语义
对两个张量 lhs 和 rhs 执行逐元素异或运算,并生成 result 张量。根据元素类型,执行以下操作:
- 对于布尔值:逻辑 XOR。
- 对于整数:按位异或。
输入
| 标签 | 名称 | 类型 | 限制条件 |
|---|---|---|---|
| (I1) | lhs |
布尔值或整数类型的张量 | (C1) |
| (I2) | rhs |
布尔值或整数类型的张量 | (C1) |
输出
| 名称 | 类型 | 限制条件 |
|---|---|---|
result |
布尔值或整数类型的张量 | (C1) |
限制条件
- (C1)
type(lhs) = type(rhs) = type(result)。
示例
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]
方言互操作
目前,实际应用中的 StableHLO 程序有时包含 StableHLO 未定义的操作。
模块、函数、调用和返回
StableHLO 使用上游 MLIR 操作来处理 ModuleOp、FuncOp、CallOp 和 ReturnOp。这样做是为了更好地与现有的 MLIR 机制进行互操作,因为许多有用的 pass 都是以 FuncOp 和 ModuleOp 为目标编写的,并且许多编译流水线都希望这些操作存在。这些操作会应用完全兼容性保证。如果这些操作以不兼容的方式(即移除)发生任何变化,系统将添加 StableHLO 等效项以保持兼容性。
CHLO
CHLO 操作集包含可分解为 StableHLO 的更高级别操作。目前,我们不保证 CHLO 的兼容性。为了保证兼容性,必须在序列化之前使用 chlo-legalize-to-stablehlo 传递。
形状操作
在社区中,一个常见的用例是在动态 StableHLO 程序中使用核心 MLIR 方言中的某些操作来执行形状计算。最常见的包括 shape 方言操作(如 shape_of 或 num_elements)、tensor 方言操作(如 dim 或 from_elements)以及内置的 index 类型。
动态 RFC > O2 将这些类型标记为超出范围,但为了实现互操作性,其中包含对 index 类型的部分支持。我们不保证这些操作或类型的兼容性。shape-legalize-to-stablehlo 传递可用于将这些操作转换为完全受支持的 StableHLO 操作。
已弃用的操作
有几个 StableHLO 操作是从 MHLO 继承的,这些操作已被弃用,并且即将从 StableHLO 中移除。如需详细了解这些移除操作,请参阅 StableHLO v1.0 清理 #2283。 这些弃用的跟踪器问题为 #2340。
这些操作可分为以下几类:
- StableHLO 操作的“不在 HLO 中”类别 - 这些操作最初是 StableHLO 操作集的一部分,但后来被认为不太适合:
broadcast、create_token、cross-replica-sum、dot、einsum、torch_index_select、unary_einsum(#3)。 - 未使用的操作 - 这些操作可能在某个时间点有用,但这些操作要么未充分开发,要么使用这些操作的流水线已重构,不再需要这些操作。这包括
map、tuple(#598)、get_tuple_element、rng、complex比较 #560 和卷积window_reversal(#1181)。
鉴于这些操作可以使用现有操作(broadcast、create_token、cross-replica-sum、dot、unary_einsum)来表示,因此可以轻松移除它们,并且将在现有兼容性窗口期(6 个月)结束后移除。我们仍在探索是否移除其他操作(einsum、get_tuple_element、map、rng torch_index_select、tuple、complex 比较、window_reversal)。在收到社区反馈之前,这些操作要么会被移除,要么会添加到规范中并获得全面支持。在这些操作的未来版本已知之前,它们只能保证 6 个月的兼容性。
执行
顺序执行
StableHLO 程序的执行方式是:向 main 函数提供输入值,然后计算输出值。函数的输出值是通过执行以相应 return 操作为根的操作图计算得出的。
执行顺序由实现定义,只要它与数据流保持一致,即在操作使用之前执行。在 StableHLO 中,所有具有副作用的操作都会消耗一个令牌并生成一个令牌(多个令牌可以通过 after_all 多路复用到一个令牌中),因此副作用的执行顺序也与数据流保持一致。例如,在下面的程序中,有两种可能的执行顺序:%0 → %1 → %2 → return 和 %1 → %0 → %2 → return。
func.func @main() -> tensor<f64> {
%0 = stablehlo.constant dense<1.0> : tensor<f64>
%1 = stablehlo.constant dense<2.0> : tensor<f64>
%2 = stablehlo.add %0, %1 : tensor<f64>
return %2 : tensor<f64>
}
更正式地说,StableHLO 进程是以下各项的组合:1) StableHLO 程序;2) 操作状态(尚未执行、已执行);3) 进程正在处理的中间值。该流程从 main 函数的输入值开始,通过更新操作状态和中间值的运算图进行,最后以输出值结束。进一步的正式化工作待定 (#484)。
并行执行
StableHLO 程序可以并行执行,并组织成 num_replicas x num_partitions 的二维进程网格,其中 num_replicas 和 num_partitions 的类型均为 ui32。
在 StableHLO 进程网格中,num_replicas * num_partitions 个 StableHLO 进程同时执行。每个进程都有一个唯一的 process_id = (replica_id, partition_id),其中 replica_ids = range(num_replicas) 中的 replica_id 和 partition_ids = range(num_partitions) 中的 partition_id 都具有 ui32 类型。
每个程序(未来,我们计划将其明确纳入 StableHLO 程序 #650)的进程网格大小都是静态已知的,并且每个进程在进程网格中的位置也是静态已知的。每个进程都可以通过 replica_id 和 partition_id 操作访问其在进程网格中的位置。
在进程网格中,所有程序可以相同(“单程序、多数据”样式),也可以全部不同(“多程序、多数据”样式),或者介于两者之间。未来,我们计划引入对其他定义并行 StableHLO 程序的惯用语法的支持,包括 GSPMD (#619)。
在进程网格中,进程之间大多是相互独立的 - 它们具有单独的运行状态、单独的输入/中间/输出值,并且大多数操作都是在进程之间单独执行的,但少数集体操作除外(如下所述)。
鉴于大多数操作的执行仅使用来自同一进程的值,因此通常可以通过名称明确引用这些值。不过,在描述集合运算的语义时,这种表示法是不够的,因此我们引入了 name@process_id 符号来表示特定进程中的值 name。(从这个角度来看,不带限定词的 name 可以视为 name@(replica_id(), partition_id()) 的简写形式)。
各进程之间的执行顺序由实现定义,但点对点通信和集体操作引入的同步除外,如下所述。
点对点通信
StableHLO 进程可以通过 StableHLO 渠道相互通信。渠道由 si64 类型的正 ID 表示。通过各种操作,可以向渠道发送值并从渠道接收值。
进一步的正式化(例如,这些渠道 ID 的来源、进程程序如何了解它们以及它们引入了哪种同步)待定 (#484)。
流式通信
每个 StableHLO 进程都可以访问两个流式接口:
- 可从中读取数据的信息源。
- 可写入的出料。
与用于在进程之间通信的通道不同,infeed 和 outfeed 的另一端实现由实现定义。通道的两端都有进程,
进一步的正式化(例如,流式通信如何影响执行顺序以及它引入了哪种同步)待定 (#484)。
集体操作
StableHLO 中有六个集合运算:all_gather、all_reduce、all_to_all、collective_broadcast、collective_permute 和 reduce_scatter。所有这些操作都会将 StableHLO 进程网格中的进程拆分为 StableHLO 进程组,并在每个进程组内执行联合计算,与其他进程组无关。
在每个进程组中,集合运算可能会引入同步屏障。进一步的正式化(例如,详细说明此同步的确切发生时间、进程到达此屏障的确切方式以及如果进程未到达此屏障会发生什么情况)尚待确定 (TBD) (#484)。
如果进程组涉及跨分区通信(即,进程组中存在分区 ID 不同的进程),则集体操作的执行需要一个通道,并且集体操作必须提供一个正 channel_id(类型为 si64)。跨副本通信不需要渠道。
集体操作执行的计算特定于各个操作,并在上面的各个操作部分中进行了说明。不过,将进程网格拆分为进程组的策略在这些操作之间是共享的,本部分将对此进行介绍。更正式地说,StableHLO 支持以下四种策略。
cross_replica
每个进程组内仅发生跨副本通信。此策略接受 replica_groups(副本 ID 的列表的列表),并通过 replica_groups 计算 partition_ids 的笛卡尔积。replica_groups 必须具有唯一元素,并且涵盖所有 replica_ids。更正式地说,使用 Python 语法:
def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
for partition_id in partition_ids:
process_group = []
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
例如,对于 replica_groups = [[0, 1], [2, 3]] 和 num_partitions = 2,cross_replica 将生成 [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]]。
cross_partition
每个进程组内仅发生跨分区通信。此策略接受 partition_groups(分区 ID 的列表的列表),并计算 partition_groups 与 replica_ids 的笛卡尔积。partition_groups 必须具有唯一元素,并且涵盖所有 partition_ids。更正式地说,使用 Python 语法:
def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
for partition_group in partition_groups:
for replica_id in replica_ids:
process_group = []
for partition_id in partition_group:
process_group.append((replica_id, partition_id))
yield process_group
例如,对于 partition_groups = [[0, 1]] 和 num_replicas = 4,cross_partition 将生成 [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]]。
cross_replica_and_partition
每个进程组内都可能会发生跨副本和跨分区通信。此策略接受 replica_groups(副本 ID 的列表的列表),并计算每个 replica_group 与 partition_ids 的笛卡尔积。replica_groups 必须具有唯一元素,并涵盖所有 replica_ids。更正式地说,使用 Python 语法:
def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
process_group = []
for partition_id in partition_ids:
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
例如,对于 replica_groups = [[0, 1], [2, 3]] 和 num_partitions = 2,cross_replica_and_partition 将生成 [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]]。
flattened_ids
此策略接受 flattened_id_groups(“扁平化”进程 ID 的列表的列表,格式为 replica_id * num_partitions + partition_id),并将其转换为进程 ID。flattened_id_groups 必须具有唯一元素,并涵盖所有 process_ids。更正式地说,使用 Python 语法:
def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
for flattened_id_group in flattened_id_groups:
process_group = []
for flattened_id in flattened_id_group:
replica_id = flattened_id // num_partitions
partition_id = flattened_id % num_partitions
process_group.append((replica_id, partition_id))
yield process_group
例如,对于 flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]、num_replicas = 4 和 num_partitions = 2,flattened_ids 将生成 [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]]。
准确率
目前,StableHLO 不提供有关数值准确性的保证,但未来可能会发生变化 (#1156)。
量化操作的执行语义
量化 StableHLO 操作的解读可能会因硬件要求和功能而异。例如,某些硬件可能会选择使用“去量化、执行浮点运算,最后量化”策略来解释量化操作。而另一些则可能使用整数运算执行整个计算。因此,量化 StableHLO 操作的解释完全取决于具体实现。混合量化 (#1575) 的解读应基于规范(通过 1792)中规定的语义。
错误
StableHLO 程序通过针对各个操作的一组广泛的约束条件进行验证,从而在运行时之前排除许多类别的错误。不过,仍有可能出现错误情况,例如通过整数溢出、越界访问等。除非明确指出,否则所有这些错误都会导致实现定义的行为,但未来可能会发生变化 (#1157)。
浮点异常
不过,StableHLO 程序中的浮点异常具有明确定义的行为。导致 IEEE-754 标准定义的异常(无效运算、除以零、溢出、下溢或不精确异常)的运算会产生默认结果(如标准中所定义),并继续执行,而不会引发相应的状态标志;这与标准中的 raiseNoFlag 异常处理类似。非标准运算(例如复杂算术运算和某些超越函数)的例外情况由实现定义。
形状不匹配
StableHLO 支持动态形状的张量。不过,形状必须在运行时保持一致,否则行为将处于未定义状态。StableHLO 没有明确提供可在运行时断言张量具有给定形状的运算。生成正确的代码是制作方的责任。
举个具体的例子,以下程序是有效的。不过,在运行时,%arg0 和 %arg1 的确切形状必须相同,否则程序的行为将是不确定的:
func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
return %0 : tensor<?xi32>
}
Notation
为了描述语法,本文档使用了经过修改的 ISO 版 EBNF 语法 (ISO/IEC 14977:1996、Wikipedia),并进行了两项修改:1) 使用 ::= 而不是 = 定义规则,
2) 连接使用并列表示,而不是 ,。
为了描述语义(即在“类型”“常量”和“操作”部分中),我们使用了基于 Python 语法的公式,该语法扩展了对简洁表达数组操作的支持,如下所述。对于小段代码,这种方法效果很好,但在极少数情况下需要较大段代码时,我们会使用纯 Python 语法,并且始终会明确说明。
公式
我们来根据规范中的示例探讨一下公式的运作方式。dot_general此操作的限制之一如下所示:dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)。
此公式中使用的名称来自两个来源:1) 全局函数,即 dim;2) 相应程序元素的成员定义,即 dot_general 的“输入”部分中定义的 lhs、lhs_batching_dimensions、rhs 和 rhs_batching_dimensions 输入。
如上所述,此公式的语法基于 Python,并添加了一些以简洁为导向的扩展功能。为了理解该公式,我们将其转换为普通的 Python 语法。
A) 在这些公式中,我们使用 = 表示相等,因此获得 Python 语法的第一个步骤是将 = 替换为 ==,如下所示:dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...)。
B) 此外,这些公式还支持省略号 (...),可将标量表达式转换为张量表达式。简而言之,f(xs...) 大致是指“对于张量 xs 中的每个标量 x,计算一个标量 f(x),然后将所有这些标量结果一起作为张量结果返回”。在标准 Python 语法中,我们的示例公式变为:[dim(lhs, dim1) for dim1 in lhs_batching_dimensions] ==
[dim(rhs, dim2) for dim2 in rhs_batching_dimensions]。
借助省略号,通常可以避免在单个标量的级别上进行操作。不过,在某些棘手的情况下,可以使用较低级别的半正式语法,如 gather 规范中的 start_indices[bi0, ..., :, ..., biN] 公式所示。为了简洁起见,我们没有提供将此类语法转换为纯 Python 的确切形式化方法,希望它在具体情况下仍然可以直观地理解。如果您觉得某些特定公式不够清晰,请告诉我们,我们会尽力改进。
此外,您会注意到,公式使用省略号来展开各种列表,包括张量、张量列表(例如,可能来自可变数量的张量)等。这是我们未提供确切形式主义的另一个领域(例如,列表甚至不是 StableHLO 类型系统的一部分),而是依赖于直观的可理解性。
C) 我们使用的最后一个值得注意的符号表示法是隐式广播。虽然 StableHLO 操作集不支持隐式广播,但公式支持,这也是为了简洁起见。简而言之,如果在需要张量的上下文中使用标量,则标量会广播到预期形状。
继续沿用 dot_general 示例,我们再添加一个限制条件:0 <= lhs_batching_dimensions < rank(lhs)。根据 dot_general 规范中的定义,lhs_batching_dimensions 是一个张量,而 0 和 rank(lhs) 都是标量。应用隐式广播后,公式将变为 [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)]。
当应用于特定的 dot_general 操作时,此公式将评估为布尔值张量。当公式用作限制条件时,如果公式的计算结果为 true 或仅包含 true 元素的张量,则限制条件成立。
名称
在公式中,词法范围包括:1) 全局函数,2) 成员定义,
3) 本地定义。下面列出了全局函数。元素定义列表取决于应用了相应表示法的程序元素:
- 对于操作,成员定义包括“输入”和“输出”部分中引入的名称。
- 对于其他所有内容,成员定义都包含程序元素的结构部分,并以相应的 EBNF 非终端命名。大多数情况下,这些结构部分的名称是通过将非终端的名称转换为蛇形命名法(例如
IntegerLiteral=>integer_literal)获得的,但有时名称会在该过程中被缩写(例如QuantizationStorageType=>storage_type),在这种情况下,名称会像操作规范中的“输入”/“输出”部分一样被明确引入。 - 此外,成员定义始终包含
self以引用相应的程序元素。
值
在评估公式时,它们会使用以下类型的值:1) Value(实际值,例如 dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>;它们始终知道自己的类型),2) Placeholder(未来值,例如 lhs、rhs 或 result;它们的实际值尚不清楚,只知道它们的类型),3) Type(“类型”部分中定义的类型),4) Function(“函数”部分中定义的全局函数)。
根据上下文的不同,名称可能指的是不同的值。更具体地说,操作的“语义”部分(以及其他程序元素的等效部分)定义了运行时逻辑,因此所有输入都以 Value 的形式提供。相比之下,操作(以及等效项)的“限制”部分定义了“编译时”逻辑,即通常在运行时之前执行的逻辑,因此只有常量输入可用作 Value,而其他输入仅可用作 Placeholder。
| 名称 | 在“语义”中 | 在“限制”中 |
|---|---|---|
| 全局函数 | Function |
Function |
| 常量输入 | Value |
Value |
| 非常量输入 | Value |
Placeholder |
| 输出 | Value |
Placeholder |
| 本地定义 | 取决于定义 | 取决于定义 |
我们来看一个 transpose 操作示例:
%result = "stablehlo.transpose"(%operand) {
permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
对于此操作,permutation 是一个常量,因此在语义和约束中均可作为 Value 使用。相比之下,operand 和 result 在语义上作为 Value 提供,但在限制方面仅作为 Placeholder 提供。
函数
类型的构造
没有可用于构建类型的函数。相反,我们直接使用类型语法,因为这种语法通常更简洁。例如,(tensor<E>, tensor<E>) -> (tensor<E>) 而不是 function_type(
[tensor_type([], E), tensor_type([], E)], [tensor_type([], E)])。
类型函数
element_type是针对张量类型和量化张量类型定义的,分别返回相应TensorType或QuantizedTensorType的TensorElementType或QuantizedTensorElementType部分。
def element_type(x: Value | Placeholder | Type):
if type(x) == TensorType:
return tensor_element_type(x)
if type(x) == QuantizedTensorType:
return quantized_tensor_element_type(x)
if type(x) is not Type:
return element_type(type(x))
is_per_axis_quantized(x: Value | Placeholder | Type) -> Value是is_quantized(x) and quantization_dimension(x) is not None的快捷方式。is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value是is_quantized(x) and quantization_dimension(x) is None的快捷方式。is_promotable(x: Type, y: Type) -> bool检查类型x是否可以提升为类型y。如果x和y均为QuantizedTensorElementType,则促销活动仅适用于storage_type。此特定版本的促销活动目前用于计算减免(有关详情,请参阅 RFC)。
def is_promotable(x: Type, y: Type) -> Value:
is_same_type = (is_bool(x) and is_bool(y)) or
(is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
(is_complex(x) and is_complex(y)) or
(is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
if is_same_type == False:
return False
if is_integer(x) or is_float(x):
return bitwidth(x) <= bitwidth(y)
if is_complex(x):
return bitwidth(element_type(x)) <= bitwidth(element_type(y))
if is_quantized(x):
return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))
return false
is_quantized(x: Value | Placeholder | Type) -> Value是is_quantized_tensor_element_type(x)的快捷方式。is_type_name(x: Value | Placeholder | Type) -> Value。适用于所有类型。例如,如果x为FloatType,则is_float(x)会返回true。 如果x是值或占位符,则此函数是is_type_name(type(x))的快捷方式。max_value(x: Type) -> Value返回TensorElementType的最大值。如果x不是TensorElementType,则返回None。min_value(x: Type) -> Value返回TensorElementType的最小值。如果x不是TensorElementType,则返回None。member_name(x: Value | Placeholder | Type) -> Any。适用于所有类型的全部成员定义member_name。例如,tensor_element_type(x)会返回相应TensorType的TensorElementType部分。如果x是值或占位符,则此函数是member_name(type(x))的快捷方式。如果x不是具有相应成员的类型,也不是此类类型的值或占位符,则返回None。is_empty_algorithm(*args: Type)检查所有点算法字段是否都设置为None。这是必需的,因为点算法具有实现定义的默认行为,因此指定默认值是不正确的。
值的构建
operation_name(*xs: Value | Type) -> Value。适用于所有操作。 例如,add(lhs, rhs)接受两个张量值lhs和rhs,并返回使用这些输入评估add操作的输出。对于某些操作(例如broadcast_in_dim),其输出的类型是“承载”类型,即需要用于评估操作。在这种情况下,该函数将这些类型作为实参。
值函数
所有 Python 运算符和函数均可使用。例如,Python 中的订阅和切片表示法均可用于为张量、量化张量和元组编制索引。
to_destination_type(x: Value, destination_type: Type) -> Value是针对张量定义的,并根据type(x)和destination_type返回x的转换值,如下所示:
def to_destination_type(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x
if is_quantized(destination_type):
if is_quantized(type(x)):
return quantize(x, destination_type)
assert is_float(type(x))
return quantize(x, destination_type)
if is_quantized(type(x)):
assert destination_type = expressed_type(type(x))
return dequantize(type(x))
return convert(x, destination_type)
目前正在初步讨论合并 convert、uniform_quantize 和 uniform_dequantize 操作 (#1576)。合并后,我们不再需要上述函数,而是可以使用 convert 的操作名称。
is_nan(x: Value) -> Value是针对张量定义的,如果x的所有元素均为NaN,则返回true,否则返回false。如果x不是张量,则返回None。is_sorted(x: Value) -> Value是针对张量定义的,如果x的元素按其索引的升序字典顺序以升序排序,则返回true,否则返回false。如果x不是张量,则返回None。is_unique(x: Value) -> Value是针对张量定义的,如果x没有重复元素,则返回true,否则返回false。如果x不是张量,则返回None。member_name(x: Value) -> Any针对所有值的全部成员定义member_name进行定义。例如,real_part(x)会返回相应ComplexConstant的RealPart部分。如果x不是具有相应成员的值,则返回None。same(x: Value) -> Value是针对张量定义的,如果x的元素全部相等,则返回true;否则返回false。如果张量没有元素,则视为“全部彼此相等”,即函数返回true。如果x不是张量,则返回None。split(x: Value, num_results: Value, axis: Value) -> Value是针对张量定义的,沿轴axis返回x的num_results切片。如果x不是张量或dim(x, axis) % num_results != 0,则返回None。is_defined_in_parent_scope(x: Value) -> Value是针对字符串定义的,如果x是在与相关操作的父函数相同的作用域中定义的函数的名称,则返回true。is_namespaced_op_name(x: Value) -> Value是针对字符串定义的,如果x是有效的操作名称(即符合以下正则表达式:[a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+),则返回true
形状计算
axes(x: Value | Placeholder | Type) -> Value是range(rank(x))的快捷方式。dim(x: Value | Placeholder | Type, axis: Value) -> Value是shape(x)[axis]的快捷方式。dims(x: Value | Placeholder | Type, axes: List) -> List是list(map(lambda axis: dim(x, axis), axes))的快捷方式。index_space(x: Value | Placeholder | Type) -> Value是针对张量定义的,并返回按升序字典顺序排序的相应TensorType的size(x)索引,即[0, ..., 0]、[0, ..., 1]、...、shape(x) - 1。如果x不是张量类型、量化张量类型或这些类型的值/占位符,则返回None。rank(x: Value | Placeholder | Type) -> Value是size(shape(x))的快捷方式。shape(x: Value | Placeholder | Type) -> Value是通过member_name在“类型函数”部分中定义的。size(x: Value | Placeholder | Type) -> Value是reduce(lambda x, y: x * y, shape(x))的快捷方式。
量化计算
def baseline_element_type(x: Value | Placeholder | Type) -> Type是element_type(baseline_type(x))的快捷方式。baseline_type是针对张量类型和量化张量类型定义的,可将它们转换为“基准”,即形状相同但元素类型的量化参数重置为默认值的类型。这是一种便捷的技巧,可用于统一比较张量和量化张量类型,这在很多情况下都是必需的。对于量化类型,这允许比较类型时忽略量化参数,也就是说,shape、storage_type、expressed_type、storage_min、storage_max和quantization_dimension(对于按轴量化类型)必须全部匹配,但scales和zero points可能不同。
def baseline_type(x: Value | Placeholder | Type) -> Type:
if type(x) == TensorType:
return x
if type(x) == QuantizedTensorType:
element_type = quantized_tensor_element_type(x)
baseline_element_type = QuantizedTensorElementType(
storage_type = storage_type(element_type),
storage_min = storage_min(element_type),
storage_max = storage_max(element_type),
expressed_type = expressed_type(element_type),
quantization_dimension = quantization_dimension(element_type),
scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
return QuantizedTensorType(shape(x), baseline_element_type)
if type(x) is not Type:
return baseline_element_type(type(x))
dequantize是针对量化张量类型定义的,可将其转换为浮点张量类型。这是通过以下方式实现的:使用与量化元素类型关联的零点和缩放比例,将表示存储类型整数值的量化元素转换为表示类型的相应浮点值。
def compute_zero_points(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
zero_points[i] = zero_points(quantized_type)[i[d]]
return zero_points
def compute_scales(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
type(result_type))
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
scales[i] = scales(quantized_type)[i[d]]
return scales
def dequantize(x: Value) -> Value:
assert is_quantized(x)
x_storage = bitcast_convert(x, storage_type(x))
x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
x_expressed_sub = convert(x_storage_sub, expressed_type(x))
return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
quantize是针对浮点张量类型定义的,可将其转换为量化张量类型。这是通过以下方式实现的:使用与量化元素类型相关联的零点和比例,将所表达类型的浮点值转换为存储类型的相应整数值。
def quantize(x: Value, result_type: Type) -> Value:
assert is_float(x) and is_quantized(result_type)
zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
converted_zero_points = convert(zero_points, expressed_type(result_type))
converted_min = convert(storage_min(result_type), expressed_type(result_type))
converted_max = convert(storage_max(result_type), expressed_type(result_type))
x_scaled = x / compute_scales(result_type, type(x))
x_scaled_add_zp = x_scaled + converted_zero_points
x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
x_rounded = round_nearest_even(x_clamped)
return convert(x_rounded, result_type)
dequantize_op_quantize用于指定对量化张量进行逐元素计算。它会进行反量化(即将量化元素转换为其表达类型),然后执行运算,再进行量化(即将结果转换回其存储类型)。目前,此函数仅适用于按张量量化。正在开发按轴量化功能 (#1574)。
def dequantize_op_quantize(op, *inputs_and_output_type):
inputs = inputs_and_output_type[:-1]
output_type = inputs_and_output_type[-1]
float_inputs = map(dequantize, inputs)
float_result = op(*float_inputs)
return quantize(float_result, output_type)
def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
inputs = inputs_and_output_type[:-3]
float_inputs = map(dequantize, inputs)
float_results = op(*float_inputs)
return map(quantize, float_results, inputs_and_output_type[-3:])
def dequantize_compare(lhs, rhs, comparison_direction):
float_lhs = dequantize(lhs)
float_rhs = dequantize(rhs)
return compare(float_lhs, float_rhs, comparison_direction, FLOAT)
def dequantize_select_quantize(pred, on_true, on_false, output_type):
float_on_true = dequantize(on_true)
float_on_false = dequantize(on_false)
float_result = select(pred, float_on_true, float_on_false)
return quantize(float_result, output_type)
hybrid_dequantize_then_op用于为混合运算指定仅权重量化,该运算接受浮点类型的左侧实参和量化类型的右侧实参。它将量化输入反量化为表达类型,并以浮点数执行计算。浮点左侧实参张量的元素类型和量化右侧实参张量的表达类型应相同。
def hybrid_dequantize_then_op(op, lhs, rhs):
assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
return op(lhs, dequantize(rhs))
网格计算
cross_partition(replica_groups: Value) -> Value。请参阅上文中的“cross_replica”部分。cross_replica(replica_groups: Value) -> Value。请参阅上文中的“cross_replica”部分。cross_replica_and_partition(replica_groups: Value) -> Value。请参阅上文中的“cross_replica_and_partition”部分。flattened_ids(replica_groups: Value) -> Value. 请参阅上文中的“flattened_ids”部分。
活力
StableHLO 值可以具有动态维度大小,例如 tensor<?xi64>。不过,StableHLO 值不能具有动态维度数量(无秩动态性,例如 tensor<*xi64>)。操作数和结果可以使用动态维度大小,即使大小存在限制也是如此。如果可能,系统将静态验证约束条件,否则会将约束条件延迟到运行时进行验证,不匹配会导致未定义的行为。如需查看示例,请参阅下文。
一元逐元素运算的形状不匹配
请考虑以下玩具程序:
func.func @foo(%arg0: tensor<?xf64>) {
%0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
return
}
此类程序并不常见,因为通常情况下,我们知道结果的形状,但不知道输入的形状。不过,这是一个有效的 StableHLO 程序。无法在此程序中静态验证 abs 操作,因为操作数的具体形状未知。不过,这些形状肯定是兼容的,并且可以静态检查:? 在运行时可能变为 2,但不会出现任何问题。不过,? 也可能是其他某个整数,在这种情况下,行为是未定义的。
请注意,如果结果中的维度大小是动态的,则不能出现未定义的行为。实际上,没有“预期”大小,因此不会出现不匹配的情况。
二元按元素运算的形状不匹配
请考虑以下玩具程序:
func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
%0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
return
}
对于二元逐元素运算,输入和结果的形状必须在运行时保持一致。在编译时,静态维度必须相等,否则只需兼容即可。如果输入中存在任何动态维度,则运行时可能会出现未定义的行为,因为动态大小可能与另一个操作数(无论是静态还是动态)中的相应大小不匹配。如果所有输入都是静态的,那么结果是动态的还是静态的并不重要:静态已知维度将进行静态检查,而动态维度不会施加任何限制。
将输出形状作为操作数的运算的形状不匹配
请考虑以下玩具程序:
func.func @foo(%arg0: tensor<2xi32>) {
%0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
return
}
运行时,形状操作数中的值必须与结果的形状一致,否则行为未定义。也就是说,在运行时,%arg0 的值必须为 dense<[3, 4]> : tensor<2xi32>。如果形状运算数是常量,则可以静态验证这一点。如果结果形状完全动态,则不会出现不匹配的情况。