Data Model
StableHLO programs are computations over tensors (n-dimensional arrays), which, in the current model, are implemented using the class `Tensor`. The underlying storage class for a `Tensor` object, `detail::Buffer`, stores the `mlir::ShapedType` of the tensor along with a `mlir::HeapAsmResourceBlob` object representing a mutable blob of tensor data laid out as a contiguous byte array in major-to-minor order. `detail::Buffer` objects are reference-counted to simplify memory management.
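As a rough illustration, here is a hedged sketch of what `detail::Buffer` could look like given the description above; the field names, the use of `llvm::RefCountedBase` for reference counting, and the `elementSizeInBytes` helper are assumptions, not the actual implementation.

```c++
#include "llvm/ADT/IntrusiveRefCntPtr.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinTypes.h"

namespace detail {
// Hypothetical sketch: reference-counted storage pairing the tensor's
// ShapedType with a mutable, contiguous, major-to-minor byte blob.
class Buffer : public llvm::RefCountedBase<Buffer> {
public:
  explicit Buffer(mlir::ShapedType type)
      : type_(type),
        blob_(mlir::HeapAsmResourceBlob::allocate(
            type.getNumElements() * elementSizeInBytes(type),
            /*align=*/alignof(std::max_align_t), /*dataIsMutable=*/true)) {}

  mlir::ShapedType getType() const { return type_; }
  llvm::MutableArrayRef<char> getData() { return blob_.getMutableData(); }

private:
  // Hypothetical helper computing the byte size of one element.
  static int64_t elementSizeInBytes(mlir::ShapedType type);

  mlir::ShapedType type_;
  mlir::AsmResourceBlob blob_;
};
} // namespace detail
```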
Individual elements of a tensor are represented using the `Element` class, which uses a discriminated union holding one of `APInt`, `APFloat`, or `pair<APFloat,APFloat>` for storage. The last one is used for storing elements with complex types.
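In modern C++ terms, that discriminated union could be expressed with `std::variant`, as in the hedged sketch below; the actual class also needs to track the MLIR element type, and all member names here are assumptions.

```c++
#include <utility>
#include <variant>

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "mlir/IR/Types.h"

// Hypothetical sketch of Element's storage: one alternative per category of
// element types (integer/boolean, floating-point, complex).
class Element {
public:
  Element(mlir::Type type, llvm::APInt value);    // Integer and boolean.
  Element(mlir::Type type, llvm::APFloat value);  // Floating-point.
  Element(mlir::Type type,
          std::pair<llvm::APFloat, llvm::APFloat> value);  // Complex.

  mlir::Type getType() const { return type_; }

private:
  mlir::Type type_;
  std::variant<llvm::APInt, llvm::APFloat,
               std::pair<llvm::APFloat, llvm::APFloat>>
      value_;
};
```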
`Tensor` has the following APIs to interact with its individual elements (a small usage sketch follows the list):

- `Element Tensor::get(llvm::ArrayRef<int64_t> index)`: extracts the individual tensor element at the multi-dimensional index `index` as an `Element` object.
- `void Tensor::set(llvm::ArrayRef<int64_t> index, Element element)`: updates the tensor with the `Element` object `element` at the multi-dimensional index `index`.
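For example, the sketch below uses these two APIs together with the tensor's index iterators (the same iterators used by the `AddOp` example later in this document) to copy one tensor into another. The header path, namespace, and the `copyElements` helper are assumptions for illustration.

```c++
#include "stablehlo/reference/Tensor.h"  // Assumed header location.

using mlir::stablehlo::Tensor;  // Assumed namespace.

// Copies `source` into `result` element by element; assumes both tensors
// have the same shape and element type.
void copyElements(const Tensor &source, Tensor &result) {
  for (auto it = source.index_begin(); it != source.index_end(); ++it)
    result.set(*it, source.get(*it));
}
```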
How the interpreter works
The entry function to the interpreter is

```c++
SmallVector<Tensor> eval(func::FuncOp func, ArrayRef<Tensor> args);
```

which does the following:

1. Tracks the SSA arguments of `func` and their associated runtime `Tensor` values, provided in `args`, using a symbol table map, M.
2. For each op within `func`, in SSACFG order:
   - Invokes `eval` on the op. For each SSA operand of the op, extracts its runtime value from M to be provided as an argument to the `eval` invocation.
   - Tracks the SSA result(s) of the op and the evaluated value in M.
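The following is a hedged sketch of that loop, assuming a `DenseMap`-based symbol table and a hypothetical per-op dispatcher `evalOp`; error handling and control-flow ops are omitted.

```c++
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"

using namespace mlir;

// Hypothetical per-op dispatcher implementing each op's execution semantics.
SmallVector<Tensor> evalOp(Operation &op, ArrayRef<Tensor> operands);

SmallVector<Tensor> eval(func::FuncOp func, ArrayRef<Tensor> args) {
  // The symbol table M: maps each SSA value to its runtime Tensor value.
  llvm::DenseMap<Value, Tensor> symbolTable;
  for (auto [blockArg, arg] : llvm::zip(func.getArguments(), args))
    symbolTable.try_emplace(blockArg, arg);

  for (Operation &op : func.getBody().front()) {
    // Extract the runtime value of each SSA operand from M.
    SmallVector<Tensor> operands;
    for (Value operand : op.getOperands())
      operands.push_back(symbolTable.lookup(operand));

    // The operands of func.return are the function's results.
    if (isa<func::ReturnOp>(op)) return operands;

    // Dispatch to the op-level eval described below.
    SmallVector<Tensor> results = evalOp(op, operands);
    for (auto [ssaResult, value] : llvm::zip(op.getResults(), results))
      symbolTable.try_emplace(ssaResult, value);
  }
  return {};
}
```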
The op-level `eval` mentioned in (2) is responsible for implementing the execution semantics of the op. Following is an example for `stablehlo::AddOp`. In the example, individual elements of the `lhs` and `rhs` tensors are pairwise extracted as `Element` objects which are then added. The result of the addition, an `Element` object, is stored in the final `result` tensor.
```c++
Tensor eval(AddOp op, const Tensor &lhs, const Tensor &rhs) {
  Tensor result(op.getType());
  for (auto it = result.index_begin(); it != result.index_end(); ++it)
    result.set(*it, lhs.get(*it) + rhs.get(*it));
  return result;
}
```
Overall, the design of the interpreter is optimized for readability of implementations of `eval` functions for individual ops because it's meant to serve as a reference implementation for StableHLO. For example, instead of defining `eval` as a template function and parameterizing it with element types, we encapsulate details about how different element types are handled in `Element::operator+` etc., simplifying the implementation of `eval`.
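To illustrate the kind of encapsulation meant here, a hedged sketch of the type dispatch inside `Element` addition follows, building on the hypothetical `Element` sketch above; the accessor names (`getIntegerValue`, `getFloatValue`, `getComplexValue`) are assumptions, not the actual API.

```c++
// Hypothetical sketch: the per-element-type logic lives here, so the
// op-level eval above can stay type-agnostic.
Element operator+(const Element &lhs, const Element &rhs) {
  mlir::Type type = lhs.getType();
  if (llvm::isa<mlir::IntegerType>(type))  // Also covers booleans (i1).
    return Element(type, lhs.getIntegerValue() + rhs.getIntegerValue());
  if (llvm::isa<mlir::FloatType>(type))
    return Element(type, lhs.getFloatValue() + rhs.getFloatValue());
  // Complex addition adds real and imaginary parts independently.
  auto [lhsReal, lhsImag] = lhs.getComplexValue();
  auto [rhsReal, rhsImag] = rhs.getComplexValue();
  return Element(type, std::make_pair(lhsReal + rhsReal, lhsImag + rhsImag));
}
```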
Using the interpreter for constant folding
We can use the interpreter mechanism to fold operations with constant operand values. The following code snippet sketches an implementation of folding for `stablehlo::AddOp` with floating-point typed operands:
```c++
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  auto attrs = adaptor.getOperands();
  // Operands that aren't compile-time constants come in as null attributes.
  auto lhsData = dyn_cast_or_null<DenseElementsAttr>(attrs[0]);
  auto rhsData = dyn_cast_or_null<DenseElementsAttr>(attrs[1]);
  if (!lhsData || !rhsData) return {};

  auto lhs = Tensor(lhsData);
  auto rhs = Tensor(rhsData);
  auto result = eval(*this, lhs, rhs);

  SmallVector<APFloat> values;
  for (auto it = result.index_begin(); it != result.index_end(); ++it) {
    Element element = result.get(*it);
    values.push_back(cast<FloatAttr>(element.getValue()).getValue());
  }
  return DenseElementsAttr::get(result.getType(), values);
}
```
At the moment, we aren't actively working on integrating the interpreter into constant folding because we aren't planning to implement folders for StableHLO. However, in the future, we are planning to leverage the interpreter for constant folding in MHLO, at which point we'll improve the ergonomics of the code snippet above (e.g., we could have a helper function which packs constant operands into `Tensor` objects and unpacks `Tensor` results into `OpFoldResult`).
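For instance, such a helper could look like the hedged sketch below; the names (`foldBinaryWithInterpreter`, `makeDenseElementsAttr`) and the restriction to binary ops are hypothetical, for illustration only.

```c++
// Hypothetical helper: pack constant operands into Tensors, run the
// interpreter's eval, and unpack the resulting Tensor into an attribute.
template <typename OpTy>
OpFoldResult foldBinaryWithInterpreter(OpTy op, ArrayRef<Attribute> attrs) {
  auto lhsData = dyn_cast_or_null<DenseElementsAttr>(attrs[0]);
  auto rhsData = dyn_cast_or_null<DenseElementsAttr>(attrs[1]);
  if (!lhsData || !rhsData) return {};  // Not all operands are constant.

  Tensor result = eval(op, Tensor(lhsData), Tensor(rhsData));
  return makeDenseElementsAttr(result);  // Hypothetical unpacking helper.
}

// With such helpers, the fold above shrinks to:
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  return foldBinaryWithInterpreter(*this, adaptor.getOperands());
}
```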
Testing the StableHLO interpreter
The interpreter takes as inputs (A) a StableHLO program and (B) data values to be fed to the program, and generates output data values, which are matched against the user-provided expected data values. The data values (B) are hard-coded in the program itself using `stablehlo.constant` operations. The interpreter evaluates the input program. The output(s) of the op under test are checked via checks (e.g. `check.expect_eq`, `check.expect_almost_eq`), as shown below. `check.expect_eq` and `check.expect_eq_const` check for bitwise equality for any supported type, and `check.expect_almost_eq` and `check.expect_almost_eq_const` check for near equality within a tolerance, explained in testing guideline (G6), for floating-point and complex types.
```mlir
// CHECK-LABEL: Evaluated results of function: add_op_test_ui4
func.func @add_op_test_ui4() {
  %0 = stablehlo.constant dense<[0, 2]> : tensor<2xui4>
  %1 = stablehlo.constant dense<[15, 3]> : tensor<2xui4>
  %2 = stablehlo.add %0, %1 : tensor<2xui4>
  check.expect_eq_const %2, dense<[15, 5]> : tensor<2xui4>
  func.return
}
```
A test utility, `stablehlo-translate --interpret` (code), is responsible for parsing the program and interpreting each function, including the operations constituting the function. We have a dedicated test suite, consisting of several tests exercising various runtime behaviors, for each StableHLO op. The tests can be found here.
Testing guidelines
(G1) Do we need to test for all the supported types for every op?
We can use a combination of the following rules to decide:
- While implementing an op, if there exists code in the corresponding `eval` function to handle a particular type, then it is imperative to have test(s) to cover that type. As an example, for the `add` op, there is dedicated code to handle integer, boolean, floating-point, and complex types, and hence we need one test for each category of types.
- If a set of types is handled uniformly in the corresponding `eval` function, then a single test for all those types should be sufficient. As an example, for the `add` op, all the variants of integer types (`si4`, `ui4`, `si8`, `ui8`, and so on) are handled alike using `llvm::APInt` APIs, and hence we can skip adding tests for each of those variants and instead add a single representative test. To avoid ambiguity in selecting the representative, we should use the following guidelines:
  - If all the types, handled uniformly, have the same primitive type (i.e., if all are integer, or floating-point, or complex types), then choose the one with maximum bit-width.
  - If all the types, handled uniformly, have a mix of primitive types, then choose the one with the following primitive type, in decreasing order of preference: integer, floating-point, boolean, complex.
(G2) How do we decide on the number of tests needed to cover an op's behavior?
The goal is to comprehensively cover the logic of the interpreter for the op (i.e. all corner cases of the implementation) with a minimal number of tests. Minimizing the number of tests is important for maintainability. The fewer tests we have, the easier it is to review them and to make sure that they comprehensively cover the op. As a result, we expect that most of the simpler ops will end up having just one test. If for some good reason comprehensive coverage is impractical, then it is fine to stop at >= 90%. This will be decided on a case-by-case basis during pull request review.
(G3) How about adding tests for the interpreter infrastructure?
The interpreter infrastructure is mostly straightforward and can be added to our trust base. The only non-trivial part is how various types are packed into and unpacked from the underlying interpreter storage. As discussed in (G1), we will be testing only those types of an op which are handled differently. With that, it is possible that the packing/unpacking code corresponding to different variants of integer/floating-point types might not get fully covered during testing. To ensure full coverage, we can choose an op like `constant` that supports all the StableHLO element types and write exhaustive tests.
(G4) If the implementation of an op depends on other ops, should we write tests for the latter?
No. For example, the implementation of `batch_norm_grad` can be based on `divide`, `subtract`, `multiply`, and others. We should avoid testing the latter ops while testing the former.
(G5) Should we write tests to exercise the implementation-defined / undefined behaviors?
We should not write tests which exercise the implementation-defined or undefined behaviors of the op. Tests exercising implementation-defined behaviors demonstrate a local behavior of the interpreter which should not be generalized. Tests exercising undefined behavior do not contribute towards the understanding of the op's behavior.
(G6) While writing tests for floating-point types, to what precision does the expected result need to be specified in checks?
For elementary operations (addition, subtraction, multiplication, division, and square), an implementation following the IEEE specification is expected to provide a rounded result within 0.5 ULP of the mathematically exact result. That said, we can safely expect the results coming out of these operations to be at most 1 ULP apart. However, this may not work for transcendental functions (`sine`, `cosine`, etc.), for which the precision guarantees are implementation-defined (rationale).

The current implementation uses a "one-size-fits-all" tolerance value of 0.0001. The following example demonstrates the above tolerance in action.
```mlir
func.func @check_tolerance() {
  %0 = stablehlo.constant dense<0.2> : tensor<f32>

  // The following check succeeds as %0 is almost equal to the provided
  // constant modulo the tolerance, mentioned above.
  check.expect_almost_eq_const %0, dense<0.19999> : tensor<f32>

  // The following check fails as %0 is not bitwise equal to the provided
  // constant.
  check.expect_eq_const %0, dense<0.19999> : tensor<f32>

  func.return
}
```
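For reference, a near-equality check with such a tolerance could be implemented along the lines of the hedged sketch below; the function name and the conversion through `double` are assumptions (and would need care for element types wider than `double`).

```c++
#include <cmath>

#include "llvm/ADT/APFloat.h"

// Hypothetical sketch: near equality within a fixed absolute tolerance,
// mirroring the 0.0001 "one-size-fits-all" value described above.
bool isAlmostEqual(llvm::APFloat actual, llvm::APFloat expected,
                   double tolerance = 0.0001) {
  bool losesInfo;
  actual.convert(llvm::APFloat::IEEEdouble(),
                 llvm::APFloat::rmNearestTiesToEven, &losesInfo);
  expected.convert(llvm::APFloat::IEEEdouble(),
                   llvm::APFloat::rmNearestTiesToEven, &losesInfo);
  return std::abs(actual.convertToDouble() - expected.convertToDouble()) <=
         tolerance;
}
```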
This is just the first step in testing the numerical accuracy of StableHLO ops. At the moment, this is an underspecified area of the StableHLO spec, and there is ongoing work to figure it out (#1156) based on our experience using StableHLO in practice and on feedback from stakeholders. As this work proceeds, we will update the infrastructure accordingly.
(G7) Anything about the coding-style of the tests?
- Make sure to use the actual names of the inputs/outputs instead of defaulting to SSA values (e.g. %0, %1, etc.).
- Make sure the tests use pretty-printed format, if it exists.
(G8) Should we include the example already provided in the spec?

Yes (for completeness of testing).