Dynamism RFC 对动态的现状进行了更正式的阐述,本页将简要介绍 RFC,并讨论与动态程序交互的重要 API 和工具。
动态效果术语和支持概览
首先,介绍本文档中出现的几个术语,并简要说明这些术语在 StableHLO 中的支持:
动态维度
动态维度是指维度大小未知的任何维度。
在 StableHLO 中,我们使用 ?
(即 tensor<16x?xf32>
)表示动态维度。
无限活力
受限的动态性是指值具有已知上限的动态维度。通常,这对于在执行期间填充张量很有用。在 StableHLO 中,我们使用 #stablehlo.bounds
作为张量编码来表示受限的动态性,即秩为 2 的张量,其中一个动态维度上限为 16,另一个无上限,可表示为 tensor<?x?xf32, #stablehlo.bounds<16, ?>>
。
StableHLO 能够表示有界限动态,但提供的框架支持有限,源自 TensorFlow,在 PyTorch/XLA 中也提供一些支持。
活力无限
顾名思义,无限动态性是指尺寸没有已知边界的动态维度。这种类型的动态性在 StableHLO 中非常常见,并且支持 JAX、PyTorch/XLA 和 TF,通常用于导出具有动态批处理大小或序列长度的模型。
在 StableHLO 中,我们只需省略这种动态形式的边界编码,即 tensor<?x?xf32>
。
形状多态性
形状多态性是我们从 JAX 继承的术语。
形状多态性有两个关键影响:
- 程序中的所有动态性都源自其输入参数。
- 所有动态性仅与张量形状有关,即不依赖于数据。
有了这两个规则,一旦知道程序的静态形状,我们就可以将动态程序完全优化为静态程序以进行编译(请参阅“用于优化动态程序的编译器传递”)。
通常,形状多态性使用无界动态性,如果已知参数形状可以导致完全静态的程序,则无需猜测如何对值进行边界限定。
数据依赖型动态性
数据依赖型动态性是指与张量内数据相关的动态维度大小。规范示例是 nonzeros
函数,用于返回张量值中所有 0
元素的索引。如果不评估数据,就无法知道形状,但通常可以使用受限的动态性进行编译,从而为潜在的输出张量大小分配额外的内存。
许多数据依赖型动态运算都可以使用边界型动态性进行建模,其中指定了张量大小的上限,并且硬件通常会通过张量填充来实现此操作。如今,PyTorch/XLA 和 TensorFlow 中对依赖数据的动态有一定的支持,但 JAX 目前未跟踪可导致依赖数据的动态的操作。
导出具有动态维度的程序
如需了解如何导出具有动态批量大小或序列长度的程序,请参阅我们的 StableHLO 教程:
用于优化动态程序的编译器传递
移除了动态性传递流水线
有几个用于优化形状的实用传递,它们都捆绑在传递流水线 createStablehloRemoveDynamismPipeline
中,非常方便:
void createStablehloRemoveDynamismPipeline(OpPassManager &pm,
TypeRange refinedTypes);
用于优化动态性的个别卡券
以下各个传递通常对形状优化很有用:
stablehlo-refine-arguments
,用于将输入参数替换为具体的张量类型。stablehlo-refine-shapes
来将新的输入参数形状信息传播到整个程序。stablehlo-canonicalize-dynamism
,用于将动态运算替换为其静态变体。
如需了解最新信息和示例,请参阅链接的文档。
示例:动态性有何用处?如何使用?
Dynamism 有很多用途,在这里,我们主要关注形状多态性的常见用例,即创建灵活的导出模型表示法,通常用于表示动态批次大小或序列长度。
静态 add_one 模型
我们将使用以下简单的 add_one
模型来演示这一点:
def add_one(x):
return x + 1
使用 tensor<4xf32>
进行跟踪时,我们将获得以下 StableHLO 程序:
// File: add_one.mlir
func.func @add_one(%arg0: tensor<4xf32>) -> tensor<4xf32> {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<4xf32>
%0 = stablehlo.add %arg0, %cst : tensor<4xf32>
return %0 : tensor<4xf32>
}
此模型仅适用于形状为 tensor<4xf32>
的输入参数。如果我们更改了批次大小或序列长度,则需要重新追溯源代码并重新降到 StableHLO,但无法保证我们仍然能够访问源代码!
动态 add_one 模型
这正是形状多态动态的用武之地。相反,JAX 和 PyTorch/XLA 可以发出具有动态有效 IR 的 add_one
模型,该 IR 将广播常量以匹配动态输入形状,如下所示:
// File: add_one_dynamic.mlir
func.func public @main(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%cst = stablehlo.constant dense<1.0> : tensor<f32>
%0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf32>) -> tensor<i32>
%1 = stablehlo.reshape %0 : (tensor<i32>) -> tensor<1xi32>
%2 = stablehlo.dynamic_broadcast_in_dim %cst, %1, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
%3 = stablehlo.add %arg0, %2 : tensor<?xf32>
return %3 : tensor<?xf32>
}
这种模型表示法更加灵活,并且允许延迟指定批处理大小或序列长度等值。此模型可部署在支持动态形状的平台(例如 AI Edge),也可以使用本文档中提到的动态性传递进行优化。
优化动态模型
例如,以下传递顺序可以完全优化此程序:
stablehlo-opt add_one_dynamic.mlir \
--stablehlo-refine-arguments='types=tensor<16xf32>' \
--stablehlo-refine-shapes \
--stablehlo-canonicalize-dynamism
以下是程序的增量转换方式:
// After stablehlo-refine-arguments: Inputs updated, shapes not propagated
func.func public @main(%arg0: tensor<16xf32>) -> tensor<?xf32> {
%c = stablehlo.constant dense<16> : tensor<1xi64>
%0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {indices_of_shape_operands = dense<1> : tensor<1xi64>} : (tensor<16xf32>, tensor<1xi64>) -> tensor<?xf32>
...
%3 = stablehlo.dynamic_broadcast_in_dim %cst, %2, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
%4 = stablehlo.add %0, %3 : tensor<?xf32>
return %4 : tensor<?xf32>
}
// After stablehlo-refine-shapes: Shapes propagated, dynamic ops still exist
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%c = stablehlo.constant dense<16> : tensor<1xi32>
%0 = stablehlo.dynamic_broadcast_in_dim %cst, %c, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<16xf32>
%1 = stablehlo.add %arg0, %0 : tensor<16xf32>
return %1 : tensor<16xf32>
}
// After stablehlo-canonicalize-dynamism: Dynamic ops replaced with static ops
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<16xf32>
%1 = stablehlo.add %arg0, %0 : tensor<16xf32>
return %1 : tensor<16xf32>
}
// (Bonus) Use ` --stablehlo-aggressive-simplification` pass to canonicalize the
// constant broadcast, leaving us with the original static program in this case.
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<16xf32>
%0 = stablehlo.add %arg0, %cst : tensor<16xf32>
return %0 : tensor<16xf32>
}