数据模型
StableHLO 程序是对张量(N 维数组)进行的计算,在当前模型中,是使用 Tensor
类实现的。Tensor
对象的底层存储类别 detail::Buffer
会存储张量的 mlir::ShapedType
以及一个 mlir::HeapAsmResourceBlob
对象,该对象表示一个可变的张量数据 blob,该 blob 的可变张量数据以主要顺序的形式布局为连续字节数组。对 detail::Buffer
对象进行引用计数,以简化内存管理。
张量的各个元素使用 Element
类表示,该类使用判别并集,其中存储了 APInt
、APFloat
或 pair<APFloat,APFloat>
。最后一个函数用于存储具有复杂类型的元素。
Tensor
具有以下 API 来与其各个元素进行交互:
Element Tensor::get(llvm::ArrayRef<int64_t> index)
:提取多维索引index
处的单个张量元素作为Element
对象。void Tensor::set(llvm::ArrayRef<int64_t> index, Element element);
:将Element
对象element
更新为多维索引index
处的张量。
解释器的运作方式
解释器的输入函数是
SmallVector<Tensor> eval(func::FuncOp func, ArrayRef<Tensor> args);
它会执行以下操作:
- 使用符号表映射 M 跟踪
func
的 SSA 参数及其关联的运行时Tensor
值(在args
中提供)。 - 对于
func
中的每个操作(按 SSACFG 顺序):- 对操作调用
eval
。对于操作的每个 SSA 运算数,从 M 中提取其运行时值,以作为参数提供给eval
调用。 - 跟踪操作的 SSA 结果和以 M 为单位的评估值。
- 对操作调用
(2) 中提到的操作级 eval
负责实现操作的执行语义。以下是 stablehlo::AddOp
的示例。在此示例中,lhs
和 rhs
张量的各个元素会成对提取为 Element
对象,然后添加这些对象。加法的结果(即 Element
对象)存储在最终的 result
张量中。
Tensor eval(AddOp op, const Tensor &lhs, const Tensor &rhs) {
Tensor result(op.getType());
for (auto it = result.index_begin(); it != result.index_end(); ++it)
result.set(*it, lhs.get(*it) + rhs.get(*it));
return result;
}
总体而言,解释器的设计针对各个操作的 eval
函数实现的可读性进行了优化,因为它要用作 StableHLO 的参考实现。例如,我们没有将 eval
定义为模板函数并使用元素类型对其进行参数化,而是在 Element::operator+
等中封装了有关如何处理不同元素类型的详细信息,从而简化了 eval
的实现。
使用解释器进行持续折叠
我们可以使用解释器机制来折叠包含常量运算数值的运算。以下代码段演示了使用浮点类型运算数折叠 stablehlo::AddOp
的实现方式:
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
auto attrs = adaptor.getOperands();
DenseElementsAttr lhsData = attrs[0].dyn_cast<DenseElementsAttr>();
DenseElementsAttr rhsData = attrs[1].dyn_cast<DenseElementsAttr>();
if (!lhsData || !rhsData) return {};
auto lhs = Tensor(lhsData);
auto rhs = Tensor(rhsData);
auto result = eval(*this, lhs, rhs);
SmallVector<APFloat> values;
for (auto i = 0; i < result.getNumElements(); ++i) {
Element element = result.get(i);
values.push_back(element.getValue().cast<FloatAttr>().getValue());
}
return DenseElementsAttr::get(result.getType(), values);
}
目前,我们没有积极致力于将解释器集成到常量折叠中,因为我们不打算为 StableHLO 实现文件夹。不过,未来,我们计划利用解释器在 MHLO 中进行常量折叠,届时我们将改进上述代码段的工效学设计(例如,我们可以使用一个辅助函数,将常量运算数打包到 Tensor
对象中,并将 Tensor
结果解压缩到 OpFoldResult
中)。
测试 StableHLO 解释器
解释器将 (A) StableHLO 程序和 (B) 数据值作为输入并馈送给该程序,并生成输出数据值,这些数据值与用户提供的预期数据值相匹配。数据值 (B) 使用 stablehlo.constant
操作在程序本身中硬编码。解释器评估输入程序。被测操作的输出通过检查(例如 check.expect_eq
、check.expect_almost_eq
)进行检查,如下所示。check.expect_eq
和 check.expect_eq_const
会检查任何受支持类型的按位相等,而 check.expect_almost_eq
和 check.expect_almost_eq_const
可检查是否在公差范围内(如测试指南 (G6) 中所述)对于浮点和复杂类型是否相等。
// CHECK-LABEL: Evaluated results of function: add_op_test_ui4
func.func @add_op_test_ui4() {
%0 = stablehlo.constant dense<[0, 2]> : tensor<2xui4>
%1 = stablehlo.constant dense<[15, 3]> : tensor<2xui4>
%2 = stablehlo.add %0, %1 : tensor<2xui4>
check.expect_eq_const %2, [15, 5] : tensor<2xui4>
func.return
}
测试实用程序 stablehlo-translate --interpret
(代码)负责解析程序,以解读每个函数(包括构成函数的运算)。我们有一个专用的测试套件,其中包含针对每个 StableHLO 操作执行各种运行时行为的多个测试。您可以在此处找到这些测试(例如 resolve_*.mlir)。
测试准则
(G1) 我们是否需要针对每个操作测试所有受支持的类型?
我们可以综合使用以下规则来确定:
在实现操作时,如果相应的
eval
函数中有用于处理特定类型的代码,则必须使用涵盖该类型的测试。例如,对于add
操作,存在用于处理整数、布尔值、浮点和复杂类型的独占代码,因此我们需要针对每种类型的类型进行一项测试。如果在相应的
eval
函数中统一处理一组类型,那么针对所有这些类型运行一次测试应该就足够了。例如,对于add
操作,整数类型(si4
、u4
、si8
、u8
等)的所有变体都使用llvm::APInt
API 以同样的方式处理,因此我们可以跳过为每个变体添加测试,而改为添加一个代表性测试。为避免在选择代表时造成歧义,我们应遵循以下准则:- 如果以统一方式处理的所有类型具有相同的基元类型(即,如果所有类型均为整数、浮点或复杂类型),请选择位宽最大的类型。
- 如果以统一方式处理的所有类型都有混合基元类型,则选择具有以下基元类型的类型,并按偏好设置降序排列:整数、浮点数、布尔值、复杂类型。
(G2) 我们如何确定涵盖操作行为所需的测试次数?
目标是用最少的测试次数全面涵盖操作解释器的逻辑(即实现的所有极端情况)。尽量减少测试次数对于可维护性非常重要。测试越少,审核它们就越容易,并确保它们能够全面涵盖相应操作。因此,我们预计大多数较简单的操作最终只会有一个测试。如果出于某种合理原因,全面覆盖不切实际,可以将 >= 90% 的覆盖率控制在 90% 以上。这将在拉取请求审核期间根据具体情况决定。
(G3) 如何为解释器基础架构添加测试?
解释器基础架构基本上比较简单,可以添加到我们的信任库中。唯一重要的部分是如何将各种类型打包到底层解释器存储空间并从中解压缩。如 (G1) 中所述,我们只会测试处理方式不同的那些操作类型。因此,测试期间可能不会完全涵盖与整数/浮点类型的不同变体对应的打包/解压缩代码。为了确保全面覆盖,我们可以选择支持所有 StableHLO 元素类型的操作(例如 constant
),并编写详尽的测试。
(G4) 如果操作的实现依赖于其他操作,我们是否应该针对后者编写测试?
否。例如,batch_norm_grad
的实现可以基于 divide
、subtract
、multiply
等。在测试前者时,我们应避免测试后者。
(G5) 我们是否应该编写测试来实施实现定义 / 未定义的行为?
我们不应编写执行操作的实现定义或未定义行为的测试。执行实现定义的行为的测试会演示不应泛化的解释器的本地行为。执行未定义行为的测试对理解操作行为没有帮助。
(G6) 在为浮点类型编写测试时,需要在检查时将预期结果指定到什么精度?
对于基本运算(加法、减法、乘法、除法和平方),遵循 IEEE 规范的实现提供的舍入结果应在数学上精确的结果的 0.5 ULP 范围内。也就是说,我们可以放心地想象,这些操作的预期结果最多相差 1 个 ULP。但是,这可能不适用于精度保证由实现定义(理由)的先后函数(sine
、cosine
等)。
当前的实现使用的公差值为 0.0001。 以下示例演示了上述公差的实际运用。
func.func @check_tolerance() {
%0 = stablehlo.constant dense<0.2> : tensor<f32>
// The following check succeeds as %0 is almost equal to the provided
// constant modulo the tolerance, mentioned above.
check.expect_almost_eq_const %0, dense<0.19999> : tensor<f32>
// The following check fails as %0 is not bitwise equal to the provided
// constant.
check.expect_eq_const %0, dense<0.19999> : tensor<f32>
func.return
}
这只是测试 StableHLO 操作数值准确性的第一步。 目前,在 StableHLO 规范中,这方面没有明确规定,我们仍在不断努力,根据我们使用 StableHLO 的实际经验和利益相关方的反馈,努力找出 #1156。随着相关工作的推进,我们将相应地更新基础架构。
(G7) 您对测试的编码风格有什么要求吗?
- 请务必使用输入/输出的实际名称,而不是默认使用 SSA 值(例如 %0、%1 等)。
- 确保测试使用美观输出的格式(如果存在)。
(G8) 是否应添加规范中提供的示例? 是(为了保证测试的完整性)。