StableHLO 规范

StableHLO 是机器学习 (ML) 模型中高级操作 (HLO) 的操作集。StableHLO 可用作不同机器学习框架和机器学习编译器之间的可移植性层:生成 StableHLO 程序的机器学习框架与使用 StableHLO 程序的机器学习编译器兼容。

我们的目标是通过在各种机器学习框架(例如 TensorFlow、JAX 和 PyTorch)和机器学习编译器(例如 XLA 和 IREE)之间实现更高的互操作性,从而简化和加速机器学习开发。为此,本文档提供了 StableHLO 编程语言的规范。

本规范包含三个主要部分。首先,程序部分介绍了 StableHLO 程序的结构,其中包含 StableHLO 函数,这些函数本身又包含 StableHLO 运算。在该结构中,操作部分指定了各个操作的语义。执行部分为程序中一起执行的所有这些操作提供了语义。最后,表示法部分讨论了整个规范中使用的表示法。

如需查看 StableHLO 的旧版本中的规范,请打开相应标记版本下的代码库。例如,StableHLO v0.19.0 规范。如需查看 StableHLO 的每个次次要版本升级时发生的更改,请参阅 VhloDialect.td 中的版本日志。

计划

Program ::= {Func}

StableHLO 程序由任意数量的 StableHLO 函数组成。以下是一个示例程序,其中包含一个函数 @main,该函数有 3 个输入(%image%weights%bias)和 1 个输出。函数正文包含 6 个操作。

func.func @main(
  %image: tensor<28x28xf32>,
  %weights: tensor<784x10xf32>,
  %bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
  %0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
  %1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
  %2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  %3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
  %4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  "func.return"(%4): (tensor<1x10xf32>) -> ()
}

函数

Func        ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs  ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput   ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput  ::= ValueType
FuncBody    ::= {Op}

StableHLO 函数(也称为“命名函数”)具有标识符、输入/输出和正文。未来,我们计划为函数引入其他元数据,以实现与 HLO 的更好兼容性 (#425#626#740#744)。

标识符

FuncId  ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
          | '%' letter {letter | digit}
letter  ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit   ::= '0' | ... | '9'

StableHLO 标识符与许多编程语言中的标识符类似,但有两个特点:1) 所有标识符都有用于区分不同类型标识符的符号,2) 值标识符可以是纯数字,以简化 StableHLO 程序的生成。

类型

Type         ::= ValueType | NonValueType
ValueType    ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType

StableHLO 类型分为值类型(也称为第一类类型),用于表示 StableHLO 值,以及非值类型,用于描述其他程序元素。StableHLO 类型与许多编程语言中的类型类似,主要的特殊之处在于 StableHLO 的领域专用性质,这会导致一些不寻常的结果(例如,标量类型不是值类型)。

TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'

张量类型表示张量,即多维数组。它们具有形状元素类型,其中形状表示从 0R-1 编号的相应维度(也称为)的升序非负或未知维度大小。维度数 R 称为“秩”。例如,tensor<2x3xf32> 是一种形状为 2x3 且元素类型为 f32 的张量类型。它有两个维度(换句话说,两个轴)- 第 0 个维度和第 1 个维度,其大小分别为 2 和 3。其排名为 2。

形状可以是部分或完全未知(动态),例如 tensor<?x2xf64> 是部分未知,tensor<?x?xf64> 是完全未知。动态尺寸使用 ? 表示。形状无法排序。

未来,我们计划探索将张量类型扩展到维度大小和元素类型之外,例如,添加布局 (#629) 和稀疏性 (#1078)。

QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
                  QuantizationStorageType
                  ['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
                  ':' QuantizationExpressedType
                  [':' QuantizationDimension]
                  ',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerLiteral
QuantizationStorageMax ::= IntegerLiteral
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerLiteral
QuantizationParameters ::= QuantizationParameter
                         | '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale [':' QuantizationZeroPoint]
QuantizationScale ::= FloatLiteral
QuantizationZeroPoint ::= IntegerLiteral
名称 类型 限制条件
storage_type 整数类型 (C1-C3)、(C8)
storage_min 整数常量 (C1)、(C3)、(C7)
storage_max 整数常量 (C2)、(C3)、(C7)
expressed_type 浮点类型 (C4)
quantization_dimension 可选整数常量 (C10-C12)
scales 可变数量的浮点常量 (C4-C6)、(C9)、(C10)、(C13)
zero_points 整数常量可变数量 (C7-C9)

量化元素类型表示 storage_minstorage_max(包括这两个数值)范围内的存储类型的整数值,这些值与表达式类型的浮点值相对应。对于给定的整数值 i,对应的浮点值 f 可计算为 f = (i - zero_point) * scale,其中 scalezero_point 称为量化参数storage_minstorage_max 在语法中是可选的,但默认值分别为 min_value(storage_type)max_value(storage_type)。量化元素类型具有以下约束条件:

  • (C1) type(storage_min) = storage_type
  • (C2) type(storage_max) = storage_type
  • (C3) min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type)
  • (C4) type(scales...) = expressed_type
  • (C5) 0 < scales
  • (C6) is_finite(scales...)
  • (C7) storage_min <= zero_points <= storage_max
  • (C8) type(zero_points...) = storage_type
  • (C9) size(scales) = size(zero_points)
  • (C10) 如果 is_empty(quantization_dimension),则 size(scales) = 1
  • (C11) 0 <= quantization_dimension

目前,QuantizationScale 是一个浮点常量,但人们对基于整数的比例(通过乘数和移位表示)非常感兴趣。我们计划在不久的将来探索这一点 (#1404)。

我们正在就 QuantizationZeroPoint 的语义(包括类型、值以及量化张量类型中是否只能有一个或可能有多个零点)进行讨论。根据本次讨论的结果,关于零点的规范未来可能会发生变化 (#1405)。

另一个正在讨论的问题涉及 QuantizationStorageMinQuantizationStorageMax 的语义,以确定是否应对这些值和量化张量的值施加任何约束条件 (#1406)。

最后,我们计划探索如何表示未知的比例和零点,这与我们计划探索如何表示未知的维度大小类似 (#1407)。

量化张量类型表示包含量化元素的张量。这些张量与常规张量完全相同,只是其元素具有量化元素类型,而不是常规元素类型。

在量化张量中,量化可以是按张量进行,即为整个张量使用一个 scalezero_point;也可以是按轴进行,即为特定维度 quantization_dimension 的每个 slice 使用一对 scaleszero_points。更正式地说,在采用按轴量化方式的张量 t 中,quantization_dimensiondim(t, quantization_dimension) 个 slice:t[:, ..., 0, ..., :], t[:, ..., 1, ..., :] 等。i 个 slice 中的所有元素都使用 scales[i]zero_points[i] 作为其量化参数。量化张量类型具有以下限制:

  • 对于每个张量量化:
    • 没有其他约束条件。
  • 对于按轴量化:
    • (C12) quantization_dimension < rank(self)
    • (C13) dim(self, quantization_dimension) = size(scales)
TokenType ::= 'token'

令牌类型表示令牌,即某些操作生成和使用的不透明值。令牌用于对操作强制执行顺序,如执行部分所述。

TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]

元组类型表示元组,即异构列表。元组是一种旧版功能,仅出于与 HLO 的兼容性考虑而存在。在 HLO 中,元组用于表示可变参数输入和输出。在 StableHLO 中,系统原生支持可变参数输入和输出,并且 StableHLO 中对元组的唯一用途是全面表示 HLO ABI,其中 Ttuple<T>tuple<tuple<T>> 可能会因具体实现而存在显著差异。未来,我们计划对 HLO ABI 进行更改,这或许可以从 StableHLO 中移除元组类型 (#598)。

TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3'
            | 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2'
            | 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'

元素类型表示张量类型的元素。与许多编程语言不同,这些类型在 StableHLO 中不是第一类。这意味着 StableHLO 程序无法直接表示这些类型的值(因此,通常使用类型为 tensor<T> 的 0 维张量值来表示类型为 T 的标量值)。

  • 布尔值类型表示布尔值 truefalse
  • 整数类型可以是有符号(si)或无符号(ui),并且具有支持的位宽(248163264)之一。有符号 siN 类型表示介于 -2^(N-1)2^(N-1)-1(包括这两个数值)之间的整数值,无符号 uiN 类型表示介于 02^N-1(包括这两个数值)之间的整数值。
  • 浮点类型可以是以下任一项:
  • 复杂类型表示具有相同元素类型的实部和虚部的复杂值。支持的复杂类型包括 complex<f32>(两个部分均为 f32 类型)和 complex<f64>(两个部分均为 f64 类型)。
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]

函数类型同时表示命名和匿名函数。它们具有输入类型(-> 左侧的类型列表)和输出类型(-> 右侧的类型列表)。在许多编程语言中,函数类型是第一类,但在 StableHLO 中却不是。

StringType ::= 'string'

字符串类型表示字节序列。与许多编程语言不同,字符串类型在 StableHLO 中不是第一类,仅用于为节目元素指定静态元数据。

操作

StableHLO 操作(也称为“操作”)代表机器学习模型中一组封闭的高级操作。如上所述,StableHLO 语法深受 MLIR 的启发,MLIR 不一定是最符合人体工学的替代方案,但可以说最适合 StableHLO 的目标,即在机器学习框架和机器学习编译器之间实现更高的互操作性。

Op            ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName        ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic    ::= 'abs' | 'add' | ...

StableHLO 操作(也称为操作)具有名称、输入/输出和签名。该名称由 stablehlo. 前缀和一个记忆法符号组成,用于唯一标识某个受支持的操作。请参阅下文,查看所有受支持操作的完整列表。

OpInputs        ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues   ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue    ::= ValueId
OpInputFuncs    ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs    ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs       ::= [OpOutput {',' OpOutput} '=']
OpOutput        ::= ValueId

操作会使用输入并生成输出。输入分为输入值(在执行期间计算)、输入函数(静态提供,因为在 StableHLO 中,函数不是一级值)和输入属性(也以静态方式提供)。运算所消耗和生成的输入和输出的类型取决于其助记符。例如,add 运算会使用 2 个输入值并生成 1 个输出值。相比之下,select_and_scatter 运算会使用 3 个输入值、2 个输入函数和 3 个输入属性。

OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused      ::= '^' digit {digit}
              | '^' letter {letter | digit}

输入函数(也称为匿名函数)与命名函数非常相似,区别在于:1) 输入函数没有标识符(因此称为“匿名”);2) 输入函数不声明输出类型(从函数内的 return 操作推断输出类型)。

输入函数的语法包含一个目前未使用的部分(请参阅上文中的 Unused 生产规则),该部分是为了与 MLIR 保持兼容性。在 MLIR 中,有更通用的“区域”概念,区域可以通过跳转操作将多个操作“块”连接在一起。这些块具有与 Unused 产生式相对应的 ID,因此可以彼此区分。StableHLO 没有跳转操作,因此 MLIR 语法的相应部分未使用(但仍保留)。

OpInputAttr      ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName  ::= letter {letter | digit}
OpInputAttrValue ::= Constant

输入属性包含名称和值,此值是受支持的常量之一。它们是为节目元素指定静态元数据的主要方式。例如,concatenate 运算使用属性 dimension 指定其输入值沿着哪个维度串联。同样,slice 运算使用 start_indiceslimit_indices 等多个属性来指定用于切片输入值的边界。

目前,实际环境中的 StableHLO 程序有时包含本文档中未介绍的属性。未来,我们计划将这些属性纳入 StableHLO 操作集,或者禁止它们出现在 StableHLO 程序中。在此期间,请参考以下这些属性的列表:

  • layout (#629)。
  • mhlo.frontend_attributes (#628)。
  • mhlo.sharding (#619)。
  • output_operand_aliases (#740)。
  • 位置元数据 (#594)。
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'

操作签名由所有输入值的类型(-> 左侧的类型列表)和所有输出值的类型(-> 右侧的类型列表)组成。严格来说,输入类型是多余的,输出类型几乎也总是多余的(因为对于大多数 StableHLO 操作,输出类型可以从输入中推断出来)。不过,操作签名特意成为 StableHLO 语法的一部分,以便与 MLIR 兼容。

以下是一个示例操作,其助记符为 select_and_scatter。它会使用 3 个输入值(%operand%source%init_value)、2 个输入函数和 3 个输入属性(window_dimensionswindow_stridespadding)。请注意,该操作的签名仅包含其输入值的类型(而不包括以内嵌方式提供的输入函数和属性的类型)。

%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i32>, tensor<i32>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
    "stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>

常量

Constant ::= BooleanConstant
           | IntegerConstant
           | FloatConstant
           | ComplexConstant
           | TensorConstant
           | QuantizedTensorConstant
           | StringConstant
           | EnumConstant

StableHLO 常量具有一个字面量和一个类型,它们共同表示一个 StableHLO 值。通常,类型是常量语法的一部分,除非它是明确的(例如,布尔常量明确为类型 i1,而整数常量可以有多种可能的类型)。

BooleanConstant ::= BooleanLiteral
BooleanLiteral  ::= 'true' | 'false'

布尔常量表示布尔值 truefalse。布尔常量的类型为 i1

IntegerConstant   ::= IntegerLiteral ':' IntegerType
IntegerLiteral    ::= ['-' | '+'] DecimalDigits
                    | ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits     ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit      ::= '0' | ... | '9'
hexadecimalDigit  ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'

整数常量通过使用十进制或十六进制表示法的字符串表示整数值。不支持其他进制(例如二进制或八进制)。整数常量具有以下限制:

  • (C1) is_wellformed(integer_literal, integer_type)
FloatConstant  ::= FloatLiteral ':' FloatType
FloatLiteral   ::= SignPart IntegerPart FractionalPart ScientificPart
                 | '0x' [HexadecimalDigits]
SignPart       ::= ['-' | '+']
IntegerPart    ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]

浮点常量:通过使用十进制或科学计数法的字符串表示浮点值。此外,十六进制表示法还可用于直接指定相应类型的浮点格式的底层位。浮点常量具有以下限制:

  • (C1) 如果使用非十六进制表示法,is_wellformed(float_literal, float_type)
  • (C2) 如果使用十六进制表示法,size(hexadecimal_digits) = num_bits(float_type) / 4
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral  ::= '(' RealPart ',' ImaginaryPart ')'
RealPart        ::= FloatLiteral
ImaginaryPart   ::= FloatLiteral

复数常量使用实部(位于前面)和虚部(位于后面)的列表来表示复值。例如,(1.0, 0.0) : complex<f32> 表示 1.0 + 0.0i(0.0, 1.0) : complex<f32> 表示 0.0 + 1.0i。然后,这些部分在内存中的存储顺序由实现定义。复杂常量具有以下限制:

  • (C1) is_wellformed(real_part, complex_element_type(complex_type))
  • (C2) is_wellformed(imaginary_part, complex_element_type(complex_type))
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral   ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements  ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral

张量常量,使用通过 NumPy 表示法指定的嵌套列表来表示张量值。例如,dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> 表示一个张量值,其中索引到元素的映射如下所示:{0, 0} => 1{0, 1} => 2{0, 2} => 3{1, 0} => 4{1, 1} => 5{1, 2} => 6。然后,这些元素在内存中的存储顺序由实现定义。张量常量具有以下约束条件:

  • (C1) has_syntax(tensor_literal, element_type(tensor_type)),其中:
    • has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type)
    • has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type)
  • (C2) has_shape(tensor_literal, shape(tensor_type)),其中:
    • has_shape(element_literal: Syntax, []) = true
    • has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:])
    • 否则,false
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'

量化张量常量使用与张量常量相同的符号表示量化张量值,其中元素被指定为其存储类型的常量。量化张量常量具有以下约束条件:

  • (C1) has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type))
  • (C2) has_shape(quantized_tensor_literal, shape(quantized_tensor_type))
StringConstant  ::= StringLiteral
StringLiteral   ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence  ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))

字符串字面量由使用 ASCII 字符和转义序列指定的字节组成。它们不依赖于编码,因此对这些字节的解读由实现定义。字符串字面量的类型为 string

Ops Agent 可以

abs

语义

operand 张量执行元素级绝对值运算,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于有符号整数:整数模。
  • 对于浮点数:IEEE-754 中的 abs
  • 对于复数:复模数。
  • 对于量化类型:dequantize_op_quantize(abs, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 有符号整数、浮点数或复杂类型的张量,或每个张量量化张量 (C1-C2)

输出

名称 类型 限制条件
result 有符号整数或浮点类型的张量,或每个张量量化的张量 (C1 - C2)

限制条件

  • (C1) shape(result) = shape(operand)
  • (C2) baseline_element_type(result) 的定义如下:
    • 如果 is_complex(operand),则设为 complex_element_type(element_type(operand))
    • 否则为 baseline_element_type(operand)

示例

// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]

更多示例

add

语义

对两个张量 lhsrhs 执行元素级加法,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于布尔值:逻辑或。
  • 对于整数:整数加法。
  • 对于浮点数:IEEE-754 中的 addition
  • 对于复数:复数加法。
  • 对于量化类型:dequantize_op_quantize(add, lhs, rhs, type(result))

输入

标签 名称 类型 限制条件
(I1) lhs 张量或量化张量 (C1-C6)
(I2) rhs 张量或量化张量 (C1-C5)、(C7)

输出

名称 类型 限制条件
result 张量或量化张量 (C1 - C7)

限制条件

  • 如果运算使用非量化张量:
    • (C1) type(lhs) = type(rhs) = type(result)
  • 如果操作使用量化张量:
    • (C2) is_quantized(lhs) and is_quantized(rhs) and is_quantized(result)
    • (C3) storage_type(lhs) = storage_type(rhs) = storage_type(result)
    • (C4) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
    • (C5) (is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result)
    • (C6) 如果 is_per_axis_quantized(lhs),则 quantization_dimension(lhs) = quantization_dimension(result)
    • (C7) 如果 is_per_axis_quantized(rhs),则 quantization_dimension(rhs) = quantization_dimension(result)

示例

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]

更多示例

after_all

语义

确保在依赖于 result 的任何操作之前执行生成 inputs 的操作。执行此操作不会执行任何操作,它只是用于建立从 resultinputs 的数据依赖项。

输入

标签 名称 类型
(I1) inputs token的可变数

输出

名称 类型
result token

示例

// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token

更多示例

all_gather

语义

在 StableHLO 进程网格中的每个进程组内,沿着 all_gather_dim 连接来自每个进程的 operands 张量的值,并生成 results 张量。

该操作将 StableHLO 进程网格拆分为 process_groups,其定义如下:

  • 如果 channel_id <= 0 and use_global_device_ids = false,则设为 cross_replica(replica_groups)
  • 如果 channel_id > 0 and use_global_device_ids = false,则设为 cross_replica_and_partition(replica_groups)
  • 如果 channel_id > 0 and use_global_device_ids = true,则设为 flattened_ids(replica_groups)

之后,在每个 process_group 中:

  • process_group 中的所有 receiver 执行 operands...@receiver = [operand@sender for sender in process_group]
  • process_group 中的所有 process 均为 results...@process = concatenate(operands...@process, all_gather_dim)

输入

标签 名称 类型 限制条件
(I1) operands 可变数量张量或每个张量量化张量 (C1), (C6)
(I2) all_gather_dim 类型为 si64 的常量 (C1)、(C6)
(I3) replica_groups 类型为 si64 的二维张量常量 (C2-C4)
(I4) channel_id si64 类型的常量 (C5)
(I5) use_global_device_ids 类型为 i1 的常量 (C5)

输出

名称 类型 限制条件
results 张量数量可变或每个张量量化张量 (C6)

限制条件

  • (C1) 0 <= all_gather_dim < rank(operands...)
  • (C2) is_unique(replica_groups)
  • (C3) size(replica_groups) 的定义如下:
    • 如果使用了 cross_replica,则为 num_replicas
    • 如果使用了 cross_replica_and_partition,则为 num_replicas
    • 如果使用了 flattened_ids,则为 num_processes
  • (C4) 0 <= replica_groups < size(replica_groups)
  • (C5) 如果 use_global_device_ids = true,则 channel_id > 0
  • (C6) type(results...) = type(operands...),但以下情况除外:
    • dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1)

示例

// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
  all_gather_dim = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  // channel_id = 0
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  // use_global_device_ids = false
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]

更多示例

all_reduce

语义

在 StableHLO 进程网格的每个进程组中,对来自每个进程的 operands 张量值应用归约函数 computation,并生成 results 张量。

该操作会将 StableHLO 进程网格拆分为 process_groups,其定义如下:

  • 如果 channel_id <= 0 and use_global_device_ids = false,则设为 cross_replica(replica_groups)
  • 如果 channel_id > 0 and use_global_device_ids = false,则设为 cross_replica_and_partition(replica_groups)
  • 如果 channel_id > 0 and use_global_device_ids = true,则设为 flattened_ids(replica_groups)

之后,在每个 process_group 中:

  • results...@process[result_index] = exec(schedule) 表示某个二元树 schedule,其中:
    • exec(node) = computation(exec(node.left), exec(node.right))
    • exec(leaf) = leaf.value
  • schedule 是由实现定义的二元树,其有序遍历为 to_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0]))

输入

标签 名称 类型 限制条件
(I1) operands 张量数量可变或每个张量量化张量 (C5), (C6)
(I2) replica_groups 类型为 si64 的一维张量常量的可变数量 (C1-C3)
(I3) channel_id 类型为 si64 的常量 (C4)
(I4) use_global_device_ids 类型为 i1 的常量 (C4)
(I5) computation 函数 (C5)

输出

名称 类型 限制条件
results 张量数量可变或每个张量量化张量 (C6-C7)

限制条件

  • (C1) is_unique(replica_groups)
  • (C2) size(replica_groups) 的定义如下:
    • 如果使用了 cross_replica,则为 num_replicas
    • 如果使用了 cross_replica_and_partition,则为 num_replicas
    • 如果使用了 flattened_ids,则为 num_processes
  • (C3) 0 <= replica_groups < size(replica_groups)
  • (C4) 如果 use_global_device_ids = true,则 channel_id > 0
  • (C5) computation 的类型为 (tensor<E>, tensor<E>) -> (tensor<E>),其中 is_promotable(element_type(operand), E)
  • (C6) shape(results...) = shape(operands...)
  • (C7) element_type(results...) = E

示例

// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  // channel_id = 0
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  // use_global_device_ids = false
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]

更多示例

all_to_all

语义

all_to_all

在 StableHLO 进程网格中的每个进程组内,沿 split_dimensionoperands 张量的值拆分为多个部分,将拆分后的部分分散到各个进程之间,沿 concat_dimension 串联分散的部分,并生成 results 张量。该操作会将 StableHLO 进程网格拆分为 process_groups,其定义如下:

  • 如果 channel_id <= 0,则为 cross_replica(replica_groups)
  • 如果 channel_id > 0,则设为 cross_partition(replica_groups)

之后,在每个 process_group 中:

  • process_group 中的所有 sendersplit_parts...@sender = split(operands...@sender, split_count, split_dimension)
  • scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group],其中 receiver_index = process_group.index(receiver)
  • results...@process = concatenate(scattered_parts...@process, concat_dimension)

输入

标签 名称 类型 限制条件
(I1) operands 张量数量可变或每个张量量化张量 (C1-C3)、(C9)
(I2) split_dimension 类型为 si64 的常量 (C1)、(C2)、(C9)
(I3) concat_dimension 类型为 si64 的常量 (C3)、(C9)
(I4) split_count si64 类型的常量 (C2)、(C4)、(C8)、(C9)
(I5) replica_groups si64 类型的二维张量常量 (C5-C8)
(I6) channel_id si64 类型的常量

输出

名称 类型 限制条件
results 张量数量可变或每个张量量化张量 (C9)

限制条件

  • (C1) 0 <= split_dimension < rank(operands...)
  • (C2) dim(operands..., split_dimension) % split_count = 0
  • (C3) 0 <= concat_dimension < rank(operands...)
  • (C4) 0 < split_count
  • (C5) is_unique(replica_groups)
  • (C6) size(replica_groups) 的定义如下:
    • 如果使用了 cross_replica,则为 num_replicas
    • 如果使用了 cross_partition,则为 num_partitions
  • (C7) 0 <= replica_groups < size(replica_groups)
  • (C8) dim(replica_groups, 1) = split_count
  • (C9) type(results...) = type(operands...),但如果 split_dimension != concat_dimension
    • dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count
    • dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count

示例

// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
//                    [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
//                    [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
//                    [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
//                    [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
  split_dimension = 1 : i64,
  concat_dimension = 0 : i64,
  split_count = 2 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  // channel_id = 0
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]

更多示例

语义

对两个张量 lhsrhs 执行元素级 AND 运算,并生成一个 result 张量。根据元素类型,执行以下操作:

  • 对于布尔值:逻辑与。
  • 对于整数:按位 AND。

输入

标签 名称 类型 限制条件
(I1) lhs 布尔值或整数类型的张量 (C1)
(I2) rhs 布尔值或整数类型的张量 (C1)

输出

名称 类型 限制条件
result 布尔值或整数类型的张量 (C1)

限制条件

  • (C1) type(lhs) = type(rhs) = type(result)

示例

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]

更多示例

atan2

语义

lhsrhs 张量执行按元素 atan2 运算,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于浮点数:IEEE-754 中的 atan2
  • 对于复数:复数 atan2。
  • 对于量化类型:dequantize_op_quantize(atan2, lhs, rhs, type(result))

输入

标签 名称 类型 限制条件
(I1) lhs 浮点型、复杂类型或每张量量化张量 (C1)
(I2) rhs 浮点或复杂类型的张量,或每个张量的量化张量 (C1)

输出

名称 类型 限制条件
result 浮点或复杂类型的张量,或每个张量的量化张量 (C1)

限制条件

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

示例

// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]

更多示例

batch_norm_grad

语义

计算从 grad_output 反向传播的 batch_norm_training 的多个输入的梯度,并生成 grad_operandgrad_scalegrad_offset 张量。更正式地,此操作可以使用 Python 语法表示为对现有 StableHLO 操作的分解,如下所示:

def compute_sum(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  return sum

def compute_mean(operand, feature_index):
  sum = compute_sum(operand, feature_index)
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
  # Broadcast inputs to type(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance`
  # Intermediate values will be useful for computing gradients
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)

  # Use the implementation from batchnorm_expander.cc in XLA
  # Temporary variables have exactly the same names as in the C++ code
  elements_per_feature = broadcast_in_dim(
      constant(divide(size(operand), dim(operand, feature_index)),
               element_type(grad_output)),
      [], type(operand))
  i1 = multiply(grad_output, elements_per_feature)
  i2 = broadcast_in_dim(
      compute_sum(grad_output, feature_index), [feature_index], type(operand))
  i3 = broadcast_in_dim(
      compute_sum(multiply(grad_output, centered_operand), feature_index),
      [feature_index], type(operand))
  i4 = multiply(i3, centered_operand)
  i5 = divide(i4, add(variance_bcast, epsilon_bcast))
  i6 = subtract(subtract(i1, i2), i5)

  grad_operand =
      multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
  grad_scale =
      compute_sum(multiply(grad_output, normalized_operand), feature_index)
  grad_offset = compute_sum(grad_output, feature_index)

  return grad_operand, grad_scale, grad_offset

对于量化类型,执行 dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean, variance, grad_output: batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index), operand, scale, mean, variance, grad_output, type(grad_operand), type(grad_scale), type(feature_index))

输入

标签 名称 类型 限制条件
(I1) operand 浮点类型的张量或每个张量量化张量 (C1-C3)、(C5)
(I2) scale 浮点或每张量量化类型的一维张量 (C2)、(C4)、(C5)
(I3) mean 浮点类型或按张量量化类型的 1 维张量 (C2)、(C4)
(I4) variance 浮点类型或按张量量化类型的 1 维张量 (C2)、(C4)
(I5) grad_output 浮点类型的张量或每个张量量化张量 (C2)、(C3)
(I6) epsilon 类型为 f32 的常量
(I7) feature_index si64 类型的常量 (C1)、(C5)

输出

名称 类型 限制条件
grad_operand 浮点类型的张量或每个张量量化张量 (C2)、(C3)
grad_scale 浮点类型或按张量量化类型的 1 维张量 (C2)、(C4)
grad_offset 浮点或每张量量化类型的一维张量 (C2)、(C4)

限制条件

  • (C1) 0 <= feature_index < rank(operand)
  • (C2) operandscalemeanvariancegrad_outputgrad_operandgrad_scalegrad_offset 具有相同的 baseline_element_type
  • (C3) operandgrad_outputgrad_operand 具有相同的形状。
  • (C4) scalemeanvariancegrad_scalegrad_offset 具有相同的形状。
  • (C5) size(scale) = dim(operand, feature_index)

示例

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
//                [[0.1, 0.1], [0.1, 0.1]],
//                [[0.1, 0.1], [0.1, 0.1]]
//               ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
     tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
//                 [[0.0, 0.0], [0.0, 0.0]],
//                 [[0.0, 0.0], [0.0, 0.0]]
//                ]
// %grad_scale:  [0.0, 0.0]
// %grad_offset: [0.4, 0.4]

batch_norm_inference

语义

feature_index 维度以外的所有维度中的 operand 张量进行归一化,并生成 result 张量。更正式地,此运算可以使用 Python 语法表示为对现有 StableHLO 运算的分解,如下所示:

def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
  # Broadcast inputs to shape(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance` instead of
  # computing them like `batch_norm_training` does.
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)
  return add(multiply(scale_bcast, normalized_operand), offset_bcast)

对于量化类型,请执行 dequantize_op_quantize(lambda operand, scale, offset, mean, variance: batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index), operand, scale, offset, mean, variance, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点类型的张量或每个张量量化张量 (C1-C7)
(I2) scale 浮点类型或按张量量化类型的 1 维张量 (C2)、(C3)
(I3) offset 浮点类型或按张量量化类型的 1 维张量 (C2)、(C4)
(I4) mean 浮点类型或按张量量化类型的 1 维张量 (C5)
(I5) variance 浮点类型或按张量量化类型的 1 维张量 (C2)、(C6)
(I6) epsilon 类型为 f32 的常量
(I7) feature_index 类型为 si64 的常量 (C1)、(C3-C6)

输出

名称 类型 限制条件
result 浮点类型的张量或每张量量化张量 (C2)、(C7)

限制条件

  • (C1) 0 <= feature_index < rank(operand)
  • (C2) operandscaleoffsetmeanvarianceresult 具有相同的 baseline_element_type
  • (C3) size(scale) = dim(operand, feature_index)
  • (C4) size(offset) = dim(operand, feature_index)
  • (C5) size(mean) = dim(operand, feature_index)
  • (C6) size(variance) = dim(operand, feature_index)
  • (C7) baseline_type(operand) = baseline_type(result)

示例

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]

batch_norm_training

语义

计算除 feature_index 维度以外的所有维度的平均和方差,并对 operand 张量进行归一化,生成 outputbatch_meanbatch_var 张量。更正式地说,此操作可以使用 Python 语法表示为对现有 StableHLO 操作的分解,如下所示:

def compute_mean(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def compute_variance(operand, feature_index):
  mean = compute_mean(operand, feature_index)
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  centered_operand = subtract(operand, mean_bcast)
  return compute_mean(mul(centered_operand, centered_operand), feature_index)

def batch_norm_training(operand, scale, offset, epsilon, feature_index):
  mean = compute_mean(operand, feature_index)
  variance = compute_variance(operand, feature_index)
  return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
                              feature_index),
         mean, variance

对于量化类型,执行 dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset: batch_norm_training(operand, scale, offset, epsilon, feature_index), operand, scale, offset, type(output), type(batch_mean), type(batch_var))

输入

标签 名称 类型 限制条件
(I1) operand 浮点类型的张量或每个张量量化张量 (C1)
(I2) scale 量化的浮点数或每张量的一维张量 (C2)、(C3)
(I3) offset 浮点值或按张量量化的 1 维张量 (C2)、(C4)
(I4) epsilon f32 类型的常量 (C1)、(C3-C6)
(I5) feature_index 类型为 si64 的常量 (C1)、(C3-C6)

输出

名称 类型 限制条件
output 浮点类型的张量或每张量量化张量 (C7)
batch_mean 浮点值或按张量量化的 1 维张量 (C2)、(C5)
batch_var 浮点值或按张量量化的 1 维张量 (C2)、(C6)

限制条件

  • (C1) 0 <= feature_index < rank(operand)
  • (C2) operandscaleoffsetbatch_meanbatch_varoutput 具有相同的 baseline_element_type
  • (C3) size(scale) = dim(operand, feature_index)
  • (C4) size(offset) = dim(operand, feature_index)
  • (C5) size(batch_mean) = dim(operand, feature_index)
  • (C6) size(batch_var) = dim(operand, feature_index)
  • (C7) baseline_type(output) = baseline_type(operand)

示例

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
    (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]

bitcast_convert

语义

operand 张量执行 Bitcast 操作,并生成一个 result 张量,其中使用 result 张量的类型重新解释整个 operand 张量的位。

更正式地说,假设有 E = element_type(operand)E' = element_type(result)R = rank(operand)

  • 如果为 num_bits(E') < num_bits(E),则为 bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])
  • 如果为 num_bits(E') > num_bits(E),则为 bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])
  • 如果为 num_bits(E') = num_bits(E),则为 bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])

bits 会返回给定值的内存表示形式,并且其行为由实现定义,因为张量的确切表示形式由实现定义,元素类型的确切表示形式也由实现定义。

输入

标签 名称 类型 限制条件
(I1) operand 张量或量化张量 (C1-C2)

输出

名称 类型 限制条件
result 张量或量化张量 (C1-C2)

限制条件

  • (C1) 假设 E = is_quantized(operand) ? storage_type(operand) : element_type(operand)E' = is_quantized(result) ? storage_type(result) : element_type(result)R = rank(operand)
    • 如果为 num_bits(E') = num_bits(E),则为 shape(result) = shape(operand)
    • 如果 num_bits(E') < num_bits(E)
    • rank(result) = R + 1
    • 针对所有0 <= i < Rdim(result, i) = dim(operand, i)
    • dim(result, R) * num_bits(E') = num_bits(E)
    • 如果 num_bits(E') > num_bits(E)
    • rank(result) = R - 1
    • dim(result, i) = dim(operand, i)(适用于所有 0 <= i < R)。
    • dim(operand, R - 1) * num_bits(E) = num_bits(E')
  • (C2) 如果 is_complex(operand) or is_complex(result),则 is_complex(operand) and is_complex(result)

示例

// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation

更多示例

broadcast_in_dim

语义

通过复制 operand 张量中的数据来扩展输入张量的维度和/或秩,并生成 result 张量。更正式地,result[result_index] = operand[operand_index],其中对于 axes(operand) 中的所有 d

  • 如果 dim(operand, d) = 1,则设为 operand_index[d] = 0
  • 否则为 operand_index[d] = result_index[broadcast_dimensions[d]]

输入

标签 名称 类型 限制条件
(I1) operand 张量或量化张量 (C1-C2)、(C5-C6)
(I2) broadcast_dimensions 类型为 si64 的一维张量常量 (C2-C6)

输出

名称 类型 限制条件
result 张量或量化张量 (C1)、(C3)、(C5-C6)

限制条件

  • (C1) element_type(result) 由以下公式给出:
    • 如果 !is_per_axis_quantized(operand),则设为 element_type(operand)
    • element_type(operand),但 quantization_dimension(operand)scales(operand)zero_points(operand) 可能与 quantization_dimension(result)scales(result)zero_points(result) 不同。
  • (C2) size(broadcast_dimensions) = rank(operand)
  • (C3) 0 <= broadcast_dimensions < rank(result)
  • (C4) is_unique(broadcast_dimensions)
  • (C5) 对于 axes(operand) 中的所有 d
    • dim(operand, d) = 1
    • dim(operand, d) = dim(result, broadcast_dimensions[d])
  • (C6) 如果 is_per_axis_quantized(result)
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
    • 如果为 dim(operand, quantization_dimension(operand)) = 1,则为 scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))

示例

// %operand: [
//            [1, 2, 3]
//           ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
  broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

更多示例

场景

语义

根据 index 的值,通过执行 branches 中的恰好一个函数来生成输出。更正式地说,result = selected_branch(),其中:

  • 如果 0 <= index < size(branches),则设为 selected_branch = branches[index]
  • 否则为 selected_branch = branches[-1]

输入

标签 名称 类型 限制条件
(I1) index 类型为 si32 的 0 维张量
(I2) branches 函数的可变数量 (C1-C4)

输出

名称 类型 限制条件
results 可变数量的张量、量化张量或令牌 (C4)

限制条件

  • (C1) 0 < size(branches)
  • (C2) input_types(branches...) = []
  • (C3) same(output_types(branches...))
  • (C4) type(results...) = output_types(branches[0])

示例

// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
  "stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
  "stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]

更多示例

cbrt

语义

operand 张量执行元素级立方根运算,并生成一个 result 张量。根据元素类型,执行以下操作:

  • 对于浮点数:IEEE-754 中的 rootn(x, 3)
  • 对于复数:复数立方根。
  • 对于量化类型:dequantize_op_quantize(cbrt, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点型、复杂类型或每张量量化张量 (C1)

输出

名称 类型 限制条件
result 浮点或复杂类型的张量,或每个张量的量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]

更多示例

ceil

语义

operand 张量执行元素级 ceil,并生成 result 张量。 实现 IEEE-754 规范中的 roundToIntegralTowardPositive 运算。对于量化类型,执行 dequantize_op_quantize(ceil, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点类型的张量或每个张量量化张量 (C1)

输出

名称 类型 限制条件
result 浮点类型的张量或每个张量量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]

更多示例

Cholesky

语义

计算一批矩阵的 Cholesky 分解。

更正式地说,对于 index_space(result) 中的所有 iresult[i0, ..., iR-3, :, :]a[i0, ..., iR-3, :, :] 的 Cholesky 分解,形式为下三角矩阵(如果 lowertrue)或上三角矩阵(如果 lowerfalse)。对角三角形(即严格的上三角形或严格的下三角形)中的输出值由实现定义。

如果存在 i,其中输入矩阵不是埃尔米特正定矩阵,则行为未定义。

对于量化类型,执行 dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result))

输入

标签 名称 类型 限制条件
(I1) a 浮点或复杂类型的张量,或每个张量的量化张量 (C1-C3)
(I2) lower i1 类型的 0 维张量常量

输出

名称 类型 限制条件
result 浮点或复杂类型的张量,或每个张量的量化张量 (C1)

限制条件

  • (C1) baseline_type(a) = baseline_type(result)
  • (C2) 2 <= rank(a)
  • (C3) dim(a, -2) = dim(a, -1)

示例

// %a: [
//      [1.0, 2.0, 3.0],
//      [2.0, 20.0, 26.0],
//      [3.0, 26.0, 70.0]
//     ]
%result = "stablehlo.cholesky"(%a) {
  lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
//           [1.0, 0.0, 0.0],
//           [2.0, 4.0, 0.0],
//           [3.0, 5.0, 6.0]
//          ]

限制取值范围

语义

operand 张量的每个元素限制在最小值和最大值之间,并生成 result 张量。更正式地,result[result_index] = minimum(maximum(operand[result_index], min_element), max_element),其中 min_element = rank(min) = 0 ? min[] : min[result_index]max_element = rank(max) = 0 ? max[] : max[result_index]。对于量化类型,执行 dequantize_op_quantize(clamp, min, operand, max, type(result))

对复杂数强制执行排序会涉及令人意外的语义,因此我们计划在未来移除对此运算的复杂数支持 (#560)。

输入

标签 名称 类型 限制条件
(I1) min 张量或按张量量化的张量 (C1)、(C3)
(I2) operand 张量或按张量量化的张量 (C1 - C4)
(I3) max 张量或按张量量化的张量 (C2)、(C3)

输出

名称 类型 限制条件
result 张量或按张量量化的张量 (C4)

限制条件

  • (C1) rank(min) = 0 or shape(min) = shape(operand)
  • (C2) rank(max) = 0 or shape(max) = shape(operand)
  • (C3) baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max)
  • (C4) baseline_type(operand) = baseline_type(result)

示例

// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]

更多示例

collective_broadcast

语义

在 StableHLO 进程网格中的每个进程组内,将 operand 张量的值从源进程发送到目标进程,并生成 result 张量。

该操作会将 StableHLO 进程网格拆分为 process_groups,其定义如下:

  • 如果 channel_id <= 0,则为 cross_replica(replica_groups)
  • 如果 channel_id > 0,则设为 cross_partition(replica_groups)

之后,result@process 可按下式给出:

  • 如果存在 i 且进程位于 process_groups[i] 中,则为 operand@process_groups[i, 0]
  • 否则为 broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))

输入

标签 名称 类型 限制条件
(I1) operand 张量或每张量量化张量 (C3)
(I2) replica_groups 类型为 si64 的一维张量常量的可变数量 (C1)、(C2)
(I3) channel_id si64 类型的常量

输出

名称 类型 限制条件
result 张量或按张量量化的张量 (C3)

限制条件

  • (C1) is_unique(replica_groups)
  • (C2) 0 <= replica_groups < N,其中 N 定义为:
    • 如果使用了 cross_replica,则为 num_replicas
    • 如果使用了 cross_partition,则为 num_partitions
  • (C3) type(result) = type(operand)

示例

// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
  replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]

collective_permute

语义

在 StableHLO 进程网格中的每个进程组内,将 operand 张量的值从源进程发送到目标进程,并生成 result 张量。

该操作会将 StableHLO 进程网格拆分为 process_groups,其定义如下:

  • 如果 channel_id <= 0,则设为 cross_replica(source_target_pairs)
  • 如果 channel_id > 0,则设为 cross_partition(source_target_pairs)

之后,result@process 可按下式给出:

  • operand@process_groups[i, 0],如果存在 iprocess_groups[i, 1] = process
  • broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))

输入

标签 名称 类型 限制条件
(I1) operand 张量或按张量量化的张量 (C5)
(I2) source_target_pairs 类型为 si64 的二维张量常量 (C1-C4)
(I3) channel_id si64 类型的常量

输出

名称 类型 限制条件
result 张量或每张量量化张量 (C1)

限制条件

  • (C1) dim(source_target_pairs, 1) = 2
  • (C2) is_unique(source_target_pairs[:, 0])
  • (C3) is_unique(source_target_pairs[:, 1])
  • (C4) 0 <= source_target_pairs < N,其中 N 的定义如下:
    • 如果使用了 cross_replica,则为 num_replicas
    • 如果使用了 cross_partition,则为 num_partitions
  • (C5) type(result) = type(operand)

示例

// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
  source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]

更多示例

比较

语义

根据 comparison_directioncompare_typelhsrhs 张量执行元素级比较,并生成 result 张量。

comparison_directioncompare_type 的值具有以下语义:

对于布尔值和整数元素类型:

  • EQlhs = rhs
  • NElhs != rhs
  • GElhs >= rhs
  • GTlhs > rhs
  • LElhs <= rhs
  • LTlhs < rhs

对于具有 compare_type = FLOAT 的浮点元素类型,该操作会实现以下 IEEE-754 操作:

  • EQcompareQuietEqual
  • NEcompareQuietNotEqual
  • GEcompareQuietGreaterEqual
  • GTcompareQuietGreater
  • LEcompareQuietLessEqual
  • LTcompareQuietLess

对于包含 compare_type = TOTALORDER 的浮点元素类型,op 会使用 IEEE-754 中的 totalOrdercompareQuietEqual 运算的组合。

对于复杂元素类型,系统会使用提供的 comparison_directioncompare_type(real, imag) 对执行字典顺序比较。对复数施加排序涉及到令人惊讶的语义,因此将来,我们计划在 comparison_directionGEGTLELT 时停止对复数的支持 (#560)。

对于量化类型。执行 dequantize_compare(lhs, rhs, comparison_direction)

输入

标签 名称 类型 限制条件
(I1) lhs 张量或按张量量化的张量 (C1-C3)
(I2) rhs 张量或每张量量化张量 (C1-C2)
(I3) comparison_direction EQNEGEGTLELT 的枚举
(I4) compare_type FLOATTOTALORDERSIGNEDUNSIGNED 的枚举 (C3)

输出

名称 类型 限制条件
result 布尔值类型的张量 (C2)

限制条件

  • (C1) baseline_element_type(lhs) = baseline_element_type(rhs)
  • (C2) shape(lhs) = shape(rhs) = shape(result)
  • (C3) compare_type 的定义如下:
    • 如果 is_signed_integer(element_type(lhs)),则设为 SIGNED
    • 如果 is_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs)),则为 UNSIGNED
    • 如果为 is_float(element_type(lhs)),则为 FLOATTOTALORDER
    • 如果 is_complex(element_type(lhs)),则设为 FLOAT

示例

// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
  comparison_direction = #stablehlo<comparison_direction LT>,
  compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]

更多示例

复杂

语义

对一对实值和虚值(lhsrhs)按元素执行转换,将其转换为复值,并生成 result 张量。

输入

标签 名称 类型 限制条件
(I1) lhs 类型为 f32f64 的张量 (C1-C3)
(I2) rhs 类型为 f32f64 的张量 (C1)

输出

名称 类型 限制条件
result 复杂类型的张量 (C2)、(C3)

限制条件

  • (C1) type(lhs) = type(rhs)
  • (C2) shape(result) = shape(lhs)
  • (C3) element_type(result) 的类型为 complex<E>,其中 E = element_type(lhs)

示例

// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]

更多示例

复合型

语义

封装由其他 StableHLO 操作组成(组合)的操作,接受 inputscomposite_attributes 并生成 results。操作的语义由 decomposition 属性实现。composite 运算可以替换为其分解,而无需更改程序语义。如果内嵌分解的分解不能提供相同的操作语义,请优先使用 custom_call

version 字段(默认为 0)用于表示复合项的语义何时发生变化。

输入

标签 名称 类型
(I1) inputs 值的变参数量
(I2) name 类型为 string 的常量
(I3) composite_attributes 属性字典
(I4) decomposition string 类型的常量
(I5) version 类型为 si32 的常量

输出

名称 类型
results 值的变参数量

限制条件

  • (C1) is_namespaced_op_name(name)
  • (C2) is_defined_in_parent_scope(decomposition)
  • (C3) types(inputs...) == input_types(decomposition)
  • (C4) types(results...) == output_types(decomposition)

示例

%results = "stablehlo.composite"(%input0, %input1) {
  name = "my_namespace.my_op",
  composite_attributes = {
    my_attribute = "my_value"
  },
  decomposition = @my_op,
  version = 1 : i32
} : (tensor<f32>, tensor<f32>) -> tensor<f32>

更多示例

concatenate

语义

沿 dimension 维度串联 inputs,顺序与给定参数相同,并生成 result 张量。更正式地说是 result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1],其中:

  1. id = d0 + ... + dk-1 + kd
  2. d 等于 dimensiond0、... 是 inputsd 个维度大小。

输入

标签 名称 类型 限制条件
(I1) inputs 可变数量张量或每个张量量化张量 (C1 - C6)
(I2) dimension si64 类型的常量 (C2)、(C4)、(C6)

输出

名称 类型 限制条件
result 张量或按张量量化的张量 (C5-C6)

限制条件

  • (C1) same(element_type(inputs...))
  • (C2) same(shape(inputs...))dim(inputs..., dimension) 除外)。
  • (C3) 0 < size(inputs)
  • (C4) 0 <= dimension < rank(inputs[0])
  • (C5) element_type(result) = element_type(inputs[0])
  • (C6) shape(result) = shape(inputs[0]),但以下情况除外:
    • dim(result, dimension) = dim(inputs[0], dimension) + ...

示例

// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
  dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]

更多示例

常量

语义

从常量 value 生成 output 张量。

输入

标签 名称 类型 限制条件
(I1) value 常量 (C1)

输出

名称 类型 限制条件
output 张量或量化张量 (C1)

限制条件

  • (C1) type(value) = type(output)

示例

%output = "stablehlo.constant"() {
  value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]

更多示例

转化

语义

operand 张量执行元素类型之间的逐元素转换,并生成 result 张量。

对于boolean-to-any-supported-type,值 false 会转换为零,值 true 会转换为 1。对于任何受支持的类型转换为布尔值,零值会转换为 false,非零值会转换为 true。请参阅下文,了解此机制如何适用于复杂类型。

对于涉及整数转换为整数整数转换为浮点数浮点数转换为浮点数的转换,如果源值可以准确表示为目标类型,则结果值就是该准确表示。否则,行为尚待确定 (#180)。

对于涉及浮点数转换为整数的转换,系统会截断小数部分。如果截断的值无法用目标类型表示,则行为待定 (#180)。

在转换实部和虚部时,涉及复数的转换遵循与浮点数到浮点数转换相同的行为。

对于复数到任意其他类型任何其他类型到复杂转换,源虚值或目的地虚值分别会被忽略。实部转换遵循浮点转换。

原则上,此运算可以表示去量化(从量化张量转换为正则张量)、量化(从常规张量转换为量化张量)和重新量化(量化张量之间的转换),但目前我们有专门的操作 - 对于第一个用例,为 uniform_dequantize,为第二个和第三个用例提供 uniform_quantize。将来,这两个操作可能会合并到 convert (#1576)。

输入

标签 名称 类型 限制条件
(I1) operand 张量 (C1)

输出

名称 类型 限制条件
result 张量 (C1)

限制条件

  • (C1) shape(operand) = shape(result)

示例

// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]

更多示例

卷积

语义

计算 lhs 窗口与 rhs 切片之间的点积,并生成 result。下图通过一个具体示例展示了如何根据 lhsrhs 计算 result 中的元素。

卷积

更正式地讲,请考虑以下以 lhs 为基础的输入重新构建,以便能够表达 lhs 的窗口:

  • lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension))
  • lhs_window_strides = lhs_shape(1, window_strides, 1)
  • lhs_padding = lhs_shape([0, 0], padding, [0, 0])
  • lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)
  • lhs_window_dilations = lhs_shape(1, rhs_dilation, 1)

此重新构建使用以下辅助函数:

  • lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])
  • result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])
  • permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1],其中 j[d] = i[permutation[d]]

如果为 feature_group_count = 1batch_group_count = 1,则对于 index_space(dim(result, output_spatial_dimensions...)) 中的所有 output_spatial_indexresult[result_shape(:, output_spatial_index, :)] = dot_product,其中:

  • padding_value = constant(0, element_type(lhs))
  • padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)
  • lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides
  • lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations)
  • reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true])。此功能似乎未使用,因此将来我们计划将其移除 (#1181)。
  • dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension])

如果为 feature_group_count > 1

  • lhses = split(lhs, feature_group_count, input_feature_dimension)
  • rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)
  • results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)
  • result = concatenate(results, output_feature_dimension)

如果 batch_group_count > 1

  • lhses = split(lhs, batch_group_count, input_batch_dimension)
  • rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)
  • results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...)
  • result = concatenate(results, output_feature_dimension)

对于量化类型,执行 dequantize_op_quantize( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs, type(result))

对于混合量化类型,执行 hybrid_dequantize_then_op( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs)

输入

标签 名称 类型 限制条件
(I1) lhs 张量或按张量量化的张量 (C1)、(C10-C11)、(C14) (C25)、(C27-C28)、(C31-C32)、(C34)
(I2) rhs 张量或量化张量 (C1)、(C14-C16)、(C25)、(C27-C29)、(C31-C34)
(I3) window_strides 类型为 si64 的一维张量常量 (C2-C3)、(C25)
(I4) padding 类型为 si64 的二维张量常量 (C4)、(C25)
(I5) lhs_dilation 类型为 si64 的一维张量常量 (C5-C6)、(C25)
(I6) rhs_dilation 类型为 si64 的一维张量常量 (C7-C8)、(C25)
(I7) window_reversal 类型为 i1 的一维张量常量 (C9)
(I8) input_batch_dimension 类型为 si64 的常量 (C10)、(C13)、(C25)
(I9) input_feature_dimension si64 类型的常量 (C11)、(C13-C14)
(I10) input_spatial_dimensions 类型为 si64 的一维张量常量 (C12)、(C13)、(C25)
(I11) kernel_input_feature_dimension 类型为 si64 的常量 (C14)、(C18)
(I12) kernel_output_feature_dimension 类型为 si64 的常量 (C15-C16)、(C18)、(C25)、(C29)
(I13) kernel_spatial_dimensions si64 类型的一维张量常量 (C17-C18)、(C25)
(I14) output_batch_dimension 类型为 si64 的常量 (C20)、(C25)
(I15) output_feature_dimension si64 类型的常量 (C20)、(C25)、(C30)
(I16) output_spatial_dimensions 类型为 si64 的一维张量常量 (C19-C20), (C25)
(I17) feature_group_count 类型为 si64 的常量 (C11)、(C14)、(C16)、(C21)、(C23)
(I18) batch_group_count 类型为 si64 的常量 (C10)、(C15)、(C22)、(C23)、(C25)
(I19) precision_config DEFAULTHIGHHIGHEST 枚举的可变数量 (C24)

输出

名称 类型 限制条件
result 张量或量化张量 (C25-C28)、(C30)、(C32-34)

限制条件

  • (C1) N = rank(lhs) = rank(rhs)
  • (C2) size(window_strides) = N - 2
  • (C3) 0 < window_strides
  • (C4) shape(padding) = [N - 2, 2]
  • (C5) size(lhs_dilation) = N - 2
  • (C6) 0 < lhs_dilation
  • (C7) size(rhs_dilation) = N - 2
  • (C8) 0 < rhs_dilation
  • (C9) size(window_reversal) = N - 2
  • (C10) dim(lhs, input_batch_dimension) % batch_group_count = 0
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0
  • (C12) size(input_spatial_dimensions) = N - 2
  • (C13) 给定 input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
    • is_unique(input_dimensions)
    • 0 <= input_dimensions < N
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
  • (C17) size(kernel_spatial_dimensions) = N - 2
  • (C18) 给定 kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
    • is_unique(kernel_dimensions)
    • 0 <= kernel_dimensions < N
  • (C19) size(output_spatial_dimensions) = N - 2
  • (C20) 给定 output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
    • is_unique(output_dimensions)
    • 0 <= output_dimensions < N
  • (C21) 0 < feature_group_count
  • (C22) 0 < batch_group_count
  • (C23) feature_group_count = 1 or batch_group_count = 1
  • (C24) size(precision_config) = 2
  • (C25) dim(result, result_dim) 的定义如下:
    • 如果 result_dim = output_batch_dimension,则设为 dim(lhs, input_batch_dimension) / batch_group_count
    • 如果 result_dim = output_feature_dimension,则设为 dim(rhs, kernel_output_feature_dimension)
    • 否则为 num_windows,其中:
    • output_spatial_dimensions[spatial_dim] = result_dim
    • lhs_dim = input_spatial_dimensions[spatial_dim]
    • rhs_dim = kernel_spatial_dimensions[spatial_dim]
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
  • (C26) rank(result) = N
  • 如果运算使用非量化张量:
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result)
  • 如果操作使用量化张量:
    • (C28) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
    • (C29) 如果 is_per_axis_quantized(rhs),则 quantization_dimension(rhs) = kernel_output_feature_dimension
    • (C30) 如果 is_per_axis_quantized(result),则 quantization_dimension(result) = output_feature_dimension
    • 如果 is_quantized(lhs)
    • (C31) storage_type(lhs) = storage_type(rhs)
    • (C32) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
    • (C33) 如果 is_per_tensor_quantized(rhs),则 is_per_tensor_quantized(result)
    • 如果 !is_quantized(lhs)
    • (C34) element_type(lhs) = expressed_type(rhs) = element_type(result)

示例

// %lhs: [[
//        [
//          [1], [2], [5], [6]
//        ],
//        [
//          [3], [4], [7], [8]
//        ],
//        [
//          [10], [11], [14], [15]
//        ],
//        [
//          [12], [13], [16], [17]
//        ]
//      ]]
//
// %rhs: [
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]]
//       ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
  window_strides = array<i64: 4, 4>,
  padding = dense<0> : tensor<2x2xi64>,
  lhs_dilation = array<i64: 2, 2>,
  rhs_dilation = array<i64: 1, 1>,
  window_reversal = array<i1: false, false>,
  // In the StableHLO dialect, dimension numbers are encoded via:
  // `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
  // "b" is batch dimension, "f" is feature dimension,
  // "i" is input feature dimension, "o" is output feature dimension,
  // "0/1/etc" are spatial dimensions.
  dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
  batch_group_count = 1 : i64,
  feature_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
//            [[10], [26]],
//            [[46], [62]]
//          ]]

更多示例

余弦

语义

operand 张量执行元素级余弦运算,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于浮点数:IEEE-754 中的 cos
  • 对于复数:复余弦。
  • 对于量化类型:dequantize_op_quantize(cosine, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点型、复杂类型或每张量量化张量 (C1)

输出

名称 类型 限制条件
result 浮点或复杂类型的张量,或每个张量的量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]

更多示例

count_leading_zeros

语义

按元素对 operand 张量中的前导零位数进行计数,并生成 result 张量。

输入

标签 名称 类型 限制条件
(I1) operand 整数类型的张量 (C1)

输出

名称 类型 限制条件
result 整数类型的张量 (C1)

限制条件

  • (C1) type(operand) = type(result)

示例

// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]

更多示例

custom_call

语义

封装实现定义的操作 call_target_name,该操作接受 inputscalled_computations 并生成 resultshas_side_effectbackend_configapi_version 可用于提供其他实现定义的元数据。

目前,此操作包含相当混乱的元数据集合,这反映了其在 XLA 编译器中的对等操作的自然演变。未来,我们计划统一这些元数据 (#741)。

输入

标签 名称 类型
(I1) inputs 值的变参数量
(I2) call_target_name 类型为 string 的常量
(I3) has_side_effect 类型为 i1 的常量
(I4) backend_config 类型为 string 的常量或属性字典
(I5) api_version si32 类型的常量
(I6) called_computations 类型为 string 的可变数量的常量

输出

名称 类型
results 值的可变数量

示例

%results = "stablehlo.custom_call"(%input0) {
  call_target_name = "foo",
  has_side_effect = false,
  backend_config = {bar = 42 : i32},
  api_version = 4 : i32,
  called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>

语义

对被除数 lhs 和除数 rhs 张量执行逐元素除法,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于整数:整数除法,会生成代数商,并舍弃任何小数部分。
  • 对于浮点数:IEEE-754 中的 division
  • 对于复数:复数除法。
  • 对于量化类型:
    • dequantize_op_quantize(divide, lhs, rhs, type(result))

输入

标签 名称 类型 限制条件
(I1) lhs 整数、浮点或复杂类型的张量,或每个张量的量化张量 (C1)
(I2) rhs 整数、浮点、复杂类型或每张量量化张量 (C1)

输出

名称 类型 限制条件
result 整数、浮点或复杂类型的张量,或每个张量的量化张量 (C1)

限制条件

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

示例

// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]

更多示例

dot_general

语义

计算 lhs 的 slice 与 rhs 的 slice 之间的点积,并生成 result 张量。

更正式的名称是 result[result_index] = dot_product,其中:

  • lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions]
  • rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions]
  • result_batching_index + result_lhs_index + result_rhs_index = result_index,其中 size(result_batching_index) = size(lhs_batching_dimensions)size(result_lhs_index) = size(lhs_result_dimensions)size(result_rhs_index) = size(rhs_result_dimensions)
  • transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions)
  • transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :])
  • reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions))
  • transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions)
  • transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :])
  • reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions))
  • dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y))

对于量化类型,执行 dequantize_op_quantize( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs, type(result))

对于混合量化类型,执行 hybrid_dequantize_then_op( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs)

precision_config 用于控制加速器后端上计算的速度和准确性。可以是以下任一值(目前,这些枚举值的语义未充分指定,但我们计划在 #755 中解决此问题):

  • DEFAULT:计算速度最快,但与原始数字的近似值最不准确。
  • HIGH:计算速度较慢,但与原始数字的近似值更准确。
  • HIGHEST:计算速度最慢,但与原始数字最接近。

DotAlgorithm 定义了用于实现点运算的算法的主要属性,该算法还定义了精度。如果设置了算法属性字段,则 precision_config 必须为 DEFAULTDotAlgorithms 没有默认值,因为默认参数由实现定义。因此,所有圆点算法字段都可以设置为 None,以指定空圆点算法,该算法将改用 precision_config 值。

DotAlgorithm 字段包括:

  • lhs_precision_typerhs_precision_type,操作的左侧和右侧舍入的精度。精度类型与输入和输出的存储类型无关。
  • accumulation_type:用于累加的精度。
  • lhs_component_countrhs_component_countnum_primitive_operations 适用于执行将左侧和/或右侧分解为多个组件并对这些值执行多个“基元”点积运算的算法 - 通常是为了模拟更高的精度(例如利用 bfloat16 人工智能数据类型进行更高精度计算:bf16_6x tf32_3x 等)。对于不分解的算法,这些值应设置为 1
  • allow_imprecise_accumulation,用于指定是否允许对某些步骤(例如 CUBLASLT_MATMUL_DESC_FAST_ACCUM)进行低精度累加。

DotAlgorithm 属性示例:

// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
 rhs_precision_type = tf32,
 accumulation_type = f32,
 lhs_component_count = 1,
 rhs_component_count = 1,
 num_primitive_operations = 1,
 allow_imprecise_accumulation = false}


// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
 rhs_precision_type = bf16,
 accumulation_type = f32,
 lhs_component_count = 3,
 rhs_component_count = 3,
 num_primitive_operations = 6,
 allow_imprecise_accumulation = false}


// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
 rhs_precision_type = f8e5m2,
 accumulation_type = f32,
 lhs_component_count = 1,
 rhs_component_count = 1,
 num_primitive_operations = 1,
 allow_imprecise_accumulation = true}

具体支持哪些组合取决于实现。一般来说,我们无法保证 StableHLO 的使用方在每种加速器类型上都支持每种算法。如果给定算法不受支持,则应引发错误,而不是回退到备选算法。StableHLO 验证将尽最大努力进行验证,以防止在任何硬件上不支持的算法。

如需查看一些受支持的算法值,请参阅 xla_data.proto > Algorithm。工单 #2483 包含后端支持的算法创建集中文档的计划。

输入

标签 名称 类型 限制条件
(I1) lhs 张量或按张量量化的张量 (C5-C6)、(C9-C10)、(C12-C14)、(C17-C18)、(C20)
(I2) rhs 张量或量化张量 (C7-C10)、(C12-C20)
(I3) lhs_batching_dimensions 类型为 si64 的一维张量常量 (C1)、(C3)、(C5)、(C9)、(C12)
(I4) rhs_batching_dimensions 类型为 si64 的一维张量常量 (C1)、(C4)、(C7)、(C9)
(I5) lhs_contracting_dimensions 类型为 si64 的一维张量常量 (C2)、(C3)、(C6)、(C10)
(I6) rhs_contracting_dimensions 类型为 si64 的一维张量常量 (C2)、(C4)、(C8)、(C10)、(C16)
(I7) precision_config DEFAULTHIGHHIGHEST 枚举的可变数量 (C11)、(C21)
(I8) lhs_precision_type FloatType 或 TensorFloat32 (C21)
(I9) rhs_precision_type FloatType 或 TensorFloat32 (C21)
(I10) accumulation_type FloatType 或 TensorFloat32 (C21)
(I11) lhs_component_count 类型为 si32 的常量 (C21)、(C22)
(I12) rhs_component_count 类型为 si32 的常量 (C21)、(C23)
(I13) num_primitive_operations 类型为 si32 的常量 (C21)、(C24)
(I14) allow_imprecise_accumulation bool 类型的常量 (C21)

输出

名称 类型 限制条件
result 张量或量化张量 (C12)、(C14)、(C18-C20)

限制条件

  • (C1) size(lhs_batching_dimensions) = size(rhs_batching_dimensions)
  • (C2) size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions)
  • (C3) is_unique(lhs_batching_dimensions + lhs_contracting_dimensions)
  • (C4) is_unique(rhs_batching_dimensions + rhs_contracting_dimensions)
  • (C5) 0 <= lhs_batching_dimensions < rank(lhs)
  • (C6) 0 <= lhs_contracting_dimensions < rank(lhs)
  • (C7) 0 <= rhs_batching_dimensions < rank(rhs)
  • (C8) 0 <= rhs_contracting_dimensions < rank(rhs)
  • (C9) dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
  • (C10) dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...)
  • (C11) size(precision_config) = 2
  • (C12) shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions)
  • 如果运算使用非量化张量:
    • (C13) element_type(lhs) = element_type(rhs)
  • 如果操作使用量化张量:
    • (C14) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
    • (C15) zero_points(rhs) = 0
    • (C16) 如果 is_per_axis_quantized(rhs),则 quantization_dimension(rhs) 不在 rhs_contracting_dimensions 中。
    • 如果 is_quantized(lhs)
    • (C17) storage_type(lhs) = storage_type(rhs)
    • (C18) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
    • (C19) 如果 is_per_tensor_quantized(rhs),则 is_per_tensor_quantized(result)
    • 如果 !is_quantized(lhs)
    • (C20) element_type(lhs) = expressed_type(rhs) = element_type(result)
  • 如果 !is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation)
    • (C21) precision_config... = DEFAULT
    • (C22) 0 < lhs_component_count
    • (C23) 0 < rhs_component_count
    • (C24) 0 < num_primitive_operations

示例

// %lhs: [
//        [[1, 2],
//         [3, 4]],
//        [[5, 6],
//         [7, 8]]
//       ]
// %rhs: [
//        [[1, 0],
//         [0, 1]],
//        [[1, 0],
//         [0, 1]]
//       ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
  dot_dimension_numbers = #stablehlo.dot<
    lhs_batching_dimensions = [0],
    rhs_batching_dimensions = [0],
    lhs_contracting_dimensions = [2],
    rhs_contracting_dimensions = [1]
  >,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
  algorithm = #stablehlo.dot_algorithm<
    lhs_precision_type = tf32,
    rhs_precision_type = tf32,
    accumulation_type = f32,
    lhs_component_count = 1,
    rhs_component_count = 1,
    num_primitive_operations = 1,
    allow_imprecise_accumulation = false
  >
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
//           [[1, 2],
//            [3, 4]],
//           [[5, 6],
//            [7, 8]]
//          ]

更多示例

dynamic_broadcast_in_dim

语义

此运算的功能与 broadcast_in_dim 运算相同,但结果形状是通过 output_dimensions 动态指定的。

该运算还接受可选属性 known_expanding_dimensionsknown_nonexpanding_dimensions,以表达有关尺寸展开行为的静态知识。如果未指定,则假定所有维度都可能扩展。

输入

标签 名称 类型 限制条件
(I1) operand 张量或量化张量 (C1-C2)、(C5-C6)、(C9)
(I2) output_dimensions 整数类型的一维张量 (C7)
(I3) broadcast_dimensions 整数类型的一维常量张量 (C2-C6)
(I4) known_expanding_dimensions 整数类型的一维常量张量 (C8-C9)
(I5) known_nonexpanding_dimensions 整数类型的一维常量张量 (C8-C9)

输出

名称 类型 限制条件
result 张量或量化张量 (C1)、(C3)、(C5-C7)

限制条件

  • (C1) element_type(result) 由以下公式给出:
    • 如果 !is_per_axis_quantized(operand),则设为 element_type(operand)
    • element_type(operand),但 quantization_dimension(operand)scales(operand)zero_points(operand) 可能与 quantization_dimension(result)scales(result)zero_points(result) 不同。
  • (C2) size(broadcast_dimensions) = rank(operand)
  • (C3) 0 <= broadcast_dimensions < rank(result)
  • (C4) is_unique(broadcast_dimensions)
  • (C5) 对于 axes(operand) 中的所有 d
    • dim(operand, d) = 1
    • dim(operand, d) = dim(result, broadcast_dimensions[d])
  • (C6) 如果 is_per_axis_quantized(result)
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
    • 如果为 dim(operand, quantization_dimension(operand)) = 1,则为 scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
  • (C7) size(output_dimensions) = rank(result)
  • (C8) is_unique(known_expanding_dimensions + known_nonexpanding_dimensions)
  • (C9) 0 <= known_expanding_dimensions < rank(operand)
  • (C10) 0 <= known_nonexpanding_dimensions < rank(operand)

示例

// %operand: [
//            [1, 2, 3]
//           ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
  broadcast_dimensions = array<i64: 2, 1>,
  known_expanding_dimensions = array<i64: 0>,
  known_nonexpanding_dimensions = array<i64: 1>
} : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

更多示例

dynamic_conv

语义

此操作在功能上与卷积操作相同,但填充是通过 padding 动态指定的。

输入

标签 名称 类型 限制条件
(I1) lhs 张量或每张量量化张量 (C1)、(C10-C11)、(C14) (C25)、(C26-C27)、(C30-C31)、(C33)
(I2) rhs 张量或量化张量 (C1)、(C14-C16)、(C26-C28)、(C30-C33)
(I3) padding 整数类型的二维张量 (C4)
(I4) window_strides 类型为 si64 的一维张量常量 (C2-C3)
(I5) lhs_dilation 类型为 si64 的一维张量常量 (C5-C6)
(I6) rhs_dilation 类型为 si64 的一维张量常量 (C7-C8)
(I7) window_reversal 类型为 i1 的一维张量常量 (C9)
(I8) input_batch_dimension si64 类型的常量 (C10)、(C13)
(I9) input_feature_dimension si64 类型的常量 (C11)、(C13-C14)
(I10) input_spatial_dimensions si64 类型的一维张量常量 (C12)、(C13)
(I11) kernel_input_feature_dimension 类型为 si64 的常量 (C14)、(C18)
(I12) kernel_output_feature_dimension 类型为 si64 的常量 (C15-C16)、(C18)、(C28)
(I13) kernel_spatial_dimensions si64 类型的一维张量常量 (C17-C18)
(I14) output_batch_dimension 类型为 si64 的常量 (C20)
(I15) output_feature_dimension si64 类型的常量 (C20)、(C29)
(I16) output_spatial_dimensions 类型为 si64 的一维张量常量 (C19-C20)
(I17) feature_group_count 类型为 si64 的常量 (C11)、(C14)、(C16)、(C21)、(C23)
(I18) batch_group_count 类型为 si64 的常量 (C10)、(C15)、(C22)、(C23)
(I19) precision_config DEFAULTHIGHHIGHEST 枚举的可变数量 (C24)

输出

名称 类型 限制条件
result 张量或量化张量 (C25-C27)、(C29)、(C31-C33)

限制条件

  • (C1) N = rank(lhs) = rank(rhs)
  • (C2) size(window_strides) = N - 2
  • (C3) 0 < window_strides
  • (C4) shape(padding) = [N - 2, 2]
  • (C5) size(lhs_dilation) = N - 2
  • (C6) 0 < lhs_dilation
  • (C7) size(rhs_dilation) = N - 2
  • (C8) 0 < rhs_dilation
  • (C9) size(window_reversal) = N - 2
  • (C10) dim(lhs, input_batch_dimension) % batch_group_count = 0
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0
  • (C12) size(input_spatial_dimensions) = N - 2
  • (C13) 给定 input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
    • is_unique(input_dimensions)
    • 0 <= input_dimensions < N
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
  • (C17) size(kernel_spatial_dimensions) = N - 2
  • (C18) 给定 kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
    • is_unique(kernel_dimensions)
    • 0 <= kernel_dimensions < N
  • (C19) size(output_spatial_dimensions) = N - 2
  • (C20) 给定 output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
    • is_unique(output_dimensions)
    • 0 <= output_dimensions < N
  • (C21) 0 < feature_group_count
  • (C22) 0 < batch_group_count
  • (C23) feature_group_count = 1 or batch_group_count = 1
  • (C24) size(precision_config) = 2
  • (C25) dim(result, result_dim) 的定义如下:
    • 如果 result_dim = output_batch_dimension,则设为 dim(lhs, input_batch_dimension) / batch_group_count
    • 如果 result_dim = output_feature_dimension,则设为 dim(rhs, kernel_output_feature_dimension)
    • 否则为 num_windows,其中:
    • output_spatial_dimensions[spatial_dim] = result_dim
    • lhs_dim = input_spatial_dimensions[spatial_dim]
    • rhs_dim = kernel_spatial_dimensions[spatial_dim]
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
  • (C26) rank(result) = N
  • 如果运算使用非量化张量:
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result)
  • 如果操作使用量化张量:
    • (C28) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
    • (C29) 如果 is_per_axis_quantized(rhs),则 quantization_dimension(rhs) = kernel_output_feature_dimension
    • (C30) 如果 is_per_axis_quantized(result),则 quantization_dimension(result) = output_feature_dimension
    • 如果 is_quantized(lhs)
    • (C31) storage_type(lhs) = storage_type(rhs)
    • (C32) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
    • (C33) 如果 is_per_tensor_quantized(rhs),则 is_per_tensor_quantized(result)
    • 如果 !is_quantized(lhs)
    • (C34) element_type(lhs) = expressed_type(rhs) = element_type(result)

示例

// %lhs: [[
//        [[1], [2], [5], [6]],
//        [[3], [4], [7], [8]],
//        [[10], [11], [14], [15]],
//        [[12], [13], [16], [17]]
//      ]]
//
// %rhs: [
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]]
//        ]
// %padding: [[1, 1],
//            [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
  window_strides = array<i64: 4, 4>,
  lhs_dilation = array<i64: 2, 2>,
  rhs_dilation = array<i64: 1, 1>,
  window_reversal = array<i1: false, false>,
  dimension_numbers = #stablehlo.conv<raw
    input_batch_dimension = 0,
    input_feature_dimension = 3,
    input_spatial_dimensions = [0, 1],
    kernel_input_feature_dimension = 2,
    kernel_output_feature_dimension = 3,
    kernel_spatial_dimensions = [0, 1],
    output_batch_dimension = 0,
    output_feature_dimension = 3,
    output_spatial_dimensions = [1, 2]
  >,
  feature_group_count = 1 : i64,
  batch_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
//            [[1], [5]],
//            [[10], [14]]
//          ]]

更多示例

dynamic_gather

语义

此操作在功能上与 gather 操作相同,其中 slice_sizes 以值的形式动态指定。

输入

标签 名称 类型 限制条件
(I1) operand 张量或每张量量化张量 (C1)、(C7)、(C10-C12)、(C14)
(I2) start_indices 整数类型的张量 (C2)、(C3)、(C13)
(I3) slice_sizes 整数类型的一维张量 (C8)、(C11-C13)
(I4) offset_dims 类型为 si64 的一维张量常量 (C1)、(C4-C5)、(C13)
(I5) collapsed_slice_dims si64 类型的一维张量常量 (C1)、(C6-C8)、(C13)
(I6) start_index_map 类型为 si64 的一维张量常量 (C3)、(C9)、(C10)
(I7) index_vector_dim 类型为 si64 的常量 (C2)、(C3)、(C13)
(I8) indices_are_sorted 类型为 i1 的常量

输出

名称 类型 限制条件
result 张量或按张量量化的张量 (C5)、(C13-C14)

限制条件

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims)
  • (C2) 0 <= index_vector_dim <= rank(start_indices)
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims)
  • (C5) 0 <= offset_dims < rank(result)
  • (C6) is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims)
  • (C7) 0 <= collapsed_slice_dims < rank(operand)
  • (C8) slice_sizes[collapsed_slice_dims...] <= 1
  • (C9) is_unique(start_index_map)
  • (C10) 0 <= start_index_map < rank(operand)
  • (C11) size(slice_sizes) = rank(operand)
  • (C12) 0 <= slice_sizes <= shape(operand)
  • (C13) shape(result) = combine(batch_dim_sizes, offset_dim_sizes),其中:
    • batch_dim_sizes = shape(start_indices),但不包含与 index_vector_dim 对应的 start_indices 的维度大小。
    • offset_dim_sizes = shape(slice_sizes),但不包括与 collapsed_slice_dims 对应的 slice_sizes 中的尺寸尺寸。
    • combine 会将 batch_dim_sizes 置于 batch_dims 对应的轴,而 offset_dim_sizes 则位于 offset_dims 对应的轴。
  • (C14) element_type(operand) = element_type(result)

示例

// %operand: [
//            [[1, 2], [3, 4], [5, 6], [7, 8]],
//            [[9, 10],[11, 12], [13, 14], [15, 16]],
//            [[17, 18], [19, 20], [21, 22], [23, 24]]
//           ]
// %start_indices: [
//                  [[0, 0], [1, 0], [2, 1]],
//                  [[0, 1], [1, 1], [0, 2]]
//                 ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [2, 3],
    collapsed_slice_dims = [0],
    start_index_map = [1, 0],
    index_vector_dim = 2>,
  indices_are_sorted = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi64>
// %result: [
//            [
//              [[1, 2], [3, 4]],
//              [[3, 4], [5, 6]],
//              [[13, 14], [15, 16]]
//            ],
//            [
//              [[9, 10], [11, 12]],
//              [[11, 12], [13, 14]],
//              [[17, 18], [19, 20]]
//            ]
//          ]

更多示例

dynamic_iota

语义

此运算在功能上与 iota 运算相同,但结果形状是通过 output_shape 动态指定的。

输入

标签 名称 类型 限制条件
(I1) output_shape 整数类型的一维张量 (C1)、(C2)
(I2) iota_dimension si64 (C1)

输出

名称 类型 限制条件
result 整数、浮点或复杂类型的张量,或每个张量的量化张量 (C2)

限制条件

  • (C1) 0 <= iota_dimension < size(output_shape)
  • (C2) rank(result) = size(output_shape)

示例

%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
  iota_dimension = 0 : i64
} : (tensor<2xi64>) -> tensor<4x5xi64>
// %result: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

更多示例

dynamic_pad

语义

此操作在功能上与填充操作相同,但将 edge_padding_lowedge_padding_highinterior_padding 动态指定为值。

输入

标签 名称 类型 限制条件
(I1) operand 张量或按张量量化的张量 (C1)、(C2)、(C4)
(I2) padding_value 0 维张量或每张量量化张量 (C1)
(I3) edge_padding_low 整数类型的一维张量 (C1)、(C4)
(I4) edge_padding_high 整数类型的一维张量 (C1)、(C4)
(I5) interior_padding 整数类型的一维张量 (C2-C4)

输出

名称 类型 限制条件
result 张量或按张量量化的张量 (C3-C6)

限制条件

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result)
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
  • (C3) 0 <= interior_padding
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high

示例

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
  %edge_padding_low, %edge_padding_high, %interior_padding
) : (tensor<2x3xi64>, tensor<i64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x9xi64>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

更多示例

dynamic_reshape

语义

此操作在功能上与重塑操作相同,但结果形状是通过 output_shape 动态指定的。

输入

标签 名称 类型 限制条件
(I1) operand 张量或量化张量 (C1 - C3)
(I2) output_shape 整数类型的一维张量 (C4)

输出

名称 类型 限制条件
result 张量或量化张量 (C1 - C4)

限制条件

  • (C1) element_type(result) 由以下公式给出:
    • 如果 !is_per_axis_quantized(operand),则设为 element_type(operand)
    • element_type(operand),但 quantization_dimension(operand)quantization_dimension(result) 可能会有所不同。
  • (C2) size(operand) = size(result)
  • (C3) 如果 is_per_axis_quantized(operand)
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
  • (C4) size(output_shape) = rank(result)

示例

// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
// %result: [[1, 2], [3, 4], [5, 6]]

更多示例

dynamic_slice

语义

使用动态计算的起始索引从 operand 中提取一个 slice,并生成 result 张量。start_indices 包含每个可能调整的维度的 slice 的起始索引,slice_sizes 包含每个维度的 slice 的大小。更正式地,result[result_index] = operand[operand_index],其中:

  • adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes)
  • operand_index = adjusted_start_indices + result_index

输入

标签 名称 类型 限制条件
(I1) operand 张量或按张量量化的张量 (C1)、(C2)、(C4)
(I2) start_indices 整数类型的 0 维张量的可变数量 (C2)、(C3)
(I3) slice_sizes 类型为 si64 的一维张量常量 (C2)、(C4)、(C5)

输出

名称 类型 限制条件
result 张量或按张量量化的张量 (C1)、(C5)

限制条件

  • (C1) element_type(operand) = element_type(result)
  • (C2) size(start_indices) = size(slice_sizes) = rank(operand)
  • (C3) same(type(start_indices...))
  • (C4) 0 <= slice_sizes <= shape(operand)
  • (C5) shape(result) = slice_sizes

示例

// %operand: [
//            [0, 0, 1, 1],
//            [0, 0, 1, 1],
//            [0, 0, 0, 0],
//            [0, 0, 0, 0]
//           ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
  slice_sizes = array<i64: 2, 2>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
//           [1, 1],
//           [1, 1]
//          ]

更多示例

dynamic_update_slice

语义

生成一个 result 张量,该张量等于 operand 张量,但从 start_indices 开始的切片将更新为使用 update 中的值。更正式地说,result[result_index] 的定义如下:

  • update[update_index],如果 0 <= update_index < shape(update),则:
    • adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update))
    • update_index = result_index - adjusted_start_indices
  • 否则为 operand[result_index]

输入

标签 名称 类型 限制条件
(I1) operand 张量或按张量量化的张量 (C1-C4)、(C6)
(I2) update 张量或每张量量化张量 (C2)、(C3)、(C6)
(I3) start_indices 整数类型的 0 维张量的可变数量 (C4)、(C5)

输出

名称 类型 限制条件
result 张量或每张量量化张量 (C1)

限制条件

  • (C1) type(operand) = type(result)
  • (C2) element_type(update) = element_type(operand)
  • (C3) rank(update) = rank(operand)
  • (C4) size(start_indices) = rank(operand)
  • (C5) same(type(start_indices...))
  • (C6) 0 <= shape(update) <= shape(operand)

示例

// %operand: [
//            [1, 1, 0, 0],
//            [1, 1, 0, 0],
//            [1, 1, 1, 1],
//            [1, 1, 1, 1]
//           ]
// %update: [
//           [1, 1],
//           [1, 1]
//          ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
  : (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1]
//          ]

更多示例

指数函数

语义

operand 张量执行元素级指数运算,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于浮点数:IEEE-754 中的 exp
  • 对于复数:复数指数。
  • 对于量化类型:dequantize_op_quantize(exponential, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点型、复杂类型或每张量量化张量 (C1)

输出

名称 类型 限制条件
result 浮点或复杂类型的张量,或每个张量的量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]

更多示例

exponential_minus_one

语义

operand 张量执行元素级指数减一个运算,并生成一个 result 张量。根据元素类型,执行以下操作:

  • 对于浮点数:IEEE-754 中的 expm1
  • 对于复数:复数指数减一。
  • 对于量化类型:dequantize_op_quantize(exponential_minus_one, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点型、复杂类型或每张量量化张量 (C1)

输出

名称 类型 限制条件
result 浮点或复杂类型的张量,或每个张量的量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]

更多示例

fft

语义

对实数和复数输入/输出执行正向和反向傅里叶转换。

fft_type 是以下值之一:

  • FFT:转发复杂到复杂 FFT。
  • IFFT:复杂数到复杂数的逆 FFT。
  • RFFT:正向实数到复数 FFT。
  • IRFFT:实数到复数的反函数 FFT(即取复数,返回实数)。

更正式地说,假设函数 fft 接受复杂类型的 1 维张量作为输入,生成与输出相同类型的 1 维张量,并计算离散傅里叶变换:

对于 fft_type = FFTresult 定义为一系列 L 计算的最终结果,其中 L = size(fft_length)。例如,对于 L = 3

  • result1[i0, ..., :] = fft(operand[i0, ..., :])
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])

此外,假设函数 ifft 具有相同的类型签名并计算 fft 的反函数:

对于 fft_type = IFFTresult 定义为 fft_type = FFT 计算的逆运算。例如,对于 L = 3

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
  • result[i0, ..., :] = ifft(result2[i0, ..., :])

此外,假设函数 rfft 接受浮点类型的一维张量,并生成具有相同浮点语义的复杂类型的一维张量,其运作方式如下:

  • rfft(real_operand) = truncated_result,其中
  • complex_operand... = (real_operand..., 0.0)
  • complex_result = fft(complex_operand)
  • truncated_result = complex_result[:(rank(complex_result) / 2 + 1)]

(在针对实际运算数计算离散福里叶转换时,结果的前 N/2 + 1 个元素会明确定义结果的其余部分,因此 rfft 的结果会被截断,以避免计算冗余元素)。

对于 fft_type = RFFTresult 定义为一系列 L 计算的最终结果,其中 L = size(fft_length)。例如,对于 L = 3

  • result1[i0, ..., :] = rfft(operand[i0, ..., :])
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])

最后,假设函数 irfft 具有相同的类型签名并计算 rfft 的反函数:

对于 fft_type = IRFFTresult 定义为 fft_type = RFFT 计算的逆运算。例如,对于 L = 3

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
  • result[i0, ..., :] = irfft(result2[i0, ..., :])

输入

标签 名称 类型 限制条件
(I1) operand 浮点类型或复杂类型的张量 (C1)、(C2)、(C4)、(C5)
(I2) fft_type FFTIFFTRFFTIRFFT 的枚举 (C2)、(C5)
(I3) fft_length si64 类型的一维张量常量 (C1)、(C3)、(C4)

输出

名称 类型 限制条件
result 浮点类型或复杂类型的张量 (C2)、(C4)、(C5)

限制条件

  • (C1) size(fft_length) <= rank(operand)
  • (C2) operandresult 元素类型之间的关系因以下原因而异:
    • 如果 fft_type = FFTelement_type(operand)element_type(result) 具有相同的复杂类型。
    • 如果 fft_type = IFFTelement_type(operand)element_type(result) 具有相同的复杂类型。
    • 如果为 fft_type = RFFT,则 element_type(operand) 为浮点类型,element_type(result) 为具有相同浮点语义的复杂类型。
    • 如果为 fft_type = IRFFT,则 element_type(operand) 为复杂类型,element_type(result) 为具有相同浮点语义的浮点类型。
  • (C3) 1 <= size(fft_length) <= 3
  • (C4) 如果 operandresult 中包含浮点类型的张量 real,则 shape(real)[-size(fft_length):] = fft_length
  • (C5) shape(result) = shape(operand),但以下情况除外:
    • 如果为 fft_type = RFFT,则为 dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1
    • 如果为 fft_type = IRFFT,则为 dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1

示例

// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
  fft_type = #stablehlo<fft_type FFT>,
  fft_length = array<i64: 4>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]

floor

语义

operand 张量执行元素级地板函数,并生成 result 张量。实现 IEEE-754 规范中的 roundToIntegralTowardNegative 运算。对于量化类型,执行 dequantize_op_quantize(floor, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点类型的张量或每个张量量化张量 (C1)

输出

名称 类型 限制条件
result 浮点类型的张量或每个张量量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]

更多示例

收集

语义

start_indices 中指定的偏移量收集 operand 张量的切片,并生成 result 张量。

下图通过一个具体示例展示了 result 中的元素如何映射到 operand 中的元素。该图表选择了一些示例 result 索引,并详细说明了它们对应的 operand 索引。

收集

更正式地说,result[result_index] = operand[operand_index],其中:

  • batch_dims = [d for d in axes(result) and d not in offset_dims]
  • batch_index = result_index[batch_dims...]
  • start_index 的定义如下:
    • start_indices[bi0, ..., :, ..., biN],其中 bibatch_index 中的各个元素,: 插入到 index_vector_dim 索引(如果 index_vector_dim < rank(start_indices))处。
    • 否则为 [start_indices[batch_index]]
  • 对于 axes(operand) 中的 d_operand
    • 如果 d_operand = start_index_map[d_start],则设为 full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])
    • 否则为 full_start_index[d_operand] = 0
  • 对于 axes(operand) 中的 d_operand
    • 如果 d_operand = operand_batching_dims[i_batching]d_start = start_indices_batching_dims[i_batching],则为 full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)]
    • 否则为 full_batching_index[d_operand] = 0
  • offset_index = result_index[offset_dims...]
  • full_offset_index = [oi0, ..., 0, ..., oiN],其中 oioffset_index 中的各个元素,0 会插入 collapsed_slice_dimsoperand_batching_dims 中的索引。
  • operand_index = full_start_index + full_batching_index + full_offset_index

如果 indices_are_sortedtrue,则实现可以假定 start_indices 是按 start_index_map 排序的,否则行为将属于未定义行为。更正式地说,对于来自 indices(result) 的所有 i1 < i2,该属性为 full_start_index(i1) <= full_start_index(i2)

输入

标签 名称 类型 限制条件
(I1) operand 张量或按张量量化的张量 (C1)、(C8)、(C11)、(C17)、(C19-C21)、(C23)
(I2) start_indices 整数类型的张量 (C2-C3)、(C14)、(C17)、(C22)
(I3) offset_dims 类型为 si64 的一维张量常量 (C1)、(C4-C5)、(C22)
(I4) collapsed_slice_dims 类型为 si64 的一维张量常量 (C1)、(C6-C9)、(C22)
(I5) operand_batching_dims si64 类型的一维张量常量 (C1)、(C6)、(C10-C12)、(C16-C18)、(C22)
(I6) start_indices_batching_dims 类型为 si64 的一维张量常量 (C13-C17)
(I7) start_index_map 类型为 si64 的一维张量常量 (C3)、(C18-C19)
(I8) index_vector_dim si64 类型的常量 (C2-C3)、(C15)、(C22)
(I9) slice_sizes si64 类型的一维张量常量 (C9)、(C12)、(C20-C22)
(I10) indices_are_sorted 类型为 i1 的常量

输出

名称 类型 限制条件
result 张量或每张量量化张量 (C5)、(C22-C23)

限制条件

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims)
  • (C2) 0 <= index_vector_dim <= rank(start_indices)
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims)
  • (C5) 0 <= offset_dims < rank(result)
  • (C6) is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))
  • (C7) is_sorted(collapsed_slice_dims)
  • (C8) 0 <= collapsed_slice_dims < rank(operand)
  • (C9) slice_sizes[collapsed_slice_dims...] <= 1
  • (C10) is_sorted(operand_batching_dims)
  • (C11) 0 <= operand_batching_dims < rank(operand)
  • (C12) slice_sizes[operand_batching_dims...] <= 1
  • (C13) is_unique(start_indices_batching_dims)
  • (C14) 0 <= start_indices_batching_dims < rank(start_indices)
  • (C15) index_vector_dim not in start_indices_batching_dims
  • (C16) size(operand_batching_dims) == size(start_indices_batching_dims)
  • (C17) dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...)
  • (C18) is_unique(concatenate(start_index_map, operand_batching_dims))
  • (C19) 0 <= start_index_map < rank(operand)
  • (C20) size(slice_sizes) = rank(operand)
  • (C21) 0 <= slice_sizes <= shape(operand)
  • (C22) shape(result) = combine(batch_dim_sizes, offset_dim_sizes),其中:
    • batch_dim_sizes = shape(start_indices),但不包含与 index_vector_dim 对应的 start_indices 的维度大小。
    • offset_dim_sizes = slice_sizes,但不包含与 collapsed_slice_dimsoperand_batching_dims 对应的 slice_sizes 中的尺寸维度。
    • combine 会将 batch_dim_sizes 置于 batch_dims 对应的轴,而 offset_dim_sizes 则位于 offset_dims 对应的轴。
  • (C23) element_type(operand) = element_type(result)

示例

// %operand: [
//            [
//             [[1, 2], [3, 4], [5, 6], [7, 8]],
//             [[9, 10],[11, 12], [13, 14], [15, 16]],
//             [[17, 18], [19, 20], [21, 22], [23, 24]]
//            ],
//            [
//             [[25, 26], [27, 28], [29, 30], [31, 32]],
//             [[33, 34], [35, 36], [37, 38], [39, 40]],
//             [[41, 42], [43, 44], [45, 46], [47, 48]]
//            ]
//           ]
// %start_indices: [
//                  [
//                   [[0, 0], [1, 0], [2, 1]],
//                   [[0, 1], [1, 1], [0, 9]]
//                  ],
//                  [
//                   [[0, 0], [2, 1], [2, 2]],
//                   [[1, 2], [0, 1], [1, 0]]
//                  ]
//                 ]
%result = "stablehlo.gather"(%operand, %start_indices) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [3, 4],
    collapsed_slice_dims = [1],
    operand_batching_dims = [0],
    start_indices_batching_dims = [1],
    start_index_map = [2, 1],
    index_vector_dim = 3>,
  slice_sizes = array<i64: 1, 1, 2, 2>,
  indices_are_sorted = false
} : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32>
// %result: [
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[3, 4], [5, 6]],
//             [[13, 14], [15, 16]]
//            ],
//            [
//             [[33, 34], [35, 36]],
//             [[35, 36], [37, 38]],
//             [[41, 42], [43, 44]]
//            ]
//           ],
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[13, 14], [15, 16]],
//             [[21, 22], [23, 24]]
//            ],
//            [
//             [[43, 44], [45, 46]],
//             [[33, 34], [35, 36]],
//             [[27, 28], [29, 30]]
//            ]
//           ]
//          ]

更多示例

get_dimension_size

语义

生成 operand 的给定 dimension 的大小。更正式地说,result = dim(operand, dimension)。语义仅与类型的形状组件相关。元素类型可以是任何类型。

输入

标签 名称 类型 限制条件
(I1) operand 张量或量化张量 (C1)
(I2) dimension si64 类型的常量 (C1)

输出

名称 类型
result 类型为 si32 的 0 维张量

限制条件

  • (C1) 0 <= dimension < rank(operand)

示例

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
  dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3

更多示例

get_tuple_element

语义

提取 operand 元组的 index 位置的元素,并生成 result。更正式地说,result = operand[index]

输入

标签 名称 类型 限制条件
(I1) operand tuple (C1)、(C2)
(I2) index 类型为 si32 的常量 (C1)、(C2)

输出

名称 类型 限制条件
result 任何受支持的类型 (C2)

限制条件

  • (C1) 0 <= index < size(operand)
  • (C2) type(result) = tuple_element_types(operand)[index]

示例

// %operand: ([1.0, 2.0], (3))
  index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]

更多示例

if

语义

根据 pred 的值,通过执行 true_branchfalse_branch 中的恰好一个函数来生成输出。更正式地说,result = pred ? true_branch() : false_branch()

输入

标签 名称 类型 限制条件
(I1) pred i1 类型的 0 维张量
(I2) true_branch 函数 (C1-C3)
(I3) false_branch 函数 (C1)、(C2)

输出

名称 类型 限制条件
results 可变数量的张量、量化张量或令牌 (C3)

限制条件

  • (C1) input_types(true_branch) = input_types(false_branch) = []
  • (C2) output_types(true_branch) = output_types(false_branch)
  • (C3) type(results...) = output_types(true_branch)

示例

// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
  "stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
  "stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10

更多示例

imag

语义

operand 中提取元素级虚部并生成 result 张量。更正式地说,对于每个元素 ximag(x) = is_complex(x) ? imaginary_part(x) : constant(0, element_type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点类型或复杂类型的张量 (C1)、(C2)

输出

名称 类型 限制条件
result 浮点类型的张量 (C1)、(C2)

限制条件

  • (C1) shape(result) = shape(operand)
  • (C2) element_type(result) 的定义如下:
    • 如果 is_complex(operand),则设为 complex_element_type(element_type(operand))
    • 否则为 element_type(operand)

示例

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]

更多示例

infeed

语义

从 infeed 读取数据并生成 results

infeed_config 的语义由实现定义。

results 由位于前面的载荷值和位于后面的令牌组成。将来,我们计划将载荷和令牌拆分为两个单独的输出,以提高清晰度 (#670)。

输入

标签 名称 类型
(I1) token token
(I2) infeed_config 类型为 string 的常量

输出

名称 类型 限制条件
results 可变数量的张量、量化张量或词元 (C1 - C3)

限制条件

  • (C1) 0 < size(results)
  • (C2) is_empty(result[:-1])is_tensor(type(results[:-1]))
  • (C3) is_token(type(results[-1]))

示例

// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]

更多示例

iota

语义

使用从 iota_dimension 维度沿着 0 开始的升序值填充 output 张量。更正式地说,

output[output_index] = constant(is_quantized(output) ? quantize(output_index[iota_dimension], element_type(output)) : output_index[iota_dimension], element_type(output))

输入

标签 名称 类型 限制条件
(I1) iota_dimension si64 (C1)

输出

名称 类型 限制条件
output 整数、浮点或复杂类型的张量,或每个张量的量化张量 (C1)

限制条件

  • (C1) 0 <= iota_dimension < rank(output)

示例

%output = "stablehlo.iota"() {
  iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

%output = "stablehlo.iota"() {
  iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4]
//          ]

更多示例

is_finite

语义

按元素检查 x 中的值是否有限(即既不是 +Inf、-Inf 也不是 NaN),并生成 y 张量。实现 IEEE-754 规范中的 isFinite 运算。对于量化类型,结果始终为 true

输入

标签 名称 类型 限制条件
(I1) x 浮点类型的张量或每个张量量化张量 (C1)

输出

名称 类型 限制条件
y 布尔值类型的张量 (C1)

限制条件

  • (C1) shape(x) = shape(y)

示例

// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]

更多示例

log

语义

operand 张量执行元素级对数运算,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于浮点数:IEEE-754 中的 log
  • 对于复数:复数对数。
  • 对于量化类型:dequantize_op_quantize(log, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点型、复杂类型或每张量量化张量 (C1)

输出

名称 类型 限制条件
result 浮点或复杂类型的张量,或每个张量的量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]

更多示例

log_plus_one

语义

operand 张量执行元素级对数加一操作,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于浮点数:IEEE-754 中的 logp1
  • 对于复数:复数对数加 1。
  • 对于量化类型:dequantize_op_quantize(log_plus_one, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点型、复杂类型或每张量量化张量 (C1)

输出

名称 类型 限制条件
result 浮点或复杂类型的张量,或每个张量的量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]

更多示例

物流

语义

operand 张量执行元素级逻辑运算,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于浮点数:IEEE-754 中的 division(1, addition(1, exp(-x)))
  • 对于复数:复杂逻辑函数。
  • 对于量化类型:dequantize_op_quantize(logistic, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点型、复杂类型或每张量量化张量 (C1)

输出

名称 类型 限制条件
result 浮点或复杂类型的张量,或每个张量的量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]

更多示例

地图

语义

沿 dimensions 将映射函数 computation 应用于 inputs,并生成 result 张量。

更正式地说,result[result_index] = computation(inputs...[result_index])

输入

标签 名称 类型 限制条件
(I1) inputs 张量数量可变或每个张量量化张量 (C1-C4)
(I2) dimensions 类型为 si64 的一维张量常量 (C3)
(I3) computation 函数 (C4)

输出

名称 类型 限制条件
result 张量或按张量量化的张量 (C1)、(C4)

限制条件

  • (C1) shape(inputs...) = shape(result)
  • (C2) 0 < size(inputs) = N
  • (C3) dimensions = range(rank(inputs[0]))
  • (C4) computation 的类型为 (tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>,其中 Ei = element_type(inputs[i])E' = element_type(result)

示例

// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
    stablehlo.return %0 : tensor<i64>
}) {
  dimensions = array<i64: 0, 1>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]

更多示例

最大值

语义

对张量 lhsrhs 执行元素级最大值运算,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于布尔值:逻辑 OR。
  • 对于整数:整数上限。
  • 对于浮点数:IEEE-754 中的 maximum
  • 对于复数:(real, imaginary) 对的字典序列最大值。 对复杂数强制执行排序会涉及令人意外的语义,因此我们计划在未来移除对此运算的复杂数支持 (#560)。
  • 对于量化类型:
    • dequantize_op_quantize(maximum, lhs, rhs, type(result))

输入

标签 名称 类型 限制条件
(I1) lhs 张量或每张量量化张量 (C1)
(I2) rhs 张量或按张量量化的张量 (C1)

输出

名称 类型 限制条件
result 张量或每张量量化张量 (C1)

限制条件

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

示例

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]

更多示例

最小值

语义

对张量 lhsrhs 执行元素级最小值运算,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于布尔值:逻辑与。
  • 对于整数:最小值整数。
  • 对于浮点数:IEEE-754 中的 minimum
  • 对于复数:(real, imaginary) 对的字典序最小值。 对复杂数强制执行排序会涉及令人意外的语义,因此我们计划在未来移除对此运算的复杂数支持 (#560)。
  • 对于量化类型:
    • dequantize_op_quantize(minimum, lhs, rhs, type(result))

输入

标签 名称 类型 限制条件
(I1) lhs 张量或每张量量化张量 (C1)
(I2) rhs 张量或按张量量化的张量 (C1)

输出

名称 类型 限制条件
result 张量或每张量量化张量 (C1)

限制条件

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

示例

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]

更多示例

相乘

语义

对两个张量 lhsrhs 执行元素级乘法,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于布尔值:逻辑与。
  • 对于整数:整数乘法。
  • 对于浮点数:IEEE-754 中的 multiplication
  • 对于复数:复数乘法。
  • 对于量化类型:
    • dequantize_op_quantize(multiply, lhs, rhs, type(result))

输入

标签 名称 类型 限制条件
(I1) lhs 张量或每张量量化张量 (C1)
(I2) rhs 张量或按张量量化的张量 (C1)

输出

名称 类型 限制条件
result 张量或每张量量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]

更多示例

negate

语义

operand 张量执行元素级否定,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于有符号整数:整数取反。
  • 对于无符号整数:按位转换为有符号整数、整数取反、按位转换回无符号整数。
  • 对于浮点数:IEEE-754 中的 negate
  • 对于复数:复数运算。
  • 对于量化类型:dequantize_op_quantize(negate, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 整数、浮点或复杂类型的张量,或每个张量的量化张量 (C1)

输出

名称 类型 限制条件
result 整数、浮点或复杂类型的张量,或每个张量的量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]

// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]

更多示例

语义

对张量 operand 执行元素级 NOT 运算,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于布尔值:逻辑 NOT。
  • 对于整数:按位非。

参数

名称 类型 限制条件
operand 布尔值或整数类型的张量 (C1)

输出

名称 类型 限制条件
result 布尔值或整数类型的张量 (C1)

限制条件

  • (C1) type(operand) = type(result)

示例

// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]

// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]

更多示例

optimization_barrier

语义

确保在任何依赖于 result 的运算之前执行生成 operand 的操作,并防止编译器转换使操作跨越屏障。除此之外,操作是一种身份,即 result = operand

参数

名称 类型 限制条件
operand 张量、每个张量量化张量或令牌的变参数量 (C1)

输出

名称 类型 限制条件
result 张量、每个张量量化张量或令牌的变参数量 (C1)

限制条件

  • (C1) type(operand...) = type(result...)

示例

// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0

更多示例

语义

对两个张量 lhsrhs 执行元素级 OR 运算,并生成一个 result 张量。根据元素类型,执行以下操作:

  • 对于布尔值:逻辑或。
  • 对于整数:按位或运算。

输入

标签 名称 类型 限制条件
(I1) lhs 整数或布尔值类型的张量 (C1)
(I2) rhs 整数或布尔值类型的张量 (C1)

输出

名称 类型 限制条件
result 整数或布尔值类型的张量 (C1)

限制条件

  • (C1) type(lhs) = type(rhs) = type(result)

示例

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]

更多示例

馈出

语义

inputs 写入输出并生成 result 令牌。

outfeed_config 的语义由实现定义。

输入

标签 名称 类型
(I1) inputs 可变数量的张量或量化张量
(I2) token token
(I3) outfeed_config 类型为 string 的常量

输出

名称 类型
result token

示例

%result = "stablehlo.outfeed"(%input0, %token) {
  outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token

更多示例

衬垫

语义

使用给定的 padding_value 在张量周围以及张量元素之间添加内边距,从而扩展 operand

edge_padding_lowedge_padding_high 分别指定在每个维度的低端(靠近编号 0)和高端(靠近最高编号)添加的填充量。内边距可以为负数,其中负内边距的绝对值表示要从指定维度移除的元素数量。

interior_padding 用于指定在每个维度中的任意两个元素之间添加的内边距,该值不得为负。内部内边距会先于边缘内边距,因此负边缘内边距会从内部内边距操作数中移除元素。

更正式地说,result[result_index] 的定义如下:

  • 如果 result_index = edge_padding_low + operand_index * (interior_padding + 1),则设为 operand[operand_index]
  • 否则为 padding_value

输入

标签 名称 类型 限制条件
(I1) operand 张量或按张量量化的张量 (C1)、(C2)、(C4)
(I2) padding_value 0 维张量或每张量量化张量 (C1)
(I3) edge_padding_low 类型为 si64 的一维张量常量 (C1)、(C4)
(I4) edge_padding_high 类型为 si64 的一维张量常量 (C1)、(C4)
(I5) interior_padding 类型为 si64 的一维张量常量 (C2-C4)

输出

名称 类型 限制条件
result 张量或按张量量化的张量 (C3-C6)

限制条件

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result)
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
  • (C3) 0 <= interior_padding
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high

示例

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
  edge_padding_low = array<i64: 0, 1>,
  edge_padding_high = array<i64: 2, 1>,
  interior_padding = array<i64: 1, 2>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

更多示例

partition_id

语义

生成当前进程的 partition_id

输出

名称 类型
result 类型为 ui32 的 0 维张量

示例

%result = "stablehlo.partition_id"() : () -> tensor<ui32>

更多示例

popcnt

语义

operand 张量中设置的位数执行元素级计数,并生成 result 张量。

输入

标签 名称 类型 限制条件
(I1) operand 整数类型的张量 (C1)

输出

名称 类型 限制条件
result 整数类型的张量 (C1)

限制条件

  • (C1) type(operand) = type(result)

示例

// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]

更多示例

幂数

语义

lhs 张量按元素对 rhs 张量求指数,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于整数:整数指数运算。
  • 对于浮点数:IEEE-754 中的 pow
  • 对于复数:复数指数运算。
  • 对于量化类型:dequantize_op_quantize(power, lhs, rhs, type(result))

输入

标签 名称 类型 限制条件
(I1) lhs 整数、浮点或复杂类型的张量,或每个张量的量化张量 (C1)
(I2) rhs 整数、浮点、复杂类型或每张量量化张量 (C1)

输出

名称 类型 限制条件
result 整数、浮点或复杂类型的张量,或每个张量的量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]

更多示例

real

语义

按元素从 operand 中提取实部,并生成 result 张量。更正式地说,对于每个元素 xreal(x) = is_complex(x) ? real_part(x) : x

输入

标签 名称 类型 限制条件
(I1) operand 浮点类型或复杂类型的张量 (C1)、(C2)

输出

名称 类型 限制条件
result 浮点类型的张量 (C1)、(C2)

限制条件

  • (C1) shape(result) = shape(operand)
  • (C2) element_type(result) 的定义如下:
    • 如果 is_complex(operand),则设为 complex_element_type(element_type(operand))
    • 否则为 element_type(operand)

示例

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]

更多示例

recv

语义

使用 channel_id 从通道接收数据,并生成 results

如果 is_host_transfertrue,则该操作会从主机传输数据。否则,它会从其他设备传输数据。具体含义由实现定义。此标志与 channel_type 中提供的信息重复,因此将来我们打算只保留其中一个 (#666)。

results 由先出现的载荷值和最后的令牌组成。未来,我们计划将载荷和令牌拆分为两个单独的输出,以提高清晰度 (#670)。

输入

标签 名称 类型 限制条件
(I1) token token (C4)
(I2) channel_id 类型为 si64 的常量
(I3) channel_type DEVICE_TO_DEVICEHOST_TO_DEVICE 的枚举 (C1)
(I4) is_host_transfer i1 类型的常量 (C1)

输出

名称 类型 限制条件
results 可变数量的张量、量化张量或令牌 (C2-C4)

限制条件

  • (C1) channel_type 的定义如下:
    • 如果 is_host_transfer = true,则为 HOST_TO_DEVICE
    • 否则为 DEVICE_TO_DEVICE
  • (C2) 0 < size(results)
  • (C3) is_empty(result[:-1])is_tensor(type(results[:-1]))
  • (C4) is_token(type(results[-1]))

示例

%results0, %results1 = "stablehlo.recv"(%token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
  is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)

更多示例

reduce

语义

沿着 dimensions 将归约函数 body 应用于 inputsinit_values,并生成 results 张量。

求值顺序由实现定义,这意味着 bodyinit_values 必须构成单元,以确保该运算针对所有实现中的所有输入产生相同的结果。不过,对于许多常见的求值,此条件并不成立。例如,body 的浮点加法和 init_values 的零实际上并不构成单元,因为浮点加法不遵守结合律。

更正式地说,results...[j0, ..., jR-1] = reduce(input_slices_converted),其中:

  • input_slices = inputs...[j0, ..., :, ..., jR-1],其中 :dimensions 处插入。
  • input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...)
  • init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)
  • reduce(input_slices_converted) = exec(schedule) 表示某个二元树 schedule,其中:
    • exec(node) = body(exec(node.left), exec(node.right))
    • exec(leaf) = leaf.value
  • schedule 是实现定义的全二元树,其有序遍历包括:
    • index_space(input_slices_converted) 中所有 indexinput_slices_converted...[index] 值,按 index 的字典顺序升序排列。
    • 在实现定义的位置与实现定义的 init_values_converted 数量穿插。

输入

标签 名称 类型 限制条件
(I1) inputs 可变数量张量或每个张量量化张量 (C1-C4)、(C6)、(C7)
(I2) init_values 0 维张量或每个张量量化张量的可变数量 (C2)、(C3)
(I3) dimensions 类型为 si64 的一维张量常量 (C4)、(C5)、(C7)
(I4) body 函数 (C6)

输出

名称 类型 限制条件
results 张量数量可变或每个张量量化张量 (C3)、(C7)、(C8)

限制条件

  • (C1) same(shape(inputs...))
  • (C2) element_type(inputs...) = element_type(init_values...)
  • (C3) 0 < size(inputs) = size(init_values) = size(results) = N
  • (C4) 0 <= dimensions < rank(inputs[0])
  • (C5) is_unique(dimensions)
  • (C6) body 的类型为 (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>),其中 is_promotable(element_type(inputs[i]), Ei)
  • (C7) shape(results...) = shape(inputs...),但不包括与 dimensions 对应的 inputs... 的尺寸。
  • (C8) [0,N) 中的所有 i 均为 element_type(results[i]) = Ei

示例

// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]

更多示例

reduce_precision

语义

使用 exponent_bitsmantissa_bitsoperand 按元素转换为另一个浮点类型,然后再转换回原始浮点类型,并生成 output 张量。

更正式地:

  • 原始值的尾数位会进行更新,以通过 roundToIntegralTiesToEven 语义将原始值四舍五入为可通过 mantissa_bits 表示的最接近的值。
  • 然后,如果 mantissa_bits 小于原始值的小数部分位数,则小数部分位数会截断为 mantissa_bits
  • 然后,如果中间结果的指数位不符合 exponent_bits 提供的范围,中间结果会使用原始符号溢出到无穷大,或者使用原始符号下溢到零。
  • 对于量化类型,执行 dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点类型的张量或每个张量量化张量 (C1)
(I2) exponent_bits si32 类型的常量 (C2)
(I3) mantissa_bits si32 类型的常量 (C3)

输出

名称 类型 限制条件
output 浮点类型的张量或每个张量量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(output)
  • (C2) 1 <= exponent_bits
  • (C3) 0 <= mantissa_bits

示例

// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
  exponent_bits = 5 : i32,
  mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]

更多示例

reduce_scatter

语义

reduce_scatter

在 StableHLO 进程网格的每个进程组内,使用 computations 对每个进程的 operand 张量值执行求和,将求和结果沿 scatter_dimension 拆分为多个部分,并将这些拆分部分分散到各个进程之间以生成 result

该操作将 StableHLO 进程网格拆分为 process_groups,其定义如下:

  • 如果 channel_id <= 0 and use_global_device_ids = false,则设为 cross_replica(replica_groups)
  • 如果 channel_id > 0 and use_global_device_ids = false,则设为 cross_replica_and_partition(replica_groups)
  • 如果 channel_id > 0 and use_global_device_ids = true,则设为 flattened_ids(replica_groups)

之后,在每个 process_group 中:

  • reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation)
  • parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension)
  • process_group 中的所有 sender 执行 result@receiver = parts@sender[receiver_index],其中 receiver_index = process_group.index(receiver)

输入

标签 名称 类型 限制条件
(I1) operand 张量或按张量量化的张量 (C1)、(C2)、(C7)、(C8)
(I2) scatter_dimension 类型为 si64 的常量 (C1)、(C2)、(C8)
(I3) replica_groups si64 类型的二维张量常量 (C3-C5)
(I4) channel_id 类型为 si64 的常量 (C6)
(I5) use_global_device_ids 类型为 i1 的常量 (C6)
(I6) computation 函数 (C7)

输出

名称 类型 限制条件
result 张量或每张量量化张量 (C8-C9)

限制条件

  • (C1) dim(operand, scatter_dimension) % dim(process_groups, 1) = 0
  • (C2) 0 <= scatter_dimension < rank(operand)
  • (C3) is_unique(replica_groups)
  • (C4) size(replica_groups) 的定义如下:
    • 如果使用了 cross_replica,则为 num_replicas
    • 如果使用了 cross_replica_and_partition,则为 num_replicas
    • 如果使用了 flattened_ids,则为 num_processes
  • (C5) 0 <= replica_groups < size(replica_groups)
  • (C6) 如果 use_global_device_ids = true,则 channel_id > 0
  • (C7) computation 的类型为 (tensor<E>, tensor<E>) -> (tensor<E>),其中 is_promotable(element_type(operand), E)
  • (C8) shape(result) = shape(operand),但以下情况除外:
    • dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1)
  • (C9) element_type(result) = E

示例

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
  "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
//                  [18, 20]]
// %result@(1, 0): [[14, 16],
//                  [22, 24]]

更多示例

reduce_window

语义

将归约函数 body 应用于 inputsinit_values 的窗口,并生成 results

下图通过一个具体示例展示了如何根据 inputs... 计算 results... 中的元素。

reduce_window

更正式地,results...[result_index] = reduce(windows, init_values, axes(inputs...), body)(请参阅reduce),其中:

  • padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1)
  • window_start = result_index * window_strides
  • window_end = window_start + (window_dimensions - 1) * window_dilations + 1
  • windows = slice(padded_inputs..., window_start, window_end, window_dilations)

输入

标签 名称 类型 限制条件
(I1) inputs 可变数量张量或每个张量量化张量 (C1-C4)、(C6)、(C8)、(C10)、(C12)、(C13)、(C15)
(I2) init_values 0 维张量或每个张量量化张量的可变数量 (C1), (C13)
(I3) window_dimensions 类型为 si64 的一维张量常量 (C4)、(C5)、(C15)
(I4) window_strides si64 类型的一维张量常量 (C6)、(C7)、(C15)
(I5) base_dilations 类型为 si64 的一维张量常量 (C8)、(C9)、(C15)
(I6) window_dilations 类型为 si64 的一维张量常量 (C10)、(C11)、(C15)
(I7) padding 类型为 si64 的二维张量常量 (C12)、(C15)
(I8) body 函数 (C13)

输出

名称 类型 限制条件
results 可变数量张量或每个张量量化张量 (C1)、(C14-C16)

限制条件

  • (C1) 0 < size(inputs) = size(init_values) = size(results) = N
  • (C2) same(shape(inputs...))
  • (C3) element_type(inputs...) = element_type(init_values...)
  • (C4) size(window_dimensions) = rank(inputs[0])
  • (C5) 0 < window_dimensions
  • (C6) size(window_strides) = rank(inputs[0])
  • (C7) 0 < window_strides
  • (C8) size(base_dilations) = rank(inputs[0])
  • (C9) 0 < base_dilations
  • (C10) size(window_dilations) = rank(inputs[0])
  • (C11) 0 < window_dilations
  • (C12) shape(padding) = [rank(inputs[0]), 2]
  • (C13) body 的类型为 (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>),其中 is_promotable(element_type(inputs[i]), Ei)
  • (C14) same(shape(results...))
  • (C15) shape(results[0]) = num_windows,其中:
    • dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1
    • padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]
    • dilated_window_shape = (window_dimensions - 1) * window_dilations + 1
    • is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape
    • num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1
  • (C16) [0,N) 中的所有 i 均为 element_type(results[i]) = Ei

示例

// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = array<i64: 2, 1>,
  window_strides = array<i64: 4, 1>,
  base_dilations = array<i64: 2, 1>,
  window_dilations = array<i64: 3, 1>,
  padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]

更多示例

余数

语义

对被除数 lhs 和除数 rhs 张量执行元素级余数运算,并生成 result 张量。

更正式地说,结果的符号取自被除数,并且结果的绝对值始终小于除数的绝对值。余数按 lhs - d * rhs 计算,其中 d 的计算公式为:

  • 对于整数:stablehlo.divide(lhs, rhs)
  • 对于浮点数:IEEE-754 中的 division(lhs, rhs),带有舍入属性 roundTowardZero
  • 对于复数:待定 (#997)。
  • 对于量化类型:
    • dequantize_op_quantize(remainder, lhs, rhs, type(result))

对于浮点元素类型,此运算与 IEEE-754 规范中的 remainder 运算相反,其中 d 是与 lhs/rhs 的确切值最接近的整数值,且在出现平分时取偶数值。

输入

标签 名称 类型 限制条件
(I1) lhs 整数、浮点或复杂类型的张量,或每个张量的量化张量 (C1)
(I2) rhs 整数、浮点、复杂类型或每张量量化张量 (C1)

输出

名称 类型 限制条件
result 整数、浮点、复杂类型或每张量量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]

更多示例

replica_id

语义

生成当前进程的 replica_id

输出

名称 类型
result 类型为 ui32 的 0 维张量

示例

%result = "stablehlo.replica_id"() : () -> tensor<ui32>

更多示例

reshape

语义

operand 张量重塑为 result 张量。从概念上讲,这相当于保持相同的规范表示法,但可能会更改形状,例如从 tensor<2x3xf32> 更改为 tensor<3x2xf32>tensor<6xf32>

更正式地说,result[result_index] = operand[operand_index](其中 result_indexoperand_indexindex_space(result)index_space(operand) 的字典顺序中)具有相同的位置。

输入

标签 名称 类型 限制条件
(I1) operand 张量或量化张量 (C1 - C3)

输出

名称 类型 限制条件
result 张量或量化张量 (C1-C3)

限制条件

  • (C1) element_type(result) 由以下公式给出:
    • 如果 !is_per_axis_quantized(operand),则设为 element_type(operand)
    • element_type(operand),但 quantization_dimension(operand)quantization_dimension(result) 可能会有所不同。
  • (C2) size(operand) = size(result)
  • (C3) 如果 is_per_axis_quantized(operand)
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)

示例

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]

更多示例

reverse

语义

沿指定的 dimensions 反转 operand 中元素的顺序,并生成一个 result 张量。更正式地,result[result_index] = operand[operand_index],其中:

  • 如果 dimensions 中包含 d,则设为 operand_index[d] = dim(result, d) - result_index[d] - 1
  • 否则为 operand_index[d] = result_index[d]

输入

标签 名称 类型 限制条件
(I1) operand 张量或按张量量化的张量 (C1)、(C3)
(I2) dimensions 类型为 si64 的一维张量常量 (C2)、(C3)

输出

名称 类型 限制条件
result 张量或按张量量化的张量 (C1)、(C3)

限制条件

  • (C1) type(operand) = type(result)
  • (C2) is_unique(dimensions)
  • (C3) 0 <= dimensions < rank(result)

示例

// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
  dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]

更多示例

rng

语义

使用 rng_distribution 算法生成随机数,并生成给定形状 shaperesult 张量。

如果值为 rng_distribution = UNIFORM,则随机数将按照 [a, b) 区间内的均匀分布生成。如果为 a >= b,则行为将属于未定义行为。

如果为 rng_distribution = NORMAL,则随机数的生成遵循均值为 a 且标准差为 b 的正态分布。如果为 b < 0,则行为将属于未定义行为。

生成随机数的确切方式由实现定义。例如,它们可能或可能不确定性,也可能或可能不使用隐藏状态。

在与许多利益相关方的讨论中,此 op 已得到有效废弃,因此未来我们计划探索将其移除 (#597)。

输入

标签 名称 类型 限制条件
(I1) a 整数、布尔值或浮点类型的 0 维张量 (C1)、(C2)
(I2) b 整数、布尔值或浮点类型的 0 维张量 (C1)、(C2)
(I3) shape 类型为 si64 的一维张量常量 (C3)
(I4) rng_distribution UNIFORMNORMAL 的枚举 (C2)

输出

名称 类型 限制条件
result 整数、布尔值或浮点类型的张量 (C1-C3)

限制条件

  • (C1) element_type(a) = element_type(b) = element_type(result)
  • (C2) 如果 rng_distribution = NORMAL,则 is_float(a)
  • (C3) shape(result) = shape

示例

// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
  rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
//           [1, 0, 1],
//           [1, 1, 1],
//           [0, 0, 0]
//          ]

rng_bit_generator

语义

给定初始状态 initial_state,使用伪随机数生成器算法 rng_algorithm 返回一个填充有均匀随机位数的 output 和更新后的输出状态 output_state。输出一定是 initial_state 的确定性函数,但无法保证在实现之间是确定性的。

rng_algorithm 是以下值之一:

  • DEFAULT:实现定义的算法。
  • THREE_FRY:Threefry 算法的实现定义的变体。*
  • PHILOX:Philox 算法的实现定义的变体。*

* 参见:Salmon et al. SC 2011. 并行随机数字:简单到 1、2、3。

输入

标签 名称 类型 限制条件
(I1) rng_algorithm DEFAULTTHREE_FRYPHILOX 的枚举 (C2)
(I2) initial_state 类型为 ui64 的一维张量 (C1)、(C2)

输出

名称 类型 限制条件
output_state 类型为 ui64 的一维张量 (C1)
output 整数或浮点类型的张量

限制条件

  • (C1) type(initial_state) = type(output_state)
  • (C2) size(initial_state) 的定义如下:
    • 如果为 rng_algorithm = DEFAULT,则由实现定义。
    • 如果 rng_algorithm = THREE_FRY,则设为 2
    • 如果为 rng_algorithm = PHILOX,则为 23

示例

// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
  rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
//           [9236835810183407956, 16087790271692313299],
//           [18212823393184779219, 2658481902456610144]
//          ]

round_nearest_afz

语义

operand 张量按元素舍入到最接近的整数(舍弃零),并生成 result 张量。实现 IEEE-754 规范中的 roundToIntegralTiesToAway 运算。对于量化类型,执行 dequantize_op_quantize(round_nearest_afz, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点类型的张量或每个张量量化张量 (C1)

输出

名称 类型 限制条件
result 浮点类型的张量或每个张量量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]

更多示例

round_nearest_even

语义

operand 张量执行向最接近的整数执行元素级舍入,破坏向偶整数的相等关系,并生成一个 result 张量。实现 IEEE-754 规范中的 roundToIntegralTiesToEven 运算。对于量化类型,请执行 dequantize_op_quantize(round_nearest_even, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点类型的张量或每个张量量化张量 (C1)

输出

名称 类型 限制条件
result 浮点类型的张量或每个张量量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]

更多示例

rsqrt

语义

operand 张量执行按元素反平方根运算,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于浮点数:IEEE-754 中的 rSqrt
  • 对于复数:复数的平方根的倒数。
  • 对于量化类型:dequantize_op_quantize(rsqrt, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点型、复杂类型或每张量量化张量 (C1)

输出

名称 类型 限制条件
result 浮点或复杂类型的张量,或每个张量的量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]

更多示例

散射

语义

生成等于 inputs 张量的 results 张量,但使用 update_computationscatter_indices 指定的多个 slice 更新为值 updates

下图使用一个具体示例展示了 updates... 中的元素如何映射到 results... 中的元素。该图表选择了一些 updates... 索引示例,并详细说明了它们对应的 results... 索引。

散射

更正式地讲,对于 index_space(updates[0]) 中的所有 update_index

  • update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]
  • update_scatter_index = update_index[update_scatter_dims...]
  • start_index 的定义如下:
    • scatter_indices[si0, ..., :, ..., siN],其中 siupdate_scatter_index 中的各个元素,: 会插入 index_vector_dim 索引(如果 index_vector_dim < rank(scatter_indices))。
    • 否则为 [scatter_indices[update_scatter_index]]
  • 对于 axes(inputs[0]) 中的 d_input
    • 如果 d_input = scatter_dims_to_operand_dims[d_start],则设为 full_start_index[d_input] = start_index[d_start]
    • 否则为 full_start_index[d_input] = 0
  • 对于 axes(inputs[0]) 中的 d_input
    • 如果 d_input = input_batching_dims[i_batching]d_start = scatter_indices_batching_dims[i_batching],则为 full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)]
    • 否则为 full_batching_index[d_input] = 0
  • update_window_index = update_index[update_window_dims...]
  • full_window_index = [wi0, ..., 0, ..., wiN],其中 wiupdate_window_index 中的单个元素,而 0inserted_window_dimsinput_batching_dims 的索引处插入。
  • result_index = full_start_index + full_batching_index + full_window_index

因此,results = exec(schedule, inputs),其中:

  • schedule 是实现定义的 index_space(updates[0]) 排列。
  • exec([update_index, ...], results) = exec([...], updated_results) 其中:
    • 如果 result_indexshape(results...) 的边界内
    • updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
    • updated_values = update_computation(results...[result_index], updates_converted)
    • updated_resultsresults 的副本,其中 results...[result_index] 设置为 updated_values...
    • 否则
    • updated_results = results
  • exec([], results) = results

如果 indices_are_sortedtrue,则实现可以假定 scatter_indices 是按 scatter_dims_to_operand_dims 排序的,否则行为将属于未定义行为。更正式地说,对于 indices(result) 中的所有 i1 < i2full_start_index(i1) <= full_start_index(i2)

如果 unique_indicestrue,则实现可以假定要分散到的所有 result_index 索引都是唯一的。如果 unique_indicestrue,但要分散到的索引不是唯一的,则行为未定义。

输入

标签 名称 类型 限制条件
(I1) inputs 张量数量可变或每个张量量化张量 (C1)、(C2)、(C4-C6)、(C11)、(C13)、(C18)、(C21)、(C23-C24)
(I2) scatter_indices 整数类型的张量 (C4)、(C15)、(C19)、(C22)
(I3) updates 张量数量可变或每个张量量化张量 (C3-C6)、(C8)
(I4) update_window_dims si64 类型的一维张量常量 (C2)、(C4)、(C7-C8)
(I5) inserted_window_dims 类型为 si64 的一维张量常量 (C2)、(C4)、(C9-C11)
(I6) input_batching_dims si64 类型的一维张量常量 (C2)、(C4)、(C9)、(C12-13)、(C17-18)、(C20)
(I7) scatter_indices_batching_dims 类型为 si64 的一维张量常量 (C14 - C18)
(I8) scatter_dims_to_operand_dims 类型为 si64 的一维张量常量 (C19-C21)
(I9) index_vector_dim 类型为 si64 的常量 (C4)、(C16)、(C19)、(C22)
(I10) indices_are_sorted 类型为 i1 的常量
(I11) unique_indices i1 类型的常量
(I12) update_computation 函数 (C23)

输出

名称 类型 限制条件
results 张量数量可变或每个张量量化张量 (C24-C25)

限制条件

  • (C1) same(shape(inputs...))
  • (C2) `rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims)
    • size(input_batching_dims)`。
  • (C3) same(shape(updates...))
  • (C4) shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes),其中:
    • update_scatter_dim_sizes = shape(scatter_indices),但不包含与 index_vector_dim 对应的 scatter_indices 的尺寸。
    • update_window_dim_sizes <= shape(inputs[0]),但不包括与 inserted_window_dimsinput_batching_dims 对应的 inputs[0] 中的尺寸维度。
    • combine 会将 update_scatter_dim_sizes 放置在与 update_scatter_dims 对应的轴上,并将 update_window_dim_sizes 放置在与 update_window_dims 对应的轴上。
  • (C5) 0 < size(inputs) = size(updates) = N
  • (C6) element_type(updates...) = element_type(inputs...)
  • (C7) is_unique(update_window_dims) and is_sorted(update_window_dims)
  • (C8) 0 <= update_window_dims < rank(updates[0])
  • (C9) is_unique(concatenate(inserted_window_dims, input_batching_dims))
  • (C10) is_sorted(inserted_window_dims)
  • (C11) 0 <= inserted_window_dims < rank(inputs[0])
  • (C12) is_sorted(input_batching_dims)
  • (C13) 0 <= input_batching_dims < rank(inputs[0]))
  • (C14) is_unique(scatter_indices_batching_dims)
  • (C15) 0 <= scatter_indices_batching_dims < rank(scatter_indices)
  • (C16) index_vector_dim not in scatter_indices_batching_dims
  • (C17) size(input_batching_dims) == size(scatter_indices_batching_dims)
  • (C18) dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...)
  • (C19) size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1
  • (C20) is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims))
  • (C21) 0 <= scatter_dims_to_operand_dims < rank(inputs[0])
  • (C22) 0 <= index_vector_dim <= rank(scatter_indices)
  • (C23) update_computation 的类型为 (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>),其中 is_promotable(element_type(inputs[i]), Ei)
  • (C24) shape(inputs...) = shape(results...)
  • (C25) [0,N) 中的所有 i 均为 element_type(results[i]) = Ei

示例

// %input: [
//          [
//           [[1, 2], [3, 4], [5, 6], [7, 8]],
//           [[9, 10],[11, 12], [13, 14], [15, 16]],
//           [[17, 18], [19, 20], [21, 22], [23, 24]]
//          ],
//          [
//           [[25, 26], [27, 28], [29, 30], [31, 32]],
//           [[33, 34], [35, 36], [37, 38], [39, 40]],
//           [[41, 42], [43, 44], [45, 46], [47, 48]]
//          ]
//         ]
// %scatter_indices: [
//                    [
//                     [[0, 0], [1, 0], [2, 1]],
//                     [[0, 1], [1, 1], [0, 9]]
//                    ],
//                    [
//                     [[0, 0], [2, 1], [2, 2]],
//                     [[1, 2], [0, 1], [1, 0]]
//                    ]
//                   ]
// %update: [
//           [
//            [[1, 1], [1, 1], [1, 1]],
//            [[1, 1], [1, 1], [1, 1]]
//           ],
//           [
//            [[1, 1], [1, 1], [1, 1]],
//            [[1, 1], [1, 1], [1, 1]]
//           ]
//          ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension_numbers = #stablehlo.scatter<
    update_window_dims = [3, 4],
    inserted_window_dims = [1],
    input_batching_dims = [0],
    scatter_indices_batching_dims = [1],
    scatter_dims_to_operand_dims = [2, 1],
    index_vector_dim = 3>,
  indices_are_sorted = false,
  unique_indices = false
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
// %result: [
//           [
//            [[3, 4], [6, 7], [6, 7], [7, 8]],
//            [[9, 10],[11, 12], [15, 16], [17, 18]],
//            [[17, 18], [19, 20], [22, 23], [24, 25]]
//           ],
//           [
//            [[25, 26], [28, 29], [30, 31], [31, 32]],
//            [[35, 36], [38, 39], [38, 39], [39, 40]],
//            [[41, 42], [44, 45], [46, 47], [47, 48]]
//           ]
//          ]

更多示例

选择

语义

生成一个 result 张量,其中每个元素都是根据 pred 的对应元素的值从 on_trueon_false 张量中选择的。更正式地,result[result_index] = pred_element ? on_true[result_index] : on_false[result_index],其中 pred_element = rank(pred) = 0 ? pred[] : pred[result_index]。对于量化类型,请执行 dequantize_select_quantize(pred, on_true, on_false, type(result))

输入

标签 名称 类型 限制条件
(I1) pred 类型为 i1 的张量 (C1)
(I2) on_true 张量或每张量量化张量 (C1-C2)
(I3) on_false 张量或每张量量化张量 (C2)

输出

名称 类型 限制条件
result 张量或按张量量化的张量 (C2)

限制条件

  • (C1) rank(pred) = 0 or shape(pred) = shape(on_true)
  • (C2) baseline_type(on_true) = baseline_type(on_false) = baseline_type(result)

示例

// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]

更多示例

select_and_scatter

语义

根据使用 selectinput 张量 reduce_window 的结果,使用 scatter 分散 source 张量的值,并生成一个 result 张量。

下图使用一个具体示例,展示了如何根据 operandsource 计算 result 中的元素。

select_and_scatter

更正式地:

  • selected_values = reduce_window_without_init(...) 和以下输入:

    • inputs = [operand].
    • window_dimensionswindow_stridespadding,按原样使用。
    • base_dilations = windows_dilations = 1
    • body 的定义如下:
    def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>:
      return select(arg0, arg1) ? arg0 : arg1;
    

    其中,E = element_type(operand)reduce_window_without_init 的运作方式与 reduce_window 完全相同,但底层 reduceschedule(请参阅reduce)不包含 init 值。目前尚未指定如果相应的窗口没有值会发生什么情况 (#731)。

  • result[result_index] = reduce([source_values], [init_value], [0], scatter) 其中:

    • source_values = [source[source_index] for source_index in source_indices]
    • 如果 selected_values[source_index] 具有 operand_index 中的 operand 元素,则为 selected_index(source_index) = operand_index
    • source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index]

输入

标签 名称 类型 限制条件
(I1) operand 张量或每张量量化张量 (C1-C4)、(C6)、(C8-C11)
(I2) source 张量或按张量量化的张量 (C1)、(C2)
(I3) init_value 0 维张量或按张量量化的张量 (C3)
(I4) window_dimensions 类型为 si64 的一维张量常量 (C2)、(C4)、(C5)
(I5) window_strides si64 类型的一维张量常量 (C2)、(C6)、(C7)
(I6) padding 类型为 si64 的二维张量常量 (C2)、(C8)
(I7) select 函数 (C9)
(I8) scatter 函数 (C10)

输出

名称 类型 限制条件
result 张量或按张量量化的张量 (C11-C12)

限制条件

  • (C1) element_type(operand) = element_type(source)
  • (C2) shape(source) = num_windows,其中:
    • padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1]
    • is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape
    • num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1
  • (C3) element_type(init_value) = element_type(operand)
  • (C4) size(window_dimensions) = rank(operand)
  • (C5) 0 < window_dimensions
  • (C6) size(window_strides) = rank(operand)
  • (C7) 0 < window_strides
  • (C8) shape(padding) = [rank(operand), 2]
  • (C9) select 的类型为 (tensor<E>, tensor<E>) -> tensor<i1>,其中 E = element_type(operand)
  • (C10) scatter 的类型为 (tensor<E>, tensor<E>) -> tensor<E>,其中 is_promotable(element_type(operand), E)
  • (C11) shape(operand) = shape(result)
  • (C12) element_type(result) = E

示例

// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = array<i64: 3, 1>,
  window_strides = array<i64: 2, 1>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]

更多示例

send

语义

inputs 发送到通道 channel_id 并生成 result 令牌。

如果 is_host_transfertrue,则该操作会将数据传输到主机。否则,它会将数据传输到其他设备。这意味着由实现定义。此标志会重复 channel_type 中提供的信息,因此我们日后计划只保留其中一个(#666)。

输入

标签 名称 类型 限制条件
(I1) inputs 可变数量的张量或量化张量
(I2) token token
(I3) channel_id 类型为 si64 的常量
(I4) channel_type DEVICE_TO_DEVICEDEVICE_TO_HOST 的枚举 (C1)
(I5) is_host_transfer 类型为 i1 的常量 (C1)

输出

名称 类型
result token

限制条件

  • (C1) channel_type 的定义如下:
    • 如果 is_host_transfer = true,则设为DEVICE_TO_HOST
    • 否则为 DEVICE_TO_DEVICE

示例

%result = "stablehlo.send"(%operand, %token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
  is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token

更多示例

shift_left

语义

lhs 张量按 rhs 位数执行元素级向左移运算,并生成 result 张量。

输入

标签 名称 类型 限制条件
(I1) lhs 整数类型的张量 (C1)
(I2) rhs 整数类型的张量 (C1)

输出

名称 类型 限制条件
result 整数类型的张量 (C1)

限制条件

  • (C1) type(lhs) = type(rhs) = type(result)

示例

// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]

更多示例

shift_right_arithmetic

语义

lhs 张量按 rhs 位数执行元素级算术右移运算,并生成 result 张量。

输入

标签 名称 类型 限制条件
(I1) lhs 整数类型的张量 (C1)
(I2) rhs 整数类型的张量 (C1)

输出

名称 类型 限制条件
result 整数类型的张量 (C1)

限制条件

  • (C1) type(lhs) = type(rhs) = type(result)

示例

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]

更多示例

shift_right_logical

语义

rhs 位对 lhs 张量执行元素级逻辑右移运算,并生成一个 result 张量。

输入

标签 名称 类型 限制条件
(I1) lhs 整数类型的张量 (C1)
(I2) rhs 整数类型的张量 (C1)

输出

名称 类型 限制条件
result 整数类型的张量 (C1)

限制条件

  • (C1) type(lhs) = type(rhs) = type(result)

示例

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]

更多示例

签名

语义

按元素返回 operand 的符号,并生成 result 张量。 更正式地,对于每个元素 x,可以使用 Python 语法按如下方式表达语义:

def sign(x):
  if is_integer(x):
    if compare(x, 0, LT, SIGNED): return -1
    if compare(x, 0, EQ, SIGNED): return 0
    return 1
  elif is_float(x):
    if is_nan(x): return NaN
    if compare(x, -0.0, EQ, FLOAT): return -0.0
    if compare(x, +0.0, EQ, FLOAT): return +0.0
    if compare(x, 0.0, LT, FLOAT): return -1.0
    return 1.0
  elif is_complex(x):
    if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
    if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
    return divide(x, convert(abs(x), type(x)))

对于量化类型,请执行 dequantize_op_quantize(sign, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 有符号整数、浮点数或复杂类型的张量,或每个张量量化张量 (C1)

输出

名称 类型 限制条件
result 带符号整数、浮点数、复杂类型或每张量量化张量的张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]

更多示例

正弦

语义

operand 张量执行元素级正弦运算,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于浮点数:IEEE-754 中的 sin
  • 对于复数:复数正弦。
  • 对于量化类型:dequantize_op_quantize(sine, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点型、复杂类型或每张量量化张量 (C1)

输出

名称 类型 限制条件
result 浮点或复杂类型的张量,或每个张量的量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]

更多示例

slice

语义

使用静态计算的起始索引从 operand 中提取一个 slice,并生成 result 张量。start_indices 包含每个维度 slice 的起始索引,limit_indices 包含每个维度 slice 的结束索引(不含边界值),strides 包含每个维度的步长。

更正式地说,result[result_index] = operand[operand_index],其中 operand_index = start_indices + result_index * strides

输入

标签 名称 类型 限制条件
(I1) operand 张量或按张量量化的张量 (C1-C3)、(C5)
(I2) start_indices si64 类型的一维张量常量 (C2)、(C3)、(C5)
(I3) limit_indices 类型为 si64 的一维张量常量 (C2)、(C3)、(C5)
(I4) strides 类型为 si64 的一维张量常量 (C2)、(C4)

输出

名称 类型 限制条件
result 张量或按张量量化的张量 (C1)、(C5)

限制条件

  • (C1) element_type(operand) = element_type(result)
  • (C2) size(start_indices) = size(limit_indices) = size(strides) = rank(operand)
  • (C3) 0 <= start_indices <= limit_indices <= shape(operand)
  • (C4) 0 < strides
  • (C5) shape(result) = ceil((limit_indices - start_indices) / strides)

示例

// %operand: [
//            [0, 0, 0, 0],
//            [0, 0, 1, 1],
//            [0, 0, 1, 1]
//           ]
%result = "stablehlo.slice"(%operand) {
  start_indices = array<i64: 1, 2>,
  limit_indices = array<i64: 3, 4>,
  strides = array<i64: 1, 1>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
//            [1, 1],
//            [1, 1]
//           ]

更多示例

排序

语义

根据 comparator 沿 dimension 维度对 inputs 的一维切片进行排序,然后生成 results

与其他运算中的类似输入不同,dimension 允许负值,具体语义如下所述。将来,出于一致性原因,我们可能会禁止这种做法 (#1377)。

如果 is_stable 为 true,则排序是稳定的,即比较器认为相等的元素的相对顺序会保留。对于只有一个输入的情况,只有当 comparator(e1, e2) = comparator(e2, e1) = false 时,比较器才会将两个元素 e1e2 视为相等。请参阅下面的形式化说明,了解如何将其推广到多个输入。

更正式地讲,对于 index_space(results[0]) 中的所有 result_index

  • adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension
  • result_slice = [ri0, ..., :, ..., riR-1],其中 riNresult_index 中的各个元素,: 插入在 adjusted_dimension 处。
  • inputs_together = (inputs[0]..., ..., inputs[N-1]...)
  • results_together[result_slice] = sort(inputs_together[result_slice], comparator_together)
  • 其中 sort 会按非降序对一维 slice 进行排序,并希望 comparator_together 在左侧参数小于右侧第二个参数时返回 true
  • def comparator_together(lhs_together, rhs_together):
      args = []
      for (lhs_el, rhs_el) in zip(lhs_together, rhs_together):
        args.append(lhs_el)
        args.append(rhs_el)
      return comparator(*args)
    
  • (results[0]..., ..., results[N-1]...) = results_together

输入

标签 名称 类型 限制条件
(I1) inputs 张量数量可变或每个张量量化张量 (C1 - C5)
(I2) dimension 类型为 si64 的常量 (C4)
(I3) is_stable 类型为 i1 的常量
(I4) comparator 函数 (C5)

输出

名称 类型 限制条件
results 张量数量可变或每个张量量化张量 (C2)、(C3)

限制条件

  • (C1) 0 < size(inputs)
  • (C2) type(inputs...) = type(results...)
  • (C3) same(shape(inputs...) + shape(results...))
  • (C4) -R <= dimension < R,其中 R = rank(inputs[0])
  • (C5) comparator 的类型为 (tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>,其中 Ei = element_type(inputs[i])

示例

// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
    %predicate = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
  dimension = 0 : i64,
  is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]

更多示例

sqrt

语义

operand 张量执行元素级平方根运算,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于浮点数:IEEE-754 中的 squareRoot
  • 对于复数:复数平方根。
  • 对于量化类型:dequantize_op_quantize(sqrt, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点型、复杂类型或每张量量化张量 (C1)

输出

名称 类型 限制条件
result 浮点或复杂类型的张量,或每个张量的量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]

更多示例

subtract

语义

对两个张量 lhsrhs 执行元素级减法,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于整数:整数减法。
  • 对于浮点数:IEEE-754 中的 subtraction
  • 对于复数:复数减法。
  • 对于量化类型:
    • dequantize_op_quantize(subtract, lhs, rhs, type(result))

输入

标签 名称 类型 限制条件
(I1) lhs 整数、浮点或复杂类型的张量,或每个张量的量化张量 (C1)
(I2) rhs 整数、浮点、复杂类型或每张量量化张量 (C1)

输出

名称 类型 限制条件
result 整数、浮点或复杂类型的张量,或每个张量的量化张量 (C1)

限制条件

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

示例

// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]

更多示例

tan

语义

operand 张量执行元素级正切运算,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于浮点数:IEEE-754 中的 tan
  • 对于复数:复正切。
  • 对于量化类型:dequantize_op_quantize(tan, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点型、复杂类型或每张量量化张量 (C1)

输出

名称 类型 限制条件
result 浮点或复杂类型的张量,或每个张量的量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.tan"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [
//           [0.0, 1.63312e+16],
//           [0.0, 5.44375e+15]
//          ]

更多示例

tanh

语义

operand 张量执行元素级双曲正切运算,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于浮点数:IEEE-754 中的 tanh
  • 对于复数:复数双曲正切。
  • 对于量化类型:
    • dequantize_op_quantize(tanh, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点型、复杂类型或每张量量化张量 (C1)

输出

名称 类型 限制条件
result 浮点或复杂类型的张量,或每个张量的量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]

更多示例

转置

语义

使用 permutationoperand 张量的维度进行排列,并生成 result 张量。更正式地说,result[result_index] = operand[operand_index],其中 result_index[d] = operand_index[permutation[d]]

输入

标签 名称 类型 限制条件
(I1) operand 张量或量化张量 (C1-C4)
(I2) permutation 类型为 si64 的一维张量常量 (C2-C4)

输出

名称 类型 限制条件
result 张量或量化张量 (C1)、(C3-C4)

限制条件

  • (C1) element_type(result) 由以下公式给出:
    • 如果 !is_per_axis_quantized(operand),则设为 element_type(operand)
    • element_type(operand),不同之处在于 quantization_dimension(operand)quantization_dimension(result) 可以不同。
  • (C2) permutationrange(rank(operand)) 的排列。
  • (C3) shape(result) = dim(operand, permutation...)
  • (C4) 如果 is_per_axis_quantized(result),则 quantization_dimension(operand) = permutation(quantization_dimension(result))

示例

// %operand: [
//            [[1,2], [3,4], [5,6]],
//            [[7,8], [9,10], [11,12]]
//           ]
%result = "stablehlo.transpose"(%operand) {
  permutation = array<i64: 2, 1, 0>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
//           [[1,7], [3,9], [5,11]],
//           [[2,8], [4,10], [6,12]]
//          ]

更多示例

triangular_solve

语义

使用下三角或上三角系数矩阵解批量线性方程组。

更正式地说,对于 ab,当 left_sidetrue 时,result[i0, ..., iR-3, :, :]op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :] 的解;当 left_sidefalse 时,result[i0, ..., iR-3, :, :]x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :] 的解,用于求解变量 x,其中 op(a)transpose_a 决定,可以是以下任一项:

  • NO_TRANSPOSE:按原样使用 a 执行操作。
  • TRANSPOSE:对 a 的转置执行操作。
  • ADJOINT:对 a 的共轭转置执行操作。

如果 lowertrue,则仅从 a 的下三角读取输入数据;否则,则从 a 的上三角读取输入数据。输出数据会在同一三角形中返回;另一个三角形中的值由实现定义。

如果 unit_diagonal 为 true,该实现可以假定 a 的对角线元素等于 1,否则行为未定义。

对于量化类型,执行 dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower, unit_diagonal, transpose_a), a, b, type(result))

输入

标签 名称 类型 限制条件
(I1) a 浮点或复杂类型的张量,或每个张量的量化张量 (C1-C3)
(I2) b 浮点或复杂类型的张量,或每个张量的量化张量 (C1-C4)
(I3) left_side 类型为 i1 的常量 (C3)
(I4) lower 类型为 i1 的常量
(I5) unit_diagonal i1 类型的常量
(I6) transpose_a NO_TRANSPOSETRANSPOSEADJOINT 的枚举

输出

名称 类型 限制条件
result 浮点或复杂类型的张量,或每个张量的量化张量 (C1)

限制条件

  • (C1) baseline_element_type(a) = baseline_element_type(b)
  • (C2) 2 <= rank(a) = rank(b) = R
  • (C3) shape(a)shape(b) 之间的关系如下所示:
    • shape(a)[:-3] = shape(b)[:-3]
    • dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1)
  • (C4) baseline_type(b) = baseline_type(result)

示例

// %a = [
//       [1.0, 0.0, 0.0],
//       [2.0, 4.0, 0.0],
//       [3.0, 5.0, 6.0]
//      ]
// %b = [
//       [2.0, 0.0, 0.0],
//       [4.0, 8.0, 0.0],
//       [6.0, 10.0, 12.0]
//      ]
%result = "stablehlo.triangular_solve"(%a, %b) {
  left_side = true,
  lower = true,
  unit_diagonal = false,
  transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
//           [2.0, 0.0, 0.0],
//           [0.0, 2.0, 0.0],
//           [0.0, 0.0, 2.0]
//          ]

tuple

语义

基于值 val 生成 result 元组。

输入

标签 名称 类型 限制条件
(I1) val 值的变参数量 (C1)

输出

名称 类型 限制条件
result tuple (C1)

限制条件

  • (C1) result 的类型为 tuple<E0, ..., EN-1>,其中 Ei = type(val[i])

示例

// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))

更多示例

uniform_dequantize

语义

根据 operand 类型定义的量化参数,将量化张量 operand 到浮点张量 result 执行元素级转换。

更正式地说,result = dequantize(operand)

输入

标签 名称 类型 限制条件
(I1) operand 量化张量 (C1)、(C2)

输出

名称 类型 限制条件
result 浮点类型的张量 (C1)、(C2)

限制条件

  • (C1) shape(operand) = shape(result)
  • (C2) element_type(result) = expressed_type(operand)

示例

// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]

uniform_quantize

语义

根据 result 类型定义的量化参数,将浮点张量或量化张量 operand 按元素转换为量化张量 result

更正式地说,

  • 如果为 is_float(operand)
    • result = quantize(operand, type(result))
  • 如果为 is_quantized(operand)
    • float_result = dequantize(operand)
    • result = quantize(float_result, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点类型或量化类型的张量 (C1)、(C2)

输出

名称 类型 限制条件
result 量化张量 (C1)、(C2)

限制条件

  • (C1) shape(operand) = shape(result)
  • (C2) expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand)

示例

// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]

// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]

场景

语义

cond 函数输出 true 时,通过执行 body 函数 0 次或多次生成输出。更正式地,可以使用 Python 语法如下所示来表达语义:

internal_state = operand
while cond(*internal_state):
  internal_state = body(*internal_state)
results = internal_state

无限循环的行为尚待确定 (#383)。

输入

标签 名称 类型 限制条件
(I1) operand 可变数量的张量、量化张量或词元 (C1-C3)
(I2) cond 函数 (C1)
(I3) body 函数 (C2)

输出

名称 类型 限制条件
results 可变数量的张量、量化张量或令牌 (C3)

限制条件

  • (C1) cond 的类型为 (T0, ..., TN-1) -> tensor<i1>,其中 Ti = type(operand[i])
  • (C2) body 的类型为 (T0, ..., TN-1) -> (T0, ..., TN-1),其中 Ti = type(operand[i])
  • (C3) type(results...) = type(operand...)

示例

// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %cond = "stablehlo.compare"(%arg0, %ten) {
      comparison_direction = #stablehlo<comparison_direction LT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    stablehlo.return %cond : tensor<i1>
  }, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %new_sum = stablehlo.add %arg1, %one : tensor<i64>
    %new_i = stablehlo.add %arg0, %one : tensor<i64>
    stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10

更多示例

xor

语义

对两个张量 lhsrhs 执行元素级 XOR,并生成一个 result 张量。根据元素类型,执行以下操作:

  • 对于布尔值:逻辑 XOR。
  • 对于整数:按位异或。

输入

标签 名称 类型 限制条件
(I1) lhs 布尔值或整数类型的张量 (C1)
(I2) rhs 布尔值或整数类型的张量 (C1)

输出

名称 类型 限制条件
result 布尔值或整数类型的张量 (C1)

限制条件

  • (C1) type(lhs) = type(rhs) = type(result)

示例

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]

更多示例

方言互操作

目前,实际环境中的 StableHLO 程序有时包含 StableHLO 未定义的操作。

模块、函数、调用和返回

StableHLO 针对 ModuleOp、FuncOp、CallOp 和 ReturnOp 使用上游 MLIR 操作。这样做是为了更好地与现有的 MLIR 机器进行互操作,因为许多实用的传递都是以 FuncOp 和 ModuleOp 为目标编写的,并且许多编译流水线都希望存在这些操作。这些操作会得到完全兼容性保证。如果这些 op 发生任何以不兼容的方式更改(即移除),我们将添加 StableHLO 等效项以保持兼容性。

CHLO

CHLO 运算集包含分解为 StableHLO 的更高级别的操作。目前,我们无法保证 CHLO 的兼容性。为了保证兼容性,必须先使用 chlo-legalize-to-stablehlo 传递,然后才能进行序列化。

形状操作

在社区中,在动态 StableHLO 程序中使用核心 MLIR 方言中的某些操作来执行形状计算是一种常见用例。最常见的操作包括 shape 方言操作(例如 shape_ofnum_elements)、tensor 方言操作(例如 dimfrom_elements),以及内置的 index 类型。

“Dynamicism RFC > O2” 将这些类型标记为超出范围,但出于互操作性目的,我们提供了对 index 类型的部分支持。我们无法保证这些操作或类型的兼容性。shape-legalize-to-stablehlo 传递可用于将这些操作转换为完全受支持的 StableHLO 操作。

已弃用的操作

有几个从 MHLO 继承的 StableHLO 操作已被弃用,并将从 StableHLO 中移除。如需详细了解这些移除操作,请参阅 StableHLO v1.0 Cleanup #2283。这些弃用的跟踪器问题是 #2340

这些操作分为以下几个类别:

  • StableHLO 操作的“不在 HLO 中”类别 - 这些操作最初属于 StableHLO 操作集,但后来被认为不太适合:broadcastcreate_tokencross-replica-sumdoteinsumtorch_index_selectunary_einsum (#3)。
  • 未使用的操作 - 这些操作在某个时间点可能很有用,但这些操作要么不开发,要么使用这些操作的流水线已重构为不再需要它们。其中包括 maptuple (#598)、get_tuple_elementrngcomplex 比较 #560 和卷积 window_reversal (#1181)。

由于这些运算可以使用现有运算 (broadcastcreate_tokencross-replica-sumdotunary_einsum) 表示,因此可以轻松移除,并且在现有兼容性期限(6 个月)过后将被移除。其他运算仍在探索以移除(einsumget_tuple_elementmaprng torch_index_selecttuplecomplex 比较项、window_reversal)。根据社区反馈,这些运算将被移除,或者添加到规范中并获得全面支持。在这些操作 Future 已知之前,我们只能保证它们在 6 个月内保持兼容性。

执行

顺序执行

通过向 main 函数提供输入值并计算输出值,即可执行 StableHLO 程序。函数的输出值通过执行根位于相应 return 操作中的操作图来计算。

只要执行顺序与数据流一致(即操作在使用前执行),则由实现定义。在 StableHLO 中,所有有副作用的操作都会消耗一个令牌并生成一个令牌(多个令牌可以通过 after_all 多路复用为一个令牌),因此副作用的执行顺序也与数据流保持一致。例如,在以下程序中,有两个可能的执行顺序:%0%1%2return%1%0%2return

func.func @main() -> tensor<f64> {
  %0 = stablehlo.constant dense<1.0> : tensor<f64>
  %1 = stablehlo.constant dense<2.0> : tensor<f64>
  %2 = stablehlo.add %0, %1 : tensor<f64>
  return %2 : tensor<f64>
}

更正式地说,StableHLO 进程是以下各项的组合:1) StableHLO 程序;2) 操作状态(尚未执行,已经执行)和 3) 进程正在处理的中间值。该过程从 main 函数的输入值开始,通过更新操作状态和中间值的操作图表进行,最后以输出值结束。进一步的形式化规范待定 (#484)。

并行执行

StableHLO 程序可以并行执行,并整理为一个由 num_replicasnum_partitions(均为 ui32 类型)组成的 2D 进程网格。

StableHLO 进程网格中,有 num_replicas * num_partitions 个 StableHLO 进程同时执行。每个进程都有一个唯一的 process_id = (replica_id, partition_id),其中 replica_ids = range(num_replicas) 中的 replica_idpartition_ids = range(num_partitions) 中的 partition_id 均为类型 ui32

每个程序的进程网格的大小是静态已知的(未来,我们计划将其作为 StableHLO 程序的显式部分#650),每个进程在进程网格中的位置也是静态已知的。每个进程都可以通过 replica_idpartition_id 操作获取其在进程网格中的位置。

在流程网格中,程序可以全部相同(采用“单个程序、多种数据”样式),也可以全部不同(采用“多个程序、多种数据”样式),还可以介于两者之间。未来,我们计划引入对定义并行 StableHLO 程序的其他惯用法则的支持,包括 GSPMD (#619)。

在进程网格中,进程大多彼此独立 - 它们具有单独的操作状态、单独的输入/中间/输出值,并且大多数操作在进程之间单独执行,但下文中介绍的少数集体操作除外。

由于大多数操作的执行仅使用同一进程中的值,因此通常可以通过名称引用这些值,这样做不会产生歧义。不过,在描述集合运算的语义时,这还不够,因此引入了 name@process_id 这种表示法来引用特定进程中的值 name。(从这个角度来看,不符合条件的 name 可被视为 name@(replica_id(), partition_id()) 的简写形式)。

除了点对点通信和集合运算(如下所述)引入的同步之外,各进程的执行顺序由实现定义。

点对点通信

StableHLO 进程可以通过 StableHLO 通道相互通信。渠道由 si64 类型的正 ID 表示。通过各种操作,可以向通道发送值并从通道接收值。

进一步的形式化(例如这些通道 ID 的来源、进程程序如何知晓这些 ID,以及它们引入了哪种类型的同步)尚待确定 (#484)。

流式通信

每个 StableHLO 进程都可以访问两个流式接口:

  • 可读取的Infeed
  • 可写入的出站 Feed

与用于在进程之间进行通信(因此两端都有进程)的渠道不同,inFeed 和 outFeed 的另一端实现则定义。

进一步的形式化(例如,流式通信如何影响执行顺序以及它引入了哪种同步)尚待确定 (#484)。

集体操作

StableHLO 中有 6 个集合运算:all_gatherall_reduceall_to_allcollective_broadcastcollective_permutereduce_scatter。所有这些操作都将 StableHLO 进程网格中的进程拆分为 StableHLO 进程组,并在每个进程组内执行联合计算,与其他进程组分开计算。

在每个进程组内,集体操作可能会引入同步屏障。进一步的形式化(例如详细说明此同步的确切发生时间、进程到达此屏障的确切方式,以及如果进程未到达此屏障会发生什么情况)尚待确定 (#484)。

如果进程组涉及跨分区通信(即进程组中有分区 ID 不同的进程),则执行集合操作需要使用通道,并且集合操作必须提供类型为 si64 的正 channel_id。跨副本通信不需要通道。

集合运算执行的计算因运算而异,如上文中的各个运算部分所述。不过,将进程网格拆分为进程组的策略在这些操作之间共享,并在本部分中进行了介绍。更正式地说,StableHLO 支持以下四种策略。

cross_replica

每个进程组内只会发生跨副本通信。此策略接受 replica_groups(一个包含副本 ID 列表的列表),并计算 replica_groupspartition_ids 的笛卡尔积。replica_groups 必须包含唯一元素并涵盖所有 replica_ids。更正式地,使用 Python 语法:

def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    for partition_id in partition_ids:
      process_group = []
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
      yield process_group

例如,对于 replica_groups = [[0, 1], [2, 3]]num_partitions = 2cross_replica 将生成 [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]]

cross_partition

只有跨分区通信在每个进程组内发生。此策略采用 partition_groups(分区 ID 列表的列表),并计算 partition_groupsreplica_ids 的笛卡尔积。partition_groups 必须包含唯一的元素,并涵盖所有 partition_ids。更正式地,使用 Python 语法:

def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
  for partition_group in partition_groups:
    for replica_id in replica_ids:
      process_group = []
      for partition_id in partition_group:
        process_group.append((replica_id, partition_id))
      yield process_group

例如,对于 partition_groups = [[0, 1]]num_replicas = 4cross_partition 将生成 [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]]

cross_replica_and_partition

每个进程组中都可能会发生跨副本和跨分区通信。此策略采用 replica_groups(一个复制 ID 列表的列表),并通过 partition_ids 计算每个 replica_group 的笛卡尔积。replica_groups 必须包含唯一的元素,并涵盖所有 replica_ids。更正式地,使用 Python 语法:

def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    process_group = []
    for partition_id in partition_ids:
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
    yield process_group

例如,对于 replica_groups = [[0, 1], [2, 3]]num_partitions = 2cross_replica_and_partition 将生成 [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]]

flattened_ids

此策略将 flattened_id_groups(采用 replica_id * num_partitions + partition_id 形式的“扁平化”进程 ID 列表)转换为进程 ID。flattened_id_groups 必须包含唯一的元素,并涵盖所有 process_ids。更正式地,使用 Python 语法:

def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
  for flattened_id_group in flattened_id_groups:
    process_group = []
    for flattened_id in flattened_id_group:
      replica_id = flattened_id // num_partitions
      partition_id = flattened_id % num_partitions
      process_group.append((replica_id, partition_id))
    yield process_group

例如,对于 flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]num_replicas = 4num_partitions = 2flattened_ids 将生成 [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]]

准确率

目前,StableHLO 不保证数值精度,但未来可能会发生变化 (#1156)。

量化操作的执行语义

量化 StableHLO 运算的解释可能因硬件要求和功能而异。例如,某些硬件可能会选择使用“去量化,执行浮点运算,最后量化”策略来解释量化运算。其他人可能会使用整数算术执行整个计算。因此,对量化 StableHLO 操作的解释完全取决于具体实现。对混合量化 (#1575) 的解读应基于规范中规定的语义(通过 1792)。

错误

StableHLO 程序通过针对各个操作的一组广泛约束条件进行验证,这可在运行时排除许多类错误。不过,错误情况仍然可能实现,例如整数溢出、出界访问等。除非明确指出,否则所有这些错误都会导致实现定义的行为,但将来可能会发生变化 (#1157)。

浮点异常

作为此规则的例外情况,StableHLO 程序中的浮点异常具有明确定义的行为。导致 IEEE-754 标准定义的异常(无效操作、除零、溢出、下溢或不精确异常)的操作会产生默认结果(如标准中所定义),并继续执行,而不会引发相应的状态标志;这类似于标准中的 raiseNoFlag 异常处理。非标准运算(例如复杂算术和某些超越函数)的异常由实现定义。

形状不匹配

StableHLO 支持动态形状的张量。不过,形状必须在运行时保持一致,否则行为将处于未定义状态。StableHLO 不会明确提供可断言张量在运行时具有给定形状的操作。生成正确的代码是生产者的责任。

具体而言,以下程序是有效的。不过,在运行时,%arg0%arg1 的确切形状必须相同,否则程序的行为将不明确:

func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
    %0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
    return %0 : tensor<?xi32>
}

Notation

在描述语法时,本文档使用了经过修改的 ISO 版 EBNF 语法(ISO/IEC 14977:1996Wikipedia),其中进行了两项修改:1) 使用 ::= 而非 = 定义规则,

2) 串联使用并列(而非 ,)来表示。

如需描述语义(即“类型”“常量”和“运算”部分),我们使用基于 Python 语法并扩展了对简洁表达数组运算的支持的公式,如下所述。这对于小段代码非常有用,但在极少数情况下,如果需要较大的代码段,我们会使用始终明确引入的原始 Python 语法。

公式

我们将通过 dot_general 规范中的示例来探索公式的运作方式。此操作的一个约束条件如下所示:dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)

此公式中使用的名称来自两个来源:1) 全局函数,即 dim;2) 相应程序元素的成员定义,即 dot_general 的“输入”部分中定义的 lhslhs_batching_dimensionsrhsrhs_batching_dimensions 输入。

如上所述,此公式的语法基于 Python,并添加了一些以简洁为导向的扩展。为了理解这个公式,我们先将其转换为普通的 Python 语法

A) 在这些公式中,我们使用 = 表示等式,因此,要想获得 Python 语法,第一步是将 = 替换为 ==,如下所示:dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...)

B) 此外,这些公式支持省略号 (...),省略号可将标量表达式转换为张量表达式。简而言之,f(xs...) 大致表示“对于张量 xs 中的每个标量 x,计算一个标量 f(x),然后以张量结果的形式返回所有这些标量结果”。在纯 Python 语法中,我们的示例公式会变成:[dim(lhs, dim1) for dim1 in lhs_batching_dimensions] == [dim(rhs, dim2) for dim2 in rhs_batching_dimensions]

得益于省略号,通常可以避免在单个标量级别进行操作。不过,在某些棘手的情况下,可以使用较低级别的半非正式语法,如 gather 规范中的 start_indices[bi0, ..., :, ..., biN] 公式。为求简洁,我们不会提供将此类语法转换为纯 Python 的确切形式化方法,希望您能根据具体情况直观地理解它。如果某些特定公式看起来不透明,请告知我们,我们将努力改进。

此外,您会注意到,公式使用省略号来扩展各种列表,包括张量、张量列表(例如,可能源于可变数量的张量)等。这是另一个我们无法提供确切形式化的领域(例如,列表甚至不是 StableHLO 类型系统的直观性的一部分),而是依赖于其理解性。

C) 我们采用的最后一个值得注意的表示法是隐式广播。虽然 StableHLO 运算集不支持隐式广播,但公式支持,这也是为了简化服务。简而言之,如果在需要张量的上下文中使用标量,该标量将被广播到预期的形状。

继续使用 dot_general 示例,下面是另一个约束条件:0 <= lhs_batching_dimensions < rank(lhs)。如 dot_general 规范中所定义,lhs_batching_dimensions 是一个张量,但 0rank(lhs) 都是标量。应用隐式广播后,该公式将变为 [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)]

应用于特定 dot_general 运算时,此公式将求值为布尔值张量。当公式用作约束条件时,如果公式的计算结果为 true 或仅包含 true 元素的张量,约束条件将保持不变。

名称

在公式中,词法范围包括:1) 全局函数;2) 成员定义;

3) 本地定义。下面列出了全局函数的列表。元素定义列表取决于应用该符号的程序元素:

  • 对于操作,成员定义包括“输入”和“输出”部分中介绍的名称。
  • 对于其他所有内容,成员定义包括节目元素的结构部分,以相应的 EBNF 非终端命名。大多数情况下,这些结构部分的名称是通过将非终结符的名称转换为蛇形命名法(例如 IntegerLiteral => integer_literal)获得的,但有时名称会在此过程中缩写(例如 QuantizationStorageType => storage_type),在这种情况下,名称的引入方式与操作规范中的“输入”/“输出”部分类似,即明确引入。
  • 此外,成员定义始终包含 self 来引用相应的计划元素。

计算公式时,它们使用以下类型的值: 1) Value(实际值,例如 dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>;它们始终知道自己的类型); 2) Placeholder(未来的值,例如 lhsrhsresult;其实际值尚不已知,只有其类型已知); 3) Type(“类型”部分中定义的类型); 4) Function(“函数”部分中定义的全局函数)。

名称可能指代不同的值,具体取决于上下文。更具体地说,操作的“Semantics”(语义)部分(以及其他程序元素的等效部分)定义了运行时逻辑,因此所有输入均可作为 Value 使用。与之相反,操作(及等效项)的“约束条件”部分定义了“编译时”逻辑,即通常在运行时之前执行的操作,因此只有常量输入可用作 Value,其他输入只能用作 Placeholder

名称 在“Semantics”(语义)中 在“约束条件”中
全局函数 Function Function
常量输入 Value Value
非常量输入 Value Placeholder
输出 Value Placeholder
本地定义 取决于定义 取决于定义

我们来看一个 transpose 操作示例:

%result = "stablehlo.transpose"(%operand) {
  permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>

对于此操作,permutation 是常量,因此在语义和约束条件中均可用作 Value。相比之下,operandresult 在语义中可用作 Value,但在约束条件中只能用作 Placeholder

函数

类型的构建

没有可用于构造类型的函数。我们改为直接使用类型语法,因为它通常更简洁。例如,(tensor<E>, tensor<E>) -> (tensor<E>) 而非 function_type( [tensor_type([], E), tensor_type([], E)], [tensor_type([], E)])

类型函数

  • element_type 基于张量类型和量化张量类型进行定义,并分别返回相应 TensorTypeQuantizedTensorTypeTensorElementTypeQuantizedTensorElementType 部分。
def element_type(x: Value | Placeholder | Type):
 if type(x) == TensorType:
    return tensor_element_type(x)
  if type(x) == QuantizedTensorType:
    return quantized_tensor_element_type(x)
  if type(x) is not Type:
    return element_type(type(x))
  • is_per_axis_quantized(x: Value | Placeholder | Type) -> Valueis_quantized(x) and quantization_dimension(x) is not None 的快捷方式。

  • is_per_tensor_quantized(x: Value | Placeholder | Type) -> Valueis_quantized(x) and quantization_dimension(x) is None 的简写。

  • is_promotable(x: Type, y: Type) -> bool 用于检查类型 x 是否可以提升为类型 y。当 xy 均为 QuantizedTensorElementType 时,促销优惠仅应用于 storage_type。此特定版本的提升目前用于减少计算上下文(如需了解详情,请参阅 RFC)。

def is_promotable(x: Type, y: Type) -> Value:
  is_same_type = (is_bool(x) and is_bool(y)) or
    (is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
    (is_complex(x) and is_complex(y)) or
    (is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))

  if is_same_type == False:
    return False

  if is_integer(x) or is_float(x):
    return bitwidth(x) <= bitwidth(y)

  if is_complex(x):
    return bitwidth(element_type(x)) <= bitwidth(element_type(y))

  if is_quantized(x):
    return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))

  return false
  • is_quantized(x: Value | Placeholder | Type) -> Valueis_quantized_tensor_element_type(x) 的快捷方式。

  • is_type_name(x: Value | Placeholder | Type) -> Value。适用于所有类型。例如,如果 xFloatType,则 is_float(x) 会返回 true。如果 x 是值或占位符,则此函数是 is_type_name(type(x)) 的快捷方式。

  • max_value(x: Type) -> Value 会返回 TensorElementType 的最大值。如果 x 不是 TensorElementType,则返回 None

  • min_value(x: Type) -> Value 会返回 TensorElementType 的可能最小值。如果 x 不是 TensorElementType,则返回 None

  • member_name(x: Value | Placeholder | Type) -> Any。适用于所有类型的所有成员定义 member_name。例如,tensor_element_type(x) 会返回相应 TensorTypeTensorElementType 部分。如果 x 是值或占位符,则此函数是 member_name(type(x)) 的快捷方式。如果 x 不是具有适当成员的类型,或者此类类型的值或占位符,则返回 None

  • is_empty_algorithm(*args: Type) 会检查所有点算法字段是否都设置为 None。由于点算法具有实现定义的默认行为,因此指定默认值是不正确的,因此需要这样做。

值的构建

  • operation_name(*xs: Value | Type) -> Value。适用于所有操作。 例如,add(lhs, rhs) 接受两个张量值 lhsrhs,并返回使用这些输入评估 add 操作的输出。对于某些操作(例如 broadcast_in_dim),其输出类型是“承载负载”的,即需要用于评估操作。在这种情况下,该函数将这些类型作为参数。

值函数

  • 您可以使用 Python 的所有运算符和函数。例如,Python 中的订阅切片符号都可以用于对张量、量化张量和元组进行编制索引。

  • to_destination_type(x: Value, destination_type: Type) -> Value 在张量上定义,并根据 type(x)destination_type 返回 x 的转换值,如下所示:

def to_destination_type(x: Value, destination_type: Type) -> Value:
  if type(x) == destination_type:
    return x

  if is_quantized(destination_type):
    if is_quantized(type(x)):
      return quantize(x, destination_type)
    assert is_float(type(x))
    return quantize(x, destination_type)

  if is_quantized(type(x)):
    assert destination_type = expressed_type(type(x))
    return dequantize(type(x))

  return convert(x, destination_type)

我们已就如何合并 convertuniform_quantizeuniform_dequantize 运算进行了早期讨论 (#1576)。合并后,我们不需要上述函数,而是可以改用 convert 的操作名称。

  • is_nan(x: Value) -> Value 在张量上定义,如果 x 的所有元素均为 NaN,则返回 true;否则返回 false。如果 x 不是张量,则返回 None

  • is_sorted(x: Value) -> Value 在张量上定义,如果 x 的元素按其索引的升序字典顺序排序,则返回 true;否则,返回 false。如果 x 不是张量,则返回 None

  • is_unique(x: Value) -> Value 在张量上定义,如果 x 没有重复元素,则返回 true;否则,返回 false。如果 x 不是张量,则返回 None

  • member_name(x: Value) -> Any 是针对所有值的所有成员定义 member_name 定义的。例如,real_part(x) 会返回相应 ComplexConstantRealPart 部分。如果 x 不是具有适当成员的值,则返回 None

  • same(x: Value) -> Value 在张量上定义,如果 x 的元素都彼此相等,则返回 true;否则返回 false。如果张量没有元素,则计为“所有元素都相等”,即函数返回 true。如果 x 不是张量,则返回 None

  • split(x: Value, num_results: Value, axis: Value) -> Value 在张量上定义,并沿轴 axis 返回 xnum_results 切片。如果 x 不是张量或 dim(x, axis) % num_results != 0,则返回 None

  • is_defined_in_parent_scope(x: Value) -> Value 基于字符串进行定义,如果 x 是与相关运算的父函数在同一作用域中定义的函数的名称,则返回 true

  • is_namespaced_op_name(x: Value) -> Value 基于字符串进行定义,如果 x 是有效的运算名称(即符合以下正则表达式),则返回 true[a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+

形状计算

  • axes(x: Value | Placeholder | Type) -> Valuerange(rank(x)) 的快捷方式。

  • dim(x: Value | Placeholder | Type, axis: Value) -> Valueshape(x)[axis] 的快捷方式。

  • dims(x: Value | Placeholder | Type, axes: List) -> Listlist(map(lambda axis: dim(x, axis), axes)) 的快捷方式。

  • index_space(x: Value | Placeholder | Type) -> Value 在张量上定义,并返回按字典顺序升序(即 [0, ..., 0][0, ..., 1]、...、shape(x) - 1)排列的相应 TensorTypesize(x) 索引。如果 x 不是张量类型、量化张量类型,也不是这些类型的值或占位符,则返回 None

  • rank(x: Value | Placeholder | Type) -> Valuesize(shape(x)) 的快捷方式。

  • shape(x: Value | Placeholder | Type) -> Value 在“类型函数”部分中通过 member_name 进行定义。

  • size(x: Value | Placeholder | Type) -> Valuereduce(lambda x, y: x * y, shape(x)) 的快捷方式。

量化计算

  • def baseline_element_type(x: Value | Placeholder | Type) -> Typeelement_type(baseline_type(x)) 的简写。

  • baseline_type 基于张量类型和量化张量类型进行定义,并将它们转换为“基准”,即具有相同形状但元素类型的量化参数重置为默认值的类型。这是一种非常实用的技巧,可用于统一比较张量和量化张量类型,这在很多情况下都需要。对于量化类型,这支持忽略量化参数来比较类型,即 shapestorage_typeexpressed_typestorage_minstorage_maxquantization_dimension(对于按轴量化类型)必须完全匹配,但 scaleszero points 可以不同。

def baseline_type(x: Value | Placeholder | Type) -> Type:
  if type(x) == TensorType:
    return x
  if type(x) == QuantizedTensorType:
    element_type = quantized_tensor_element_type(x)
    baseline_element_type = QuantizedTensorElementType(
      storage_type = storage_type(element_type),
      storage_min = storage_min(element_type),
      storage_max = storage_max(element_type),
      expressed_type = expressed_type(element_type),
      quantization_dimension = quantization_dimension(element_type),
      scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
      zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
    return QuantizedTensorType(shape(x), baseline_element_type)
  if type(x) is not Type:
    return baseline_element_type(type(x))
  • dequantize 基于量化张量类型定义,并将其转换为浮点张量类型。为此,系统会使用与量化元素类型关联的零点和比例,将表示存储类型的整数值的量化元素转换为表示类型的相应浮点值。
def compute_zero_points(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      zero_points[i] = zero_points(quantized_type)[i[d]]
    return zero_points

def compute_scales(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
            type(result_type))
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      scales[i] = scales(quantized_type)[i[d]]
    return scales

def dequantize(x: Value) -> Value:
  assert is_quantized(x)
  x_storage = bitcast_convert(x, storage_type(x))
  x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
  x_expressed_sub = convert(x_storage_sub, expressed_type(x))
  return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
  • quantize 基于浮点张量类型定义,并将其转换为量化张量类型。为此,您可以使用与量化元素类型关联的零点和缩放比例,将表示类型的浮点值转换为存储类型的相应整数值。
def quantize(x: Value, result_type: Type) -> Value:
  assert is_float(x) and is_quantized(result_type)
  zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
  converted_zero_points = convert(zero_points, expressed_type(result_type))
  converted_min = convert(storage_min(result_type), expressed_type(result_type))
  converted_max = convert(storage_max(result_type), expressed_type(result_type))

  x_scaled = x / compute_scales(result_type, type(x))
  x_scaled_add_zp = x_scaled + converted_zero_points
  x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
  x_rounded = round_nearest_even(x_clamped)
  return convert(x_rounded, result_type)
  • dequantize_op_quantize 用于指定对量化张量进行逐元素计算。它会解量化(即将量化元素转换为其表达类型),然后执行操作,最后再量化(即将结果转换回其存储类型)。目前,此函数仅适用于逐张量量化。我们正在研究按轴量化 (#1574)。
def dequantize_op_quantize(op, *inputs_and_output_type):
  inputs = inputs_and_output_type[:-1]
  output_type = inputs_and_output_type[-1]

  float_inputs = map(dequantize, inputs)
  float_result = op(*float_inputs)
  return quantize(float_result, output_type)

def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
  inputs = inputs_and_output_type[:-3]
  float_inputs = map(dequantize, inputs)
  float_results = op(*float_inputs)
  return map(quantize, float_results, inputs_and_output_type[-3:])

def dequantize_compare(lhs, rhs, comparison_direction):
  float_lhs = dequantize(lhs)
  float_rhs = dequantize(rhs)
  return compare(float_lhs, float_rhs, comparison_direction, FLOAT)

def dequantize_select_quantize(pred, on_true, on_false, output_type):
  float_on_true = dequantize(on_true)
  float_on_false = dequantize(on_false)
  float_result = select(pred, float_on_true, float_on_false)
  return quantize(float_result, output_type)
  • hybrid_dequantize_then_op 用于指定混合操作的仅权重量化,它接受浮点类型的 lhs 和量化类型的 rh。它会将量化输入反量化为表示的类型,并以浮点数执行计算。浮点型 lhs 张量的元素类型和量化型 rhs 张量的表达类型应相同。
def hybrid_dequantize_then_op(op, lhs, rhs):
  assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
  return op(lhs, dequantize(rhs))

网格计算

  • cross_partition(replica_groups: Value) -> Value。请参阅上面的“cross_replica”部分。

  • cross_replica(replica_groups: Value) -> Value。请参阅上面的“cross_replica”部分。

  • cross_replica_and_partition(replica_groups: Value) -> Value。请参阅上面的“cross_replica_and_Partition”部分。

  • flattened_ids(replica_groups: Value) -> Value。请参阅上文中的“flattened_ids”部分。

动力

StableHLO 值可以具有动态维度大小,例如 tensor<?xi64>。不过,StableHLO 值不能具有动态维度数(非排名动态性,例如 tensor<*xi64>)。运算数和结果可以使用动态维度大小,即使大小存在限制也是如此。系统会尽可能对约束条件进行静态验证,否则会将其推迟到运行时,不匹配将导致未定义的行为。如需查看示例,请参阅下文。

一元元素级运算的形状不匹配

请考虑以下玩具程序:

func.func @foo(%arg0: tensor<?xf64>) {
  %0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
  return
}

这样的程序并不常见,因为知道结果的形状而不知道输入的形状并不常见。尽管如此,这仍是一个有效的 StableHLO 程序。由于运算元的确切形状未知,因此无法对此程序中的 abs 运算进行静态验证。不过,形状肯定是兼容的,并且可以进行静态检查:? 在运行时可能会变成 2,但不会出现任何问题。不过,? 也可能变成某个其他整数,在这种情况下,行为将属于未定义行为。

请注意,如果结果中的尺寸大小是动态的,则不能出现未定义的行为。事实上,没有“预期”大小,因此不会出现不匹配的情况。

二进制元素级操作的形状不匹配

请考虑以下玩具程序:

func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
  %0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
  return
}

对于按元素执行的二进制运算,输入和结果的形状在运行时必须一致。在编译时,静态尺寸必须相同,否则只需兼容即可。如果输入中的任何维度是动态的,则在运行时可能会出现未定义的行为,因为动态大小可能与另一个运算数(无论是静态还是动态)中的相应大小不匹配。如果所有输入都是静态的,则结果是动态的还是静态的无关紧要:静态已知维度将以静态方式进行检查,而动态维度不会施加任何约束。

将输出形状用作运算元的运算的形状不匹配

请参考以下玩具计划:

func.func @foo(%arg0: tensor<2xi32>) {
  %0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
  return
}

运行时形状运算符中的值必须与结果的形状一致,否则行为未定义。也就是说,在运行时,%arg0 的值必须为 dense<[3, 4]> : tensor<2xi32>。如果形状运算数是常量,则可以静态验证。如果结果形状是完全动态的,则不会出现不匹配的情况。