数据模型
StableHLO 程序是基于张量的计算
(n 维数组),在当前模型中是使用
Tensor
类。Tensor
对象的底层存储类别,
detail::Buffer
,存储张量的 mlir::ShapedType
以及
表示可变张量 blob 的 mlir::HeapAsmResourceBlob
对象
以连续字节数组的形式
从大到小顺序。
为简化内存管理,系统会对 detail::Buffer
对象进行引用计数。
张量的各个元素均使用 Element
类表示,
使用歧视性并集,包含 APInt
、APFloat
或
pair<APFloat,APFloat>
存储空间。最后一个用于存储元素
具有多种复杂类型。
Tensor
具有以下 API 与其各个元素进行交互:
Element Tensor::get(llvm::ArrayRef<int64_t> index)
:如需提取 多维索引index
处的各个张量元素,如Element
对象。void Tensor::set(llvm::ArrayRef<int64_t> index, Element element);
: 将Element
对象element
更新为多维张量 索引:index
。
翻译工具的工作原理
解释器的条目函数是
SmallVector<Tensor> eval(func::FuncOp func, ArrayRef<Tensor> args);
该命令将执行以下操作:
- 跟踪
func
的 SSA 参数及其关联的运行时Tensor
值,使用符号表映射 M,在args
中提供。 - 对于
func
中的每个操作(按 SSACFG 顺序排列): <ph type="x-smartling-placeholder">- </ph>
- 对操作调用
eval
。对于操作的每个 SSA 操作数,提取其 来自 M 的运行时值作为参数提供给eval
调用。 - 跟踪操作的 SSA 结果和 M 中的评估值。
- 对操作调用
(2) 中提到的操作级 eval
负责实现
操作的执行语义。以下是 stablehlo::AddOp
的示例。
在该示例中,lhs
和 rhs
张量的各个元素成对显示
提取为 Element
对象,然后再进行添加。相加的结果是
Element
对象存储在最终 result
张量中。
Tensor eval(AddOp op, const Tensor &lhs, const Tensor &rhs) {
Tensor result(op.getType());
for (auto it = result.index_begin(); it != result.index_end(); ++it)
result.set(*it, lhs.get(*it) + rhs.get(*it));
return result;
}
总体而言,解释器的设计针对
单个 op 的 eval
函数的实现,因为它旨在
用作 StableHLO 的参考实现。例如,将
将 eval
定义为模板函数,并使用元素类型对其进行参数化;
我们封装了有关在 Google Analytics 中
Element::operator+
等,从而简化 eval
的实现。
使用解释器进行常量折叠
我们可以使用解释器机制来折叠具有常量运算数的运算
值。以下代码段说明了在什么情况下
用于对带有浮点型运算数的 stablehlo::AddOp
折叠:
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
auto attrs = adaptor.getOperands();
DenseElementsAttr lhsData = dyn_cast<DenseElementsAttr>(attrs[0]);
DenseElementsAttr rhsData = dyn_cast<DenseElementsAttr>(attrs[1]);
if (!lhsData || !rhsData) return {};
auto lhs = Tensor(lhsData);
auto rhs = Tensor(rhsData);
auto result = eval(*this, lhs, rhs);
SmallVector<APFloat> values;
for (auto i = 0; i < result.getNumElements(); ++i) {
Element element = result.get(i);
values.push_back(cast<FloatAttr>(element.getValue()).getValue());
}
return DenseElementsAttr::get(result.getType(), values);
}
目前,我们没有积极致力于将翻译工具集成到
常量折叠,因为我们不打算为 StableHLO 实现文件夹。
不过,我们计划在未来利用该解释器来
折叠,此时我们将改进代码段的工效学设计
(例如,我们可以创建一个辅助函数,将常量运算数
Tensor
对象并将 Tensor
结果解压缩到 OpFoldResult
中。
测试 StableHLO 解释器
解释器将 (A) StableHLO 程序和 (B) 数据值作为输入
提供给程序,并生成输出数据值,
与用户提供的数据值对比。数据值 (B) 为
使用 stablehlo.constant
操作在程序本身中进行硬编码。通过
解释器评估输入程序。被测操作的输出
通过检查(例如 check.expect_eq
、check.expect_almost_eq
)进行检查,
如下所示。check.expect_eq
和 check.expect_eq_const
检查按位
相等,任何支持的类型以及 check.expect_almost_eq
和
check.expect_almost_eq_const
检查在容限内是否接近相等;
有关浮点数和复杂类型的测试指南 (G6) 中进行了详细说明。
// CHECK-LABEL: Evaluated results of function: add_op_test_ui4
func.func @add_op_test_ui4() {
%0 = stablehlo.constant dense<[0, 2]> : tensor<2xui4>
%1 = stablehlo.constant dense<[15, 3]> : tensor<2xui4>
%2 = stablehlo.add %0, %1 : tensor<2xui4>
check.expect_eq_const %2, [15, 5] : tensor<2xui4>
func.return
}
测试实用程序 stablehlo-translate --interpret
(代码)
负责解析该程序,解释每个函数(包括
构成函数的操作。我们有专门的测试套件
针对每个 StableHLO Op 执行各种运行时行为的多项测试。
您可以在此处查看这些测试。
检测指南
(G1) 我们是否需要测试每个操作的所有受支持类型?
我们可以结合使用以下规则来做出决定:
实现操作时,如果相应的
eval
中存在代码 函数处理特定类型,则必须有测试 来涵盖该类型例如,对于add
操作,有 处理整数、布尔值、浮点和复杂类型,因此我们 每个类型都需要一项测试。如果在相应的
eval
函数中统一处理一组类型, 那么针对所有这些类型进行一次测试就应该足够了例如 对于add
操作,整数类型(si4
、u4
、si8
、u8
)的所有变体 等)均使用llvm::APInt
API 进行处理,因此可以跳过 分别为每个变体添加测试, 有代表性测试。为了避免在选择代表性方面产生歧义, 应遵循以下准则:- 如果所有类型都得到统一处理,它们具有相同的基元类型 (即,如果它们均为整数、浮点或复杂类型),则 请选择位宽最大的那个。
- 如果所有类型都统一处理,并且混合了基元类型,那么 选择具有以下基元类型的那一个,按照从高到低的顺序 偏好设置:整数、浮点、布尔值、复杂。
(G2) 我们如何确定覆盖操作所需的测试数量 行为?
目标是全面介绍操作解释器的逻辑 (即实现的所有极端情况)。 最大限度地减少测试数量对于可维护性非常重要。测试数量越少 就越容易进行审核并确保它们 全面介绍操作因此,我们预计大多数 运维人员最终只会有一项测试。如果出于某种充分的理由,全面详解 覆盖率不切实际,则最好停止在 >= 90%。这将由 根据具体情况在拉取请求审核期间提供。
(G3) 如何为解释器基础架构添加测试?
解释器基础架构基本上非常简单,可以添加到
我们的信任库。唯一重要的部分是如何将各种类型打包到
并从底层解释器存储空间中提取出来。如 (G1) 中所述,
将仅测试处理方式不同的操作类型。包含
打包/解打包代码可能对应不同的
整数/浮点类型的变体,在
测试。为确保完全覆盖,我们可以选择类似 constant
的操作,
支持所有 StableHLO 元素类型,并编写详尽的测试。
(G4) 如果一项操作的实现依赖于其他操作,我们应该写 是否对后者进行了测试?
不需要。例如,batch_norm_grad
的实现可以基于
divide
、subtract
、multiply
和其他人。对于后者
而对前者进行测试
(G5) 我们应该编写测试以执行实现定义的 / 未定义的 哪些行为?
我们不应在编写测试时执行实现定义的 未定义的操作行为。执行实现定义的行为的测试 展现解释器的本地行为,而这种行为 泛化处理。执行未定义行为的测试不会产生 对操作行为的理解
(G6) 在针对浮点类型编写测试时, 是否需要在检查中指定预期结果?
对于基本运算(加法、减法、乘法、除法和
正方形),则遵循 IEEE 规范的实现应提供
四舍五入的结果与数学精确结果相差不超过 0.5 ULP。尽管如此,
可以放心地想象,这些操作的预期结果是
最多相隔 1 个 ULP。不过,这种方法可能不适用于先验函数
(sine
、cosine
等),其中精度保证适用
实现定义的(理由)。
当前实现使用的是“通用方案”设置为 0.0001。 以下示例演示了上述容限的实际运用。
func.func @check_tolerance() {
%0 = stablehlo.constant dense<0.2> : tensor<f32>
// The following check succeeds as %0 is almost equal to the provided
// constant modulo the tolerance, mentioned above.
check.expect_almost_eq_const %0, dense<0.19999> : tensor<f32>
// The following check fails as %0 is not bitwise equal to the provided
// constant.
check.expect_eq_const %0, dense<0.19999> : tensor<f32>
func.return
}
这只是测试 StableHLO 运算的数字准确性的第一步。 目前,这是 StableHLO 规范中未指定的区域, 正在进行这项研究 #1156 根据我们在实践中使用 StableHLO 的经验,以及 相关方。随着工作的进行,我们会更新基础架构 。
(G7) 对测试的编码风格有什么建议吗?
- 请务必使用输入/输出的实际名称,而不是使用默认值 SSA 值(例如 %0、%1 等)
- 请确保测试使用整齐打印格式(如果存在)。
(G8) 我们是否应添加规范中提供的示例? 是(为了完成测试)。