Tutorial: Embedding StableHLO in SavedModel


This tutorial covers the stablehlo.savedmodel module.

This tutorial details how to embed arbitrary StableHLO in a SavedModel. Note that most frameworks have their own APIs for emitting SavedModels; see the other StableHLO tutorials for instructions on using them.

Tutorial Setup

Install required dependencies

We'll use the stablehlo nightly wheel for StableHLO's Python APIs, and tensorflow-cpu for the SavedModel dependency.

pip install stablehlo -f https://github.com/openxla/stablehlo/releases/expanded_assets/dev-wheels
pip install tensorflow-cpu

Embed StableHLO model in SavedModel

In this section we'll take a very basic StableHLO module and demonstrate some of the APIs for embedding it in a SavedModel. In practice, this StableHLO module can come from a debug dump, an export from a framework, or even a conversion from HLO.

Define a StableHLO add module

For this tutorial we'll use a simple add model with two input arguments, arg0 and bias. When packaging into a SavedModel, bias will be a constant stored in the SavedModel, while arg0 is provided when calling the model.

MODULE_STRING = """
func.func @main(%arg0: tensor<1xf32>, %bias: tensor<1xf32>) -> tensor<1xf32> {
  %0 = stablehlo.add %arg0, %bias : tensor<1xf32>
  return %0 : tensor<1xf32>
}
"""

Parse to a StableHLO MLIR Module

Once we have a StableHLO file or dump of interest, we can parse it back into an MLIR module using ir.Module.parse.

Note that all dialects used in the module must be registered in the context; otherwise, parsing will fail.

import mlir.ir as ir
import mlir.dialects.stablehlo as stablehlo

with ir.Context() as ctx:
  # Register the StableHLO dialect so the parser recognizes stablehlo.* ops.
  stablehlo.register_dialect(ctx)
  module = ir.Module.parse(MODULE_STRING)

print(module)
module {
  func.func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
    %0 = stablehlo.add %arg0, %arg1 : tensor<1xf32>
    return %0 : tensor<1xf32>
  }
}

Embed in SavedModel using stablehlo_to_tf_saved_model

StableHLO's Python wheel includes a savedmodel module to help with packaging StableHLO in SavedModels.

Packaging StableHLO in a SavedModel requires a few details:

input_locations specifies where each input to the model lives: either in the SavedModel (InputLocation.parameter) or passed in as an input argument at invocation time (InputLocation.input_arg).

state_dict specifies the values for the parameter arguments that live in the SavedModel. Its entries are linked to InputLocation.parameter locations by name.

In this example, we'll specify that the second input argument is a value named module.bias, stored in the SavedModel with the value 2.

from mlir.stablehlo.savedmodel.stablehlo_to_tf_saved_model import InputLocation
import numpy as np

input_locations = [
    InputLocation.input_arg(position=0),          # Runtime input argument, non-constant
    InputLocation.parameter(name='module.bias'),  # Constant data in SavedModel
]
state_dict = {
    'module.bias': np.array([2], dtype='float32'),
}
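To make the relationship between input_locations and state_dict concrete, here is a conceptual sketch of how a full argument list for @main could be assembled: parameters are looked up by name in the stored state, and input arguments are taken from the call by position. ToyInputLocation and resolve_arguments are hypothetical stand-ins written for illustration; this is not how stablehlo_to_tf_saved_model is implemented internally.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ToyInputLocation:
    """Hypothetical stand-in for InputLocation: set exactly one of position or name."""
    position: Optional[int] = None  # analogous to InputLocation.input_arg(position=...)
    name: Optional[str] = None      # analogous to InputLocation.parameter(name=...)


def resolve_arguments(locations, state_dict, call_args):
    """Assemble @main's arguments from runtime call args and stored parameters."""
    resolved = []
    for loc in locations:
        if loc.name is not None:
            # Parameters are linked to state_dict entries by name.
            if loc.name not in state_dict:
                raise KeyError(f"state_dict has no entry for parameter {loc.name!r}")
            resolved.append(state_dict[loc.name])
        else:
            # Input arguments are taken from the invocation by position.
            resolved.append(call_args[loc.position])
    return resolved


locations = [ToyInputLocation(position=0), ToyInputLocation(name='module.bias')]
toy_state = {'module.bias': [2.0]}
print(resolve_arguments(locations, toy_state, call_args=([3.0],)))  # [[3.0], [2.0]]
```

The point of the sketch is that the caller only supplies arg0; bias comes from the stored state, which is why the restored model is later called with a single tensor.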

Now we can use stablehlo_to_tf_saved_model to create the SavedModel at the path given by the saved_model_dir argument.

from mlir.stablehlo.savedmodel.stablehlo_to_tf_saved_model import stablehlo_to_tf_saved_model

stablehlo_to_tf_saved_model(
    module,
    saved_model_dir='/tmp/add_model',
    input_locations=input_locations,
    state_dict=state_dict,
)

!ls /tmp/add_model/
assets  fingerprint.pb  saved_model.pb  variables
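If you want a quick programmatic check instead of listing the directory, a small heuristic like the one below can confirm the expected layout exists. looks_like_saved_model is a hypothetical helper; it only checks for the saved_model.pb file and the variables/ directory from the listing above and does not validate their contents.

```python
import os


def looks_like_saved_model(path):
    """Heuristic: True if path has a saved_model.pb file and a variables/ directory."""
    return (
        os.path.isfile(os.path.join(path, 'saved_model.pb'))
        and os.path.isdir(os.path.join(path, 'variables'))
    )


print(looks_like_saved_model('/tmp/add_model'))
```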

Reload and call the SavedModel

Now we can load that SavedModel and call it with a sample input.

Here we'll just use a TF constant with the value 3.

import tensorflow as tf

restored_model = tf.saved_model.load('/tmp/add_model')
print(restored_model.f(tf.constant([3], tf.float32)))
[<tf.Tensor: shape=(1,), dtype=float32, numpy=array([5.], dtype=float32)>]