StableHLO 规范

StableHLO 是用于机器中高级操作 (HLO) 的操作集 学习 (ML) 模型。StableHLO 充当不同应用之间的可移植性层, 机器学习框架和机器学习编译器:可生成 StableHLO 程序的机器学习框架 兼容使用 StableHLO 程序的机器学习编译器。

我们的目标是通过创建更多 API 来简化和加速机器学习开发, 各种机器学习框架(例如 TensorFlow、JAX 和 PyTorch)和机器学习编译器(例如 XLA 和 IREE)。为此, 文档提供了 StableHLO 编程语言的规范。

本规范包含三个主要部分。首先, 程序部分介绍了 StableHLO 程序的结构 其中包含 StableHLO 函数,而其本身由 StableHLO 操作组成。 在该结构内,Ops 部分指定 单个操作。Execution 部分提供了 这些运算在一个程序内一起执行。最后, 表示法部分讨论了 规范

如需查看先前版本的 StableHLO 中的规范,请在以下位置打开代码库: 感兴趣的标记版本。 例如 StableHLO v0.19.0 规范。 如需查看 StableHLO 的每个次要版本递增时发生的更改,请参阅 VhloDialect.td 中的版本日志。

计划

Program ::= {Func}

StableHLO 程序由任意数量的 StableHLO 函数组成。 下面是一个包含函数 @main 且包含 3 个输入的示例程序 (%image%weights%bias)和 1 个输出。函数正文 共有 6 个操作

func.func @main(
  %image: tensor<28x28xf32>,
  %weights: tensor<784x10xf32>,
  %bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
  %0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
  %1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
  %2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  %3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
  %4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  "func.return"(%4): (tensor<1x10xf32>) -> ()
}

函数

Func        ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs  ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput   ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput  ::= ValueType
FuncBody    ::= {Op}

StableHLO 函数(也称为命名函数)具有 标识符、输入/输出和正文。今后,我们计划 为函数引入额外的元数据,以实现更好的兼容性 (#425#626#740#744)。

标识符

FuncId  ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
          | '%' letter {letter | digit}
letter  ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit   ::= '0' | ... | '9'

StableHLO 标识符与许多编程中的标识符类似 具有两个特点:1) 所有标识符都有 区分不同类型的标识符;2) 值标识符 完全数字,以简化 StableHLO 程序的生成。

类型

Type         ::= ValueType | NonValueType
ValueType    ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType

StableHLO 类型分为值类型(也称为 一类类型),用于表示 StableHLO 值和非值类型 用于描述其他程序元素StableHLO 类型与 许多编程语言,其主要特性是 StableHLO 特定于领域的性质,这会导致一些不寻常的结果(例如标量类型) 不是值类型)。

TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'

张量类型表示张量,即多维数组。他们有 形状元素类型,其中形状表示非负或 未知的尺寸大小,按 维度(也称为“轴”),编号为从 0R-1通过 维度的数量 R 称为“排名”。例如,tensor<2x3xf32> 是 一个形状为 2x3 且元素类型为 f32 的张量类型。它有两个维度 (也就是两个轴)- 第 0 个维度和第 1 个维度 - 其尺寸 分别为 2 和 3。其排名为 2。

形状可以是部分未知或完全未知(动态),例如tensor<?x2xf64> 部分未知,tensor<?x?xf64> 完全未知。动态 尺寸使用 ? 表示。无法对形状取消排名。

未来,我们计划探索将张量类型扩展到 尺寸和元素类型,例如添加版式 (#629) 和稀疏性 (#1078)。

QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
                  QuantizationStorageType
                  ['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
                  ':' QuantizationExpressedType
                  [':' QuantizationDimension]
                  ',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerConstant
QuantizationStorageMax ::= IntegerConstant
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerConstant
QuantizationParameters ::= QuantizationParameter
                         | '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale ':' QuantizationZeroPoint
QuantizationScale ::= FloatConstant
QuantizationZeroPoint ::= IntegerConstant
名称 类型 限制条件
storage_type 整数类型 (C1-C3)、(C8)
storage_min 整数常量 (C1)、(C3)、(C7)
storage_max 整数常量 (C2)、(C3)、(C7)
expressed_type 浮点类型 (C4)
quantization_dimension 可选整数常量 (C10 - C12)
scales 可变数的浮点常量 (C4-C6)、(C9)、(C10)、(C13)
zero_points 整数常量可变数量 (C7-C9)

量化元素类型表示以下类型存储类型的整数值: 从 storage_minstorage_max(含)的范围,对应于 所表达类型的浮点值。对于给定的整数值 i, 相应的浮点值 f 可按以下公式进行计算: f = (i - zero_point) * scale,其中 scalezero_point 被调用 量化参数storage_minstorage_max 是可选的 但有默认值 min_value(storage_type)max_value(storage_type)。量化元素类型具有 以下限制:

  • (C1) type(storage_min) = storage_type
  • (C2) type(storage_max) = storage_type
  • (C3) min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type)
  • (C4) type(scales...) = expressed_type
  • (C5) 0 < scales
  • (C6) is_finite(scales...)
  • (C7) storage_min <= zero_points <= storage_max
  • (C8) type(zero_points...) = storage_type
  • (C9) size(scales) = size(zero_points)
  • (C10) 如果 is_empty(quantization_dimension),则 size(scales) = 1
  • (C11) 0 <= quantization_dimension

目前,QuantizationScale 是一个浮点常量,但 对基于整数的标度(用乘数和 变化。我们计划在不久的将来对此进行探索 (#1404)。

围绕 QuantizationZeroPoint 的语义有持续讨论, 包括类型、值以及 量化张量类型中可能有多个零点。基于 因此有关零点的规范可能会发生变化 。 (#1405)。

另一个正在进行的讨论涉及 QuantizationStorageMin 的语义 和 QuantizationStorageMax 来确定是否应 加在这些值和量化张量的值上 (#1406)。

最后,我们计划探索如何表示未知量表和零 与我们计划探索表示未知路径的 尺寸尺寸 (#1407)。

量化张量类型表示包含量化元素的张量。这些 张量与常规张量完全相同,只不过它们的元素 量化元素类型,而不是常规元素类型。

在量化张量中,量化可以是“每个张量”,也就是说,具有 一个 scalezero_point(针对整个张量),或者可以是 per-axis, 也就是说,有多个 scaleszero_points,每个 Slice 包含一对 特定维度quantization_dimension。更正式地说,在张量 t 中 采用每轴量化后,有 dim(t, quantization_dimension) 个切片 (共 quantization_dimension 个):t[:, ..., 0, ..., :], t[:, ..., 1, ..., :], 等等。i 切片中的所有元素都使用 scales[i]zero_points[i] 作为 量化参数。量化张量类型具有以下特征 限制条件:

  • 对于每张量量化: <ph type="x-smartling-placeholder">
      </ph>
    • 没有其他限制。
  • 对于每轴量化: <ph type="x-smartling-placeholder">
      </ph>
    • (C12) quantization_dimension < rank(self)
    • (C13) dim(self, quantization_dimension) = size(scales)
TokenType ::= 'token'

令牌类型表示令牌,即生成和使用的不透明值 执行某些操作令牌用于对操作施加执行顺序 如执行部分中所述。

TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]

元组类型表示元组,即异构列表。元组是一种旧版 功能只是为了与 HLO 兼容而存在。在 HLO 中,元组是 用于表示可变的输入和输出。在 StableHLO 中,可变输入和 输出原生支持,而元组在 StableHLO 中的唯一用途是 全面表示 HLO ABI,例如Ttuple<T>tuple<tuple<T>> 可能会因具体情况而异, 实施。未来,我们计划更改 HLO ABI 这或许可以让我们从 StableHLO 中移除元组类型 (#598)。

TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ'
            | 'f8E4M3B11FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'

元素类型表示张量类型的元素。与许多编程中的 因此这些类型在 StableHLO 中不是第一类。这意味着 StableHLO 程序无法直接表示这些类型的值(因此, 惯用方式是使用 0 维张量表示 T 类型的标量值 类型为 tensor<T> 的值)。

  • 布尔值类型表示布尔值 truefalse
  • 整数类型可以是有符号 (si),也可以是无符号 (ui),并且具有 其中一个受支持的位宽度(248163264)。 有符号 siN 类型表示从 -2^(N-1)2^(N-1)-1 的整数值 含符号的 uiN 类型表示从 00 的整数值 2^N-1(含)。
  • 浮点类型可以是下列之一: <ph type="x-smartling-placeholder">
  • 复杂类型表示具有实部的复杂值 以及属于同一元素类型虚部。支持的复杂 类型为 complex<f32>(两个部分均为 f32 类型)和 complex<f64> (两个部分均为 f64 类型)。
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]

函数类型同时表示命名和匿名函数。它们具有 输入类型(-> 左侧的类型列表)和输出类型 (-> 右侧的类型列表)。在许多编程中 函数类型是第一类,但在 StableHLO 中并非如此。

StringType ::= 'string'

String type 表示字节序列。与许多编程中的 语言,因此字符串类型不是 StableHLO 中的第一类,仅用于 指定节目元素的静态元数据。

操作

StableHLO 运算(也称为“运算”)表示闭合集 机器学习模型中的概要操作。如上所述 StableHLO 语法深受 MLIR 的启发,而 MLIR 未必是 但可以说是最适合 StableHLO 的目标 在机器学习框架和机器学习编译器之间建立更高的互操作性。

Op            ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName        ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic    ::= 'abs' | 'add' | ...

StableHLO 操作(也称为 ops)的名称为, 输入/输出和签名。名称由 stablehlo. 前缀和 一个助记符,用于唯一标识其中一个支持的操作。请参阅下文,了解 所有受支持操作的完整列表。

OpInputs        ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues   ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue    ::= ValueId
OpInputFuncs    ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs    ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs       ::= [OpOutput {',' OpOutput} '=']
OpOutput        ::= ValueId

操作会使用输入并生成输出。输入分为 输入值(在执行期间计算)、输入函数(提供的 静态引用,因为在 StableHLO 中,函数不是一类值)和 输入属性(也以静态方式提供)。输入和输出的种类 操作消耗和生成的内容取决于其助记符。例如,add 操作使用 2 个输入值并生成 1 个输出值。相比之下, select_and_scatter 操作会使用 3 个输入值、2 个输入函数和 3 种输入属性。

OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused      ::= '^' digit {digit}
              | '^' letter {letter | digit}

输入函数(也称为“匿名函数”)是非常实用的 与命名函数类似,不同之处在于:1) 它们没有标识符(因此 名称为“匿名”),2) 未声明输出类型(输出类型 根据函数中的 return 操作推断出)。

输入函数的语法包含目前未使用的部分(请参阅 Unused 生产文件),以便与 MLIR 兼容。在 MLIR 中, 有一个更宽泛的概念,即“区域”该类型可以包含多个“屏蔽设置” 操作通过跳跃操作连接在一起。这些区块的 ID 分别对应于 添加到 Unused 正式版中,以便它们可以彼此区分。 StableHLO 没有跳转操作,因此 MLIR 语法的相应部分是 未使用(但仍存在)。

OpInputAttr      ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName  ::= letter {letter | digit}
OpInputAttrValue ::= Constant

输入属性具有名称和值,并且是受支持的值之一 常量。它们是为节目指定静态元数据的主要方式 元素。例如,concatenate 操作使用属性 dimension 来 指定串联其输入值所依据的维度。同样, slice 操作使用多个属性,如 start_indiceslimit_indices 指定用于对输入值进行切片的边界。

目前,外部的 StableHLO 程序有时包含属性 本文档未作介绍。今后,我们计划 要么将这些属性吸收到 StableHLO 操作集中,要么禁止它们 出现在 StableHLO 程序中。在此期间,您可以查看这些 属性:

  • layout (#629)。
  • mhlo.frontend_attributes (#628)。
  • mhlo.sharding (#619)。
  • output_operand_aliases (#740)。
  • 位置元数据 (#594)。
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'

操作签名包含所有输入值的类型(即, -> 的左侧)和所有输出值的类型( -> 右侧的类型)。严格来说,输入类型 冗余,而输出类型也几乎总是冗余(因为对于 大多数 StableHLO 操作,输出类型都可以根据输入推断出来)。尽管如此, 签名是 StableHLO 语法的一部分,以便与 MLIR 兼容。

下面是一个助记符为 select_and_scatter 的示例操作。消耗 3 输入值(%operand%source%init_value),2 个输入函数 和 3 个输入属性(window_dimensionswindow_stridespadding)。 请注意,该操作的签名如何仅包含其输入值的类型 (但不包括内嵌提供的输入函数和属性的类型)。

%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i32>, tensor<i32>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
    "stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>

常量

Constant ::= BooleanConstant
           | IntegerConstant
           | FloatConstant
           | ComplexConstant
           | TensorConstant
           | QuantizedTensorConstant
           | StringConstant
           | EnumConstant

StableHLO 常量有一个字面量和一个类型,它们共同表示 StableHLO 值。通常,该类型是常量语法的一部分, 在含义明确时(例如,布尔值常量明确为 i1 类型, 而整数常量可以有多种可能的类型)。

BooleanConstant ::= BooleanLiteral
BooleanLiteral  ::= 'true' | 'false'

布尔值常量表示布尔值 truefalse。布尔值 常量的类型为 i1

IntegerConstant   ::= IntegerLiteral ':' IntegerType
IntegerLiteral    ::= ['-' | '+'] DecimalDigits
                    | ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits     ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit      ::= '0' | ... | '9'
hexadecimalDigit  ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'

整数常量通过使用十进制数或 十六进制记数法。其他基地,如二进制或八进制,均不支持。 整数常量具有以下限制:

  • (C1) is_wellformed(integer_literal, integer_type)
FloatConstant  ::= FloatLiteral ':' FloatType
FloatLiteral   ::= SignPart IntegerPart FractionalPart ScientificPart
                 | '0x' [HexadecimalDigits]
SignPart       ::= ['-' | '+']
IntegerPart    ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]

浮点常量:通过 使用十进制数或科学计数法。此外,十六进制记数法可以是 用于以 相应的类型。浮点常量具有以下限制:

  • (C1) 如果使用非十六进制记数法, is_wellformed(float_literal, float_type)
  • (C2) 如果使用十六进制记数法, size(hexadecimal_digits) = num_bits(float_type) / 4
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral  ::= '(' RealPart ',' ImaginaryPart ')'
RealPart        ::= FloatLiteral
ImaginaryPart   ::= FloatLiteral

复数常量:使用实数部分的列表表示复数值 (在前)和虚部(第二)。例如: (1.0, 0.0) : complex<f32> 表示 1.0 + 0.0i(0.0, 1.0) : complex<f32> 表示 0.0 + 1.0i。系统会按照 然后由实现定义的内存存储在内存中复常数 具有以下限制:

  • (C1) is_wellformed(real_part, complex_element_type(complex_type))
  • (C2) is_wellformed(imaginary_part, complex_element_type(complex_type))
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral   ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements  ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral

张量常量,使用通过以下方式指定的嵌套列表来表示张量值: NumPy 表示法。例如:dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> 表示一个张量值,具有以下从索引到元素的映射: {0, 0} => 1{0, 1} => 2{0, 2} => 3{1, 0} => 4{1, 1} => 5{1, 2} => 6。这些元素随后在内存中的存储顺序为 实现定义的。张量常量具有以下限制:

  • (C1) has_syntax(tensor_literal, element_type(tensor_type)),其中: <ph type="x-smartling-placeholder">
      </ph>
    • has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type)
    • has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type)
  • (C2) has_shape(tensor_literal, shape(tensor_type)),其中: <ph type="x-smartling-placeholder">
      </ph>
    • has_shape(element_literal: Syntax, []) = true
    • has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:])
    • 否则为 false
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'

量化张量常量使用相同的 以张量常量的形式表示,其中元素指定为其 存储类型。量化张量常量具有以下限制:

  • (C1) has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type))
  • (C2) has_shape(quantized_tensor_literal, shape(quantized_tensor_type))
StringConstant  ::= StringLiteral
StringLiteral   ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence  ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))

字符串字面量由使用 ASCII 字符和 转义序列。它们与编码无关,因此对它们的解释 由实现定义。字符串字面量的类型为 string

Ops Agent 可以

abs

语义

operand 张量执行元素级绝对运算,并生成一个 result 张量。根据元素类型,执行以下操作:

  • 对于有符号整数:整数模数。
  • 对于浮点数:IEEE-754 中的 abs
  • 对于复数:复模数。
  • 对于量化类型:dequantize_op_quantize(abs, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 带符号整数、浮点数、复杂类型或每张量量化张量的张量 (C1 - C2)

Outputs

名称 类型 限制条件
result 带符号整数或浮点类型的张量或每张量量化张量 (C1 - C2)

限制条件

  • (C1) shape(result) = shape(operand)
  • (C2) baseline_element_type(result) 的定义如下: <ph type="x-smartling-placeholder">
      </ph>
    • 如果 is_complex(operand),则为 complex_element_type(element_type(operand))
    • 否则为 baseline_element_type(operand)

示例

// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]

更多示例

add

语义

对两个张量 lhsrhs 执行元素相加,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于布尔值:逻辑 OR。
  • 对于整数:整数加法。
  • 对于浮点数:IEEE-754 中的 addition
  • 对于复数:复杂加法。
  • 对于量化类型:dequantize_op_quantize(add, lhs, rhs, type(result))

输入

标签 名称 类型 限制条件
(I1) lhs 张量或量化张量 (C1 - C6)
(I2) rhs 张量或量化张量 (C1-C5)、(C7)

Outputs

名称 类型 限制条件
result 张量或量化张量 (C1 - C7)

限制条件

  • 如果运算使用非量化张量: <ph type="x-smartling-placeholder">
      </ph>
    • (C1) type(lhs) = type(rhs) = type(result)
  • 如果运算使用量化张量: <ph type="x-smartling-placeholder">
      </ph>
    • (C2) is_quantized(lhs) and is_quantized(rhs) and is_quantized(result)
    • (C3) storage_type(lhs) = storage_type(rhs) = storage_type(result)
    • (C4) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
    • (C5) (is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result)
    • (C6) 如果 is_per_axis_quantized(lhs),则 quantization_dimension(lhs) = quantization_dimension(result)
    • (C7) 如果 is_per_axis_quantized(rhs),则 quantization_dimension(rhs) = quantization_dimension(result)

示例

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]

更多示例

after_all

语义

确保产生 inputs 的操作在执行任何 依赖于 result 的操作。执行这项操作不起任何作用 它的存在只是为了建立从 resultinputs 的数据依赖关系。

输入

标签 名称 类型
(I1) inputs token的可变数

Outputs

名称 类型
result token

示例

// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token

更多示例

all_gather

语义

在 StableHLO 进程网格中的每个进程组内,串联值 沿着 all_gather_dim 分布的每个进程的 operands 张量,并生成 results 张量。

该操作将 StableHLO 进程网格拆分为 process_groups, 定义如下:

  • cross_replica(replica_groups) 如果 channel_id <= 0 and use_global_device_ids = false
  • cross_replica_and_partition(replica_groups) 如果 channel_id > 0 and use_global_device_ids = false
  • flattened_ids(replica_groups) 如果 channel_id > 0 and use_global_device_ids = true

之后,在每个 process_group 中:

  • 全部operands...@receiver = [operand@sender for sender in process_group]process_group”中的“receiver”。
  • 全部results...@process = concatenate(operands...@process, all_gather_dim)process_group”中的“process”。

输入

标签 名称 类型 限制条件
(I1) operands 可变数量张量或每个张量量化张量 (C1)、(C6)
(I2) all_gather_dim si64 类型的常量 (C1)、(C6)
(I3) replica_groups si64 类型的二维张量常量 (C2-C4)
(I4) channel_id si64 类型的常量 (C5)
(I5) use_global_device_ids i1 类型的常量 (C5)

Outputs

名称 类型 限制条件
results 可变数量张量或每个张量量化张量 (C6)

限制条件

  • (C1) 0 <= all_gather_dim < rank(operands...)
  • (C2) is_unique(replica_groups)
  • (C3) size(replica_groups) 的定义如下: <ph type="x-smartling-placeholder">
      </ph>
    • 如果使用了 cross_replica,则为 num_replicas
    • 如果使用了 cross_replica_and_partition,则为 num_replicas
    • 如果使用了 flattened_ids,则为 num_processes
  • (C4) 0 <= replica_groups < size(replica_groups)
  • (C5) 如果 use_global_device_ids = true,则 channel_id > 0
  • (C6) type(results...) = type(operands...),以下项除外: <ph type="x-smartling-placeholder">
      </ph>
    • dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1)

示例

// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
  all_gather_dim = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  // channel_id = 0
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  // use_global_device_ids = false
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]

更多示例

all_reduce

语义

在 StableHLO 进程网格中的每个进程组内,应用归约 函数 computation 映射到来自每个进程的 operands 张量的值。 并生成 results 张量。

该操作将 StableHLO 进程网格拆分为 process_groups, 定义如下:

  • cross_replica(replica_groups) 如果 channel_id <= 0 and use_global_device_ids = false
  • cross_replica_and_partition(replica_groups) 如果 channel_id > 0 and use_global_device_ids = false
  • flattened_ids(replica_groups) 如果 channel_id > 0 and use_global_device_ids = true

之后,在每个 process_group 中:

  • results...@process[result_index] = exec(schedule),适用于某些二元树 schedule,其中: <ph type="x-smartling-placeholder">
      </ph>
    • exec(node) = computation(exec(node.left), exec(node.right))
    • exec(leaf) = leaf.value
  • schedule 是一种由实现定义的二元树,其顺序 遍历为 to_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0]))

输入

标签 名称 类型 限制条件
(I1) operands 可变数量张量或每个张量量化张量 (C5)、(C6)
(I2) replica_groups si64 类型的一维张量常量的可变数量 (C1 - C3)
(I3) channel_id si64 类型的常量 (C4)
(I4) use_global_device_ids i1 类型的常量 (C4)
(I5) computation 函数 (C5)

Outputs

名称 类型 限制条件
results 可变数量张量或每个张量量化张量 (C6-C7)

限制条件

  • (C1) is_unique(replica_groups)
  • (C2) size(replica_groups) 的定义如下: <ph type="x-smartling-placeholder">
      </ph>
    • 如果使用了 cross_replica,则为 num_replicas
    • 如果使用了 cross_replica_and_partition,则为 num_replicas
    • 如果使用了 flattened_ids,则为 num_processes
  • (C3) 0 <= replica_groups < size(replica_groups)
  • (C4) 如果 use_global_device_ids = true,则 channel_id > 0
  • (C5) computation 的类型为 (tensor<E>, tensor<E>) -> (tensor<E>),其中 is_promotable(element_type(operand), E)
  • (C6) shape(results...) = shape(operands...)
  • (C7) element_type(results...) = E

示例

// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  // channel_id = 0
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  // use_global_device_ids = false
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]

更多示例

all_to_all

语义

all_to_all

在 StableHLO 进程网格的每个进程组中,将 operands 张量沿 split_dimension 分成多个部分,将拆分后的 部分将零散的部分串联起来 concat_dimension 并生成 results 张量。 该操作将 StableHLO 进程网格拆分为 process_groups, 定义如下:

  • 如果 channel_id <= 0,则为 cross_replica(replica_groups)
  • 如果 channel_id > 0,则为 cross_partition(replica_groups)

之后,在每个 process_group 中:

  • split_parts...@sender = split(operands...@sender, split_count, split_dimension)process_group中的所有sender
  • scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group],其中 receiver_index = process_group.index(receiver)
  • results...@process = concatenate(scattered_parts...@process, concat_dimension)

输入

标签 名称 类型 限制条件
(I1) operands 可变数量张量或每个张量量化张量 (C1-C3)、(C9)
(I2) split_dimension si64 类型的常量 (C1)、(C2)、(C9)
(I3) concat_dimension si64 类型的常量 (C3)、(C9)
(I4) split_count si64 类型的常量 (C2)、(C4)、(C8)、(C9)
(I5) replica_groups si64 类型的二维张量常量 (C5-C8)
(I6) channel_id si64 类型的常量

Outputs

名称 类型 限制条件
results 可变数量张量或每个张量量化张量 (C9)

限制条件

  • (C1) 0 <= split_dimension < rank(operands...)
  • (C2) dim(operands..., split_dimension) % split_count = 0
  • (C3) 0 <= concat_dimension < rank(operands...)
  • (C4) 0 < split_count
  • (C5) is_unique(replica_groups)
  • (C6) size(replica_groups) 的定义如下: <ph type="x-smartling-placeholder">
      </ph>
    • 如果使用了 cross_replica,则为 num_replicas
    • 如果使用了 cross_partition,则为 num_partitions
  • (C7) 0 <= replica_groups < size(replica_groups)
  • (C8) dim(replica_groups, 1) = split_count
  • (C9) type(results...) = type(operands...),除非 split_dimension != concat_dimension: <ph type="x-smartling-placeholder">
      </ph>
    • dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count
    • dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count

示例

// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
//                    [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
//                    [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
//                    [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
//                    [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
  split_dimension = 1 : i64,
  concat_dimension = 0 : i64,
  split_count = 2 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  // channel_id = 0
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]

更多示例

语义

对两个张量 lhsrhs 执行元素级 AND 运算,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于布尔值:逻辑 AND。
  • 对于整数:按位 AND。

输入

标签 名称 类型 限制条件
(I1) lhs 布尔值或整数类型的张量 (C1)
(I2) rhs 布尔值或整数类型的张量 (C1)

Outputs

名称 类型 限制条件
result 布尔值或整数类型的张量 (C1)

限制条件

  • (C1) type(lhs) = type(rhs) = type(result)

示例

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]

更多示例

atan2

语义

lhsrhs 张量执行元素级 atan2 运算,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于浮点数:IEEE-754 中的 atan2
  • 对于复数:复数 atan2。
  • 对于量化类型:dequantize_op_quantize(atan2, lhs, rhs, type(result))

输入

标签 名称 类型 限制条件
(I1) lhs 浮点型、复杂类型或每张量量化张量 (C1)
(I2) rhs 浮点型、复杂类型或每张量量化张量 (C1)

Outputs

名称 类型 限制条件
result 浮点型、复杂类型或每张量量化张量 (C1)

限制条件

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

示例

// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]

更多示例

batch_norm_grad

语义

计算 batch_norm_training 反向传播的多个输入的梯度 来自 grad_output,并生成 grad_operandgrad_scalegrad_offset 张量。更正式地说,此操作可以表示为 现有的 StableHLO 操作,如下所示:

def compute_sum(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  return sum

def compute_mean(operand, feature_index):
  sum = compute_sum(operand, feature_index)
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
  # Broadcast inputs to type(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance`
  # Intermediate values will be useful for computing gradients
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)

  # Use the implementation from batchnorm_expander.cc in XLA
  # Temporary variables have exactly the same names as in the C++ code
  elements_per_feature = broadcast_in_dim(
      constant(divide(size(operand), dim(operand, feature_index)),
               element_type(grad_output)),
      [], type(operand))
  i1 = multiply(grad_output, elements_per_feature)
  i2 = broadcast_in_dim(
      compute_sum(grad_output, feature_index), [feature_index], type(operand))
  i3 = broadcast_in_dim(
      compute_sum(multiply(grad_output, centered_operand), feature_index),
      [feature_index], type(operand))
  i4 = multiply(i3, centered_operand)
  i5 = divide(i4, add(variance_bcast, epsilon_bcast))
  i6 = subtract(subtract(i1, i2), i5)

  grad_operand =
      multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
  grad_scale =
      compute_sum(multiply(grad_output, normalized_operand), feature_index)
  grad_offset = compute_sum(grad_output, feature_index)

  return grad_operand, grad_scale, grad_offset

对于量化类型, dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean, variance, grad_output: batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index), operand, scale, mean, variance, grad_output, type(grad_operand), type(grad_scale), type(feature_index))

输入

标签 名称 类型 限制条件
(I1) operand 浮点类型的张量或每张量量化张量 (C1-C3)、(C5)
(I2) scale 浮点或每张量量化类型的一维张量 (C2)、(C4)、(C5)
(I3) mean 浮点或每张量量化类型的一维张量 (C2)、(C4)
(I4) variance 浮点或每张量量化类型的一维张量 (C2)、(C4)
(I5) grad_output 浮点类型的张量或每张量量化张量 (C2)、(C3)
(I6) epsilon f32 类型的常量
(I7) feature_index si64 类型的常量 (C1)、(C5)

Outputs

名称 类型 限制条件
grad_operand 浮点类型的张量或每张量量化张量 (C2)、(C3)
grad_scale 浮点或每张量量化类型的一维张量 (C2)、(C4)
grad_offset 浮点或每张量量化类型的一维张量 (C2)、(C4)

限制条件

  • (C1) 0 <= feature_index < rank(operand)
  • (C2) operandscalemeanvariancegrad_outputgrad_operandgrad_scalegrad_offset 具有相同的 baseline_element_type
  • (C3) operandgrad_outputgrad_operand 具有相同的形状。
  • (C4) scalemeanvariancegrad_scalegrad_offset 具有 形状相同。
  • (C5) size(scale) = dim(operand, feature_index)

示例

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
//                [[0.1, 0.1], [0.1, 0.1]],
//                [[0.1, 0.1], [0.1, 0.1]]
//               ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
     tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
//                 [[0.0, 0.0], [0.0, 0.0]],
//                 [[0.0, 0.0], [0.0, 0.0]]
//                ]
// %grad_scale:  [0.0, 0.0]
// %grad_offset: [0.4, 0.4]

batch_norm_inference

语义

针对所有维度对 operand 张量进行归一化,但 feature_index 维度,并生成一个 result 张量。更正式地说, 操作可以表示为对现有 StableHLO 操作的分解 如下所示:

def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
  # Broadcast inputs to shape(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance` instead of
  # computing them like `batch_norm_training` does.
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)
  return add(multiply(scale_bcast, normalized_operand), offset_bcast)

对于量化类型, dequantize_op_quantize(lambda operand, scale, offset, mean, variance: batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index), operand, scale, offset, mean, variance, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点类型的张量或每张量量化张量 (C1 - C7)
(I2) scale 浮点或每张量量化类型的一维张量 (C2)、(C3)
(I3) offset 浮点或每张量量化类型的一维张量 (C2)、(C4)
(I4) mean 浮点或每张量量化类型的一维张量 (C5)
(I5) variance 浮点或每张量量化类型的一维张量 (C2)、(C6)
(I6) epsilon f32 类型的常量
(I7) feature_index si64 类型的常量 (C1)、(C3-C6)

Outputs

名称 类型 限制条件
result 浮点类型的张量或每张量量化张量 (C2)、(C7)

限制条件

  • (C1) 0 <= feature_index < rank(operand)
  • (C2) operandscaleoffsetmeanvarianceresult 具有 相同的 baseline_element_type
  • (C3) size(scale) = dim(operand, feature_index)
  • (C4) size(offset) = dim(operand, feature_index)
  • (C5) size(mean) = dim(operand, feature_index)
  • (C6) size(variance) = dim(operand, feature_index)
  • (C7) baseline_type(operand) = baseline_type(result)

示例

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]

batch_norm_training

语义

计算除feature_index之外的所有维度的平均值和方差 对 operand 张量进行归一化,从而生成 outputbatch_meanbatch_var 张量。更正式地说,此操作可以表示为 使用 Python 语法分解为现有的 StableHLO 操作, 如下:

def compute_mean(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def compute_variance(operand, feature_index):
  mean = compute_mean(operand, feature_index)
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  centered_operand = subtract(operand, mean_bcast)
  return compute_mean(mul(centered_operand, centered_operand), feature_index)

def batch_norm_training(operand, scale, offset, epsilon, feature_index):
  mean = compute_mean(operand, feature_index)
  variance = compute_variance(operand, feature_index)
  return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
                              feature_index),
         mean, variance

对于量化类型, dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset: batch_norm_training(operand, scale, offset, epsilon, feature_index), operand, scale, offset, type(output), type(batch_mean), type(batch_var))

输入

标签 名称 类型 限制条件
(I1) operand 浮点类型的张量或每张量量化张量 (C1)
(I2) scale 量化的浮点数或每张量的一维张量 (C2)、(C3)
(I3) offset 量化的浮点数或每张量的一维张量 (C2)、(C4)
(I4) epsilon f32 类型的常量 (C1)、(C3-C6)
(I5) feature_index si64 类型的常量 (C1)、(C3-C6)

Outputs

名称 类型 限制条件
output 浮点类型的张量或每张量量化张量 (C7)
batch_mean 量化的浮点数或每张量的一维张量 (C2)、(C5)
batch_var 量化的浮点数或每张量的一维张量 (C2)、(C6)

限制条件

  • (C1) 0 <= feature_index < rank(operand)
  • (C2) operandscaleoffsetbatch_meanbatch_varoutput 同一个 baseline_element_type
  • (C3) size(scale) = dim(operand, feature_index)
  • (C4) size(offset) = dim(operand, feature_index)
  • (C5) size(batch_mean) = dim(operand, feature_index)
  • (C6) size(batch_var) = dim(operand, feature_index)
  • (C7) baseline_type(output) = baseline_type(operand)

示例

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
    (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]

bitcast_convert

语义

operand 张量执行 Bitcast 操作,并生成一个 result 张量 其中,整个operand张量的位使用 result 张量的类型。

更正式地说,鉴于 E = element_type(operand)E' = element_type(result), 和 R = rank(operand)

  • 如果为 num_bits(E') < num_bits(E)bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])
  • 如果为 num_bits(E') > num_bits(E)bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])
  • 如果为 num_bits(E') = num_bits(E)bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])

bits 返回给定值的内存中表示法及其行为 是由实现定义的,因为张量的确切表示形式是 实现定义的,元素类型的确切表示形式为 实现定义的。

输入

标签 名称 类型 限制条件
(I1) operand 张量或量化张量 (C1 - C2)

Outputs

名称 类型 限制条件
result 张量或量化张量 (C1 - C2)

限制条件

  • (C1) 假定有 E = is_quantized(operand) ? storage_type(operand) : element_type(operand)E' = is_quantized(result) ? storage_type(result) : element_type(result)R = rank(operand)
    • 如果为 num_bits(E') = num_bits(E),则为 shape(result) = shape(operand)
    • 如果为 num_bits(E') < num_bits(E)
    • rank(result) = R + 1
    • 针对所有0 <= i < Rdim(result, i) = dim(operand, i)
    • dim(result, R) * num_bits(E') = num_bits(E)
    • 如果为 num_bits(E') > num_bits(E)
    • rank(result) = R - 1
    • 针对所有0 <= i < Rdim(result, i) = dim(operand, i)
    • dim(operand, R - 1) * num_bits(E) = num_bits(E')
  • (C2) 如果 is_complex(operand) or is_complex(result),则: is_complex(operand) and is_complex(result)

示例

// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation

更多示例

broadcast_in_dim

语义

通过复制数据来扩展输入张量的维度和/或秩 并生成一个 result 张量。operand更正式地说, result[result_index] = operand[operand_index],其中所有d axes(operand):

  • 如果 dim(operand, d) = 1,则为 operand_index[d] = 0
  • 否则为 operand_index[d] = result_index[broadcast_dimensions[d]]

输入

标签 名称 类型 限制条件
(I1) operand 张量或量化张量 (C1-C2)、(C5-C6)
(I2) broadcast_dimensions si64 类型的一维张量常量 (C2-C6)

Outputs

名称 类型 限制条件
result 张量或量化张量 (C1)、(C3)、(C5-C6)

限制条件

  • (C1) element_type(result) 由以下公式计算得出: <ph type="x-smartling-placeholder">
      </ph>
    • element_type(operand)(如果 !is_per_axis_quantized(operand))。
    • element_type(operand),但 quantization_dimension(operand)scales(operand)zero_points(operand) 可能不同于 quantization_dimension(result)scales(result)zero_points(result) 否则。
  • (C2) size(broadcast_dimensions) = rank(operand)
  • (C3) 0 <= broadcast_dimensions < rank(result)
  • (C4) is_unique(broadcast_dimensions)
  • (C5) 对于 axes(operand) 中的所有 d: <ph type="x-smartling-placeholder">
      </ph>
    • dim(operand, d) = 1
    • dim(operand, d) = dim(result, broadcast_dimensions[d])
  • (C6) 如果 is_per_axis_quantized(result): <ph type="x-smartling-placeholder">
      </ph>
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
    • 如果为 dim(operand, quantization_dimension(operand)) = 1,则 scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))

示例

// %operand: [
//            [1, 2, 3]
//           ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
  broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

更多示例

场景

语义

通过 branches 只执行一个函数来生成输出 具体取决于 index 的值。更正式地说,result = selected_branch() 其中:

  • 如果 0 <= index < size(branches),则为 selected_branch = branches[index]
  • 否则为 selected_branch = branches[-1]

输入

标签 名称 类型 限制条件
(I1) index si32 类型的 0 维张量
(I2) branches 可变数量的函数 (C1 - C4)

Outputs

名称 类型 限制条件
results 可变数量的张量、量化张量或词元 (C4)

限制条件

  • (C1) 0 < size(branches)
  • (C2) input_types(branches...) = []
  • (C3) same(output_types(branches...))
  • (C4) type(results...) = output_types(branches[0])

示例

// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
  "stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
  "stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]

更多示例

cbrt

语义

operand 张量执行元素级立方根运算,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于浮点数:IEEE-754 中的 rootn(x, 3)
  • 对于复数:复数立方根。
  • 对于量化类型:dequantize_op_quantize(cbrt, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点型、复杂类型或每张量量化张量 (C1)

Outputs

名称 类型 限制条件
result 浮点型、复杂类型或每张量量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]

更多示例

ceil

语义

operand 张量执行元素级 ceil,并生成 result 张量。 实现来自 IEEE-754 的 roundToIntegralTowardPositive 操作 规范对于量化类型, dequantize_op_quantize(ceil, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点类型的张量或每张量量化张量 (C1)

Outputs

名称 类型 限制条件
result 浮点类型的张量或每张量量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]

更多示例

Cholesky

语义

计算一批矩阵的 Cholesky 分解。

更正式地说,对于 index_space(result) 中的所有 iresult[i0, ..., iR-3, :, :] 是 Cholesky 分解 a[i0, ..., iR-3, :, :],采用下三角形中的任意一个 (如果 lowertrue)或上三角形(如果 lowerfalse)矩阵。 相反三角形中的输出值,即严格的上方三角形或 则由实现定义。

如果存在 i,其中输入矩阵不是埃尔米特正定数 则行为是未定义的。

对于量化类型, dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result))

输入

标签 名称 类型 限制条件
(I1) a 浮点型、复杂类型或每张量量化张量 (C1 - C3)
(I2) lower i1 类型的 0 维张量常量

Outputs

名称 类型 限制条件
result 浮点型、复杂类型或每张量量化张量 (C1)

限制条件

  • (C1) baseline_type(a) = baseline_type(result)
  • (C2) 2 <= rank(a)
  • (C3) dim(a, -2) = dim(a, -1)

示例

// %a: [
//      [1.0, 2.0, 3.0],
//      [2.0, 20.0, 26.0],
//      [3.0, 26.0, 70.0]
//     ]
%result = "stablehlo.cholesky"(%a) {
  lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
//           [1.0, 0.0, 0.0],
//           [2.0, 4.0, 0.0],
//           [3.0, 5.0, 6.0]
//          ]

限制取值范围

语义

operand 张量的每个元素限制在最小值和最大值之间 值并生成一个 result 张量。更正式地说,result[result_index] = minimum(maximum(operand[result_index], min_element), max_element), 其中,min_element = rank(min) = 0 ? min[] : min[result_index]max_element = rank(max) = 0 ? max[] : max[result_index]。对于量化类型 执行 dequantize_op_quantize(clamp, min, operand, max, type(result))

对复数施加排序涉及令人惊讶的语义, 所以将来我们计划取消对复数的支持 (#560)。

输入

标签 名称 类型 限制条件
(I1) min 张量或每张量量化张量 (C1)、(C3)
(I2) operand 张量或每张量量化张量 (C1 - C4)
(I3) max 张量或每张量量化张量 (C2)、(C3)

Outputs

名称 类型 限制条件
result 张量或每张量量化张量 (C4)

限制条件

  • (C1) rank(min) = 0 or shape(min) = shape(operand)
  • (C2) rank(max) = 0 or shape(max) = shape(operand)
  • (C3) baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max)
  • (C4) baseline_type(operand) = baseline_type(result)

示例

// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]

更多示例

collective_broadcast

语义

在 StableHLO 进程网格中的每个进程组内,发送 operand 张量从源进程传输到目标进程,并生成一个 result 张量。

该操作将 StableHLO 进程网格拆分为 process_groups, 定义如下:

  • 如果 channel_id <= 0,则为 cross_replica(replica_groups)
  • 如果 channel_id > 0,则为 cross_partition(replica_groups)

之后,result@process 将通过以下方式提供:

  • operand@process_groups[i, 0](如果存在 i,使得进程 在 process_groups[i] 中。
  • broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result)) 否则。

输入

标签 名称 类型 限制条件
(I1) operand 张量或每张量量化张量 (C3)
(I2) replica_groups si64 类型的一维张量常量的可变数量 (C1)、(C2)
(I3) channel_id si64 类型的常量

Outputs

名称 类型 限制条件
result 张量或每张量量化张量 (C3)

限制条件

  • (C1) is_unique(replica_groups)
  • (C2) 0 <= replica_groups < N,其中 N 的定义如下: <ph type="x-smartling-placeholder">
      </ph>
    • 如果使用了 cross_replica,则为 num_replicas
    • 如果使用了 cross_partition,则为 num_partitions
  • (C3) type(result) = type(operand)

示例

// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
  replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]

collective_permute

语义

在 StableHLO 进程网格中的每个进程组内,发送 从源进程到目标进程的 operand 张量,并生成一个 result 张量。

该操作将 StableHLO 进程网格拆分为 process_groups, 定义如下:

  • 如果 channel_id <= 0,则为 cross_replica(source_target_pairs)
  • 如果 channel_id > 0,则为 cross_partition(source_target_pairs)

之后,result@process 将通过以下方式提供:

  • operand@process_groups[i, 0](如果存在 i,使得 process_groups[i, 1] = process
  • broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result)) 否则。

输入

标签 名称 类型 限制条件
(I1) operand 张量或每张量量化张量 (C5)
(I2) source_target_pairs si64 类型的二维张量常量 (C1 - C4)
(I3) channel_id si64 类型的常量

Outputs

名称 类型 限制条件
result 张量或每张量量化张量 (C1)

限制条件

  • (C1) dim(source_target_pairs, 1) = 2
  • (C2) is_unique(source_target_pairs[:, 0])
  • (C3) is_unique(source_target_pairs[:, 1])
  • (C4) 0 <= source_target_pairs < N,其中 N 的定义如下: <ph type="x-smartling-placeholder">
      </ph>
    • 如果使用了 cross_replica,则为 num_replicas
    • 如果使用了 cross_partition,则为 num_partitions
  • (C5) type(result) = type(operand)

示例

// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
  source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]

更多示例

比较

语义

根据以下内容对 lhsrhs 张量执行元素级比较: comparison_directioncompare_type,并生成一个 result 张量。

comparison_directioncompare_type 的值包含以下 语义:

对于布尔值和整数元素类型:

  • EQlhs = rhs
  • NElhs != rhs
  • GElhs >= rhs
  • GTlhs > rhs
  • LElhs <= rhs
  • LTlhs < rhs

对于具有 compare_type = FLOAT 的浮点元素类型,相应操作会实现 以下 IEEE-754 操作:

  • EQcompareQuietEqual
  • NEcompareQuietNotEqual
  • GEcompareQuietGreaterEqual
  • GTcompareQuietGreater
  • LEcompareQuietLessEqual
  • LTcompareQuietLess

对于包含 compare_type = TOTALORDER 的浮点元素类型,相应操作 结合使用 totalOrdercompareQuietEqual 运算 IEEE-754。

对于复杂元素类型,(real, imag) 对的字典顺序比较如下: (使用提供的 comparison_directioncompare_type 执行) 对复数施加排序涉及令人惊讶的语义, 所以将来我们计划取消对复数的支持 当 comparison_directionGEGTLELT 时 (#560)。

适用于量化类型。执行 dequantize_compare(lhs, rhs, comparison_direction)

输入

标签 名称 类型 限制条件
(I1) lhs 张量或每张量量化张量 (C1 - C3)
(I2) rhs 张量或每张量量化张量 (C1 - C2)
(I3) comparison_direction EQNEGEGTLELT 的枚举
(I4) compare_type FLOATTOTALORDERSIGNEDUNSIGNED 的枚举 (C3)

Outputs

名称 类型 限制条件
result 布尔值类型的张量 (C2)

限制条件

  • (C1) baseline_element_type(lhs) = baseline_element_type(rhs)
  • (C2) shape(lhs) = shape(rhs) = shape(result)
  • (C3) compare_type 的定义如下: <ph type="x-smartling-placeholder">
      </ph>
    • 如果 is_signed_integer(element_type(lhs)),则为 SIGNED
    • 如果 is_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs)),则为 UNSIGNED
    • FLOATTOTALORDER(如果 is_float(element_type(lhs)))。
    • 如果 is_complex(element_type(lhs)),则为 FLOAT

示例

// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
  comparison_direction = #stablehlo<comparison_direction LT>,
  compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]

更多示例

复杂

语义

从一对实数和 虚值 lhsrhs,并生成一个 result 张量。

输入

标签 名称 类型 限制条件
(I1) lhs f32f64 类型的张量 (C1 - C3)
(I2) rhs f32f64 类型的张量 (C1)

Outputs

名称 类型 限制条件
result 复杂类型的张量 (C2)、(C3)

限制条件

  • (C1) type(lhs) = type(rhs)
  • (C2) shape(result) = shape(lhs)
  • (C3) element_type(result) 的类型为 complex<E>,其中 E = element_type(lhs)

示例

// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]

更多示例

复合型

语义

封装由其他 StableHLO 操作组成(组合)的操作, 获取 inputscomposite_attributes,并生成 results。通过 操作的语义通过 decomposition 属性实现。通过 composite 操作可替换为其分解,而无需更改程序 语义信息。当内嵌分解不能提供相同的 操作语义,最好使用 custom_call

version 字段(默认为 0)用于表示复合的 语义变化。

输入

标签 名称 类型
(I1) inputs 值的可变数量
(I2) name string 类型的常量
(I3) composite_attributes 属性字典
(I4) decomposition string 类型的常量
(I5) version si32 类型的常量

Outputs

名称 类型
results 值的可变数量

限制条件

  • (C1) is_namespaced_op_name(name)
  • (C2) is_defined_in_parent_scope(decomposition)
  • (C3) types(inputs...) == input_types(decomposition)
  • (C4) types(results...) == output_types(decomposition)

示例

%results = "stablehlo.composite"(%input0, %input1) {
  name = "my_namespace.my_op",
  composite_attributes = {
    my_attribute = "my_value"
  },
  decomposition = @my_op,
  version = 1 : i32
} : (tensor<f32>, tensor<f32>) -> tensor<f32>

更多示例

concatenate

语义

按照给定顺序沿 dimension 维度串联 inputs 并生成一个 result 张量。更正式地说, result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1],其中:

  1. id = d0 + ... + dk-1 + kd
  2. d 等于 dimensiond0 等是第 d 个维度大小 共 inputs 个。

输入

标签 名称 类型 限制条件
(I1) inputs 可变数量张量或每个张量量化张量 (C1 - C6)
(I2) dimension si64 类型的常量 (C2)、(C4)、(C6)

Outputs

名称 类型 限制条件
result 张量或每张量量化张量 (C5-C6)

限制条件

  • (C1) same(element_type(inputs...))
  • (C2) same(shape(inputs...))dim(inputs..., dimension) 除外)。
  • (C3) 0 < size(inputs)
  • (C4) 0 <= dimension < rank(inputs[0])
  • (C5) element_type(result) = element_type(inputs[0])
  • (C6) shape(result) = shape(inputs[0]),以下情形除外: <ph type="x-smartling-placeholder">
      </ph>
    • dim(result, dimension) = dim(inputs[0], dimension) + ...

示例

// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
  dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]

更多示例

常量

语义

从常量 value 生成 output 张量。

输入

标签 名称 类型 限制条件
(I1) value 常量 (C1)

Outputs

名称 类型 限制条件
output 张量或量化张量 (C1)

限制条件

  • (C1) type(value) = type(output)

示例

%output = "stablehlo.constant"() {
  value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]

更多示例

转化

语义

operand 张量,并生成一个 result 张量。

对于 boolean-to-any-supported-type 转化,值 false 为 转换为 0,值 true 将转换为 1。对于 any-supported-type-to-boolean转换,零值都会转换为 false 和非零值将转换为 true。请参阅下文,了解 适用于复杂类型。

对于涉及整数到整数、整数到浮点数的转换floating-point-to-floating-point(如果来源值可以正好) 以目的地类型表示,则结果值是 表示。否则,行为将处于待定状态 (#180)。

对于涉及floating-point-to-integer的转换,小数部分为 被截断。如果截断的值无法在目标类型中表示, 行为待定 (#180)。

涉及复杂到复杂的转化遵循相同的行为, floating-point-to-floating-point 转换,用于将实数和 虚部。

对于“复杂到任何其他类型”和“任何其他类型到复杂类型”complex-to-any-other-typecomplex-to-any-other-type转化, 源虚值被忽略,或目标虚值是 分别为 0。实际部分的转换跟随 浮点数转换。

原则上,此运算可以表示反量化( 量化张量到正则张量)、量化(从正则 张量到量化张量)和重新量化(量化 张量),但目前我们有专门的运算 - uniform_dequantize 用于第一个用例,uniform_quantize 用于 应用场景。将来,这两个操作可能会合并 转换为 convert (#1576)。

输入

标签 名称 类型 限制条件
(I1) operand 张量 (C1)

Outputs

名称 类型 限制条件
result 张量 (C1)

限制条件

  • (C1) shape(operand) = shape(result)

示例

// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]

更多示例

卷积

语义

计算 lhs 窗口与 rhs 切片之间的点积,并生成 result。下图显示了如何根据以下内容计算 result 中的元素: lhsrhs

卷积

更正式地说,请考虑以下根据 lhs 对输入进行重构 为了能够表示 lhs 窗口:

  • lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension))
  • lhs_window_strides = lhs_shape(1, window_strides, 1)
  • lhs_padding = lhs_shape([0, 0], padding, [0, 0])
  • lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)
  • lhs_window_dilations = lhs_shape(1, rhs_dilation, 1)

此重新构图使用以下辅助函数:

  • lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])
  • result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])
  • permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1],其中 j[d] = i[permutation[d]]

如果 feature_group_count = 1batch_group_count = 1,则对所有 index_space(dim(result, output_spatial_dimensions...))output_spatial_indexresult[result_shape(:, output_spatial_index, :)] = dot_product,其中:

  • padding_value = constant(0, element_type(lhs))
  • padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)
  • lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides
  • lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations)
  • reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true])。 此功能似乎未被使用,因此将来我们计划移除 (#1181)。
  • dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension])

如果为 feature_group_count > 1

  • lhses = split(lhs, feature_group_count, input_feature_dimension)
  • rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)
  • results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)
  • result = concatenate(results, output_feature_dimension)

如果为 batch_group_count > 1

  • lhses = split(lhs, batch_group_count, input_batch_dimension)
  • rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)
  • results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...)
  • result = concatenate(results, output_feature_dimension)

对于量化类型,执行 dequantize_op_quantize( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs, type(result))

对于混合量化类型,请执行 hybrid_dequantize_then_op( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs)

输入

标签 名称 类型 限制条件
(I1) lhs 张量或每张量量化张量 (C1)、(C10-C11)、(C14) (C25)、(C27-C28)、(C31-C32)、(C34)
(I2) rhs 张量或量化张量 (C1)、(C14-C16)、(C25)、(C27-C29)、(C31-C34)
(I3) window_strides si64 类型的一维张量常量 (C2-C3)、(C25)
(I4) padding si64 类型的二维张量常量 (C4)、(C25)
(I5) lhs_dilation si64 类型的一维张量常量 (C5-C6)、(C25)
(I6) rhs_dilation si64 类型的一维张量常量 (C7-C8)、(C25)
(I7) window_reversal i1 类型的一维张量常量 (C9)
(I8) input_batch_dimension si64 类型的常量 (C10)、(C13)、(C25)
(I9) input_feature_dimension si64 类型的常量 (C11)、(C13-C14)
(I10) input_spatial_dimensions si64 类型的一维张量常量 (C12)、(C13)、(C25)
(I11) kernel_input_feature_dimension si64 类型的常量 (C14)、(C18)
(I12) kernel_output_feature_dimension si64 类型的常量 (C15-C16)、(C18)、(C25)、(C29)
(I13) kernel_spatial_dimensions si64 类型的一维张量常量 (C17-C18)、(C25)
(I14) output_batch_dimension si64 类型的常量 (C20)、(C25)
(I15) output_feature_dimension si64 类型的常量 (C20)、(C25)、(C30)
(I16) output_spatial_dimensions si64 类型的一维张量常量 (C19-C20)、(C25)
(I17) feature_group_count si64 类型的常量 (C11)、(C14)、(C16)、(C21)、(C23)
(I18) batch_group_count si64 类型的常量 (C10)、(C15)、(C22)、(C23)、(C25)
(I19) precision_config DEFAULTHIGHHIGHEST 枚举的可变数量 (C24)

Outputs

名称 类型 限制条件
result 张量或量化张量 (C25-C28)、(C30)、(C32-34)

限制条件

  • (C1) N = rank(lhs) = rank(rhs)
  • (C2) size(window_strides) = N - 2
  • (C3) 0 < window_strides
  • (C4) shape(padding) = [N - 2, 2]
  • (C5) size(lhs_dilation) = N - 2
  • (C6) 0 < lhs_dilation
  • (C7) size(rhs_dilation) = N - 2
  • (C8) 0 < rhs_dilation
  • (C9) size(window_reversal) = N - 2
  • (C10) dim(lhs, input_batch_dimension) % batch_group_count = 0
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0
  • (C12) size(input_spatial_dimensions) = N - 2
  • (C13) 假设 input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]: <ph type="x-smartling-placeholder">
      </ph>
    • is_unique(input_dimensions)
    • 0 <= input_dimensions < N
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
  • (C17) size(kernel_spatial_dimensions) = N - 2
  • (C18) 假设 kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]: <ph type="x-smartling-placeholder">
      </ph>
    • is_unique(kernel_dimensions)
    • 0 <= kernel_dimensions < N
  • (C19) size(output_spatial_dimensions) = N - 2
  • (C20) 假设 output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]: <ph type="x-smartling-placeholder">
      </ph>
    • is_unique(output_dimensions)
    • 0 <= output_dimensions < N
  • (C21) 0 < feature_group_count
  • (C22) 0 < batch_group_count
  • (C23) feature_group_count = 1 or batch_group_count = 1
  • (C24) size(precision_config) = 2
  • (C25) dim(result, result_dim) 的定义如下: <ph type="x-smartling-placeholder">
      </ph>
    • 如果 result_dim = output_batch_dimension,则为 dim(lhs, input_batch_dimension) / batch_group_count
    • 如果 result_dim = output_feature_dimension,则为 dim(rhs, kernel_output_feature_dimension)
    • 否则为 num_windows,其中:
    • output_spatial_dimensions[spatial_dim] = result_dim
    • lhs_dim = input_spatial_dimensions[spatial_dim]
    • rhs_dim = kernel_spatial_dimensions[spatial_dim]
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
  • (C26) rank(result) = N
  • 如果运算使用非量化张量: <ph type="x-smartling-placeholder">
      </ph>
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result)
  • 如果运算使用量化张量: <ph type="x-smartling-placeholder">
      </ph>
    • (C28) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
    • (C29) 如果 is_per_axis_quantized(rhs), 之后价格为 quantization_dimension(rhs) = kernel_output_feature_dimension
    • (C30) 如果 is_per_axis_quantized(result),则: quantization_dimension(result) = output_feature_dimension
    • 如果为 is_quantized(lhs)
    • (C31) storage_type(lhs) = storage_type(rhs)
    • (C32) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
    • (C33) 如果 is_per_tensor_quantized(rhs),则 is_per_tensor_quantized(result)
    • 如果为 !is_quantized(lhs)
    • (C34) element_type(lhs) = expressed_type(rhs) = element_type(result)

示例

// %lhs: [[
//        [
//          [1], [2], [5], [6]
//        ],
//        [
//          [3], [4], [7], [8]
//        ],
//        [
//          [10], [11], [14], [15]
//        ],
//        [
//          [12], [13], [16], [17]
//        ]
//      ]]
//
// %rhs: [
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]]
//       ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
  window_strides = array<i64: 4, 4>,
  padding = dense<0> : tensor<2x2xi64>,
  lhs_dilation = array<i64: 2, 2>,
  rhs_dilation = array<i64: 1, 1>,
  window_reversal = array<i1: false, false>,
  // In the StableHLO dialect, dimension numbers are encoded via:
  // `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
  // "b" is batch dimension, "f" is feature dimension,
  // "i" is input feature dimension, "o" is output feature dimension,
  // "0/1/etc" are spatial dimensions.
  dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
  batch_group_count = 1 : i64,
  feature_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
//            [[10], [26]],
//            [[46], [62]]
//          ]]

更多示例

余弦

语义

operand 张量执行元素级余弦运算,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于浮点数:IEEE-754 中的 cos
  • 对于复数:复余弦。
  • 对于量化类型:dequantize_op_quantize(cosine, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点型、复杂类型或每张量量化张量 (C1)

Outputs

名称 类型 限制条件
result 浮点型、复杂类型或每张量量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]

更多示例

count_leading_zeros

语义

operand 中的前导零位的数量执行元素级计数 张量并生成一个 result 张量。

输入

标签 名称 类型 限制条件
(I1) operand 整数类型的张量 (C1)

Outputs

名称 类型 限制条件
result 整数类型的张量 (C1)

限制条件

  • (C1) type(operand) = type(result)

示例

// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]

更多示例

custom_call

语义

封装一个由实现定义的操作 call_target_name,该操作将 inputscalled_computations,并生成 resultshas_side_effect, backend_configapi_version 可用于提供额外的 实现定义的元数据

目前,此操作包含相当无序的 元数据,用于反映其对应操作在 XLA 编译器未来,我们计划统一这些元数据 (#741)。

输入

标签 名称 类型
(I1) inputs 值的可变数量
(I2) call_target_name string 类型的常量
(I3) has_side_effect i1 类型的常量
(I4) backend_config 类型为 string 或属性字典的常量
(I5) api_version si32 类型的常量
(I6) called_computations string 类型的常量可变数量

Outputs

名称 类型
results 值的可变数量

示例

%results = "stablehlo.custom_call"(%input0) {
  call_target_name = "foo",
  has_side_effect = false,
  backend_config = {bar = 42 : i32},
  api_version = 4 : i32,
  called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>

语义

对被除数 lhs 和除数 rhs 张量执行元素级除法,并且 会生成 result 张量。根据元素类型,执行以下操作:

  • 对于整数:整数除法,可生成具有任意值的代数商 已舍弃小数部分。
  • 对于浮点数:IEEE-754 中的 division
  • 对于复数:复数除法。
  • 对于量化类型: <ph type="x-smartling-placeholder">
      </ph>
    • dequantize_op_quantize(divide, lhs, rhs, type(result))

输入

标签 名称 类型 限制条件
(I1) lhs 整数、浮点、复杂类型或每张量量化张量 (C1)
(I2) rhs 整数、浮点、复杂类型或每张量量化张量 (C1)

Outputs

名称 类型 限制条件
result 整数、浮点、复杂类型或每张量量化张量 (C1)

限制条件

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

示例

// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]

更多示例

dot_general

语义

计算 lhs 切片与 rhs 切片之间的点积,并生成 result 张量。

更正式的名称是 result[result_index] = dot_product,其中:

  • lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions]
  • rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions]
  • result_batching_index + result_lhs_index + result_rhs_index = result_index 其中,size(result_batching_index) = size(lhs_batching_dimensions)size(result_lhs_index) = size(lhs_result_dimensions)size(result_rhs_index) = size(rhs_result_dimensions)
  • transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions)
  • transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :])
  • reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions))
  • transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions)
  • transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :])
  • reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions))
  • dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y))

对于量化类型,执行 dequantize_op_quantize( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs, type(result))

对于混合量化类型,请执行 hybrid_dequantize_then_op( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs)

precision_config 用于控制 加速器后端上的计算。可以是以下某一项(在 这些枚举值的语义未指定,但我们 我们计划 #755):

  • DEFAULT:计算速度最快,但近似于 原始编号。
  • HIGH:计算速度较慢,但近似于 原始编号。
  • HIGHEST:计算速度最慢,但近似于 原始编号。

DotAlgorithm 定义实现算法时使用的主要属性 点运算,这也定义了精度。如果算法属性 字段,那么 precision_config 必须为 DEFAULTDotAlgorithms 则没有默认值,因为默认参数是实现 。因此,所有点算法字段均可设置为 None,以指定 空点算法,该算法将使用 precision_config 值。

DotAlgorithm 字段包括:

  • lhs_precision_typerhs_precision_type,即 LHS 和 运算的 RHS 舍入。精度类型独立于 输入和输出的存储类型。
  • accumulation_type:用于累加的精度。
  • lhs_component_countrhs_component_countnum_primitive_operations 这种算法会将 LHS 和/或 RHS 分解为 并执行多个“原语”对它们执行点运算 值 - 通常是为了模拟更高的精度(例如 利用 bfloat16 人工智能数据类型进行更精确的计算: bf16_6x tf32_3x 等)。对于不分解的算法,以下值 应设为 1
  • allow_imprecise_accumulation:用于指定累积的精度是否较低 允许用于某些步骤(例如 CUBLASLT_MATMUL_DESC_FAST_ACCUM)。

DotAlgorithm 属性示例:

// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
 rhs_precision_type = tf32,
 accumulation_type = f32,
 lhs_component_count = 1,
 rhs_component_count = 1,
 num_primitive_operations = 1,
 allow_imprecise_accumulation = false}


// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
 rhs_precision_type = bf16,
 accumulation_type = f32,
 lhs_component_count = 3,
 rhs_component_count = 3,
 num_primitive_operations = 6,
 allow_imprecise_accumulation = false}


// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
 rhs_precision_type = f8e5m2,
 accumulation_type = f32,
 lhs_component_count = 1,
 rhs_component_count = 1,
 num_primitive_operations = 1,
 allow_imprecise_accumulation = true}

具体支持哪些组合由实现决定。在 我们无法保证每个算法都支持每种算法, 加速器类型。如果给定算法不是 应引发错误,而不是回退到 StableHLO 验证将提供尽力验证, 可阻止已知在任何硬件上支持的算法。

请参阅 xla_data.proto > Algorithm 部分支持的算法值。工单 #2483 显示创建 关于后端支持的算法的集中式文档。

输入

标签 名称 类型 限制条件
(I1) lhs 张量或每张量量化张量 (C5-C6)、(C9-C10)、(C12-C14)、(C17-C18)、(C20)
(I2) rhs 张量或量化张量 (C7-C10)、(C12-C20)
(I3) lhs_batching_dimensions si64 类型的一维张量常量 (C1)、(C3)、(C5)、(C9)、(C12)
(I4) rhs_batching_dimensions si64 类型的一维张量常量 (C1)、(C4)、(C7)、(C9)
(I5) lhs_contracting_dimensions si64 类型的一维张量常量 (C2)、(C3)、(C6)、(C10)
(I6) rhs_contracting_dimensions si64 类型的一维张量常量 (C2)、(C4)、(C8)、(C10)、(C16)
(I7) precision_config DEFAULTHIGHHIGHEST 枚举的可变数量 (C11)、(C21)
(I8) lhs_precision_type FloatType 或 TensorFloat32 (C21)
(I9) rhs_precision_type FloatType 或 TensorFloat32 (C21)
(I10) accumulation_type FloatType 或 TensorFloat32 (C21)
(I11) lhs_component_count si32 类型的常量 (C21)、(C22)
(I12) rhs_component_count si32 类型的常量 (C21)、(C23)
(I13) num_primitive_operations si32 类型的常量 (C21)、(C24)
(I14) allow_imprecise_accumulation bool 类型的常量 (C21)

Outputs

名称 类型 限制条件
result 张量或量化张量 (C12)、(C14)、(C18-C20)

限制条件

  • (C1) size(lhs_batching_dimensions) = size(rhs_batching_dimensions)
  • (C2) size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions)
  • (C3) is_unique(lhs_batching_dimensions + lhs_contracting_dimensions)
  • (C4) is_unique(rhs_batching_dimensions + rhs_contracting_dimensions)
  • (C5) 0 <= lhs_batching_dimensions < rank(lhs)
  • (C6) 0 <= lhs_contracting_dimensions < rank(lhs)
  • (C7) 0 <= rhs_batching_dimensions < rank(rhs)
  • (C8) 0 <= rhs_contracting_dimensions < rank(rhs)
  • (C9) dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
  • (C10) dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...)
  • (C11) size(precision_config) = 2
  • (C12) shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions)
  • 如果运算使用非量化张量: <ph type="x-smartling-placeholder">
      </ph>
    • (C13) element_type(lhs) = element_type(rhs)
  • 如果运算使用量化张量: <ph type="x-smartling-placeholder">
      </ph>
    • (C14) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
    • (C15) zero_points(rhs) = 0
    • (C16) 如果 is_per_axis_quantized(rhs),则 quantization_dimension(rhs)不在rhs_contracting_dimensions中。
    • 如果为 is_quantized(lhs)
    • (C17) storage_type(lhs) = storage_type(rhs)
    • (C18) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
    • (C19) 如果 is_per_tensor_quantized(rhs),则 is_per_tensor_quantized(result)
    • 如果为 !is_quantized(lhs)
    • (C20) element_type(lhs) = expressed_type(rhs) = element_type(result)
  • 如果为 !is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation): <ph type="x-smartling-placeholder">
      </ph>
    • (C21) precision_config... = DEFAULT
    • (C22) 0 < lhs_component_count
    • (C23) 0 < rhs_component_count
    • (C24) 0 < num_primitive_operations

示例

// %lhs: [
//        [[1, 2],
//         [3, 4]],
//        [[5, 6],
//         [7, 8]]
//       ]
// %rhs: [
//        [[1, 0],
//         [0, 1]],
//        [[1, 0],
//         [0, 1]]
//       ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
  dot_dimension_numbers = #stablehlo.dot<
    lhs_batching_dimensions = [0],
    rhs_batching_dimensions = [0],
    lhs_contracting_dimensions = [2],
    rhs_contracting_dimensions = [1]
  >,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
  algorithm = #stablehlo.dot_algorithm<
    lhs_precision_type = tf32,
    rhs_precision_type = tf32,
    accumulation_type = f32,
    lhs_component_count = 1,
    rhs_component_count = 1,
    num_primitive_operations = 1,
    allow_imprecise_accumulation = false
  >
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
//           [[1, 2],
//            [3, 4]],
//           [[5, 6],
//            [7, 8]]
//          ]

更多示例

dynamic_broadcast_in_dim

语义

此操作在功能上与 broadcast_in_dim 操作,但结果形状是通过 output_dimensions 动态指定的。

该操作还接受可选属性 known_expanding_dimensionsknown_non_expanding_dimensions 来表达关于维度展开行为的静态知识。 如果未指定,系统会假定所有尺寸均可展开。

输入

标签 名称 类型 限制条件
(I1) operand 张量或量化张量 (C1-C2)、(C5-C6)、(C9)
(I2) output_dimensions 整数类型的一维张量 (C7)
(I3) broadcast_dimensions 整数类型的一维常量张量 (C2-C6)
(I4) known_expanding_dimensions 整数类型的一维常量张量 (C8 - C9)
(I5) known_non_expanding_dimensions 整数类型的一维常量张量 (C8 - C9)

Outputs

名称 类型 限制条件
result 张量或量化张量 (C1)、(C3)、(C5-C7)

限制条件

  • (C1) element_type(result) 由以下公式计算得出: <ph type="x-smartling-placeholder">
      </ph>
    • element_type(operand)(如果 !is_per_axis_quantized(operand))。
    • element_type(operand),但 quantization_dimension(operand)scales(operand)zero_points(operand) 可能不同于 quantization_dimension(result)scales(result)zero_points(result) 否则。
  • (C2) size(broadcast_dimensions) = rank(operand)
  • (C3) 0 <= broadcast_dimensions < rank(result)
  • (C4) is_unique(broadcast_dimensions)
  • (C5) 对于 axes(operand) 中的所有 d: <ph type="x-smartling-placeholder">
      </ph>
    • dim(operand, d) = 1
    • dim(operand, d) = dim(result, broadcast_dimensions[d])
  • (C6) 如果 is_per_axis_quantized(result): <ph type="x-smartling-placeholder">
      </ph>
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
    • 如果为 dim(operand, quantization_dimension(operand)) = 1,则 scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
  • (C7) size(output_dimensions) = rank(result)
  • (C8) is_unique(known_expanding_dimensions + known_non_expanding_dimensions)
  • (C9) 0 <= known_expanding_dimensions < rank(operand)
  • (C10) 0 <= known_non_expanding_dimensions < rank(operand)

示例

// %operand: [
//            [1, 2, 3]
//           ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
  broadcast_dimensions = array<i64: 2, 1>,
  known_expanding_dimensions = array<i64: 0>,
  known_non_expanding_dimensions = array<i64: 1>
} : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

更多示例

dynamic_conv

语义

此操作在功能上与 卷积 操作,但内边距是通过 padding 动态指定的。

输入

标签 名称 类型 限制条件
(I1) lhs 张量或每张量量化张量 (C1)、(C10-C11)、(C14) (C25)、(C26-C27)、(C30-C31)、(C33)
(I2) rhs 张量或量化张量 (C1)、(C14-C16)、(C26-C28)、(C30-C33)
(I3) padding 整数类型的二维张量 (C4)
(I4) window_strides si64 类型的一维张量常量 (C2-C3)
(I5) lhs_dilation si64 类型的一维张量常量 (C5-C6)
(I6) rhs_dilation si64 类型的一维张量常量 (C7-C8)
(I7) window_reversal i1 类型的一维张量常量 (C9)
(I8) input_batch_dimension si64 类型的常量 (C10)、(C13)
(I9) input_feature_dimension si64 类型的常量 (C11)、(C13-C14)
(I10) input_spatial_dimensions si64 类型的一维张量常量 (C12)、(C13)
(I11) kernel_input_feature_dimension si64 类型的常量 (C14)、(C18)
(I12) kernel_output_feature_dimension si64 类型的常量 (C15-C16)、(C18)、(C28)
(I13) kernel_spatial_dimensions si64 类型的一维张量常量 (C17-C18)
(I14) output_batch_dimension si64 类型的常量 (C20)
(I15) output_feature_dimension si64 类型的常量 (C20)、(C29)
(I16) output_spatial_dimensions si64 类型的一维张量常量 (C19 至 C20)
(I17) feature_group_count si64 类型的常量 (C11)、(C14)、(C16)、(C21)、(C23)
(I18) batch_group_count si64 类型的常量 (C10)、(C15)、(C22)、(C23)
(I19) precision_config DEFAULTHIGHHIGHEST 枚举的可变数量 (C24)

Outputs

名称 类型 限制条件
result 张量或量化张量 (C25-C27)、(C29)、(C31-C33)

限制条件

  • (C1) N = rank(lhs) = rank(rhs)
  • (C2) size(window_strides) = N - 2
  • (C3) 0 < window_strides
  • (C4) shape(padding) = [N - 2, 2]
  • (C5) size(lhs_dilation) = N - 2
  • (C6) 0 < lhs_dilation
  • (C7) size(rhs_dilation) = N - 2
  • (C8) 0 < rhs_dilation
  • (C9) size(window_reversal) = N - 2
  • (C10) dim(lhs, input_batch_dimension) % batch_group_count = 0
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0
  • (C12) size(input_spatial_dimensions) = N - 2
  • (C13) 假设 input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]: <ph type="x-smartling-placeholder">
      </ph>
    • is_unique(input_dimensions)
    • 0 <= input_dimensions < N
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
  • (C17) size(kernel_spatial_dimensions) = N - 2
  • (C18) 假设 kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]: <ph type="x-smartling-placeholder">
      </ph>
    • is_unique(kernel_dimensions)
    • 0 <= kernel_dimensions < N
  • (C19) size(output_spatial_dimensions) = N - 2
  • (C20) 假设 output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]: <ph type="x-smartling-placeholder">
      </ph>
    • is_unique(output_dimensions)
    • 0 <= output_dimensions < N
  • (C21) 0 < feature_group_count
  • (C22) 0 < batch_group_count
  • (C23) feature_group_count = 1 or batch_group_count = 1
  • (C24) size(precision_config) = 2
  • (C25) dim(result, result_dim) 的定义如下: <ph type="x-smartling-placeholder">
      </ph>
    • 如果 result_dim = output_batch_dimension,则为 dim(lhs, input_batch_dimension) / batch_group_count
    • 如果 result_dim = output_feature_dimension,则为 dim(rhs, kernel_output_feature_dimension)
    • 否则为 num_windows,其中:
    • output_spatial_dimensions[spatial_dim] = result_dim
    • lhs_dim = input_spatial_dimensions[spatial_dim]
    • rhs_dim = kernel_spatial_dimensions[spatial_dim]
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
  • (C26) rank(result) = N
  • 如果运算使用非量化张量: <ph type="x-smartling-placeholder">
      </ph>
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result)
  • 如果运算使用量化张量: <ph type="x-smartling-placeholder">
      </ph>
    • (C28) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
    • (C29) 如果 is_per_axis_quantized(rhs), 之后价格为 quantization_dimension(rhs) = kernel_output_feature_dimension
    • (C30) 如果 is_per_axis_quantized(result),则: quantization_dimension(result) = output_feature_dimension
    • 如果为 is_quantized(lhs)
    • (C31) storage_type(lhs) = storage_type(rhs)
    • (C32) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
    • (C33) 如果 is_per_tensor_quantized(rhs),则 is_per_tensor_quantized(result)
    • 如果为 !is_quantized(lhs)
    • (C34) element_type(lhs) = expressed_type(rhs) = element_type(result)

示例

// %lhs: [[
//        [[1], [2], [5], [6]],
//        [[3], [4], [7], [8]],
//        [[10], [11], [14], [15]],
//        [[12], [13], [16], [17]]
//      ]]
//
// %rhs: [
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]]
//        ]
// %padding: [[1, 1],
//            [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
  window_strides = array<i64: 4, 4>,
  lhs_dilation = array<i64: 2, 2>,
  rhs_dilation = array<i64: 1, 1>,
  window_reversal = array<i1: false, false>,
  dimension_numbers = #stablehlo.conv<raw
    input_batch_dimension = 0,
    input_feature_dimension = 3,
    input_spatial_dimensions = [0, 1],
    kernel_input_feature_dimension = 2,
    kernel_output_feature_dimension = 3,
    kernel_spatial_dimensions = [0, 1],
    output_batch_dimension = 0,
    output_feature_dimension = 3,
    output_spatial_dimensions = [1, 2]
  >,
  feature_group_count = 1 : i64,
  batch_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
//            [[1], [5]],
//            [[10], [14]]
//          ]]

更多示例

dynamic_gather

语义

此操作在功能上与 收集 操作,并将 slice_sizes 动态指定为值。

输入

标签 名称 类型 限制条件
(I1) operand 张量或每张量量化张量 (C1)、(C7)、(C10-C12)、(C14)
(I2) start_indices 整数类型的张量 (C2)、(C3)、(C13)
(I3) slice_sizes 整数类型的一维张量 (C8)、(C11-C13)
(I4) offset_dims si64 类型的一维张量常量 (C1)、(C4-C5)、(C13)
(I5) collapsed_slice_dims si64 类型的一维张量常量 (C1)、(C6-C8)、(C13)
(I6) start_index_map si64 类型的一维张量常量 (C3)、(C9)、(C10)
(I7) index_vector_dim si64 类型的常量 (C2)、(C3)、(C13)
(I8) indices_are_sorted i1 类型的常量

Outputs

名称 类型 限制条件
result 张量或每张量量化张量 (C5)、(C13-C14)

限制条件

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims)
  • (C2) 0 <= index_vector_dim <= rank(start_indices)
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims)
  • (C5) 0 <= offset_dims < rank(result)
  • (C6) is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims)
  • (C7) 0 <= collapsed_slice_dims < rank(operand)
  • (C8) slice_sizes[collapsed_slice_dims...] <= 1
  • (C9) is_unique(start_index_map)
  • (C10) 0 <= start_index_map < rank(operand)
  • (C11) size(slice_sizes) = rank(operand)
  • (C12) 0 <= slice_sizes <= shape(operand)
  • (C13) shape(result) = combine(batch_dim_sizes, offset_dim_sizes),其中: <ph type="x-smartling-placeholder">
      </ph>
    • batch_dim_sizes = shape(start_indices),只不过尺寸尺寸 与 index_vector_dim 对应的 start_indices 未包含在内。
    • offset_dim_sizes = shape(slice_sizes),只不过尺寸 与 collapsed_slice_dims 对应的 slice_sizes 中不包括在内。
    • combinebatch_dim_sizes 放置在与 batch_dims 对应的轴上, offset_dims 所对应的轴上的 offset_dim_sizes
  • (C14) element_type(operand) = element_type(result)

示例

// %operand: [
//            [[1, 2], [3, 4], [5, 6], [7, 8]],
//            [[9, 10],[11, 12], [13, 14], [15, 16]],
//            [[17, 18], [19, 20], [21, 22], [23, 24]]
//           ]
// %start_indices: [
//                  [[0, 0], [1, 0], [2, 1]],
//                  [[0, 1], [1, 1], [0, 2]]
//                 ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [2, 3],
    collapsed_slice_dims = [0],
    start_index_map = [1, 0],
    index_vector_dim = 2>,
  indices_are_sorted = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi64>
// %result: [
//            [
//              [[1, 2], [3, 4]],
//              [[3, 4], [5, 6]],
//              [[13, 14], [15, 16]]
//            ],
//            [
//              [[9, 10], [11, 12]],
//              [[11, 12], [13, 14]],
//              [[17, 18], [19, 20]]
//            ]
//          ]

更多示例

dynamic_iota

语义

此操作在功能上与 iota 操作,但结果形状是通过 output_shape 动态指定的。

输入

标签 名称 类型 限制条件
(I1) output_shape 整数类型的一维张量 (C1)、(C2)
(I2) iota_dimension si64 (C1)

Outputs

名称 类型 限制条件
result 整数、浮点、复杂类型或每张量量化张量 (C2)

限制条件

  • (C1) 0 <= iota_dimension < size(output_shape)
  • (C2) rank(result) = size(output_shape)

示例

%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
  iota_dimension = 0 : i64
} : (tensor<2xi64>) -> tensor<4x5xi64>
// %result: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

更多示例

dynamic_pad

语义

此操作在功能上与 键盘 操作,但包含 edge_padding_lowedge_padding_highinterior_padding 动态指定为值。

输入

标签 名称 类型 限制条件
(I1) operand 张量或每张量量化张量 (C1)、(C2)、(C4)
(I2) padding_value 0 维张量或每张量量化张量 (C1)
(I3) edge_padding_low 整数类型的一维张量 (C1)、(C4)
(I4) edge_padding_high 整数类型的一维张量 (C1)、(C4)
(I5) interior_padding 整数类型的一维张量 (C2-C4)

Outputs

名称 类型 限制条件
result 张量或每张量量化张量 (C3-C6)

限制条件

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result)
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
  • (C3) 0 <= interior_padding
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high

示例

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
  %edge_padding_low, %edge_padding_high, %interior_padding
) : (tensor<2x3xi64>, tensor<i64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x9xi64>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

更多示例

dynamic_reshape

语义

此操作在功能上与 重塑 操作,但结果形状是通过 output_shape 动态指定的。

输入

标签 名称 类型 限制条件
(I1) operand 张量或量化张量 (C1 - C3)
(I2) output_shape 整数类型的一维张量 (C4)

Outputs

名称 类型 限制条件
result 张量或量化张量 (C1 - C4)

限制条件

  • (C1) element_type(result) 由以下公式计算得出: <ph type="x-smartling-placeholder">
      </ph>
    • element_type(operand)(如果 !is_per_axis_quantized(operand))。
    • element_type(operand),但 quantization_dimension(operand) 和 否则,quantization_dimension(result) 可能会有所不同。
  • (C2) size(operand) = size(result)
  • (C3) 如果 is_per_axis_quantized(operand): <ph type="x-smartling-placeholder">
      </ph>
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
  • (C4) size(output_shape) = rank(result)

示例

// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
// %result: [[1, 2], [3, 4], [5, 6]]

更多示例

dynamic_slice

语义

使用动态计算的起始索引从 operand 中提取 Slice 并生成一个 result 张量。start_indices 包含 每个维度的切片(可能会发生调整)以及slice_sizes 包含每个维度切片的大小。更正式地说, result[result_index] = operand[operand_index],其中:

  • adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes)
  • operand_index = adjusted_start_indices + result_index

输入

标签 名称 类型 限制条件
(I1) operand 张量或每张量量化张量 (C1)、(C2)、(C4)
(I2) start_indices 整数类型的 0 维张量的可变数量 (C2)、(C3)
(I3) slice_sizes si64 类型的一维张量常量 (C2)、(C4)、(C5)

Outputs

名称 类型 限制条件
result 张量或每张量量化张量 (C1)、(C5)

限制条件

  • (C1) element_type(operand) = element_type(result)
  • (C2) size(start_indices) = size(slice_sizes) = rank(operand)
  • (C3) same(type(start_indices...))
  • (C4) 0 <= slice_sizes <= shape(operand)
  • (C5) shape(result) = slice_sizes

示例

// %operand: [
//            [0, 0, 1, 1],
//            [0, 0, 1, 1],
//            [0, 0, 0, 0],
//            [0, 0, 0, 0]
//           ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
  slice_sizes = array<i64: 2, 2>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
//           [1, 1],
//           [1, 1]
//          ]

更多示例

dynamic_update_slice

语义

生成一个 result 张量,该张量等于 operand 张量,但存在 从 start_indices 开始的 Slice 会更新为 update 中的值。 更正式地说,result[result_index] 的定义如下:

  • 如果 0 <= update_index < shape(update),则为 update[update_index],其中: <ph type="x-smartling-placeholder">
      </ph>
    • adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update))
    • update_index = result_index - adjusted_start_indices
  • 否则为 operand[result_index]

输入

标签 名称 类型 限制条件
(I1) operand 张量或每张量量化张量 (C1-C4)、(C6)
(I2) update 张量或每张量量化张量 (C2)、(C3)、(C6)
(I3) start_indices 整数类型的 0 维张量的可变数量 (C4)、(C5)

Outputs

名称 类型 限制条件
result 张量或每张量量化张量 (C1)

限制条件

  • (C1) type(operand) = type(result)
  • (C2) element_type(update) = element_type(operand)
  • (C3) rank(update) = rank(operand)
  • (C4) size(start_indices) = rank(operand)
  • (C5) same(type(start_indices...))
  • (C6) 0 <= shape(update) <= shape(operand)

示例

// %operand: [
//            [1, 1, 0, 0],
//            [1, 1, 0, 0],
//            [1, 1, 1, 1],
//            [1, 1, 1, 1]
//           ]
// %update: [
//           [1, 1],
//           [1, 1]
//          ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
  : (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1]
//          ]

更多示例

指数函数

语义

operand 张量执行元素级指数运算,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于浮点数:IEEE-754 中的 exp
  • 对于复数:复数指数。
  • 对于量化类型: dequantize_op_quantize(exponential, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点型、复杂类型或每张量量化张量 (C1)

Outputs

名称 类型 限制条件
result 浮点型、复杂类型或每张量量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]

更多示例

exponential_minus_one

语义

operand 张量执行元素级指数减一个运算,并 会生成 result 张量。根据元素类型,执行以下操作:

  • 对于浮点数:IEEE-754 中的 expm1
  • 对于复数:复数指数减一。
  • 对于量化类型: dequantize_op_quantize(exponential_minus_one, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点型、复杂类型或每张量量化张量 (C1)

Outputs

名称 类型 限制条件
result 浮点型、复杂类型或每张量量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]

更多示例

FFT

语义

对实数和复数进行正向和逆向福里叶转换 输入/输出。

fft_type 是以下值之一:

  • FFT:转发复杂到复杂 FFT。
  • IFFT:复数到复数的 FFT 的反函数。
  • RFFT:转发实数到复数 FFT。
  • IRFFT:实数到复数的反函数 FFT(即取复数,返回实数)。

更正式地说,假设函数 fft 接受以下公式的一维张量: 作为输入,会生成与 输出并计算离散 Furier 转换:

对于 fft_type = FFTresult 定义为一系列 L L = size(fft_length) 的计算。例如,对于 L = 3

  • result1[i0, ..., :] = fft(operand[i0, ..., :])
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])

此外,假设函数 ifft 具有相同的类型签名和 计算 fft 的反函数:

对于 fft_type = IFFTresult 定义为计算的逆 价格为 fft_type = FFT。例如,对于 L = 3

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
  • result[i0, ..., :] = ifft(result2[i0, ..., :])

此外,假设函数 rfft 接受以下公式的一维张量: 浮点类型可以生成复杂类型的一维张量, 相同的浮点语义,工作原理如下:

  • rfft(real_operand) = truncated_result,其中
  • complex_operand... = (real_operand..., 0.0)
  • complex_result = fft(complex_operand)
  • truncated_result = complex_result[:(rank(complex_result) / 2 + 1)]

(在针对实操作数计算离散 Furier 转换时,第一个 结果的 N/2 + 1 元素明确定义结果的其余部分, 因此 rfft 的结果会被截断,以避免计算冗余元素)。

对于 fft_type = RFFTresult 定义为一系列 L L = size(fft_length) 的计算。例如,对于 L = 3

  • result1[i0, ..., :] = rfft(operand[i0, ..., :])
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])

最后,给定具有相同类型签名的函数 irfft, 计算 rfft 的反函数:

对于 fft_type = IRFFTresult 定义为计算的逆 价格为 fft_type = RFFT。例如,对于 L = 3

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
  • result[i0, ..., :] = irfft(result2[i0, ..., :])

输入

标签 名称 类型 限制条件
(I1) operand 浮点或复杂类型的张量 (C1)、(C2)、(C4)、(C5)
(I2) fft_type FFTIFFTRFFTIRFFT 的枚举 (C2)、(C5)
(I3) fft_length si64 类型的一维张量常量 (C1)、(C3)、(C4)

Outputs

名称 类型 限制条件
result 浮点或复杂类型的张量 (C2)、(C4)、(C5)

限制条件

  • (C1) size(fft_length) <= rank(operand)
  • (C2) operandresult 元素类型之间的关系各不相同: <ph type="x-smartling-placeholder">
      </ph>
    • 如果 fft_type = FFTelement_type(operand)element_type(result) 具有相同的复杂类型。
    • 如果状态为 fft_type = IFFTelement_type(operand)element_type(result) 具有相同的复杂类型。
    • 如果为 fft_type = RFFT,则 element_type(operand) 为浮点类型,并且 element_type(result) 是同一浮点数的复杂类型 语义信息。
    • 如果为 fft_type = IRFFT,则 element_type(operand) 为复杂类型; element_type(result) 是同一浮点数的浮点类型 语义信息。
  • (C3) 1 <= size(fft_length) <= 3
  • (C4) 如果 operandresult 中存在一个张量 real, 浮点型,然后是 shape(real)[-size(fft_length):] = fft_length
  • (C5) shape(result) = shape(operand),以下情形除外: <ph type="x-smartling-placeholder">
      </ph>
    • 如果为 fft_type = RFFTdim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1
    • 如果为 fft_type = IRFFTdim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1

示例

// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
  fft_type = #stablehlo<fft_type FFT>,
  fft_length = array<i64: 4>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]

floor

语义

operand 张量执行元素级底价,并生成 result 张量。 实现来自 IEEE-754 的 roundToIntegralTowardNegative 操作 规范对于量化类型, dequantize_op_quantize(floor, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点类型的张量或每张量量化张量 (C1)

Outputs

名称 类型 限制条件
result 浮点类型的张量或每张量量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]

更多示例

收集

语义

start_indices 中指定的偏移量收集 operand 张量的切片 并生成一个 result 张量。

下图显示了 result 中的元素如何映射到 operand。该图选择了几个示例 result 索引,并详细说明它们与哪些 operand 索引相对应。

收集

更正式地说,result[result_index] = operand[operand_index],其中:

  • batch_dims = [d for d in axes(result) and d not in offset_dims]
  • batch_index = result_index[batch_dims...]
  • start_index 的定义如下: <ph type="x-smartling-placeholder">
      </ph>
    • start_indices[bi0, ..., :, ..., biN],其中 bi 是以下元素中的单个元素: 如果满足以下条件,则会将 batch_index: 插入 index_vector_dim 索引处 index_vector_dim <rank(start_indices)
    • 否则为 [start_indices[batch_index]]
  • 对于axes(operand)d_operand, <ph type="x-smartling-placeholder">
      </ph>
    • full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand]) 如果 d_operand = start_index_map[d_start]
    • 否则为 full_start_index[d_operand] = 0
  • 对于axes(operand)d_operand, <ph type="x-smartling-placeholder">
      </ph>
    • full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)] 如果 d_operand = operand_batching_dims[i_batching]d_start = start_indices_batching_dims[i_batching]
    • 否则为 full_batching_index[d_operand] = 0
  • offset_index = result_index[offset_dims...]
  • full_offset_index = [oi0, ..., 0, ..., oiN],其中 oi 为个人 元素在 offset_index 中,而 0 在索引位置处插入 collapsed_slice_dimsoperand_batching_dims
  • operand_index = full_start_index + full_batching_index + full_offset_index

如果 indices_are_sortedtrue,则实现可以假定: start_indices 相对于 start_index_map 进行排序,否则 行为未定义更正式地说,对于来自 indices(result) 的所有 i1 < i2full_start_index(i1) <= full_start_index(i2)

输入

标签 名称 类型 限制条件
(I1) operand 张量或每张量量化张量 (C1)、(C8)、(C11)、(C17)、(C19-C21)、(C23)
(I2) start_indices 整数类型的张量 (C2-C3)、(C14)、(C17)、(C22)
(I3) offset_dims si64 类型的一维张量常量 (C1)、(C4-C5)、(C22)
(I4) collapsed_slice_dims si64 类型的一维张量常量 (C1)、(C6-C9)、(C22)
(I5) operand_batching_dims si64 类型的一维张量常量 (C1)、(C6)、(C10-C12)、(C16-C18)、(C22)
(I6) start_indices_batching_dims si64 类型的一维张量常量 (C13-C17)
(I7) start_index_map si64 类型的一维张量常量 (C3)、(C18-C19)
(I8) index_vector_dim si64 类型的常量 (C2-C3)、(C15)、(C22)
(I9) slice_sizes si64 类型的一维张量常量 (C9)、(C12)、(C20-C22)
(I10) indices_are_sorted i1 类型的常量

Outputs

名称 类型 限制条件
result 张量或每张量量化张量 (C5)、(C22-C23)

限制条件

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims)
  • (C2) 0 <= index_vector_dim <= rank(start_indices)
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims)
  • (C5) 0 <= offset_dims < rank(result)
  • (C6) is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))
  • (C7) is_sorted(collapsed_slice_dims)
  • (C8) 0 <= collapsed_slice_dims < rank(operand)
  • (C9) slice_sizes[collapsed_slice_dims...] <= 1
  • (C10) is_sorted(operand_batching_dims)
  • (C11) 0 <= operand_batching_dims < rank(operand)
  • (C12) slice_sizes[operand_batching_dims...] <= 1
  • (C13) is_unique(start_indices_batching_dims)
  • (C14) 0 <= start_indices_batching_dims < rank(start_indices)
  • (C15) index_vector_dim not in start_indices_batching_dims
  • (C16) size(operand_batching_dims) == size(start_indices_batching_dims)
  • (C17) dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...)
  • (C18) is_unique(concatenate(start_index_map, operand_batching_dims))
  • (C19) 0 <= start_index_map < rank(operand)
  • (C20) size(slice_sizes) = rank(operand)
  • (C21) 0 <= slice_sizes <= shape(operand)
  • (C22) shape(result) = combine(batch_dim_sizes, offset_dim_sizes),其中: <ph type="x-smartling-placeholder">
      </ph>
    • batch_dim_sizes = shape(start_indices),只不过尺寸尺寸 与 index_vector_dim 对应的 start_indices 未包含在内。
    • offset_dim_sizes = slice_sizes,只不过此处的 slice_sizes(对应 collapsed_slice_dims 和 未包含 operand_batching_dims
    • combinebatch_dim_sizes 放置在与 batch_dims 对应的轴上, offset_dims 所对应的轴上的 offset_dim_sizes
  • (C23) element_type(operand) = element_type(result)

示例

// %operand: [
//            [
//             [[1, 2], [3, 4], [5, 6], [7, 8]],
//             [[9, 10],[11, 12], [13, 14], [15, 16]],
//             [[17, 18], [19, 20], [21, 22], [23, 24]]
//            ],
//            [
//             [[25, 26], [27, 28], [29, 30], [31, 32]],
//             [[33, 34], [35, 36], [37, 38], [39, 40]],
//             [[41, 42], [43, 44], [45, 46], [47, 48]]
//            ]
//           ]
// %start_indices: [
//                  [
//                   [[0, 0], [1, 0], [2, 1]],
//                   [[0, 1], [1, 1], [0, 9]]
//                  ],
//                  [
//                   [[0, 0], [2, 1], [2, 2]],
//                   [[1, 2], [0, 1], [1, 0]]
//                  ]
//                 ]
%result = "stablehlo.gather"(%operand, %start_indices) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [3, 4],
    collapsed_slice_dims = [1],
    operand_batching_dims = [0],
    start_indices_batching_dims = [1],
    start_index_map = [2, 1],
    index_vector_dim = 3>,
  slice_sizes = array<i64: 1, 1, 2, 2>,
  indices_are_sorted = false
} : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32>
// %result: [
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[3, 4], [5, 6]],
//             [[13, 14], [15, 16]]
//            ],
//            [
//             [[33, 34], [35, 36]],
//             [[35, 36], [37, 38]],
//             [[41, 42], [43, 44]]
//            ]
//           ],
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[13, 14], [15, 16]],
//             [[21, 22], [23, 24]]
//            ],
//            [
//             [[43, 44], [45, 46]],
//             [[33, 34], [35, 36]],
//             [[27, 28], [29, 30]]
//            ]
//           ]
//          ]

更多示例

get_dimension_size

语义

生成 operand 的指定 dimension 的大小。更正式地说, result = dim(operand, dimension)。语义学仅关注形状 类型的组件。元素类型可以是任何内容。

输入

标签 名称 类型 限制条件
(I1) operand 张量或量化张量 (C1)
(I2) dimension si64 类型的常量 (C1)

Outputs

名称 类型
result si32 类型的 0 维张量

限制条件

  • (C1) 0 <= dimension < rank(operand)

示例

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
  dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3

更多示例

get_tuple_element

<ph type="x-smartling-placeholder">
</ph>

语义

提取 operand 元组位于 index 位置的元素,并生成 result。更正式地说,result = operand[index]

输入

标签 名称 类型 限制条件
(I1) operand tuple (C1)、(C2)
(I2) index si32 类型的常量 (C1)、(C2)

Outputs

名称 类型 限制条件
result 任何受支持的类型 (C2)

限制条件

  • (C1) 0 <= index < size(operand)
  • (C2) type(result) = tuple_element_types(operand)[index]

示例

// %operand: ([1.0, 2.0], (3))
  index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]

更多示例

if

语义

通过 true_branch 仅执行一个函数来生成输出,或者 false_branch,具体取决于 pred 的值。更正式地说,result = pred ? true_branch() : false_branch()

输入

标签 名称 类型 限制条件
(I1) pred i1 类型的 0 维张量
(I2) true_branch 函数 (C1 - C3)
(I3) false_branch 函数 (C1)、(C2)

Outputs

名称 类型 限制条件
results 可变数量的张量、量化张量或词元 (C3)

限制条件

  • (C1) input_types(true_branch) = input_types(false_branch) = []
  • (C2) output_types(true_branch) = output_types(false_branch)
  • (C3) type(results...) = output_types(true_branch)

示例

// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
  "stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
  "stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10

更多示例

图片

语义

operand 中提取元素级虚部,并生成 result 张量。更正式地说,对于每个 x 元素: imag(x) = is_complex(x) ? imaginary_part(x) : constant(0, element_type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点或复杂类型的张量 (C1)、(C2)

Outputs

名称 类型 限制条件
result 浮点类型的张量 (C1)、(C2)

限制条件

  • (C1) shape(result) = shape(operand)
  • (C2) element_type(result) 的定义如下: <ph type="x-smartling-placeholder">
      </ph>
    • 如果 is_complex(operand),则为 complex_element_type(element_type(operand))
    • 否则为 element_type(operand)

示例

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]

更多示例

信息流

语义

从馈入中读取数据并生成 results

infeed_config 的语义由实现定义。

results 由载荷值和令牌值组成, 最后。将来,我们计划将载荷和令牌拆分为 单独的输出以提高清晰度 (#670)。

输入

标签 名称 类型
(I1) token token
(I2) infeed_config string 类型的常量

Outputs

名称 类型 限制条件
results 可变数量的张量、量化张量或词元 (C1 - C3)

限制条件

  • (C1) 0 < size(results)
  • (C2) is_empty(result[:-1])is_tensor(type(results[:-1]))
  • (C3) is_token(type(results[-1]))

示例

// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]

更多示例

Iota

语义

使用从零开始的升序值填充 output 张量 沿iota_dimension维度更正式地说,

output[output_index] = constant(is_quantized(output) ? quantize(output_index[iota_dimension], element_type(output)) : output_index[iota_dimension], element_type(output))

输入

标签 名称 类型 限制条件
(I1) iota_dimension si64 (C1)

Outputs

名称 类型 限制条件
output 整数、浮点、复杂类型或每张量量化张量 (C1)

限制条件

  • (C1) 0 <= iota_dimension < rank(output)

示例

%output = "stablehlo.iota"() {
  iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

%output = "stablehlo.iota"() {
  iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4]
//          ]

更多示例

is_finite

语义

执行元素级检查 x 中的值是否有限(即 +Inf、-Inf 或 NaN),并生成 y 张量。实现 isFinite 操作符合 IEEE-754 规范的要求。对于量化类型,结果为 始终为 true

输入

标签 名称 类型 限制条件
(I1) x 浮点类型的张量或每张量量化张量 (C1)

Outputs

名称 类型 限制条件
y 布尔值类型的张量 (C1)

限制条件

  • (C1) shape(x) = shape(y)

示例

// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]

更多示例

log

语义

operand 张量执行元素级对数运算,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于浮点数:IEEE-754 中的 log
  • 对于复数:复数。
  • 对于量化类型:dequantize_op_quantize(log, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点型、复杂类型或每张量量化张量 (C1)

Outputs

名称 类型 限制条件
result 浮点型、复杂类型或每张量量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]

更多示例

log_plus_one

语义

operand 张量执行元素级对数加一运算,并 会生成 result 张量。根据元素类型,执行以下操作:

  • 对于浮点数:IEEE-754 中的 logp1
  • 对于复数:复对数加 1。
  • 对于量化类型: dequantize_op_quantize(log_plus_one, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点型、复杂类型或每张量量化张量 (C1)

Outputs

名称 类型 限制条件
result 浮点型、复杂类型或每张量量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]

更多示例

物流

语义

operand 张量执行元素级逻辑运算,并生成一个 result 张量。根据元素类型,执行以下操作:

  • 对于浮点数:IEEE-754 中的 division(1, addition(1, exp(-x)))
  • 对于复数:复杂逻辑。
  • 对于量化类型: dequantize_op_quantize(logistic, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点型、复杂类型或每张量量化张量 (C1)

Outputs

名称 类型 限制条件
result 浮点型、复杂类型或每张量量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]

更多示例

地图

<ph type="x-smartling-placeholder">
</ph>

语义

将映射函数 computation 沿 dimensions 应用于 inputs, 会生成一个 result 张量。

更正式地说,result[result_index] = computation(inputs...[result_index])

输入

标签 名称 类型 限制条件
(I1) inputs 可变数量张量或每个张量量化张量 (C1 - C4)
(I2) dimensions si64 类型的一维张量常量 (C3)
(I3) computation 函数 (C4)

Outputs

名称 类型 限制条件
result 张量或每张量量化张量 (C1)、(C4)

限制条件

  • (C1) shape(inputs...) = shape(result)
  • (C2) 0 < size(inputs) = N
  • (C3) dimensions = range(rank(inputs[0]))
  • (C4) computation 的类型为 (tensor<E0>, ..., tensor<EN-1>) -> tensor<E'> 其中 Ei = element_type(inputs[i])E' = element_type(result)

示例

// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
    stablehlo.return %0 : tensor<i64>
}) {
  dimensions = array<i64: 0, 1>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]

更多示例

最大值

语义

对张量 lhsrhs 执行元素级最大运算,并生成一个 result 张量。根据元素类型,执行以下操作:

  • 对于布尔值:逻辑 OR。
  • 对于整数:最大值为整数。
  • 对于浮点数:IEEE-754 中的 maximum
  • 对于复数:(real, imaginary) 对的字典顺序最大值。 对复数施加排序涉及令人惊讶的语义, 所以将来我们计划取消对复数的支持 (#560)。
  • 对于量化类型: <ph type="x-smartling-placeholder">
      </ph>
    • dequantize_op_quantize(maximum, lhs, rhs, type(result))

输入

标签 名称 类型 限制条件
(I1) lhs 张量或每张量量化张量 (C1)
(I2) rhs 张量或每张量量化张量 (C1)

Outputs

名称 类型 限制条件
result 张量或每张量量化张量 (C1)

限制条件

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

示例

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]

更多示例

最小值

语义

对张量 lhsrhs 执行元素级最小运算,并生成一个 result 张量。根据元素类型,执行以下操作:

  • 对于布尔值:逻辑 AND。
  • 对于整数:最小值整数。
  • 对于浮点数:IEEE-754 中的 minimum
  • 对于复数:(real, imaginary) 对的字典顺序最小值。 对复数施加排序涉及令人惊讶的语义, 所以将来我们计划取消对复数的支持 (#560)。
  • 对于量化类型: <ph type="x-smartling-placeholder">
      </ph>
    • dequantize_op_quantize(minimum, lhs, rhs, type(result))

输入

标签 名称 类型 限制条件
(I1) lhs 张量或每张量量化张量 (C1)
(I2) rhs 张量或每张量量化张量 (C1)

Outputs

名称 类型 限制条件
result 张量或每张量量化张量 (C1)

限制条件

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

示例

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]

更多示例

相乘

语义

对两个张量 lhsrhs 执行元素级乘积,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于布尔值:逻辑 AND。
  • 对于整数:整数乘法。
  • 对于浮点数:IEEE-754 中的 multiplication
  • 对于复数:复数乘法。
  • 对于量化类型: <ph type="x-smartling-placeholder">
      </ph>
    • dequantize_op_quantize(multiply, lhs, rhs, type(result))

输入

标签 名称 类型 限制条件
(I1) lhs 张量或每张量量化张量 (C1)
(I2) rhs 张量或每张量量化张量 (C1)

Outputs

名称 类型 限制条件
result 张量或每张量量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]

更多示例

negate

语义

operand 张量执行元素级否定,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于有符号整数:整数求反。
  • 对于无符号整数:比特转换为有符号整数、整数求反、比特型 转换为无符号整数。
  • 对于浮点数:IEEE-754 中的 negate
  • 对于复数:采用复数否定。
  • 对于量化类型: dequantize_op_quantize(negate, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 整数、浮点、复杂类型或每张量量化张量 (C1)

Outputs

名称 类型 限制条件
result 整数、浮点、复杂类型或每张量量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]

// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]

更多示例

语义

对张量 operand 执行元素级 NOT,并生成一个 result 张量。 根据元素类型,执行以下操作:

  • 对于布尔值:逻辑 NOT。
  • 对于整数:按位 NOT。

参数

名称 类型 限制条件
operand 布尔值或整数类型的张量 (C1)

Outputs

名称 类型 限制条件
result 布尔值或整数类型的张量 (C1)

限制条件

  • (C1) type(operand) = type(result)

示例

// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]

// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]

更多示例

optimization_barrier

语义

确保产生 operand 的操作在执行任何 依赖于 result 并阻止编译器转换的操作 使业务可以跨越壁垒。除此之外,操作是 身份,即 result = operand

参数

名称 类型 限制条件
operand 可变数量的张量、每个张量的量化张量或词元 (C1)

Outputs

名称 类型 限制条件
result 可变数量的张量、每个张量的量化张量或词元 (C1)

限制条件

  • (C1) type(operand...) = type(result...)

示例

// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0

更多示例

语义

对两个张量 lhsrhs 执行元素级 OR,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于布尔值:逻辑 OR。
  • 对于整数:按位 OR。

输入

标签 名称 类型 限制条件
(I1) lhs 整数或布尔值类型的张量 (C1)
(I2) rhs 整数或布尔值类型的张量 (C1)

Outputs

名称 类型 限制条件
result 整数或布尔值类型的张量 (C1)

限制条件

  • (C1) type(lhs) = type(rhs) = type(result)

示例

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]

更多示例

馈出

语义

inputs 写入输出并生成 result 令牌。

outfeed_config 的语义由实现定义。

输入

标签 名称 类型
(I1) inputs 可变数量张量或量化张量
(I2) token token
(I3) outfeed_config string 类型的常量

Outputs

名称 类型
result token

示例

%result = "stablehlo.outfeed"(%input0, %token) {
  outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token

更多示例

语义

通过围绕张量以及元素之间的内边距扩展 operand 具有指定 padding_value 的张量。

edge_padding_lowedge_padding_high 指定添加的内边距量 分别表示各个维度。内边距可以为负数,其中 负填充的绝对值表示要移除的元素数量 来自指定维度的数据。

interior_padding 指定任意两个之间相加的内边距 这些元素不能是负数。发生内部填充 ,以便负的边缘内边距会从 内部填充的操作数。

更正式地说,result[result_index] 的定义如下:

  • operand[operand_index](如果 result_index = edge_padding_low + operand_index * (interior_padding + 1)
  • 否则为 padding_value

输入

标签 名称 类型 限制条件
(I1) operand 张量或每张量量化张量 (C1)、(C2)、(C4)
(I2) padding_value 0 维张量或每张量量化张量 (C1)
(I3) edge_padding_low si64 类型的一维张量常量 (C1)、(C4)
(I4) edge_padding_high si64 类型的一维张量常量 (C1)、(C4)
(I5) interior_padding si64 类型的一维张量常量 (C2-C4)

Outputs

名称 类型 限制条件
result 张量或每张量量化张量 (C3-C6)

限制条件

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result)
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
  • (C3) 0 <= interior_padding
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high

示例

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
  edge_padding_low = array<i64: 0, 1>,
  edge_padding_high = array<i64: 2, 1>,
  interior_padding = array<i64: 1, 2>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

更多示例

partition_id

语义

生成当前进程的 partition_id

Outputs

名称 类型
result ui32 类型的 0 维张量

示例

%result = "stablehlo.partition_id"() : () -> tensor<ui32>

更多示例

弹出式窗口

语义

operand 张量中设置的位数执行元素级计数 并生成一个 result 张量。

输入

标签 名称 类型 限制条件
(I1) operand 整数类型的张量 (C1)

Outputs

名称 类型 限制条件
result 整数类型的张量 (C1)

限制条件

  • (C1) type(operand) = type(result)

示例

// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]

更多示例

幂数

语义

rhs 张量对 lhs 张量执行元素级指数化,并且 会生成 result 张量。根据元素类型,执行以下操作:

  • 对于整数:整数指数。
  • 对于浮点数:IEEE-754 中的 pow
  • 对于复数:使用复数指数。
  • 对于量化类型:dequantize_op_quantize(power, lhs, rhs, type(result))

输入

标签 名称 类型 限制条件
(I1) lhs 整数、浮点、复杂类型或每张量量化张量 (C1)
(I2) rhs 整数、浮点、复杂类型或每张量量化张量 (C1)

Outputs

名称 类型 限制条件
result 整数、浮点、复杂类型或每张量量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]

更多示例

real

语义

operand 中提取元素级实部,并生成 result 张量。更正式地说,对于每个 x 元素: real(x) = is_complex(x) ? real_part(x) : x

输入

标签 名称 类型 限制条件
(I1) operand 浮点或复杂类型的张量 (C1)、(C2)

Outputs

名称 类型 限制条件
result 浮点类型的张量 (C1)、(C2)

限制条件

  • (C1) shape(result) = shape(operand)
  • (C2) element_type(result) 的定义如下: <ph type="x-smartling-placeholder">
      </ph>
    • 如果 is_complex(operand),则为 complex_element_type(element_type(operand))
    • 否则为 element_type(operand)

示例

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]

更多示例

接收

语义

使用 channel_id 从通道接收数据并生成 results

如果 is_host_transfertrue,则该操作会从 主机。否则,系统会从其他设备传输数据。这意味着 实现定义的。此标记与 channel_type,因此以后我们打算只保留其中一个 (#666)。

results 由载荷值和令牌值组成, 最后。将来,我们计划将载荷和令牌拆分为 单独的输出以提高清晰度 (#670)。

输入

标签 名称 类型 限制条件
(I1) token token (C4)
(I2) channel_id si64 类型的常量
(I3) channel_type DEVICE_TO_DEVICEHOST_TO_DEVICE 的枚举 (C1)
(I4) is_host_transfer i1 类型的常量 (C1)

Outputs

名称 类型 限制条件
results 可变数量的张量、量化张量或词元 (C2-C4)

限制条件

  • (C1) channel_type 的定义如下: <ph type="x-smartling-placeholder">
      </ph>
    • 如果 is_host_transfer = true,则为 HOST_TO_DEVICE
    • 否则为 DEVICE_TO_DEVICE
  • (C2) 0 < size(results)
  • (C3) is_empty(result[:-1])is_tensor(type(results[:-1]))
  • (C4) is_token(type(results[-1]))

示例

%results0, %results1 = "stablehlo.recv"(%token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
  is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)

更多示例

reduce

语义

将归约函数 body 应用于 inputsinit_values 以及 dimensions 并生成 results 张量。

归约顺序由实现定义,这意味着 bodyinit_values 必须形成一个单函数,以保证运算产生 在所有实现上针对所有输入实现相同的结果。不过,这个条件 无法适用于很多热门简化视频例如:对以下项进行浮点加法: 实际上,init_valuesbody 和 0 不会形成单元函数,因为 浮点加法不遵守结合律。

更正式地说,results...[j0, ..., jR-1] = reduce(input_slices_converted),其中:

  • input_slices = inputs...[j0, ..., :, ..., jR-1],其中 : 的插入位置 dimensions
  • input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...)
  • init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)
  • reduce(input_slices_converted) = exec(schedule),适用于某些二元树 schedule,其中: <ph type="x-smartling-placeholder">
      </ph>
    • exec(node) = body(exec(node.left), exec(node.right))
    • exec(leaf) = leaf.value
  • schedule 是一种由实现定义的完整二元树,其顺序 遍历包括: <ph type="x-smartling-placeholder">
      </ph>
    • input_slices_converted...[index] 个值,适用于所有 index index_space(input_slices_converted)(按字典顺序升序排列) 共 index 个。
    • 其中穿插了实施定义的 init_values_converted 位于实现定义的位置。

输入

标签 名称 类型 限制条件
(I1) inputs 可变数量张量或每个张量量化张量 (C1-C4)、(C6)、(C7)
(I2) init_values 0 维张量或每张量量化张量的可变数量 (C2)、(C3)
(I3) dimensions si64 类型的一维张量常量 (C4)、(C5)、(C7)
(I4) body 函数 (C6)

Outputs

名称 类型 限制条件
results 可变数量张量或每个张量量化张量 (C3)、(C7)、(C8)

限制条件

  • (C1) same(shape(inputs...))
  • (C2) element_type(inputs...) = element_type(init_values...)
  • (C3) 0 < size(inputs) = size(init_values) = size(results) = N
  • (C4) 0 <= dimensions < rank(inputs[0])
  • (C5) is_unique(dimensions)
  • (C6) body 的类型为 (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>),其中 is_promotable(element_type(inputs[i]), Ei)
  • (C7) shape(results...) = shape(inputs...),但维度 未包含与 dimensions 对应的 inputs... 尺寸。
  • (C8) 对 [0,N) 中的所有 i 使用 element_type(results[i]) = Ei

示例

// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]

更多示例

reduce_precision

语义

operand 执行到其他浮点类型的元素级转换 使用 exponent_bitsmantissa_bits,然后再还原为原始版本 浮点类型,并生成 output 张量。

更正式地说:

  • 原始值的尾数位会更新,以对原始值进行四舍五入 值设为可使用 mantissa_bits 表示的最接近的值(使用 roundToIntegralTiesToEven 语义。
  • 那么,如果 mantissa_bits 小于 原始值,则尾数位会截断为 mantissa_bits
  • 然后,如果中间结果的指数位不适合 由 exponent_bits 提供的范围,中间结果溢出至 使用原符号输入无穷大的值,或使用 原始符号。
  • 对于量化类型,执行 dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点类型的张量或每张量量化张量 (C1)
(I2) exponent_bits si32 类型的常量 (C2)
(I3) mantissa_bits si32 类型的常量 (C3)

Outputs

名称 类型 限制条件
output 浮点类型的张量或每张量量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(output)
  • (C2) 1 <= exponent_bits
  • (C3) 0 <= mantissa_bits

示例

// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
  exponent_bits = 5 : i32,
  mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]

更多示例

reduce_scatter

语义

reduce_scatter

在 StableHLO 进程网格的每个进程组中,执行归约, 对来自每个进程的 operand 张量的值使用 computations, 将归约结果沿 scatter_dimension 拆分为部分, 进程之间的拆分部分,以生成 result

该操作将 StableHLO 进程网格拆分为 process_groups, 定义如下:

  • cross_replica(replica_groups) 如果 channel_id <= 0 and use_global_device_ids = false
  • cross_replica_and_partition(replica_groups) 如果 channel_id > 0 and use_global_device_ids = false
  • flattened_ids(replica_groups) 如果 channel_id > 0 and use_global_device_ids = true

之后,在每个 process_group 中:

  • reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation)
  • parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension)
  • result@receiver = parts@sender[receiver_index]”中的所有sender process_group,其中 receiver_index = process_group.index(receiver)

输入

标签 名称 类型 限制条件
(I1) operand 张量或每张量量化张量 (C1)、(C2)、(C7)、(C8)
(I2) scatter_dimension si64 类型的常量 (C1)、(C2)、(C8)
(I3) replica_groups si64 类型的二维张量常量 (C3-C5)
(I4) channel_id si64 类型的常量 (C6)
(I5) use_global_device_ids i1 类型的常量 (C6)
(I6) computation 函数 (C7)

Outputs

名称 类型 限制条件
result 张量或每张量量化张量 (C8 - C9)

限制条件

  • (C1) dim(operand, scatter_dimension) % dim(process_groups, 1) = 0
  • (C2) 0 <= scatter_dimension < rank(operand)
  • (C3) is_unique(replica_groups)
  • (C4) size(replica_groups) 的定义如下: <ph type="x-smartling-placeholder">
      </ph>
    • 如果使用了 cross_replica,则为 num_replicas
    • 如果使用了 cross_replica_and_partition,则为 num_replicas
    • 如果使用了 flattened_ids,则为 num_processes
  • (C5) 0 <= replica_groups < size(replica_groups)
  • (C6) 如果 use_global_device_ids = true,则 channel_id > 0
  • (C7) computation 的类型为 (tensor<E>, tensor<E>) -> (tensor<E>),其中 is_promotable(element_type(operand), E)
  • (C8) shape(result) = shape(operand),以下项除外: <ph type="x-smartling-placeholder">
      </ph>
    • dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1)
  • (C9) element_type(result) = E

示例

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
  "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
//                  [18, 20]]
// %result@(1, 0): [[14, 16],
//                  [22, 24]]

更多示例

reduce_window

语义

将归约函数 body 应用于 inputsinit_values 的窗口 并生成 results

下图显示了如何根据以下内容计算 results... 中的元素: inputs...

reduce_window

更正式地说, results...[result_index] = reduce(windows, init_values, axes(inputs...), body) (请参阅 reduce),其中:

  • padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1)
  • window_start = result_index * window_strides
  • window_end = window_start + (window_dimensions - 1) * window_dilations + 1
  • windows = slice(padded_inputs..., window_start, window_end, window_dilations)

输入

标签 名称 类型 限制条件
(I1) inputs 可变数量张量或每个张量量化张量 (C1-C4)、(C6)、(C8)、(C10)、(C12)、(C13)、(C15)
(I2) init_values 0 维张量或每张量量化张量的可变数量 (C1)、(C13)
(I3) window_dimensions si64 类型的一维张量常量 (C4)、(C5)、(C15)
(I4) window_strides si64 类型的一维张量常量 (C6)、(C7)、(C15)
(I5) base_dilations si64 类型的一维张量常量 (C8)、(C9)、(C15)
(I6) window_dilations si64 类型的一维张量常量 (C10)、(C11)、(C15)
(I7) padding si64 类型的二维张量常量 (C12)、(C15)
(I8) body 函数 (C13)

Outputs

名称 类型 限制条件
results 可变数量张量或每个张量量化张量 (C1)、(C14-C16)

限制条件

  • (C1) 0 < size(inputs) = size(init_values) = size(results) = N
  • (C2) same(shape(inputs...))
  • (C3) element_type(inputs...) = element_type(init_values...)
  • (C4) size(window_dimensions) = rank(inputs[0])
  • (C5) 0 < window_dimensions
  • (C6) size(window_strides) = rank(inputs[0])
  • (C7) 0 < window_strides
  • (C8) size(base_dilations) = rank(inputs[0])
  • (C9) 0 < base_dilations
  • (C10) size(window_dilations) = rank(inputs[0])
  • (C11) 0 < window_dilations
  • (C12) shape(padding) = [rank(inputs[0]), 2]
  • (C13) body 的类型为 (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>),其中 is_promotable(element_type(inputs[i]), Ei)
  • (C14) same(shape(results...))
  • (C15) shape(results[0]) = num_windows,其中: <ph type="x-smartling-placeholder">
      </ph>
    • dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1
    • padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]
    • dilated_window_shape = (window_dimensions - 1) * window_dilations + 1
    • is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape
    • num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1
  • (C16) 对 [0,N) 中的所有 i 执行了 element_type(results[i]) = Ei 操作。

示例

// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = array<i64: 2, 1>,
  window_strides = array<i64: 4, 1>,
  base_dilations = array<i64: 2, 1>,
  window_dilations = array<i64: 3, 1>,
  padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]

更多示例

余数

语义

对被除数 lhs 和除数 rhs 张量执行元素级取余,并且 会生成 result 张量。

更正式地说,结果的符号取自被除数, 结果的绝对值始终小于除数的绝对值。 余数按 lhs - d * rhs 计算,其中 d 的计算公式如下:

  • 对于整数:stablehlo.divide(lhs, rhs)
  • 对于浮点数:IEEE-754 中的 division(lhs, rhs),具有舍入属性 roundTowardZero
  • 对于复数:待定 (#997)。
  • 对于量化类型: <ph type="x-smartling-placeholder">
      </ph>
    • dequantize_op_quantize(remainder, lhs, rhs, type(result))

对于浮点元素类型,此运算与 remainder 运算符合 IEEE-754 规范,其中 d 为整数值 最接近 lhs/rhs 的精确值,并且与偶数相等。

输入

标签 名称 类型 限制条件
(I1) lhs 整数、浮点、复杂类型或每张量量化张量 (C1)
(I2) rhs 整数、浮点、复杂类型或每张量量化张量 (C1)

Outputs

名称 类型 限制条件
result 整数、浮点、复杂类型或每张量量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]

更多示例

replica_id

语义

生成当前进程的 replica_id

Outputs

名称 类型
result ui32 类型的 0 维张量

示例

%result = "stablehlo.replica_id"() : () -> tensor<ui32>

更多示例

调整形状

语义

operand 张量变形为 result 张量。从概念上讲, 也就是保持相同的规范表示法 形状,例如从tensor<2x3xf32>tensor<3x2xf32>tensor<6xf32>

更正式地说,result[result_index] = operand[operand_index],其中 result_indexoperand_index 在字典中的位置相同 index_space(result)index_space(operand) 的顺序。

输入

标签 名称 类型 限制条件
(I1) operand 张量或量化张量 (C1 - C3)

Outputs

名称 类型 限制条件
result 张量或量化张量 (C1 - C3)

限制条件

  • (C1) element_type(result) 由以下公式计算得出: <ph type="x-smartling-placeholder">
      </ph>
    • element_type(operand)(如果 !is_per_axis_quantized(operand))。
    • element_type(operand),但 quantization_dimension(operand) 和 否则,quantization_dimension(result) 可能会有所不同。
  • (C2) size(operand) = size(result)
  • (C3) 如果 is_per_axis_quantized(operand): <ph type="x-smartling-placeholder">
      </ph>
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)

示例

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]

更多示例

reverse

语义

反转 operand 中元素的顺序(沿着指定的 dimensions) 并生成一个 result 张量。更正式地说, result[result_index] = operand[operand_index],其中:

  • operand_index[d] = dim(result, d) - result_index[d] - 1 如果 dimensions 中的 d
  • 否则为 operand_index[d] = result_index[d]

输入

标签 名称 类型 限制条件
(I1) operand 张量或每张量量化张量 (C1)、(C3)
(I2) dimensions si64 类型的一维张量常量 (C2)、(C3)

Outputs

名称 类型 限制条件
result 张量或每张量量化张量 (C1)、(C3)

限制条件

  • (C1) type(operand) = type(result)
  • (C2) is_unique(dimensions)
  • (C3) 0 <= dimensions < rank(result)

示例

// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
  dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]

更多示例

rng

<ph type="x-smartling-placeholder">
</ph>

语义

使用 rng_distribution 算法生成随机数字,并生成一个 给定形状 shaperesult 张量。

如果为 rng_distribution = UNIFORM,则生成随机数字 分布在区间[a, b)内的均匀分布。如果为 a >= b, 行为未定义

如果为 rng_distribution = NORMAL,则生成随机数字 遵循正态分布,平均值 = a,标准差 = b。 如果为 b < 0,则行为未定义。

生成随机数的确切方式由实现定义。对于 例如,它们不一定具有确定性, 隐藏状态。

在与许多利益相关者的对话中,这一行动的确有效 所以未来我们计划探索如何将其移除 (#597)。

输入

标签 名称 类型 限制条件
(I1) a 整数、布尔值或浮点类型的 0 维张量 (C1)、(C2)
(I2) b 整数、布尔值或浮点类型的 0 维张量 (C1)、(C2)
(I3) shape si64 类型的一维张量常量 (C3)
(I4) rng_distribution UNIFORMNORMAL 的枚举 (C2)

Outputs

名称 类型 限制条件
result 整数、布尔值或浮点类型的张量 (C1 - C3)

限制条件

  • (C1) element_type(a) = element_type(b) = element_type(result)
  • (C2) 如果 rng_distribution = NORMAL,则 is_float(a)
  • (C3) shape(result) = shape

示例

// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
  rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
//           [1, 0, 1],
//           [1, 1, 1],
//           [0, 0, 0]
//          ]

rng_bit_generator

语义

返回一个填充了统一随机位和更新后的输出状态的 output 使用伪随机数生成器算法 rng_algorithmoutput_state 初始状态为 initial_state。输出结果肯定是 initial_state 的确定性函数,但不一定是 确定性。

rng_algorithm 是以下值之一:

  • DEFAULT:实现定义的算法。
  • THREE_FRY:Threefry 算法的实现定义的变体。*
  • PHILOX:Pilox 算法的实现定义的变体。*

* 请参阅:Salmon 等人SC 2011。并行随机数字:简单到 1、2、3。

输入

标签 名称 类型 限制条件
(I1) rng_algorithm DEFAULTTHREE_FRYPHILOX 的枚举 (C2)
(I2) initial_state ui64 类型的一维张量 (C1)、(C2)

Outputs

名称 类型 限制条件
output_state ui64 类型的一维张量 (C1)
output 整数或浮点类型的张量

限制条件

  • (C1) type(initial_state) = type(output_state)
  • (C2) size(initial_state) 的定义如下: <ph type="x-smartling-placeholder">
      </ph>
    • 如果 rng_algorithm = DEFAULT,则由实现定义。
    • 如果 rng_algorithm = THREE_FRY,则为 2
    • 23(如果 rng_algorithm = PHILOX)。

示例

// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
  rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
//           [9236835810183407956, 16087790271692313299],
//           [18212823393184779219, 2658481902456610144]
//          ]

round_nearest_afz

语义

对最接近的整数执行元素级舍入,消除相等的值 在 operand 张量上从零开始,并生成一个 result 张量。实现 IEEE-754 规范中的 roundToIntegralTiesToAway 运算。对于 量化类型、执行 dequantize_op_quantize(round_nearest_afz, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点类型的张量或每张量量化张量 (C1)

Outputs

名称 类型 限制条件
result 浮点类型的张量或每张量量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]

更多示例

round_nearest_even

语义

对最接近的整数执行元素级四舍五入,取消相等 在 operand 张量上向偶整数赋值,并生成一个 result 张量。实现来自 IEEE-754 的 roundToIntegralTiesToEven 操作 规范对于量化类型, dequantize_op_quantize(round_nearest_even, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点类型的张量或每张量量化张量 (C1)

Outputs

名称 类型 限制条件
result 浮点类型的张量或每张量量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]

更多示例

rsqrt

语义

operand 张量执行元素级倒数平方根运算,并且 会生成 result 张量。根据元素类型,执行以下操作:

  • 对于浮点数:IEEE-754 中的 rSqrt
  • 对于复数:复数倒数平方根。
  • 对于量化类型:dequantize_op_quantize(rsqrt, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点型、复杂类型或每张量量化张量 (C1)

Outputs

名称 类型 限制条件
result 浮点型、复杂类型或每张量量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]

更多示例

散点图

语义

生成 results 张量,这些张量等于 inputs 张量,但以下情况除外: 使用值更新由 scatter_indices 指定的多个切片 updates(使用 update_computation)。

下图显示了 updates... 中的元素如何映射到 results...。该图选择了几个示例 updates... 索引,并详细说明它们分别是哪些 results... 索引 对应的版本。

散点图

更正式地说,对于 index_space(updates[0]) 中的所有 update_index

  • update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]
  • update_scatter_index = update_index[update_scatter_dims...]
  • start_index 的定义如下: <ph type="x-smartling-placeholder">
      </ph>
    • scatter_indices[si0, ..., :, ..., siN],其中 si 为个人 update_scatter_index: 中的元素将插入到 index_vector_dim 索引(如果 index_vector_dim <) rank(scatter_indices)
    • 否则为 [scatter_indices[update_scatter_index]]
  • 对于axes(inputs[0])d_input, <ph type="x-smartling-placeholder">
      </ph>
    • full_start_index[d_input] = start_index[d_start](如果 d_input = scatter_dims_to_operand_dims[d_start]
    • 否则为 full_start_index[d_input] = 0
  • 对于axes(inputs[0])d_input, <ph type="x-smartling-placeholder">
      </ph>
    • full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)] 如果 d_input = input_batching_dims[i_batching]d_start = scatter_indices_batching_dims[i_batching]
    • 否则为 full_batching_index[d_input] = 0
  • update_window_index = update_index[update_window_dims...]
  • full_window_index = [wi0, ..., 0, ..., wiN],其中 wi 为个人 元素在 update_window_index 中,而 0 在索引位置处插入 inserted_window_dimsinput_batching_dims
  • result_index = full_start_index + full_batching_index + full_window_index

鉴于此,results = exec(schedule, inputs),其中:

  • schedule 是实现定义的 index_space(updates[0])
  • exec([update_index, ...], results) = exec([...], updated_results),其中: <ph type="x-smartling-placeholder">
      </ph>
    • 如果result_indexshape(results...)的边界之内
    • updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
    • updated_values = update_computation(results...[result_index], updates_converted)
    • updated_results 是包含 results...[result_index]results 的副本 设为 updated_values...
    • 否则
    • updated_results = results
  • exec([], results) = results

如果 indices_are_sortedtrue,则实现可以假定: scatter_indices 相对于 scatter_dims_to_operand_dims 进行排序, 否则行为将处于未定义状态。更正式地说,对于来自以下广告的所有i1 < i2indices(result)full_start_index(i1) <= full_start_index(i2)

如果 unique_indicestrue,则实现可以假定所有 分散到的 result_index 索引是唯一的。如果unique_indices true,但分散到的索引不是唯一的,则行为 未定义。

输入

标签 名称 类型 限制条件
(I1) inputs 可变数量张量或每个张量量化张量 (C1)、(C2)、(C4-C6)、(C11)、(C13)、(C18)、(C21)、(C23-C24)
(I2) scatter_indices 整数类型的张量 (C4)、(C15)、(C19)、(C22)
(I3) updates 可变数量张量或每个张量量化张量 (C3-C6)、(C8)
(I4) update_window_dims si64 类型的一维张量常量 (C2)、(C4)、(C7-C8)
(I5) inserted_window_dims si64 类型的一维张量常量 (C2)、(C4)、(C9-C11)
(I6) input_batching_dims si64 类型的一维张量常量 (C2)、(C4)、(C9)、(C12-13)、(C17-18)、(C20)
(I7) scatter_indices_batching_dims si64 类型的一维张量常量 (C14 - C18)
(I8) scatter_dims_to_operand_dims si64 类型的一维张量常量 (C19 至 C21)
(I9) index_vector_dim si64 类型的常量 (C4)、(C16)、(C19)、(C22)
(I10) indices_are_sorted i1 类型的常量
(I11) unique_indices i1 类型的常量
(I12) update_computation 函数 (C23)

Outputs

名称 类型 限制条件
results 可变数量张量或每个张量量化张量 (C24 - C25)

限制条件

  • (C1) same(shape(inputs...))
  • (C2) `rank(inputs[0]) = 大小(update_window_dims) + 大小(inserted_window_dims) <ph type="x-smartling-placeholder">
      </ph>
    • size(input_batching_dims)`.
  • (C3) same(shape(updates...))
  • (C4) shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes),其中: <ph type="x-smartling-placeholder">
      </ph>
    • update_scatter_dim_sizes = shape(scatter_indices) 除外 scatter_indices 的尺寸大小,对应于 未包含 index_vector_dim
    • update_window_dim_sizes <= shape(inputs[0]) 除外 与 inserted_window_dims 对应的 inputs[0] 尺寸尺寸 和 input_batching_dims 未包含在内。
    • combineupdate_scatter_dim_sizes 放置在 update_scatter_dimsupdate_window_dim_sizes 位于相应的轴上 发送至 update_window_dims
  • (C5) 0 < size(inputs) = size(updates) = N
  • (C6) element_type(updates...) = element_type(inputs...)
  • (C7) is_unique(update_window_dims) and is_sorted(update_window_dims)
  • (C8) 0 <= update_window_dims < rank(updates[0])
  • (C9) is_unique(concatenate(inserted_window_dims, input_batching_dims))
  • (C10) is_sorted(inserted_window_dims)
  • (C11) 0 <= inserted_window_dims < rank(inputs[0])
  • (C12) is_sorted(input_batching_dims)
  • (C13) 0 <= input_batching_dims < rank(inputs[0]))
  • (C14) is_unique(scatter_indices_batching_dims)
  • (C15) 0 <= scatter_indices_batching_dims < rank(scatter_indices)
  • (C16) index_vector_dim not in scatter_indices_batching_dims
  • (C17) size(input_batching_dims) == size(scatter_indices_batching_dims)
  • (C18) dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...)
  • (C19) size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1
  • (C20) is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims))
  • (C21) 0 <= scatter_dims_to_operand_dims < rank(inputs[0])
  • (C22) 0 <= index_vector_dim <= rank(scatter_indices)
  • (C23) update_computation 的类型为 (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), 其中 is_promotable(element_type(inputs[i]), Ei)
  • (C24) shape(inputs...) = shape(results...)
  • (C25) element_type(results[i]) = Ei 表示 [0,N) 中的所有 i

示例

// %input: [
//          [
//           [[1, 2], [3, 4], [5, 6], [7, 8]],
//           [[9, 10],[11, 12], [13, 14], [15, 16]],
//           [[17, 18], [19, 20], [21, 22], [23, 24]]
//          ],
//          [
//           [[25, 26], [27, 28], [29, 30], [31, 32]],
//           [[33, 34], [35, 36], [37, 38], [39, 40]],
//           [[41, 42], [43, 44], [45, 46], [47, 48]]
//          ]
//         ]
// %scatter_indices: [
//                    [
//                     [[0, 0], [1, 0], [2, 1]],
//                     [[0, 1], [1, 1], [0, 9]]
//                    ],
//                    [
//                     [[0, 0], [2, 1], [2, 2]],
//                     [[1, 2], [0, 1], [1, 0]]
//                    ]
//                   ]
// %update: [
//           [
//            [[1, 1], [1, 1], [1, 1]],
//            [[1, 1], [1, 1], [1, 1]]
//           ],
//           [
//            [[1, 1], [1, 1], [1, 1]],
//            [[1, 1], [1, 1], [1, 1]]
//           ]
//          ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension_numbers = #stablehlo.scatter<
    update_window_dims = [3, 4],
    inserted_window_dims = [1],
    input_batching_dims = [0],
    scatter_indices_batching_dims = [1],
    scatter_dims_to_operand_dims = [2, 1],
    index_vector_dim = 3>,
  indices_are_sorted = false,
  unique_indices = false
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
// %result: [
//           [
//            [[3, 4], [6, 7], [6, 7], [7, 8]],
//            [[9, 10],[11, 12], [15, 16], [17, 18]],
//            [[17, 18], [19, 20], [22, 23], [24, 25]]
//           ],
//           [
//            [[25, 26], [28, 29], [30, 31], [31, 32]],
//            [[35, 36], [38, 39], [38, 39], [39, 40]],
//            [[41, 42], [44, 45], [46, 47], [47, 48]]
//           ]
//          ]

更多示例

select

语义

生成一个 result 张量,其中每个元素都选自 on_trueon_false 张量。pred 更正式地说,为 result[result_index] = pred_element ? on_true[result_index] : on_false[result_index],其中 pred_element = rank(pred) = 0 ? pred[] : pred[result_index]。对于量化类型, dequantize_select_quantize(pred, on_true, on_false, type(result))

输入

标签 名称 类型 限制条件
(I1) pred i1 类型的张量 (C1)
(I2) on_true 张量或每张量量化张量 (C1 - C2)
(I3) on_false 张量或每张量量化张量 (C2)

Outputs

名称 类型 限制条件
result 张量或每张量量化张量 (C2)

限制条件

  • (C1) rank(pred) = 0 or shape(pred) = shape(on_true)
  • (C2) baseline_type(on_true) = baseline_type(on_false) = baseline_type(result)

示例

// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]

更多示例

select_and_scatter

语义

根据sourcescatter 使用 selectinput 张量进行 reduce_window 运算的结果,并生成 result 张量。

下图显示了如何根据以下内容计算 result 中的元素: operandsource

select_and_scatter

更正式地说:

  • selected_values = reduce_window_without_init(...),其中包含以下输入:

    • inputs = [operand].
    • window_dimensionswindow_stridespadding,按原样使用。
    • base_dilations = windows_dilations = 1
    • body 的定义如下:
    def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>:
      return select(arg0, arg1) ? arg0 : arg1;
    

    E = element_type(operand)reduce_window_without_init的工作地点 与 reduce_window 完全相同,但底层的 schedule reduce(请参阅 reduce)不包含 init 值。目前 未指定如果对应窗口没有值会发生什么 (#731)。

  • result[result_index] = reduce([source_values], [init_value], [0], scatter) 其中:

    • source_values = [source[source_index] for source_index in source_indices]
    • selected_index(source_index) = operand_index(如果 selected_values[source_index] 包含 operand 元素 operand_index起。
    • source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index]

输入

标签 名称 类型 限制条件
(I1) operand 张量或每张量量化张量 (C1-C4)、(C6)、(C8-C11)
(I2) source 张量或每张量量化张量 (C1)、(C2)
(I3) init_value 0 维张量或每张量量化张量 (C3)
(I4) window_dimensions si64 类型的一维张量常量 (C2)、(C4)、(C5)
(I5) window_strides si64 类型的一维张量常量 (C2)、(C6)、(C7)
(I6) padding si64 类型的二维张量常量 (C2)、(C8)
(I7) select 函数 (C9)
(I8) scatter 函数 (C10)

Outputs

名称 类型 限制条件
result 张量或每张量量化张量 (C11 至 C12)

限制条件

  • (C1) element_type(operand) = element_type(source)
  • (C2) shape(source) = num_windows,其中: <ph type="x-smartling-placeholder">
      </ph>
    • padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1]
    • is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape
    • num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1
  • (C3) element_type(init_value) = element_type(operand)
  • (C4) size(window_dimensions) = rank(operand)
  • (C5) 0 < window_dimensions
  • (C6) size(window_strides) = rank(operand)
  • (C7) 0 < window_strides
  • (C8) shape(padding) = [rank(operand), 2]
  • (C9) select 的类型为 (tensor<E>, tensor<E>) -> tensor<i1>,其中 E = element_type(operand)
  • (C10) scatter 的类型为 (tensor<E>, tensor<E>) -> tensor<E>,其中 is_promotable(element_type(operand), E)
  • (C11) shape(operand) = shape(result)
  • (C12) element_type(result) = E

示例

// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = array<i64: 3, 1>,
  window_strides = array<i64: 2, 1>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]

更多示例

send

语义

inputs 发送到通道 channel_id,并生成 result 令牌。

如果 is_host_transfertrue,相应操作会将数据传输到 主机。否则,它会将数据传输到其他设备。这意味着 实现定义的。此标记与 channel_type,因此以后我们打算只保留其中一个 (#666)。

输入

标签 名称 类型 限制条件
(I1) inputs 可变数量张量或量化张量
(I2) token token
(I3) channel_id si64 类型的常量
(I4) channel_type DEVICE_TO_DEVICEDEVICE_TO_HOST 的枚举 (C1)
(I5) is_host_transfer i1 类型的常量 (C1)

Outputs

名称 类型
result token

限制条件

  • (C1) channel_type 的定义如下: <ph type="x-smartling-placeholder">
      </ph>
    • 如果 is_host_transfer = true,则为 DEVICE_TO_HOST
    • 否则为 DEVICE_TO_DEVICE

示例

%result = "stablehlo.send"(%operand, %token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
  is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token

更多示例

shift_left

语义

rhs 数值对 lhs 张量执行元素级左移运算 并生成一个 result 张量。

输入

标签 名称 类型 限制条件
(I1) lhs 整数类型的张量 (C1)
(I2) rhs 整数类型的张量 (C1)

Outputs

名称 类型 限制条件
result 整数类型的张量 (C1)

限制条件

  • (C1) type(lhs) = type(rhs) = type(result)

示例

// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]

更多示例

shift_right_arithmetic

语义

lhs 张量执行元素级算术右移运算,方法如下: rhs 位数并生成一个 result 张量。

输入

标签 名称 类型 限制条件
(I1) lhs 整数类型的张量 (C1)
(I2) rhs 整数类型的张量 (C1)

Outputs

名称 类型 限制条件
result 整数类型的张量 (C1)

限制条件

  • (C1) type(lhs) = type(rhs) = type(result)

示例

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]

更多示例

shift_right_logical

语义

通过 rhslhs 张量执行元素级逻辑右移运算 并生成一个 result 张量。

输入

标签 名称 类型 限制条件
(I1) lhs 整数类型的张量 (C1)
(I2) rhs 整数类型的张量 (C1)

Outputs

名称 类型 限制条件
result 整数类型的张量 (C1)

限制条件

  • (C1) type(lhs) = type(rhs) = type(result)

示例

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]

更多示例

签名

语义

返回 operand 元素级的符号,并生成 result 张量。 更正式地说,对于每个元素 x,相应语义可以使用 Python 语法如下所示:

def sign(x):
  if is_integer(x):
    if compare(x, 0, LT, SIGNED): return -1
    if compare(x, 0, EQ, SIGNED): return 0
    return 1
  elif is_float(x):
    if is_nan(x): return NaN
    if compare(x, -0.0, EQ, FLOAT): return -0.0
    if compare(x, +0.0, EQ, FLOAT): return +0.0
    if compare(x, 0.0, LT, FLOAT): return -1.0
    return 1.0
  elif is_complex(x):
    if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
    if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
    return divide(x, convert(abs(x), type(x)))

对于量化类型, dequantize_op_quantize(sign, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 带符号整数、浮点数、复杂类型或每张量量化张量的张量 (C1)

Outputs

名称 类型 限制条件
result 带符号整数、浮点数、复杂类型或每张量量化张量的张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]

更多示例

正弦

语义

operand 张量执行元素级正弦运算,并生成一个 result 张量。根据元素类型,执行以下操作:

  • 对于浮点数:IEEE-754 中的 sin
  • 对于复数:复数正弦。
  • 对于量化类型:dequantize_op_quantize(sine, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点型、复杂类型或每张量量化张量 (C1)

Outputs

名称 类型 限制条件
result 浮点型、复杂类型或每张量量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]

更多示例

slice

语义

使用静态计算的起始索引从 operand 中提取 Slice 并生成一个 result 张量。start_indices 包含 每个维度的切片,limit_indices 包含结束索引 (不含)针对每个维度的切片,strides 包含步长 数据。

更正式地说,result[result_index] = operand[operand_index],其中 operand_index = start_indices + result_index * strides

输入

标签 名称 类型 限制条件
(I1) operand 张量或每张量量化张量 (C1-C3)、(C5)
(I2) start_indices si64 类型的一维张量常量 (C2)、(C3)、(C5)
(I3) limit_indices si64 类型的一维张量常量 (C2)、(C3)、(C5)
(I4) strides si64 类型的一维张量常量 (C2)、(C4)

Outputs

名称 类型 限制条件
result 张量或每张量量化张量 (C1)、(C5)

限制条件

  • (C1) element_type(operand) = element_type(result)
  • (C2) size(start_indices) = size(limit_indices) = size(strides) = rank(operand)
  • (C3) 0 <= start_indices <= limit_indices <= shape(operand)
  • (C4) 0 < strides
  • (C5) shape(result) = ceil((limit_indices - start_indices) / strides)

示例

// %operand: [
//            [0, 0, 0, 0],
//            [0, 0, 1, 1],
//            [0, 0, 1, 1]
//           ]
%result = "stablehlo.slice"(%operand) {
  start_indices = array<i64: 1, 2>,
  limit_indices = array<i64: 3, 4>,
  strides = array<i64: 1, 1>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
//            [1, 1],
//            [1, 1]
//           ]

更多示例

排序

语义

沿着维度 dimensioninputs 的一维切片一起排序, 根据 comparator 生成 results

与其他运算中的类似输入不同,dimension 允许负值, 。以后,系统可能会禁止执行此操作 以确保一致性 (#1377)。

如果 is_stable 为 true,则排序是稳定的,即相对顺序 比较器视为相等的元素将被保留。适用情形 其中,只有一个输入,e1e2 这两个元素会被视为 当且仅当满足以下条件时,比较运算符等于 comparator(e1, e2) = comparator(e2, e1) = false。请参阅以下规范 了解它如何泛化到多个输入。

更正式地说,对于 index_space(results[0]) 中的所有 result_index

  • adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension
  • result_slice = [ri0, ..., :, ..., riR-1],其中 riN 为个人 result_index 中的元素,并在 adjusted_dimension 处插入 :
  • inputs_together = (inputs[0]..., ..., inputs[N-1]...)
  • results_together[result_slice] = sort(inputs_together[result_slice], comparator_together)
  • 其中 sort 以非降序方式对一维切片进行排序 如果左侧参数为 ,则 comparator_together 会返回 true 小于右侧的秒参数。
  • def comparator_together(lhs_together, rhs_together):
      args = []
      for (lhs_el, rhs_el) in zip(lhs_together, rhs_together):
        args.append(lhs_el)
        args.append(rhs_el)
      return comparator(*args)
    
  • (results[0]..., ..., results[N-1]...) = results_together

输入

标签 名称 类型 限制条件
(I1) inputs 可变数量张量或每个张量量化张量 (C1 - C5)
(I2) dimension si64 类型的常量 (C4)
(I3) is_stable i1 类型的常量
(I4) comparator 函数 (C5)

Outputs

名称 类型 限制条件
results 可变数量张量或每个张量量化张量 (C2)、(C3)

限制条件

  • (C1) 0 < size(inputs)
  • (C2) type(inputs...) = type(results...)
  • (C3) same(shape(inputs...) + shape(results...))
  • (C4) -R <= dimension < R,其中 R = rank(inputs[0])
  • (C5) comparator 包含类型 (tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>, 其中 Ei = element_type(inputs[i])

示例

// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
    %predicate = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
  dimension = 0 : i64,
  is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]

更多示例

平方

语义

operand 张量执行元素级平方根运算,并生成一个 result 张量。根据元素类型,执行以下操作:

  • 对于浮点数:IEEE-754 中的 squareRoot
  • 对于复数:复数平方根。
  • 对于量化类型:dequantize_op_quantize(sqrt, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点型、复杂类型或每张量量化张量 (C1)

Outputs

名称 类型 限制条件
result 浮点型、复杂类型或每张量量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]

更多示例

subtract

语义

对两个张量 lhsrhs 执行元素级减法,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于整数:整数减法。
  • 对于浮点数:IEEE-754 中的 subtraction
  • 对于复数:复数减法。
  • 对于量化类型: <ph type="x-smartling-placeholder">
      </ph>
    • dequantize_op_quantize(subtract, lhs, rhs, type(result))

输入

标签 名称 类型 限制条件
(I1) lhs 整数、浮点、复杂类型或每张量量化张量 (C1)
(I2) rhs 整数、浮点、复杂类型或每张量量化张量 (C1)

Outputs

名称 类型 限制条件
result 整数、浮点、复杂类型或每张量量化张量 (C1)

限制条件

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

示例

// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]

更多示例

tan

语义

operand 张量执行元素级正切运算,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于浮点数:IEEE-754 中的 tan
  • 对于复数:复数正切。
  • 对于量化类型:dequantize_op_quantize(tan, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点型、复杂类型或每张量量化张量 (C1)

Outputs

名称 类型 限制条件
result 浮点型、复杂类型或每张量量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.tan"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [
//           [0.0, 1.63312e+16],
//           [0.0, 5.44375e+15]
//          ]

更多示例

坦赫

语义

operand 张量执行元素级双曲正切运算,并且 会生成 result 张量。根据元素类型,执行以下操作:

  • 对于浮点数:IEEE-754 中的 tanh
  • 对于复数:复双曲正切。
  • 对于量化类型: <ph type="x-smartling-placeholder">
      </ph>
    • dequantize_op_quantize(tanh, operand, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点型、复杂类型或每张量量化张量 (C1)

Outputs

名称 类型 限制条件
result 浮点型、复杂类型或每张量量化张量 (C1)

限制条件

  • (C1) baseline_type(operand) = baseline_type(result)

示例

// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]

更多示例

转置

语义

使用 permutation 排列 operand 张量的维,并生成 result 张量。更正式地说,result[result_index] = operand[operand_index] 其中 result_index[d] = operand_index[permutation[d]]

输入

标签 名称 类型 限制条件
(I1) operand 张量或量化张量 (C1 - C4)
(I2) permutation si64 类型的一维张量常量 (C2-C4)

Outputs

名称 类型 限制条件
result 张量或量化张量 (C1)、(C3-C4)

限制条件

  • (C1) element_type(result) 由以下公式计算得出: <ph type="x-smartling-placeholder">
      </ph>
    • element_type(operand)(如果 !is_per_axis_quantized(operand))。
    • element_type(operand),但 quantization_dimension(operand) 和 否则,quantization_dimension(result) 可能会有所不同。
  • (C2) permutationrange(rank(operand)) 的排列。
  • (C3) shape(result) = dim(operand, permutation...)
  • (C4) 如果 is_per_axis_quantized(result),则: quantization_dimension(operand) = permutation(quantization_dimension(result))

示例

// %operand: [
//            [[1,2], [3,4], [5,6]],
//            [[7,8], [9,10], [11,12]]
//           ]
%result = "stablehlo.transpose"(%operand) {
  permutation = array<i64: 2, 1, 0>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
//           [[1,7], [3,9], [5,11]],
//           [[2,8], [4,10], [6,12]]
//          ]

更多示例

triangular_solve

语义

求解带有下三角或上三角的批量线性方程组 创建系数矩阵。

更正式地说,鉴于 abresult[i0, ..., iR-3, :, :] 是解决方案 当left_sideop(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :]truex * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :] left_sidefalse,求解变量 x,其中 op(a) 已确定 截止到 transpose_a,可以是下列选项之一:

  • NO_TRANSPOSE:按原样使用 a 执行操作。
  • TRANSPOSE:对 a 的转置执行操作。
  • ADJOINT:对 a 的共振转置执行操作。

如果 lowertrue,则系统仅从 a 的下三角形读取输入数据;或者 否则为 a 的上三角形。输出数据在同一三角形中返回; 另一个三角形中的值则由实现定义。

如果 unit_diagonal 为 true,该实现可以假定对角线 a 的元素等于 1,否则行为未定义。

对于量化类型, dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower, unit_diagonal, transpose_a), a, b, type(result))

输入

标签 名称 类型 限制条件
(I1) a 浮点型、复杂类型或每张量量化张量 (C1 - C3)
(I2) b 浮点型、复杂类型或每张量量化张量 (C1 - C4)
(I3) left_side i1 类型的常量 (C3)
(I4) lower i1 类型的常量
(I5) unit_diagonal i1 类型的常量
(I6) transpose_a NO_TRANSPOSETRANSPOSEADJOINT 的枚举

Outputs

名称 类型 限制条件
result 浮点型、复杂类型或每张量量化张量 (C1)

限制条件

  • (C1) baseline_element_type(a) = baseline_element_type(b)
  • (C2) 2 <= rank(a) = rank(b) = R
  • (C3) shape(a)shape(b) 之间的关系如下: <ph type="x-smartling-placeholder">
      </ph>
    • shape(a)[:-3] = shape(b)[:-3]
    • dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1)
  • (C4) baseline_type(b) = baseline_type(result)

示例

// %a = [
//       [1.0, 0.0, 0.0],
//       [2.0, 4.0, 0.0],
//       [3.0, 5.0, 6.0]
//      ]
// %b = [
//       [2.0, 0.0, 0.0],
//       [4.0, 8.0, 0.0],
//       [6.0, 10.0, 12.0]
//      ]
%result = "stablehlo.triangular_solve"(%a, %b) {
  left_side = true,
  lower = true,
  unit_diagonal = false,
  transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
//           [2.0, 0.0, 0.0],
//           [0.0, 2.0, 0.0],
//           [0.0, 0.0, 2.0]
//          ]

tuple

<ph type="x-smartling-placeholder">
</ph>

语义

根据值 val 生成一个 result 元组。

输入

标签 名称 类型 限制条件
(I1) val 值的可变数量 (C1)

Outputs

名称 类型 限制条件
result tuple (C1)

限制条件

  • (C1) result 的类型为 tuple<E0, ..., EN-1>,其中 Ei = type(val[i])

示例

// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))

更多示例

uniform_dequantize

语义

对量化张量 operand 执行元素级转换 浮点张量 result,根据定义的量化参数 按 operand 类型排序。

更正式地说,result = dequantize(operand)

输入

标签 名称 类型 限制条件
(I1) operand 量化张量 (C1)、(C2)

Outputs

名称 类型 限制条件
result 浮点类型的张量 (C1)、(C2)

限制条件

  • (C1) shape(operand) = shape(result)
  • (C2) element_type(result) = expressed_type(operand)

示例

// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]

uniform_quantize

语义

对浮点张量或量化张量执行元素级转换 根据量化将 operand 映射到量化张量 resultresult 类型定义的参数。

更正式地说,

  • 如果为 is_float(operand): <ph type="x-smartling-placeholder">
      </ph>
    • result = quantize(operand, type(result))
  • 如果为 is_quantized(operand): <ph type="x-smartling-placeholder">
      </ph>
    • float_result = dequantize(operand)
    • result = quantize(float_result, type(result))

输入

标签 名称 类型 限制条件
(I1) operand 浮点或量化类型的张量 (C1)、(C2)

Outputs

名称 类型 限制条件
result 量化张量 (C1)、(C2)

限制条件

  • (C1) shape(operand) = shape(result)
  • (C2) expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand)

示例

// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]

// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]

语义

body 函数执行 0 次或多次后生成输出, cond 函数会输出 true。更正式地说,语义 如下所示:

internal_state = operand
while cond(*internal_state):
  internal_state = body(*internal_state)
results = internal_state

无限循环的行为待定 (#383)。

输入

标签 名称 类型 限制条件
(I1) operand 可变数量的张量、量化张量或词元 (C1 - C3)
(I2) cond 函数 (C1)
(I3) body 函数 (C2)

Outputs

名称 类型 限制条件
results 可变数量的张量、量化张量或词元 (C3)

限制条件

  • (C1) cond 的类型为 (T0, ..., TN-1) -> tensor<i1>,其中 Ti = type(operand[i])
  • (C2) body 的类型为 (T0, ..., TN-1) -> (T0, ..., TN-1),其中 Ti = type(operand[i])
  • (C3) type(results...) = type(operand...)

示例

// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %cond = "stablehlo.compare"(%arg0, %ten) {
      comparison_direction = #stablehlo<comparison_direction LT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    stablehlo.return %cond : tensor<i1>
  }, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %new_sum = stablehlo.add %arg1, %one : tensor<i64>
    %new_i = stablehlo.add %arg0, %one : tensor<i64>
    stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10

更多示例

XOR

语义

对两个张量 lhsrhs 执行元素级 XOR,并生成 result 张量。根据元素类型,执行以下操作:

  • 对于布尔值:逻辑 XOR。
  • 对于整数:按位 XOR。

输入

标签 名称 类型 限制条件
(I1) lhs 布尔值或整数类型的张量 (C1)
(I2) rhs 布尔值或整数类型的张量 (C1)

Outputs

名称 类型 限制条件
result 布尔值或整数类型的张量 (C1)

限制条件

  • (C1) type(lhs) = type(rhs) = type(result)

示例

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]

更多示例

方言互操作性

目前,实际使用的 StableHLO 程序有时包含 并非由 StableHLO 定义。

模块、函数、调用和返回

StableHLO 为 ModuleOp、FuncOp、CallOp 和 退货这样做是为了更好地与现有 MLIR 机制的互操作性, 实用的卡券包括针对 FuncOp 和 ModuleOp 的写入,以及许多编译 流水线期望这些 op 存在。全面的兼容性保证 应用于这些运算如果这些运维套件在 不兼容(即移除),将添加 StableHLO 等效项以保留 兼容性。

CHLO

CHLO 操作集包含分解为 StableHLO 的更高级别的操作。 目前,没有针对 CHLO 的兼容性保证。兼容性 chlo-legalize-to-stablehlo 通行证 必须在序列化之前使用

形状操作

在社区中,使用核心的某些操作是 动态 StableHLO 程序中的 MLIR 方言,用于执行形状计算。 最常见的语言包括shape方言 操作,例如 shape_ofnum_elementstensor 方言 操作(如 dimfrom_elements)和内置的 index 类型。

Dynamism RFC >O2 表示这些不在范围内,但也有一些对 index 类型的支持 包含它们是为了实现互操作性我们不提供以下方面的兼容性保证 操作或类型。shape-legalize-to-stablehlo 可将这些操作转换为完全受支持的 StableHLO 操作。

已弃用的操作

有多个 StableHLO 操作继承自 MHLO 这些版本已被弃用,并将逐步淘汰 StableHLO。全面、详细地了解 您可以在 StableHLO v1.0 清理操作 #2283 中找到移除步骤。 与弃用相关的跟踪器问题是 #2340

这些操作分为以下几个类别:

  • “不在 HLO 中”StableHLO 操作类别 - 它们最初是 StableHLO 操作集,但后来被认为不适合它: broadcastcreate_tokencross-replica-sumdoteinsumtorch_index_selectunary_einsum#3)。
  • 未使用的操作 - 这些操作在某个时间点可能有用,但 要么是不发达,要么是使用这些运算的流水线 已重构为不再需要它们这包括 maptuple (#598)、 get_tuple_elementrngcomplex 比较 #560、 和卷积 window_reversal (#1181)。

其中一些运算可以用 现有操作(broadcastcreate_tokencross-replica-sumdotunary_einsum),并且将在现有兼容性窗口期后移除 卡券(6 个月)。其他 einsum 个网址仍在接受移除, get_tuple_elementmaprng torch_index_selecttuplecomplex 比较,window_reversal)。待社区反馈 这些操作要么会被移除,要么添加到规范中并获得全面支持。直到 这些 op 是已知的,仅保证 6 个月的兼容性。

执行

顺序执行

StableHLO 程序通过向 main 函数提供输入值来执行 和计算输出值。函数的输出值由 执行位于相应 return 操作的根操作图。

执行顺序由实现定义,只要符合 数据流,即操作是否在使用前执行。在 StableHLO 中 附带效应操作消耗一个词元并生成一个词元(多个词元可以 通过 after_all 多路复用到一个令牌中),因此侧的执行顺序 效果也与 Dataflow 保持一致。例如,在以下程序中 有两种可能的执行顺序:%0%1%2return%1%0%2return

func.func @main() -> tensor<f64> {
  %0 = stablehlo.constant dense<1.0> : tensor<f64>
  %1 = stablehlo.constant dense<2.0> : tensor<f64>
  %2 = stablehlo.add %0, %1 : tensor<f64>
  return %2 : tensor<f64>
}

更正式地说,StableHLO 进程是以下各项的组合: 1) StableHLO 程序;2) 运行状态(尚未执行); 以及 3) 流程正在处理的中间值。 该过程从 main 函数的输入值开始,然后 更新操作状态和中间值的操作图,以及 完成输出值。进一步正式化待定 (#484)。

并行执行

StableHLO 程序可以并行执行,并整理成 2D 进程网格 (共有 num_replicas 列,按num_partitions列,类型均为ui32类型)。

StableHLO 进程网格中,StableHLO 的 num_replicas * num_partitions 多个进程同时执行每个进程都有一个 process_id = (replica_id, partition_id),其中 replica_idreplica_ids = range(num_replicas)partition_ids = range(num_partitions)partition_id,两者都有 类型 ui32

对于每个程序(在 未来,我们计划将其明确纳入 StableHLO 计划 #650),并将位置 对于每个进程而言都是静态已知的每个进程都有 通过 replica_id 获取其在进程网格中的位置, partition_id 次操作。

在进程网格中,所有项目都可以相同(在“单个项目”的 程序、多个数据"样式),都可以互不相同(在“多重节目”中, 多个数据"样式)或两者之间的内容。未来,我们计划 引入对定义并行 StableHLO 程序的其他习语的支持, 包括 GSPMD (#619)。

在进程网格中,进程大多彼此独立, 它们具有单独的操作状态、单独的输入/中间/输出值 而且大部分操作都是在进程之间单独执行的, 下面介绍的少数集体操作除外。

鉴于大多数操作的执行仅使用来自 通过名称引用这些值通常没有歧义。 不过,在描述集体运算的语义时,这样还不够, 引发 name@process_id 表示法以引用 name 值 特定进程内的资源(从这个角度来看,不符合条件的 name 可能 视为 name@(replica_id(), partition_id()) 的简写形式)。

各进程的执行顺序由实现定义,但 通过点对点通信和集体运算引入的同步 如下所述。

点对点通信

StableHLO 进程可以通过 StableHLO 频道。渠道由类型为正数的 ID 表示 si64。通过各种操作,可以将值发送到通道和 从频道接收它们。

进一步规范化,例如这些频道 ID 的来源、 程序、程序、程序、程序、程序、程序、 由他们介绍,待定 (#484)。

流式通信

每个 StableHLO 进程都可以访问两个流处理接口:

  • 可供读取的信息流广告
  • 可以写入的馈出

与渠道不同,渠道用于在进程之间进行通信, 两端都有流程,Feed 和出 Feed 都有各自的 最终实现定义。

进一步规范化,例如流式通信如何影响执行 以及由此引入的同步类型,待定 (#484)。

集体行动

StableHLO 中有六个集体操作:all_gatherall_reduceall_to_allcollective_broadcastcollective_permutereduce_scatter。所有这些操作都会在 StableHLO 进程中拆分进程 网格划分为 StableHLO 进程组,并在该进程组内执行联合计算, 独立于其他进程组

在每个进程组内,集体操作可以引入同步 屏障。进一步规范化,例如我们详细说明了 以及进程究竟是如何到达这个屏障的 如果不采取应对措施,则待定 (#484)。

如果进程组涉及跨分区通信,即存在 进程组内分区 ID 不同的进程,则执行 集体操作需要一个通道,而该集合操作必须提供一个 si64类型的正数channel_id。不需要跨副本通信 渠道。

集体操作执行的计算特定于单个操作 并在上面的各个操作部分进行了介绍。然而,这些 将进程网格拆分为进程组的进程在这些操作之间共享 具体内容。更正式地说,StableHLO 支持 。

cross_replica

只有跨副本通信发生在每个进程组内。这个 策略接受 replica_groups(副本 ID 列表),并计算 partition_idsreplica_groups 的笛卡尔积。replica_groups 必须具有唯一元素,并涵盖所有 replica_ids。更正式地说,使用 Python 语法:

def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    for partition_id in partition_ids:
      process_group = []
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
      yield process_group

例如,对于 replica_groups = [[0, 1], [2, 3]]num_partitions = 2cross_replica将生成 [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]]

cross_partition

只有跨分区通信在每个进程组内发生。这个 策略使用 partition_groups(分区 ID 列表),以及 按 replica_ids 计算 partition_groups 的笛卡尔积。 partition_groups 必须具有唯一元素,并涵盖所有 partition_ids。 更正式地说,使用 Python 语法:

def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
  for partition_group in partition_groups:
    for replica_id in replica_ids:
      process_group = []
      for partition_id in partition_group:
        process_group.append((replica_id, partition_id))
      yield process_group

例如,对于 partition_groups = [[0, 1]]num_replicas = 4cross_partition将生成 [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]]

cross_replica_and_partition

跨副本和跨分区通信可能在每个副本中发生, 进程组此策略使用replica_groups,即 副本 ID - 并按以下方式计算每个 replica_group 的笛卡尔积 partition_idsreplica_groups必须包含唯一元素且涵盖所有元素 replica_ids。更正式地说,使用 Python 语法:

def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    process_group = []
    for partition_id in partition_ids:
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
    yield process_group

例如,对于 replica_groups = [[0, 1], [2, 3]]num_partitions = 2cross_replica_and_partition将生成 [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]]

flattened_ids

此策略采用 flattened_id_groups -“扁平化”列表 采用 replica_id * num_partitions + partition_id 形式的进程 ID - 以及 并将其转换为进程 ID。flattened_id_groups必须包含唯一元素 并涵盖所有process_ids。更正式地说,使用 Python 语法:

def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
  for flattened_id_group in flattened_id_groups:
    process_group = []
    for flattened_id in flattened_id_group:
      replica_id = flattened_id // num_partitions
      partition_id = flattened_id % num_partitions
      process_group.append((replica_id, partition_id))
    yield process_group

例如,对于 flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]num_replicas = 4num_partitions = 2flattened_ids 将生成 [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]]

准确率

目前,StableHLO 不保证数值准确性, 但未来可能会发生变化 (#1156)。

量化运算的执行语义

量化 StableHLO 运算的解释可能会因 硬件要求和功能。例如,某些硬件可能会选择 使用“去量化、执行浮点数”和 运算,最后进行量化”策略还有一些人可能完整执行 使用整数算术进行计算。因此,解释 量化 StableHLO 运算完全由特定的 实施。混合量化的解释 (#1575) 应基于 它的语义(通过 1792)。

错误

StableHLO 程序通过一系列针对 从而在运行之前排除了许多类别的错误。 不过,可能会出现错误情况,例如直到整数溢出, 出界访问等。除非明确指出,否则所有这些错误都是 会导致实现定义的行为, 将来 (#1157)。

浮点异常

作为此规则的一个例外情况,StableHLO 程序中的浮点异常 具有明确定义的行为。导致由 IEEE-754 标准(无效运算、除以零、溢出、下溢或 不精确异常)生成默认结果(如标准中所定义),并且 继续执行,而不引发相应的状态标记;类似于 来自标准的 raiseNoFlag 异常处理。非标准广告的例外情况 运算(例如复数算术和某些先验函数) 实现定义的。

形状不匹配

StableHLO 支持动态形状的张量。但是,形状必须 运行时,否则行为将处于未定义状态。StableHLO 未 提供一个操作,用于断言张量在运行时具有给定形状。 生成正确的代码由提供方负责。

作为一个具体示例,以下程序是有效的。不过,在运行时 %arg0%arg1 的确切形状必须相同,否则 未定义程序的行为:

func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
    %0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
    return %0 : tensor<?xi32>
}

Notation

为描述语法,本文档使用经过修改的 EBNF ISO 变种 语法 (ISO/IEC 14977:1996Wikipedia)、 进行了两项修改:1) 规则使用 ::=(而不是 =)定义;

2) 串联使用并列(而非 ,)来表示。

用于描述语义(即在“类型”、“常量”和“操作”部分中), 我们使用的公式基于 Python 语法,并且支持扩展 以简洁地表示数组操作,如下所示。效果很好 但在极少数情况下 我们使用始终明确引入的 vanilla Python 语法。

公式

让我们根据 dot_general 中的示例来探索公式的工作原理 规范此操作的一个限制条件如下所示: dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)

此公式中使用的名称有两个来源:1) 全局函数; 即 dim、2) 相应计划元素的成员定义,即 lhslhs_batching_dimensionsrhsrhs_batching_dimensions 输入 “输入”部分中部分dot_general

如前所述,此公式的语法基于 Python,部分 简洁型扩展为了理解这个公式,让我们将 转换成普通 Python 语法

A) 在这些公式中,我们使用 = 表示等式,因此第一步 为了获取 Python 语法,请将 = 替换为 ==,如下所示: dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...)

B) 此外,这些公式支持将标量表达式转换为省略号 (...) 张量表达式。简而言之,f(xs...) 大致表示“每个 张量 xs 中的标量 x,计算标量 f(x),然后返回所有 这些标量结果一起形成张量结果”。在原版 Python 语法中 我们的示例公式将变成: [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] == [dim(rhs, dim2) for dim2 in rhs_batching_dimensions]

由于使用省略号,因此通常可以避免在 单个标量。然而,在一些棘手的情况下,较低级别的半非正式用语 语法可以像在 start_indices[bi0, ..., :, ..., biN] 公式中一样使用 。gather为确保简洁性,我们不会 提供了确切的正式形式,用于将此类语法转换为原始 Python 语言, 但希望每次使用都能直观易懂。 如果某些特定的公式看起来不透明,请告知我们,我们将尽力 改进它们。

此外,您会注意到,公式会使用省略号来展开各种列表, 包括张量、张量列表(例如,可能 张量数量等) 形式化(例如,列表甚至不是 StableHLO 类型系统的一部分)和 而是依靠直观的可理解性。

C) 最后一种值得注意的记号工具是内含 广播。虽然 StableHLO 运算集不支持隐式广播, 公式的做法同样如此,同时也是为了简洁性。简而言之,如果标量 在需要张量的上下文中使用,标量会广播到 预期形状。

继续以 dot_general 为例,以下是另一个限制条件: 0 <= lhs_batching_dimensions < rank(lhs)。如 dot_general 中所定义, lhs_batching_dimensions 是一个张量,但是 0rank(lhs) 是标量。应用隐式广播后,公式会 会变为 [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)]

应用于特定的 dot_general 运算时,此公式将 计算为布尔值张量。将公式用作约束条件时, 如果公式的计算结果为 true 或张量,则约束条件为 仅包含 true 元素。

名称

在公式中,词法范围包括:1) 全局函数;2) 成员定义;

3) 局部定义。全局函数列表如下所示。列表 取决于表示法为 已应用于:

  • 对于操作,成员定义包括“输入源”中引入的名称和 “输出”部分。
  • 对于其他所有内容,成员定义包括 节目元素,以相应的 EBNF 非终端命名。大部分 可通过以下方式获得这些结构部件的名称: 非终端名称变成蛇形命名法(例如 IntegerLiteral =>) integer_literal),但有时名称在流程中会缩写(例如 QuantizationStorageType =>storage_type),在这种情况下名称为 以与“Inputs”(输入)方式明确相似的方式引入/"输出"运行中的部分 。
  • 此外,成员定义始终包含 self 来引用 相应的计划元素。

计算公式时,它们可以使用以下类型的值: 1) Value(实际值,例如 dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>; 始终知道自己所属的类型)、 2) Placeholder(期值,如 lhsrhsresult;其实际值 值还未知,只有它们的类型是已知的), 3) Type(“类型”部分中定义的类型), 4) Function(“函数”部分中定义的全局函数)。

根据上下文,名称可以引用不同的值。更多 具体来说就是“语义”操作部分(以及其他程序的等效内容) 元素)定义了运行时逻辑,因此所有输入都以 Value 的形式提供。 相比之下,“限制条件”部分定义了 “编译时”逻辑,即通常在运行时之前执行的内容, 因此只有恒定输入可用作 Value,其他输入则 只能以 Placeholder 的形式提供。

名称 在“语义”中 在“限制条件”中
全局函数 Function Function
常量输入 Value Value
非常量输入 Value Placeholder
Outputs Value Placeholder
本地定义 取决于定义 取决于定义

我们来看一个 transpose 操作示例:

%result = "stablehlo.transpose"(%operand) {
  permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>

对于此操作,permutation 是一个常量,因此它以 Value 的形式提供 语义和约束条件。相比之下,operandresult 在语义中以 Value 的形式提供,但在约束条件中仅以 Placeholder 的形式提供。

函数

类型的构造

没有可用于构造类型的函数。相反,我们直接 使用类型语法,因为这种语法通常更简洁。例如: (tensor<E>, tensor<E>) -> (tensor<E>),而不是 function_type( [tensor_type([], E), tensor_type([], E)], [tensor_type([], E)])

类型的函数

  • element_type 基于张量类型和量化张量类型定义, 分别会返回 TensorElementTypeQuantizedTensorElementType 对应的 TensorTypeQuantizedTensorType 的一部分。
def element_type(x: Value | Placeholder | Type):
 if type(x) == TensorType:
    return tensor_element_type(x)
  if type(x) == QuantizedTensorType:
    return quantized_tensor_element_type(x)
  if type(x) is not Type:
    return element_type(type(x))
  • is_per_axis_quantized(x: Value | Placeholder | Type) -> Value”是快捷方式 价格为 is_quantized(x) and quantization_dimension(x) is not None

  • is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value是 “is_quantized(x) and quantization_dimension(x) is None”的快捷方式。

  • is_promotable(x: Type, y: Type) -> bool 会检查类型 x 是否可以升级 以输入 y。当 xyQuantizedTensorElementType 时,促销 仅适用于 storage_type。这种宣传方式具体是 目前用于归约计算的上下文中(请参阅 RFC 可了解更多详情)。

def is_promotable(x: Type, y: Type) -> Value:
  is_same_type = (is_bool(x) and is_bool(y)) or
    (is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
    (is_complex(x) and is_complex(y)) or
    (is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))

  if is_same_type == False:
    return False

  if is_integer(x) or is_float(x):
    return bitwidth(x) <= bitwidth(y)

  if is_complex(x):
    return bitwidth(element_type(x)) <= bitwidth(element_type(y))

  if is_quantized(x):
    return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))

  return false
  • is_quantized(x: Value | Placeholder | Type) -> Value”是以下账号的快捷方式: is_quantized_tensor_element_type(x)

  • is_type_name(x: Value | Placeholder | Type) -> Value。适用于所有用户 。例如,如果 xFloatType,则 is_float(x) 会返回 true。 如果 x 是值或占位符,则此函数是以下操作的快捷方式: is_type_name(type(x))

  • max_value(x: Type) -> Value 会返回 TensorElementType。如果 x 不是 TensorElementType,则返回 None

  • min_value(x: Type) -> Value 会返回 TensorElementType。如果 x 不是 TensorElementType,则返回 None

  • member_name(x: Value | Placeholder | Type) -> Any。所有会员都可观看 所有类型的 member_name 定义。例如:tensor_element_type(x) 返回相应 TensorTypeTensorElementType 部分。 如果 x 是值或占位符,则此函数是以下操作的快捷方式: member_name(type(x))。如果 x 不是具有适当成员的类型,或者 此类值或占位符,则返回 None

  • is_empty_algorithm(*args: Type) 检查是否设置了所有点算法字段 至 None。之所以需要这样做,是因为点算法已定义了实现 默认行为,因此指定默认值是不正确的。

价值结构

  • operation_name(*xs: Value | Type) -> Value。适用于所有操作。 例如,add(lhs, rhs) 接受两个张量值 lhsrhs, 返回使用这些输入评估 add 运算的输出。 对于某些操作,例如broadcast_in_dim,其输出类型为 “负载传送”,即评估操作所需的内容。在本示例中,函数 采用这些类型作为参数。

针对值的函数

  • Python 的所有运算符和函数都可用。例如:两者都有 订阅切片 来自 Python 的表示法可用于为张量、量化张量编制索引 和元组。

  • to_destination_type(x: Value, destination_type: Type) -> Value 在 张量并基于 type(x)x 返回转换后的值 destination_type

def to_destination_type(x: Value, destination_type: Type) -> Value:
  if type(x) == destination_type:
    return x

  if is_quantized(destination_type):
    if is_quantized(type(x)):
      return quantize(x, destination_type)
    assert is_float(type(x))
    return quantize(x, destination_type)

  if is_quantized(type(x)):
    assert destination_type = expressed_type(type(x))
    return dequantize(type(x))

  return convert(x, destination_type)

我们已经讨论了如何合并 convertuniform_quantizeuniform_dequantize 操作 (#1576)。 合并后,我们不需要上面的函数,可以使用操作名称 价格为 convert

  • is_nan(x: Value) -> Value 在张量上定义,如果true 否则,x 的所有元素均为 NaNfalse。如果 x 不是张量, 返回 None

  • is_sorted(x: Value) -> Value 在张量上定义,如果true x 的元素按升序排序 索引的字典顺序,否则为 false。如果 x 不是 张量,返回 None

  • is_unique(x: Value) -> Value 在张量上定义,如果 x,则返回 true 不具有重复元素,否则为 false。如果 x 不是张量, 返回 None

  • 为所有成员定义定义了 member_name(x: Value) -> Any 占所有值的 member_name。例如,real_part(x) 会返回 RealPart。 相应 ComplexConstant 的一部分。如果 x 不是具有 相应成员返回 None

  • same(x: Value) -> Value 在张量上定义,如果true x 的元素都相等,否则为 false。如果张量的 不包含元素,这会计为“所有元素都相等”,即 函数返回 true。如果 x 不是张量,则返回 None

  • split(x: Value, num_results: Value, axis: Value) -> Value 在 张量并返回沿 axis 轴的 xnum_results 切片。 如果 x 不是张量或 dim(x, axis) % num_results != 0,则返回 None

  • is_defined_in_parent_scope(x: Value) -> Value 在字符串中定义 如果 x 是同一范围内定义的函数的名称,则返回 true 相关操作的父级函数。

  • is_namespaced_op_name(x: Value) -> Value 在字符串中定义并返回 true,如果 x 是有效的操作名称,则它遵循以下常规 表达式:[a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+

形状计算

  • axes(x: Value | Placeholder | Type) -> Value”是以下账号的快捷方式: range(rank(x))

  • dim(x: Value | Placeholder | Type, axis: Value) -> Value”是以下账号的快捷方式: shape(x)[axis]

  • dims(x: Value | Placeholder | Type, axes: List) -> List”是以下账号的快捷方式: list(map(lambda axis: dim(x, axis), axes))

  • index_space(x: Value | Placeholder | Type) -> Value 在张量上定义 并返回排序的相应 TensorTypesize(x) 索引 字典顺序升序,即 [0, ..., 0][0, ..., 1]、...、 shape(x) - 1。如果 x 不是张量类型、量化张量类型或值 或其中一种类型的占位符,则返回 None

  • rank(x: Value | Placeholder | Type) -> Value”是以下账号的快捷方式: size(shape(x))

  • shape(x: Value | Placeholder | Type) -> Value在“函数”部分中 类型"部分(通过 member_name)。

  • size(x: Value | Placeholder | Type) -> Value”是以下账号的快捷方式: reduce(lambda x, y: x * y, shape(x))

量化计算

  • def baseline_element_type(x: Value | Placeholder | Type) -> Type是 “element_type(baseline_type(x))”的快捷方式。

  • baseline_type 基于张量类型和量化张量类型定义, 会将它们转换为“基准”,即形状相同但 元素类型的量化参数会重置为默认值。这是 可轻松比较张量类型和量化张量类型 这也是很多情况所需要的对于量化类型,这会启用 比较类型忽略量化参数,即 shapestorage_typeexpressed_typestorage_minstorage_maxquantization_dimension(针对每轴量化类型)必须全部匹配,但 scaleszero points 可能会有所不同。

def baseline_type(x: Value | Placeholder | Type) -> Type:
  if type(x) == TensorType:
    return x
  if type(x) == QuantizedTensorType:
    element_type = quantized_tensor_element_type(x)
    baseline_element_type = QuantizedTensorElementType(
      storage_type = storage_type(element_type),
      storage_min = storage_min(element_type),
      storage_max = storage_max(element_type),
      expressed_type = expressed_type(element_type),
      quantization_dimension = quantization_dimension(element_type),
      scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
      zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
    return QuantizedTensorType(shape(x), baseline_element_type)
  if type(x) is not Type:
    return baseline_element_type(type(x))
  • dequantize 基于量化张量类型定义,并将它们转换为 浮点张量类型。这是通过转换量化元素来实现的 表示存储类型的整数值, 使用零点和缩放比例表示的浮点值 与量化元素类型相关联。
def compute_zero_points(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      zero_points[i] = zero_points(quantized_type)[i[d]]
    return zero_points

def compute_scales(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
            type(result_type))
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      scales[i] = scales(quantized_type)[i[d]]
    return scales

def dequantize(x: Value) -> Value:
  assert is_quantized(x)
  x_storage = bitcast_convert(x, storage_type(x))
  x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
  x_expressed_sub = convert(x_storage_sub, expressed_type(x))
  return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
  • quantize 基于浮点张量类型定义,并将它们转换为 量化张量类型。可以通过转换浮点值来实现 转换为相应存储类型的整数值 使用与量化元素类型相关联的零点和缩放比例。
def quantize(x: Value, result_type: Type) -> Value:
  assert is_float(x) and is_quantized(result_type)
  zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
  converted_zero_points = convert(zero_points, expressed_type(result_type))
  converted_min = convert(storage_min(result_type), expressed_type(result_type))
  converted_max = convert(storage_max(result_type), expressed_type(result_type))

  x_scaled = x / compute_scales(result_type, type(x))
  x_scaled_add_zp = x_scaled + converted_zero_points
  x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
  x_rounded = round_nearest_even(x_clamped)
  return convert(x_rounded, result_type)
  • dequantize_op_quantize 用于指定元素级计算, 量化张量。反量化,也就是将量化元素转换为 然后执行运算,然后进行量化,即 并将其转换回存储类型。目前,此函数仅 适用于每张量量化。正在进行每轴量化 (#1574)。
def dequantize_op_quantize(op, *inputs_and_output_type):
  inputs = inputs_and_output_type[:-1]
  output_type = inputs_and_output_type[-1]

  float_inputs = map(dequantize, inputs)
  float_result = op(*float_inputs)
  return quantize(float_result, output_type)

def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
  inputs = inputs_and_output_type[:-3]
  float_inputs = map(dequantize, inputs)
  float_results = op(*float_inputs)
  return map(quantize, float_results, inputs_and_output_type[-3:])

def dequantize_compare(lhs, rhs, comparison_direction):
  float_lhs = dequantize(lhs)
  float_rhs = dequantize(rhs)
  return compare(float_lhs, float_rhs, comparison_direction, FLOAT)

def dequantize_select_quantize(pred, on_true, on_false, output_type):
  float_on_true = dequantize(on_true)
  float_on_false = dequantize(on_false)
  float_result = select(pred, float_on_true, float_on_false)
  return quantize(float_result, output_type)
  • hybrid_dequantize_then_op 用于指定仅权重的量化 混合操作,接受浮点类型的 lhs 和量化类型的 rhs。它 将量化输入反量化为表示的类型,并执行计算 以浮点数表示。浮点 lhs 张量的元素类型以及量化 rhs 的表示类型 张量都应该是相同的。
def hybrid_dequantize_then_op(op, lhs, rhs):
  assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
  return op(lhs, dequantize(rhs))

网格计算

  • cross_partition(replica_groups: Value) -> Value。如需了解详情,请参阅“cross_replica” 部分。

  • cross_replica(replica_groups: Value) -> Value。如需了解详情,请参阅“cross_replica” 部分。

  • cross_replica_and_partition(replica_groups: Value) -> Value。请参阅 &quot;cross_replica_and_partition&quot;部分。

  • flattened_ids(replica_groups: Value) -> Value。请参阅“flattened_ids” 部分。

动感

StableHLO 值可以包含动态尺寸大小,例如tensor<?xi64>。 不过,StableHLO 值不能有动态的维度数量(未排名) 动感,例如tensor<*xi64>)。操作数和结果可以使用动态 即使这些尺寸存在限制,约束条件将 如果可能的话,则进行静态验证,否则将推迟到运行时并 都会导致未定义的行为。如需查看示例,请参阅下文。

一元元素级操作的形状不匹配

请参考以下玩具计划:

func.func @foo(%arg0: tensor<?xf64>) {
  %0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
  return
}

这样的程序并不常见,因为很少有人知道 而不是输入的形状。尽管如此,这是一个有效的 StableHLO 计划。在此例中,无法静态验证 abs 操作, 因为操作数的确切形状未知。不过,这些形状 它们肯定是兼容的,并且可进行静态检查:? 可能变为 在运行时设为 2,这就不会出现问题。但是,? 可以 结果也会变为其他整数,在这种情况下行为未定义。

请注意,如果结果中的尺寸大小是动态的, 未定义的行为。事实上,并没有“预期”的这样就没有 不匹配。

二进制元素级操作的形状不匹配

请参考以下玩具计划:

func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
  %0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
  return
}

对于二元元素级运算,输入的形状和 结果必须在运行时一致。在编译时,静态尺寸必须相等, 否则,它们只需保持兼容即可。 如果输入中有任何维度是动态维度,则未定义 行为,因为动态尺寸可能与对应的 另一个操作数(无论是静态还是动态)的大小。如果所有输入都 则结果是否为动态并不重要:静态地 系统会对已知维度进行静态检查,而不会对动态维度进行 并施加任何限制

将输出形状作为操作数的操作的形状不匹配

请参考以下玩具计划:

func.func @foo(%arg0: tensor<2xi32>) {
  %0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
  return
}

运行时形状运算数中的值必须与结果的形状匹配, 否则行为将处于未定义状态。也就是说,在运行时,%arg0 必须具有 值为 dense<[3, 4]> : tensor<2xi32>。如果形状操作数是常量,则此 可以进行静态验证如果结果形状是完全动态的,那么 不能不匹配。