The stablehlo.savedmodel module
This tutorial details how to embed arbitrary StableHLO in a SavedModel. Note that most frameworks have specific APIs for emitting SavedModels; see the other StableHLO tutorials for instructions on using those.
Tutorial Setup
Install required dependencies
We'll be using the stablehlo nightly wheel to get StableHLO's Python APIs, and tensorflow for the SavedModel dependency.
pip install stablehlo -f https://github.com/openxla/stablehlo/releases/expanded_assets/dev-wheels
pip install tensorflow-cpu
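As a quick, optional sanity check (not part of the original setup), you can confirm that both dependencies import cleanly before proceeding:

# Optional sanity check: both packages should import without errors.
import mlir.dialects.stablehlo
import tensorflow as tf
print(tf.__version__)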
Embed StableHLO model in SavedModel
In this section we'll take a very basic StableHLO module and demonstrate some of the APIs for embedding it in a SavedModel. In practice, this StableHLO module could come from a debug dump, an export from a framework, or even be converted from HLO.
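For instance, here's a minimal sketch of one way to obtain such a module from a framework, using JAX's jax.export API (an illustration, not part of this tutorial; it assumes a recent JAX release):

# Hypothetical illustration: exporting StableHLO from JAX.
# Assumes a recent JAX release with the jax.export API.
import jax
import jax.numpy as jnp
from jax import export

def add(x, bias):
    return x + bias

exported = export.export(jax.jit(add))(
    jax.ShapeDtypeStruct((1,), jnp.float32),
    jax.ShapeDtypeStruct((1,), jnp.float32),
)
print(exported.mlir_module())  # StableHLO printed as MLIR assembly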
Define a StableHLO add module
For this tutorial we'll use a simple add model with two input arguments, arg0 and bias. When packaged in a SavedModel, bias will be a constant stored in the SavedModel, while arg0 is provided when calling the model.
MODULE_STRING = """
func.func @main(%arg0: tensor<1xf32>, %bias: tensor<1xf32>) -> tensor<1xf32> {
  %0 = stablehlo.add %arg0, %bias : tensor<1xf32>
  return %0 : tensor<1xf32>
}
"""
Parse to a StableHLO MLIR Module
Once we have a StableHLO file / dump of interest, we can parse it back into an MLIR module using ir.Module.parse.
Note that all dialects used in the module must be registered with the context, otherwise parse will fail.
import mlir.ir as ir
import mlir.dialects.stablehlo as stablehlo

with ir.Context() as ctx:
    stablehlo.register_dialect(ctx)
    module = ir.Module.parse(MODULE_STRING)
    print(module)
module {
  func.func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
    %0 = stablehlo.add %arg0, %arg1 : tensor<1xf32>
    return %0 : tensor<1xf32>
  }
}
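To see why registration matters, here's a small sketch (not in the original tutorial) of the failure mode; the exact exception type and message vary by MLIR version:

# Sketch: parsing in a context without the StableHLO dialect registered
# fails, since the context doesn't know the stablehlo.add op.
try:
    with ir.Context():
        ir.Module.parse(MODULE_STRING)
except Exception as e:
    print('parse failed:', e)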
Embed in SavedModel using stablehlo_to_tf_saved_model
StableHLO's Python wheel includes a savedmodel module to help with packaging StableHLO in SavedModels.
Packaging in a SavedModel requires a few details:

- input_locations specify where the inputs to a model live: either in the SavedModel itself (InputLocation.parameter) or passed in as input arguments during invocation (InputLocation.input_arg).
- state_dict specifies the values for the parameter arguments that live in the SavedModel. These are linked by name.
In this example, we'll specify that the second input argument is a parameter named module.bias, which is stored in the SavedModel with the value 2.
from mlir.stablehlo.savedmodel.stablehlo_to_tf_saved_model import InputLocation
import numpy as np

input_locations = [
    InputLocation.input_arg(position=0),          # Non-constant, provided at call time
    InputLocation.parameter(name='module.bias'),  # Constant data stored in the SavedModel
]
state_dict = {
    'module.bias': np.array([2], dtype='float32'),
}
Now we can use stablehlo_to_tf_saved_model to create the SavedModel at a path specified by the saved_model_dir argument.
from mlir.stablehlo.savedmodel.stablehlo_to_tf_saved_model import stablehlo_to_tf_saved_model

stablehlo_to_tf_saved_model(
    module,
    saved_model_dir='/tmp/add_model',
    input_locations=input_locations,
    state_dict=state_dict,
    target_version='1.8.5',
)
!ls /tmp/add_model/
assets fingerprint.pb saved_model.pb variables
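Optionally, the saved_model_cli tool that ships with TensorFlow can also inspect the artifact (this step isn't in the original tutorial, and the output varies by TensorFlow version):

!saved_model_cli show --dir /tmp/add_model --all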
Reload and call the SavedModel
Now we can load that SavedModel and call it with a sample input. Here we'll just use a TF constant with the value 3.
import tensorflow as tf

restored_model = tf.saved_model.load('/tmp/add_model')
print(restored_model.f(tf.constant([3], tf.float32)))
[<tf.Tensor: shape=(1,), dtype=float32, numpy=array([5.], dtype=float32)>]
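As an extra check (not in the original tutorial) that the bias constant really is baked into the SavedModel, a different input should shift the result by the stored value of 2:

print(restored_model.f(tf.constant([10], tf.float32)))
# Expect [12.]: input 10 plus the stored module.bias value of 2.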