StableHLO 中的活力感

动态性 RFC 中更正式地阐述了动态性的当前状态,本页面将提供该 RFC 的高级别概览,并讨论用于与动态程序交互的重要 API 和工具。

动态广告素材术语和支持服务概览

首先,为了涵盖本文档中将出现的一些术语,以及简要介绍 StableHLO 对这些术语的支持,请参阅下表:

动态维度

动态维度是指维度大小未知的任何维度。 在 StableHLO 中,我们使用 ? 表示动态维度,即 tensor<16x?xf32>

有界限的动态性

有界动态是指值具有已知上限的动态维度。通常,这对于在执行期间填充张量很有用。 在 StableHLO 中,我们使用 #stablehlo.bounds 作为张量编码来表示有界动态性,即一个动态维度有界为 16 且另一个没有界限的秩 2 张量可以表示为 tensor<?x?xf32, #stablehlo.bounds<16, ?>>

StableHLO 能够表示有界动态性,但框架支持有限,最初是在 TensorFlow 中,在 PyTorch/XLA 中也有一些支持。

无界限的动态性

顾名思义,无界动态是指大小没有已知界限的动态维度。这种动态性在 StableHLO 中非常常见,JAX、PyTorch/XLA 和 TF 都支持这种动态性,通常用于导出具有动态批次大小或序列长度的模型。

在 StableHLO 中,我们只需省略这种形式的动态性的边界编码,即 tensor<?x?xf32>

形状多态性

形状多态性是我们从 JAX 继承的术语

形状多态性有两项关键含义:

  1. 程序轨迹中的所有动态性都可追溯到其输入实参。
  2. 所有动态性仅与张量形状有关,即不依赖于数据。

有了这两条规则,一旦程序的静态形状已知,我们就可以获取动态程序并将其完全细化为静态程序以进行编译(请参阅“用于细化动态程序的编译器传递”)。

一般来说,如果已知实参形状可以生成完全静态的程序,则形状多态性会使用无界动态性,无需猜测如何限定值。

取决于数据的动态性

数据相关动态是指与张量内的数据相关的动态维度大小。一个规范示例是 nonzeros 函数,该函数会返回张量值中所有 0 元素的索引。如果不评估数据,就无法知道形状,但通常可以使用有界动态性进行编译,在潜在的输出张量大小上花费额外的内存。

许多依赖于数据的动态操作都可以使用有界动态性进行建模,即指定张量大小的上限,硬件通常会通过张量填充来实现这一点。目前,PyTorch/XLA 和 TensorFlow 在一定程度上支持数据相关动态性,但 JAX 目前不会跟踪导致数据相关动态性的操作。

导出具有动态维度的计划

如需了解如何导出具有动态批大小或序列长度的程序,请参阅我们的 StableHLO 教程:

用于优化动态程序的编译器传递

移除动态传递流水线

有一些实用的通道可用于优化形状,它们都方便地捆绑在一个通道流水线 createStablehloRemoveDynamismPipeline 中:

void createStablehloRemoveDynamismPipeline(OpPassManager &pm,
                                           TypeRange refinedTypes);

用于细化动态效果的各个通道

就单个而言,往往有助于细化形状的通道包括:

如需了解最新信息和示例,请参阅链接的文档。

示例:动态性有何用处?如何使用?

动态性有很多用途,这里我们主要关注形状多态性的常见用例 - 创建灵活的导出模型表示形式,通常用于表示动态批次大小或序列长度。

静态 add_one 模型

我们将使用以下简单的 add_one 模型来演示这一点:

def add_one(x):
  return x + 1

使用 tensor<4xf32> 进行跟踪时,我们将获得以下 StableHLO 程序:

// File: add_one.mlir
func.func @add_one(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<4xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<4xf32>
  return %0 : tensor<4xf32>
}

此模型将适用于具有 tensor<4xf32> 形参的输入实参。如果我们更改了批次大小或序列长度,就需要重新跟踪源代码并重新降级到 StableHLO,但我们甚至无法保证是否仍能访问源代码!

动态 add_one 模型

这时,形状多态动态性就派上用场了。而 JAX 和 PyTorch/XLA 可以发出具有动态有效 IR 的 add_one 模型,该模型将广播常量以匹配动态输入形状,如下所示:

// File: add_one_dynamic.mlir
func.func public @main(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %cst = stablehlo.constant dense<1.0> : tensor<f32>
  %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf32>) -> tensor<i32>
  %1 = stablehlo.reshape %0 : (tensor<i32>) -> tensor<1xi32>
  %2 = stablehlo.dynamic_broadcast_in_dim %cst, %1, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %3 = stablehlo.add %arg0, %2 : tensor<?xf32>
  return %3 : tensor<?xf32>
}

这种模型表示法更加灵活,并且允许延迟指定批次大小或序列长度等值。此模型可以部署在支持动态形状的平台(例如 AI Edge)上,也可以使用本文档中提及的动态传递进行优化。

优化动态模型

例如,以下传递顺序可以完全优化此程序:

stablehlo-opt add_one_dynamic.mlir \
  --stablehlo-refine-arguments='types=tensor<16xf32>' \
  --stablehlo-refine-shapes \
  --stablehlo-canonicalize-dynamism

该计划的逐步转变方式如下:

// After stablehlo-refine-arguments: Inputs updated, shapes not propagated
func.func public @main(%arg0: tensor<16xf32>) -> tensor<?xf32> {
  %c = stablehlo.constant dense<16> : tensor<1xi64>
  %0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {indices_of_shape_operands = dense<1> : tensor<1xi64>} : (tensor<16xf32>, tensor<1xi64>) -> tensor<?xf32>
  ...
  %3 = stablehlo.dynamic_broadcast_in_dim %cst, %2, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %4 = stablehlo.add %0, %3 : tensor<?xf32>
  return %4 : tensor<?xf32>
}

// After stablehlo-refine-shapes: Shapes propagated, dynamic ops still exist
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %c = stablehlo.constant dense<16> : tensor<1xi32>
  %0 = stablehlo.dynamic_broadcast_in_dim %cst, %c, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// After stablehlo-canonicalize-dynamism: Dynamic ops replaced with static ops
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// (Bonus) Use ` --stablehlo-aggressive-simplification` pass to canonicalize the
// constant broadcast, leaving us with the original static program in this case.
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<16xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<16xf32>
  return %0 : tensor<16xf32>
}