Interpreter Design

Data Model

StableHLO programs are computations over tensors (n-dimensional arrays), which, in the current model, are implemented using the class Tensor. The underlying storage class for a Tensor object, detail::Buffer, stores the mlir::ShapedType of the tensor along with an mlir::HeapAsmResourceBlob object that represents a mutable blob of tensor data laid out as a contiguous byte array in major-to-minor order. detail::Buffer objects are reference-counted to simplify memory management.
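To make this storage layout concrete, here is a minimal standalone sketch that does not use MLIR. All names (ToyBuffer, makeBuffer, linearize, byteOffset, readF32, writeF32) are hypothetical. std::vector<int64_t> stands in for mlir::ShapedType, std::vector<uint8_t> stands in for the HeapAsmResourceBlob, and std::shared_ptr provides the reference counting.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for detail::Buffer: a shape plus a mutable byte blob.
struct ToyBuffer {
  std::vector<int64_t> shape;  // stands in for mlir::ShapedType
  size_t elementSize;          // bytes per element
  std::vector<uint8_t> bytes;  // contiguous data in major-to-minor order
};

// Handles share a buffer through std::shared_ptr, so copying a handle bumps a
// reference count instead of copying data; the last owner frees the bytes.
std::shared_ptr<ToyBuffer> makeBuffer(std::vector<int64_t> shape,
                                      size_t elementSize) {
  int64_t numElements = 1;
  for (int64_t dim : shape) numElements *= dim;
  auto buffer = std::make_shared<ToyBuffer>();
  buffer->shape = std::move(shape);
  buffer->elementSize = elementSize;
  buffer->bytes.assign(static_cast<size_t>(numElements) * elementSize, 0);
  return buffer;
}

// Major-to-minor linearization: the last dimension varies fastest.
int64_t linearize(const std::vector<int64_t> &shape,
                  const std::vector<int64_t> &index) {
  int64_t offset = 0;
  for (size_t d = 0; d < shape.size(); ++d)
    offset = offset * shape[d] + index[d];
  return offset;
}

// Byte offset of an element inside the blob.
size_t byteOffset(const ToyBuffer &buffer, const std::vector<int64_t> &index) {
  return static_cast<size_t>(linearize(buffer.shape, index)) * buffer.elementSize;
}

// Reads and writes f32 elements by copying raw bytes out of / into the blob.
float readF32(const ToyBuffer &buffer, const std::vector<int64_t> &index) {
  float value;
  std::memcpy(&value, buffer.bytes.data() + byteOffset(buffer, index), sizeof(float));
  return value;
}

void writeF32(ToyBuffer &buffer, const std::vector<int64_t> &index, float value) {
  std::memcpy(buffer.bytes.data() + byteOffset(buffer, index), &value, sizeof(float));
}
```

For a 2x3 tensor, index {1, 2} maps to linear offset 1 * 3 + 2 = 5. A write through one shared handle is visible through every other handle to the same buffer.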

Individual elements of a tensor are represented using the Element class, which uses a discriminated union holding one of APInt, APFloat or pair<APFloat,APFloat> for storage. The last one is used for storing elements with complex types.
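The discriminated union can be illustrated with std::variant. This is a sketch under simplifying assumptions, not the actual Element class. The name ToyElement is hypothetical, and plain C++ types replace the LLVM ones: int64_t for APInt, double for APFloat, and a pair of doubles for pair<APFloat,APFloat>.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <variant>

// Hypothetical stand-in for the Element class.
class ToyElement {
 public:
  // Real and imaginary parts of a complex element.
  using Complex = std::pair<double, double>;

  explicit ToyElement(int64_t value) : storage_(value) {}
  explicit ToyElement(double value) : storage_(value) {}
  explicit ToyElement(Complex value) : storage_(value) {}

  bool isInteger() const { return std::holds_alternative<int64_t>(storage_); }
  bool isFloat() const { return std::holds_alternative<double>(storage_); }
  bool isComplex() const { return std::holds_alternative<Complex>(storage_); }

  // Each accessor throws std::bad_variant_access if another kind is active.
  int64_t getInteger() const { return std::get<int64_t>(storage_); }
  double getFloat() const { return std::get<double>(storage_); }
  Complex getComplex() const { return std::get<Complex>(storage_); }

 private:
  // std::variant is a discriminated union: it holds exactly one alternative
  // and records which one is active.
  std::variant<int64_t, double, Complex> storage_;
};
```

A plain int literal such as ToyElement(7) would be ambiguous between the int64_t and double constructors, so integer elements are constructed from an explicit int64_t value.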

Tensor has the following APIs to interact with its individual elements:

  • Element Tensor::get(llvm::ArrayRef<int64_t> index): extracts the individual tensor element at the multi-dimensional index given by index as an Element object.
  • void Tensor::set(llvm::ArrayRef<int64_t> index, Element element): stores the Element object element into the tensor at the multi-dimensional index given by index.
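The shape of this get/set API can be mimicked with a toy class. This is a sketch under stated assumptions, not the real Tensor. ToyTensor is a hypothetical name, double stands in for Element, and std::vector<int64_t> stands in for llvm::ArrayRef<int64_t>. The bounds checks are an addition made for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical ToyTensor mirroring Tensor::get and Tensor::set.
class ToyTensor {
 public:
  explicit ToyTensor(std::vector<int64_t> shape) : shape_(std::move(shape)) {
    int64_t numElements = 1;
    for (int64_t dim : shape_) numElements *= dim;
    data_.assign(static_cast<size_t>(numElements), 0.0);
  }

  // Extracts the element at a multi-dimensional index.
  double get(const std::vector<int64_t> &index) const {
    return data_[offset(index)];
  }

  // Stores an element at a multi-dimensional index.
  void set(const std::vector<int64_t> &index, double element) {
    data_[offset(index)] = element;
  }

 private:
  // Converts a multi-dimensional index into a major-to-minor linear offset,
  // rejecting indices of the wrong rank or with out-of-bounds coordinates.
  size_t offset(const std::vector<int64_t> &index) const {
    if (index.size() != shape_.size())
      throw std::out_of_range("index rank does not match tensor rank");
    int64_t linear = 0;
    for (size_t d = 0; d < shape_.size(); ++d) {
      if (index[d] < 0 || index[d] >= shape_[d])
        throw std::out_of_range("index out of bounds");
      linear = linear * shape_[d] + index[d];
    }
    return static_cast<size_t>(linear);
  }

  std::vector<int64_t> shape_;
  std::vector<double> data_;
};
```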

How the interpreter works

The entry function to the interpreter is

SmallVector<Tensor> eval(func::FuncOp func, ArrayRef<Tensor> args);

which does the following:

  1. Tracks the SSA arguments of func and their associated runtime Tensor values, provided in args, using a symbol table map, M.
  2. For each op within func, in SSACFG order:
    • Invokes eval on the op, passing as arguments the runtime values of the op's SSA operands, which are looked up in M.
    • Records the SSA result(s) of the op and their evaluated values in M.
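The two steps above can be modeled with a toy interpreter. This is a simplified sketch, not the real eval. The IR types (ToyOp, ToyFunc) and the helper addTensors are hypothetical. SSA values are identified by strings, tensors are flat vectors of doubles, and each op carries its own evaluator. Unlike the real eval signature, the toy function also names the values it returns.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <map>
#include <string>
#include <vector>

using ToyTensor = std::vector<double>;

struct ToyOp {
  std::vector<std::string> operands;  // SSA ids of the operands
  std::string result;                 // SSA id of the result
  std::function<ToyTensor(const std::vector<ToyTensor> &)> eval;
};

struct ToyFunc {
  std::vector<std::string> arguments;
  std::vector<ToyOp> ops;            // assumed to be in SSACFG order already
  std::vector<std::string> returns;  // SSA ids returned by the function
};

std::vector<ToyTensor> eval(const ToyFunc &func, const std::vector<ToyTensor> &args) {
  // Step 1: bind each function argument to its runtime value in M.
  std::map<std::string, ToyTensor> M;
  for (size_t i = 0; i < func.arguments.size(); ++i) M[func.arguments[i]] = args.at(i);

  // Step 2: evaluate ops in order, reading operands from M, recording results in M.
  for (const ToyOp &op : func.ops) {
    std::vector<ToyTensor> operandValues;
    for (const std::string &operand : op.operands) operandValues.push_back(M.at(operand));
    M[op.result] = op.eval(operandValues);
  }

  std::vector<ToyTensor> results;
  for (const std::string &id : func.returns) results.push_back(M.at(id));
  return results;
}

// Element-wise addition of two equally sized tensors.
ToyTensor addTensors(const std::vector<ToyTensor> &operands) {
  ToyTensor result(operands[0].size());
  for (size_t i = 0; i < result.size(); ++i) result[i] = operands[0][i] + operands[1][i];
  return result;
}
```

For example, a function that computes sum = x + y and then out = sum + x maps the inputs {1, 2} and {10, 20} to {12, 24}.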

The op-level eval mentioned in (2) is responsible for implementing the execution semantics of the op. The following is an example for stablehlo::AddOp. In the example, individual elements of the lhs and rhs tensors are extracted pairwise as Element objects, which are then added. The result of each addition, an Element object, is stored in the final result tensor.

Tensor eval(AddOp op, const Tensor &lhs, const Tensor &rhs) {
  Tensor result(op.getType());

  for (auto it = result.index_begin(); it != result.index_end(); ++it)
    result.set(*it, lhs.get(*it) + rhs.get(*it));

  return result;
}
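The loop above visits every index between result.index_begin() and result.index_end(). One way such an iteration could enumerate indices is with an odometer that increments the last dimension first. This is a sketch that assumes the iteration follows the same major-to-minor order as the storage. The helpers advanceIndex and allIndices are hypothetical and are not part of the interpreter's API.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Advances a multi-dimensional index in major-to-minor order (the last
// dimension varies fastest). Returns false once every index has been visited.
bool advanceIndex(const std::vector<int64_t> &shape, std::vector<int64_t> &index) {
  for (int64_t d = static_cast<int64_t>(shape.size()) - 1; d >= 0; --d) {
    if (++index[d] < shape[d]) return true;
    index[d] = 0;  // carry into the next, more major dimension
  }
  return false;
}

// Collects all indices of a shape in iteration order.
std::vector<std::vector<int64_t>> allIndices(const std::vector<int64_t> &shape) {
  std::vector<std::vector<int64_t>> indices;
  for (int64_t dim : shape)
    if (dim == 0) return indices;  // a zero-sized dimension means no elements
  std::vector<int64_t> index(shape.size(), 0);  // rank 0: one empty index
  do {
    indices.push_back(index);
  } while (advanceIndex(shape, index));
  return indices;
}
```

A scalar (rank-0) shape yields exactly one empty index, which matches a tensor<f32> holding a single element.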

Overall, the design of the interpreter is optimized for the readability of the eval functions of individual ops, because the interpreter is meant to serve as a reference implementation for StableHLO. For example, instead of defining eval as a template function parameterized by element type, we encapsulate the details of how different element types are handled in Element::operator+ and similar operators, which keeps the implementation of eval simple.
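The following sketch illustrates how per-type dispatch can live in an operator rather than in each eval. It is an illustration of the idea, not the real Element::operator+. ToyElement is a hypothetical std::variant alias, and int64_t, double and std::complex<double> replace APInt, APFloat and pair<APFloat,APFloat>. Rejecting operands of different types is an assumption made here for simplicity.

```cpp
#include <cassert>
#include <complex>
#include <cstdint>
#include <stdexcept>
#include <type_traits>
#include <variant>

using ToyElement = std::variant<int64_t, double, std::complex<double>>;

// All per-type handling lives here, so an op's eval can simply write lhs + rhs.
ToyElement operator+(const ToyElement &lhs, const ToyElement &rhs) {
  return std::visit(
      [](const auto &a, const auto &b) -> ToyElement {
        using A = std::decay_t<decltype(a)>;
        using B = std::decay_t<decltype(b)>;
        if constexpr (std::is_same_v<A, B>) {
          return a + b;  // integer, floating-point or complex addition
        } else {
          throw std::invalid_argument("operands of + must have the same element type");
        }
      },
      lhs, rhs);
}
```

With this in place, a toy add eval never needs to branch on element types. The variant selects the right addition at runtime.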

Using the interpreter for constant folding

We can use the interpreter to fold operations whose operands are constants. The following code snippet sketches how folding stablehlo::AddOp with floating-point operands could be implemented:

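OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  // Non-constant operands show up as null attributes, so use the
  // null-tolerant cast and bail out unless both operands are dense constants.
  auto attrs = adaptor.getOperands();
  auto lhsData = dyn_cast_or_null<DenseElementsAttr>(attrs[0]);
  auto rhsData = dyn_cast_or_null<DenseElementsAttr>(attrs[1]);
  if (!lhsData || !rhsData) return {};

  // Pack the constants into Tensors and run the interpreter's eval.
  auto lhs = Tensor(lhsData);
  auto rhs = Tensor(rhsData);
  auto result = eval(*this, lhs, rhs);

  // Unpack the result. Tensor::get takes a multi-dimensional index, so walk
  // the result's indices (assumed to be visited in major-to-minor order, as
  // DenseElementsAttr::get expects) instead of a linear element count.
  // getFloatValue() is assumed to return the APFloat held by the Element.
  SmallVector<APFloat> values;
  for (auto it = result.index_begin(); it != result.index_end(); ++it)
    values.push_back(result.get(*it).getFloatValue());

  return DenseElementsAttr::get(result.getType(), values);
}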

At the moment, we aren't actively working on integrating the interpreter into constant folding, because we aren't planning to implement a folder for StableHLO. However, we plan to leverage the interpreter for constant folding in MHLO in the future, at which point we'll improve the ergonomics of the code snippet above (e.g. with a helper function that packs constant operands into Tensor objects and unpacks Tensor results into OpFoldResult).
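The pack/eval/unpack helper mentioned above could take roughly the following shape. This is a framework-free sketch, not an MLIR API. foldWithInterpreter, evalAdd, ToyTensor and ToyConstant are hypothetical names. A known value is modeled as a non-empty std::optional, mirroring how a folder receives null attributes for non-constant operands.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <optional>
#include <vector>

using ToyTensor = std::vector<double>;
using ToyConstant = std::optional<ToyTensor>;  // empty == operand not constant

// Generic helper: if every operand is constant, unpack them into tensors,
// run the op's evaluator and pack the result back. Otherwise, don't fold.
ToyConstant foldWithInterpreter(
    const std::vector<ToyConstant> &operands,
    const std::function<ToyTensor(const std::vector<ToyTensor> &)> &evalOp) {
  std::vector<ToyTensor> tensors;
  for (const ToyConstant &operand : operands) {
    if (!operand) return std::nullopt;  // at least one operand is not constant
    tensors.push_back(*operand);
  }
  return evalOp(tensors);
}

// Element-wise addition used as the op evaluator in this sketch.
ToyTensor evalAdd(const std::vector<ToyTensor> &operands) {
  ToyTensor result(operands[0].size());
  for (size_t i = 0; i < result.size(); ++i)
    result[i] = operands[0][i] + operands[1][i];
  return result;
}
```

With such a helper, each per-op folder reduces to a single call that passes the op's evaluator, which is the ergonomic improvement the paragraph above anticipates.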

Testing the StableHLO interpreter

The interpreter takes as inputs (A) a StableHLO program and (B) data values to be fed to the program. It generates output data values, which are matched against user-provided expected data values. The data values (B) are hard-coded in the program itself using stablehlo.constant operations. The interpreter evaluates the input program, and the output(s) of the op under test are verified using checks (e.g. check.expect_eq, check.expect_almost_eq), as shown below. check.expect_eq and check.expect_eq_const check for bitwise equality for any supported type. check.expect_almost_eq and check.expect_almost_eq_const check for near equality within a tolerance for floating-point and complex types, as explained in testing guideline (G6).

// CHECK-LABEL: Evaluated results of function: add_op_test_ui4
func.func @add_op_test_ui4() {
  %lhs = stablehlo.constant dense<[0, 2]> : tensor<2xui4>
  %rhs = stablehlo.constant dense<[15, 3]> : tensor<2xui4>
  %result = stablehlo.add %lhs, %rhs : tensor<2xui4>
  check.expect_eq_const %result, dense<[15, 5]> : tensor<2xui4>
  func.return
}
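Bitwise equality is stricter than numeric ==, which matters for floating-point types. The sketch below compares raw f32 bit patterns. It illustrates the notion of bitwise equality described above and is not the check dialect's implementation. The helper name bitwiseEqual is hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <limits>

// Compares the raw bit patterns of two f32 values rather than their
// numeric values.
bool bitwiseEqual(float lhs, float rhs) {
  uint32_t lhsBits, rhsBits;
  std::memcpy(&lhsBits, &lhs, sizeof(float));
  std::memcpy(&rhsBits, &rhs, sizeof(float));
  return lhsBits == rhsBits;
}
```

Under this definition, 0.0 and -0.0 are numerically equal but not bitwise equal. A NaN is bitwise equal to a copy of itself, even though NaN != NaN numerically.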

The test utility stablehlo-translate --interpret (code) is responsible for parsing the program and interpreting each function, including the operations that constitute it. We have a dedicated test suite for each StableHLO op, consisting of several tests that exercise various runtime behaviors. The tests can be found here.

Testing guidelines

(G1) Do we need to test for all the supported types for every op?

We can use a combination of the following rules to decide:

  1. While implementing an op, if the corresponding eval function contains code to handle a particular type, then there must be test(s) covering that type. For example, for the add op, there is dedicated code to handle integer, boolean, floating-point, and complex types, so we need one test for each of these categories of types.

  2. If a set of types is handled uniformly in the corresponding eval function, then a single test for all of those types should be sufficient. For example, for the add op, all the variants of integer types (si4, ui4, si8, ui8, and so on) are handled alike using llvm::APInt APIs, so we can skip adding a test for each of those variants and instead add a single representative test. To avoid ambiguity in selecting the representative, we use the following guidelines:

    • If all the uniformly handled types have the same primitive type (i.e., all are integer, all are floating-point, or all are complex types), then choose the one with the maximum bit-width.
    • If the uniformly handled types have a mix of primitive types, then choose one whose primitive type comes first in the following order of preference: integer, floating-point, boolean, complex.
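The selection rule can be written down as a comparator. This is a sketch that models each element type by a hypothetical ElementType record. One assumption goes beyond the guidelines: when types are mixed, the maximum bit-width is used as a tie-breaker among the types of the preferred primitive type. The guideline itself does not specify a tie-breaker for that case.

```cpp
#include <algorithm>
#include <cassert>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical model of an element type: primitive category plus bit-width.
enum class Primitive { Integer, FloatingPoint, Boolean, Complex };

struct ElementType {
  std::string name;  // e.g. "si8", "f32"
  Primitive primitive;
  int bitWidth;
};

// Position of each primitive type in the order of preference (lower wins).
int preference(Primitive p) {
  switch (p) {
    case Primitive::Integer: return 0;
    case Primitive::FloatingPoint: return 1;
    case Primitive::Boolean: return 2;
    case Primitive::Complex: return 3;
  }
  return 4;
}

// Picks the representative: preferred primitive type first (a no-op when all
// types share one), then the maximum bit-width.
ElementType chooseRepresentative(const std::vector<ElementType> &types) {
  if (types.empty()) throw std::invalid_argument("no candidate types");
  return *std::min_element(
      types.begin(), types.end(), [](const ElementType &a, const ElementType &b) {
        if (preference(a.primitive) != preference(b.primitive))
          return preference(a.primitive) < preference(b.primitive);
        return a.bitWidth > b.bitWidth;  // wider is "better", so it sorts first
      });
}
```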

(G2) How do we decide on the number of tests needed to cover an op's behavior?

The goal is to comprehensively cover the logic of the interpreter for the op (i.e. all corner cases of the implementation) with a minimal number of tests. Minimizing the number of tests is important for maintainability: the fewer tests we have, the easier it is to review them and to make sure that they comprehensively cover the op. As a result, we expect that most of the simpler ops will end up having just one test. If for some good reason comprehensive coverage is impractical, then it is fine to stop at >= 90% coverage. This will be decided on a case-by-case basis during pull request review.

(G3) How about adding tests for the interpreter infrastructure?

The interpreter infrastructure is mostly straightforward and can be added to our trust base. The only non-trivial part is how various types are packed into and unpacked from the underlying interpreter storage. As discussed in (G1), we only test those types of an op that are handled differently. As a result, the packing/unpacking code for the different variants of integer/floating-point types might not be fully covered during testing. To ensure full coverage, we can choose an op like constant that supports all the StableHLO element types and write exhaustive tests.
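For narrow types, an exhaustive round-trip test is cheap. The sketch below uses a hypothetical si4 packing scheme: the value is stored in the low nibble of a byte in two's complement and sign-extended on read. This scheme is an assumption made for illustration and is not the interpreter's actual storage format. packSi4, unpackSi4 and roundTripsAllSi4 are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>

// Packs an si4 value into the low nibble of a byte (two's complement).
uint8_t packSi4(int8_t value) {
  return static_cast<uint8_t>(static_cast<uint8_t>(value) & 0x0F);
}

// Unpacks the low nibble and sign-extends it back to an 8-bit integer.
int8_t unpackSi4(uint8_t byte) {
  int8_t nibble = static_cast<int8_t>(byte & 0x0F);
  return nibble >= 8 ? static_cast<int8_t>(nibble - 16) : nibble;
}

// Exhaustive round-trip check over every si4 value, -8 through 7.
bool roundTripsAllSi4() {
  for (int v = -8; v <= 7; ++v) {
    int8_t value = static_cast<int8_t>(v);
    if (unpackSi4(packSi4(value)) != value) return false;
  }
  return true;
}
```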

(G4) If the implementation of an op depends on other ops, should we write tests for the latter?

No. For example, the implementation of batch_norm_grad can be based on divide, subtract, multiply and others. We should avoid testing the latter ops while testing the former.

(G5) Should we write tests to exercise the implementation-defined / undefined behaviors?

We should not write tests which exercise the implementation-defined or undefined behaviors of the op. Tests exercising implementation-defined behaviors demonstrate a local behavior of the interpreter which should not be generalized. Tests exercising undefined behavior do not contribute towards the understanding of the op's behavior.

(G6) While writing tests for floating-point types, to what precision does the expected result need to be specified in checks?

For elementary operations (addition, subtraction, multiplication, division, and square root), an implementation following the IEEE 754 specification is expected to provide a rounded result within 0.5 ULP of the mathematically exact result. Consequently, we can safely expect the results of these operations to be at most 1 ULP apart from the expected result. However, this may not hold for transcendental functions (sine, cosine, etc.), for which the precision guarantees are implementation-defined (rationale).
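The distance in ULPs between two f32 values can be computed by mapping their bit patterns onto integers that increase monotonically with the value. In that mapping, adjacent representable floats differ by exactly 1. This is a standard technique shown as a sketch for finite values. The helper names orderedBits and ulpDistance are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Maps a float to an integer that increases monotonically with its value.
int64_t orderedBits(float value) {
  int32_t bits;
  std::memcpy(&bits, &value, sizeof(float));
  // Negative floats use sign-magnitude encoding. Reflect them onto the
  // negative integers so integer order matches numeric order (and -0.0 maps
  // to the same integer as +0.0).
  return bits < 0 ? static_cast<int64_t>(INT32_MIN) - bits : bits;
}

// Number of representable floats between a and b (finite inputs only).
int64_t ulpDistance(float a, float b) {
  int64_t diff = orderedBits(a) - orderedBits(b);
  return diff < 0 ? -diff : diff;
}
```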

The current implementation uses a "one-size-fits-all" tolerance value of 0.0001. The following example demonstrates the above tolerance in action.

func.func @check_tolerance() {
  %0 = stablehlo.constant dense<0.2> : tensor<f32>

  // The following check succeeds as %0 is almost equal to the provided
  // constant modulo the tolerance, mentioned above.
  check.expect_almost_eq_const %0, dense<0.19999> : tensor<f32>

  // The following check fails as %0 is not bitwise equal to the provided
  // constant.
  check.expect_eq_const %0, dense<0.19999> : tensor<f32>

  func.return
}
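The tolerance check in the example can be sketched as follows. The sketch assumes an absolute-difference comparison against 0.0001, and it assumes that complex values compare their real and imaginary parts independently. The interpreter's exact formula is not specified here, so treat both as assumptions. The names kTolerance and almostEqual are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <complex>

// Assumed one-size-fits-all absolute tolerance.
constexpr double kTolerance = 0.0001;

bool almostEqual(float actual, float expected) {
  return std::fabs(static_cast<double>(actual) - static_cast<double>(expected)) <= kTolerance;
}

// Assumption: both the real and imaginary parts must be within tolerance.
bool almostEqual(std::complex<float> actual, std::complex<float> expected) {
  return almostEqual(actual.real(), expected.real()) &&
         almostEqual(actual.imag(), expected.imag());
}
```

Under this sketch, 0.2 and 0.19999 differ by about 0.00001 and pass, which matches the expect_almost_eq_const check in the example. 0.2 and 0.1998 differ by 0.0002 and fail.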

This is just the first step in testing the numerical accuracy of StableHLO ops. At the moment, this is an underspecified area of the StableHLO spec. There is ongoing work (#1156) to figure it out based on our experience using StableHLO in practice and on feedback from stakeholders. As this work proceeds, we will update the infrastructure accordingly.

(G7) Anything about the coding-style of the tests?

  1. Use the actual names of the inputs/outputs (e.g. %lhs, %rhs, and %result for add) instead of defaulting to numbered SSA values (e.g. %0, %1).
  2. Use the pretty-printed format of ops, if one exists.

(G8) Should we include the example already provided in the spec?

Yes (for completeness of testing).