Tutorial: Exporting StableHLO from JAX


JAX is a Python library for high-performance numerical computing. This tutorial shows how to export JAX and Flax (a JAX-powered neural network library) models to StableHLO, and directly to TensorFlow SavedModel.

Tutorial Setup

Install required dependencies

We use jax and jaxlib (JAX's support library with compiled binaries), along with flax and transformers, which provide some of the models we export. We also need tensorflow to work with SavedModel; we recommend tensorflow-cpu or tf-nightly for this tutorial.

pip install -U jax jaxlib flax transformers tensorflow-cpu

Define get_stablehlo_asm to help with MLIR printing

Note: This helper uses a JAX internal API that may break at any time, but it serves no functional purpose in the tutorial aside from readability.
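A minimal sketch of such a helper follows. It assumes the JAX-internal modules jax._src.interpreters.mlir (for make_ir_context) and jax._src.lib.mlir (the MLIR Python bindings) keep their current layout; neither is a public API. The large_elements_limit argument of the MLIR get_asm method keeps large constants, such as model weights, out of the printout. The short usage at the end exports a hypothetical double function just to exercise the helper.

```python
import jax
import numpy as np
from jax import export
# Internal (non-public) JAX modules: they may move or change without notice.
from jax._src.interpreters import mlir as jax_mlir
from jax._src.lib.mlir import ir

def get_stablehlo_asm(module_str):
  """Pretty-prints a StableHLO module string, eliding large constants."""
  with jax_mlir.make_ir_context() as ctx:
    module = ir.Module.parse(module_str, context=ctx)
    # Constants with more than 20 elements are elided from the printout.
    return module.operation.get_asm(large_elements_limit=20)

# Usage: pretty-print a tiny exported function (hypothetical example).
@jax.jit
def double(x):
  return x * 2

double_exported = export.export(double)(jax.ShapeDtypeStruct((3,), np.float32))
double_asm = get_stablehlo_asm(double_exported.mlir_module())
print(double_asm)
```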

Export JAX model to StableHLO using jax.export

In this section we'll export a basic JAX function and a Flax model to StableHLO.

The preferred API for export is jax.export. The function to export must be JIT-transformed; specifically, it must be the result of jax.jit.

Export basic JAX model to StableHLO

Let's start by exporting a basic plus function to StableHLO, using np.int32 argument types to trace the function.

Export requires specifying shapes using jax.ShapeDtypeStruct, which can be constructed from NumPy values.

import jax
from jax import export
import jax.numpy as jnp
import numpy as np

# Create a JIT-transformed function
@jax.jit
def plus(x, y):
  return jnp.add(x, y)

# Create abstract input shapes
inputs = (np.int32(1), np.int32(1))
input_shapes = [jax.ShapeDtypeStruct(arg.shape, arg.dtype) for arg in inputs]

# Export the function to StableHLO
stablehlo_add = export.export(plus)(*input_shapes).mlir_module()
print(get_stablehlo_asm(stablehlo_add))
module @jit_plus attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32> {jax.result_info = ""}) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<i32>
    return %0 : tensor<i32>
  }
}
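Beyond printing the module, the Exported object returned by export.export can be serialized to bytes and later deserialized and called, possibly in a different process. The sketch below round-trips the same plus function through Exported.serialize and export.deserialize, then invokes the restored object with Exported.call. It assumes the restored module runs on the platform it was exported for (by default, the current JAX backend, e.g. CPU).

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

# Same function as above, redefined so this block is self-contained.
@jax.jit
def plus(x, y):
  return jnp.add(x, y)

arg_specs = [jax.ShapeDtypeStruct((), np.int32)] * 2
exported_plus = export.export(plus)(*arg_specs)

# Serialize to a bytearray (e.g. to write to disk), then restore it.
serialized = exported_plus.serialize()
restored_plus = export.deserialize(serialized)

# The restored Exported object is called with concrete arguments.
plus_result = restored_plus.call(np.int32(1), np.int32(2))
print(plus_result)
```

Serializing the Exported object, rather than only the MLIR text, keeps metadata such as input/output avals and target platforms alongside the module.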

Export Hugging Face FlaxResNet18 to StableHLO

Now let's look at a simple model that appears in the wild: ResNet-18.

We'll export a Flax model from the Hugging Face transformers ResNet page, FlaxResNetModel. The setup steps below were copied from the Hugging Face documentation.

The documentation also states: "Finally, this model supports inherent JAX features such as: Just-In-Time (JIT) compilation ..." which means it is perfect for export.

Similar to our very basic example, our steps for export are:

  1. Instantiate a callable (model/function)
  2. JIT-transform it with jax.jit
  3. Specify shapes for export using jax.ShapeDtypeStruct on NumPy values
  4. Use the JAX export API to get a StableHLO module
from transformers import FlaxResNetModel
import jax
from jax import export
import numpy as np

# Construct jit-transformed flax model with sample inputs
resnet18 = FlaxResNetModel.from_pretrained("microsoft/resnet-18", return_dict=False)
resnet18_jit = jax.jit(resnet18)
# np.random.randn returns float64; cast explicitly to float32
sample_input = np.random.randn(1, 3, 224, 224).astype(np.float32)
input_shape = jax.ShapeDtypeStruct(sample_input.shape, sample_input.dtype)

# Export to StableHLO
stablehlo_resnet18_export = export.export(resnet18_jit)(input_shape)
resnet18_stablehlo = get_stablehlo_asm(stablehlo_resnet18_export.mlir_module())
print(resnet18_stablehlo[:600], "\n...\n", resnet18_stablehlo[-345:])
module @jit__unnamed_wrapped_function_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<1x3x224x224xf32>) -> (tensor<1x512x7x7xf32> {jax.result_info = "[0]"}, tensor<1x512x1x1xf32> {jax.result_info = "[1]"}) {
    %c = stablehlo.constant dense<49> : tensor<i32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_0 = stablehlo.constant dense<0xFF800000> : tensor<f32>
    %cst_1 = stablehlo.constant dense<9.99999974E-6> : tensor<f32>
    %cst_2 = stablehlo.constant dense_reso 
...
 func.func private @relu_3(%arg0: tensor<1x7x7x512xf32>) -> tensor<1x7x7x512xf32> {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<1x7x7x512xf32>
    %1 = stablehlo.maximum %arg0, %0 : tensor<1x7x7x512xf32>
    return %1 : tensor<1x7x7x512xf32>
  }
}
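The same four steps apply to any jittable callable. As a download-free sketch, the hypothetical tiny_model below (not part of the tutorial) takes a pytree of parameters; jax.tree_util.tree_map turns a whole pytree of NumPy sample values into matching ShapeDtypeStructs in one call.

```python
import jax
import numpy as np
from jax import export

# 1. A callable: a hypothetical one-layer "model" taking a params pytree.
def tiny_model(params, x):
  return jax.nn.relu(x @ params["w"] + params["b"])

# 2. JIT-transform it.
tiny_model_jit = jax.jit(tiny_model)

# 3. Build ShapeDtypeStructs from NumPy sample values (works on whole pytrees).
sample_params = {"w": np.zeros((4, 2), np.float32), "b": np.zeros((2,), np.float32)}
sample_x = np.zeros((1, 4), np.float32)
tiny_specs = jax.tree_util.tree_map(
    lambda v: jax.ShapeDtypeStruct(v.shape, v.dtype), (sample_params, sample_x))

# 4. Export to get a StableHLO module plus input/output metadata.
tiny_exported = export.export(tiny_model_jit)(*tiny_specs)
print(tiny_exported.in_avals)   # flattened inputs: b, w, x
print(tiny_exported.out_avals)
```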

Export with dynamic batch size

Now let's export that same model with a dynamic batch size!

In the first example, we used an input shape of tensor<1x3x224x224xf32>, specifying strict constraints on the input shape. If we want to defer the concrete shapes used in compilation until a later point, we can specify a symbolic_shape. In this example, we'll export using tensor<?x3x224x224xf32>.

Symbolic shapes are specified using export.symbolic_shape, with letters naming symbolic integer dimensions. For example, a 2-D matrix multiplication could use the symbolic shapes 2,a and a,5 for its two operands, which guarantees that the contracted dimensions match in any refined program. Symbolic dimension names are tracked by an export.SymbolicScope, which prevents unintentional name clashes.
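The matrix-multiplication example can be sketched as follows. Both operand shapes are created in the same SymbolicScope, so the a in "2,a" and the a in "a,5" are the same dimension variable and the contraction type-checks for any value of a. JAX does not allow symbolic dimensions from different scopes to be mixed, so creating the two shapes without a shared scope would be rejected. The matmul function is a hypothetical stand-in.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

@jax.jit
def matmul(x, y):
  return jnp.matmul(x, y)

# One scope shared by both operands, so both "a"s denote the same symbol.
mm_scope = export.SymbolicScope()
x_spec = jax.ShapeDtypeStruct(export.symbolic_shape("2,a", scope=mm_scope), np.float32)
y_spec = jax.ShapeDtypeStruct(export.symbolic_shape("a,5", scope=mm_scope), np.float32)

mm_exported = export.export(matmul)(x_spec, y_spec)
print(mm_exported.in_avals)   # symbolic input shapes (2, a) and (a, 5)
print(mm_exported.out_avals)  # the contracted dimension a does not appear

# At call time, a is resolved from the concrete argument shapes (here a = 7).
mm_result = mm_exported.call(np.ones((2, 7), np.float32), np.ones((7, 5), np.float32))
print(mm_result.shape)
```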

# Construct dynamic sample inputs
dyn_scope = export.SymbolicScope()
dyn_input_shape = jax.ShapeDtypeStruct(export.symbolic_shape("a,3,224,224", scope=dyn_scope), np.float32)

# Export to StableHLO
dyn_resnet18_export = export.export(resnet18_jit)(dyn_input_shape)
dyn_resnet18_stablehlo = get_stablehlo_asm(dyn_resnet18_export.mlir_module())
print(dyn_resnet18_stablehlo[:1900], "\n...\n", dyn_resnet18_stablehlo[-1000:])
module @jit__unnamed_wrapped_function_ attributes {jax.uses_shape_polymorphism = true, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<?x3x224x224xf32>) -> (tensor<?x512x7x7xf32> {jax.result_info = "[0]"}, tensor<?x512x1x1xf32> {jax.result_info = "[1]"}) {
    %c = stablehlo.constant dense<1> : tensor<i32>
    %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?x3x224x224xf32>) -> tensor<i32>
    %1 = stablehlo.compare  GE, %0, %c,  SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
    stablehlo.custom_call @shape_assertion(%1, %0) {api_version = 2 : i32, error_message = "Input shapes do not match the polymorphic shapes specification. Expected value >= 1 for dimension variable 'a'. Using the following polymorphic shapes specifications: args[0].shape = (a, 3, 224, 224). Obtained dimension variables: 'a' = {0} from specification 'a' for dimension args[0].shape[0] (= {0}), . Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details.", has_side_effect = true} : (tensor<i1>, tensor<i32>) -> ()
    %2:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<i32>, tensor<?x3x224x224xf32>) -> (tensor<?x512x7x7xf32>, tensor<?x512x1x1xf32>)
    return %2#0, %2#1 : tensor<?x512x7x7xf32>, tensor<?x512x1x1xf32>
  }
  func.func private @_wrapped_jax_export_main(%arg0: tensor<i32> {jax.global_constant = "a"}, %arg1: tensor<?x3x224x224xf32>) -> (tensor<?x512x7x7xf32> {jax.result_info = "[0]"}, tensor<?x512x1x1xf32> {jax.result_info = "[1]"}) {
    %c = stablehlo.constant dense<1> : tensor<1xi32>
    %c_0 = stablehlo.constant dense<49> : tensor<i32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %c_1 = stablehlo.constant dense<512> : tensor<1xi32>
    %c_2 = stablehlo.constant dense<7> : tensor<1xi32>
    %c_3 = stablehlo.constant dense<256> : tensor<1x 
...
 , tensor<1xi32>) -> tensor<4xi32>
    %2 = stablehlo.dynamic_broadcast_in_dim %cst, %1, dims = [] : (tensor<f32>, tensor<4xi32>) -> tensor<?x14x14x256xf32>
    %3 = stablehlo.maximum %arg1, %2 : tensor<?x14x14x256xf32>
    return %3 : tensor<?x14x14x256xf32>
  }
  func.func private @relu_3(%arg0: tensor<i32> {jax.global_constant = "a"}, %arg1: tensor<?x7x7x512xf32>) -> tensor<?x7x7x512xf32> {
    %c = stablehlo.constant dense<512> : tensor<1xi32>
    %c_0 = stablehlo.constant dense<7> : tensor<1xi32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.reshape %arg0 : (tensor<i32>) -> tensor<1xi32>
    %1 = stablehlo.concatenate %0, %c_0, %c_0, %c, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
    %2 = stablehlo.dynamic_broadcast_in_dim %cst, %1, dims = [] : (tensor<f32>, tensor<4xi32>) -> tensor<?x7x7x512xf32>
    %3 = stablehlo.maximum %arg1, %2 : tensor<?x7x7x512xf32>
    return %3 : tensor<?x7x7x512xf32>
  }
}

A few things to note in the exported StableHLO:

  1. The exported program's input type is now tensor<?x3x224x224xf32>. This input type can be refined in many ways: StableHLO has APIs to refine shapes and canonicalize dynamic programs to static programs. TensorFlow SavedModel execution also takes care of refinement, as we'll see in the next example.
  2. JAX generates guards to ensure the values of a are valid; in this case, a >= 1 is checked (the stablehlo.compare GE and @shape_assertion ops above). These checks can be optimized away at compile time once the shapes are refined.
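To see refinement in action without TensorFlow, the sketch below exports a small hypothetical batch_sum function with a symbolic batch dimension, then calls the single Exported object with two different batch sizes; each call refines a to a concrete value. The check on the module text assumes the dynamic dimension prints as ?, as in the ResNet output above.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

@jax.jit
def batch_sum(x):
  # Sum over the feature axis; the batch axis stays symbolic.
  return jnp.sum(x, axis=1)

batch_spec = jax.ShapeDtypeStruct(export.symbolic_shape("a,3"), np.float32)
batch_exported = export.export(batch_sum)(batch_spec)
module_text = batch_exported.mlir_module()

# One exported artifact, refined at call time to different values of a.
out_1 = batch_exported.call(np.ones((1, 3), np.float32))
out_4 = batch_exported.call(np.ones((4, 3), np.float32))
print(out_1.shape, out_4.shape)
```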

Export to TensorFlow SavedModel

It is common to export a StableHLO model to SavedModel for interoperability with existing compilation pipelines, existing TensorFlow tooling, or serving via TensorFlow Serving.

JAX makes it easy to pack StableHLO into a SavedModel, and load that SavedModel in the future. For this section, we'll be using our dynamic model from the previous section.

Export to SavedModel using jax2tf

The jax.experimental.jax2tf module provides a simple API for converting a JAX function into a TensorFlow function whose StableHLO can be packaged in a SavedModel. It uses the export function under the hood, so the function being converted must likewise be JIT-compatible.

Full details on jax2tf can be found in the README. For this example, we'll only need to know the polymorphic_shapes option to specify our dynamic batch dimension.

from jax.experimental import jax2tf
import tensorflow as tf

exported_f = jax2tf.convert(resnet18, polymorphic_shapes=["(a,3,224,224)"])

# Copied from the jax2tf README.md > Usage: saved model
my_model = tf.Module()
my_model.f = tf.function(exported_f, autograph=False).get_concrete_function(tf.TensorSpec([None, 3, 224, 224], tf.float32))
tf.saved_model.save(my_model, '/tmp/resnet18_tf', options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))

!ls /tmp/resnet18_tf
assets         fingerprint.pb saved_model.pb variables

Reload and call the SavedModel

Now we can load that SavedModel and run it on the sample_input from the earlier example.

Note: The restored model does not require JAX to run; TensorFlow with XLA is enough.

restored_model = tf.saved_model.load('/tmp/resnet18_tf')
restored_result = restored_model.f(tf.constant(sample_input, tf.float32))
print("Result shape:", restored_result[0].shape)
Result shape: (1, 512, 7, 7)

Troubleshooting

jax.jit issues

If a function can be JIT-compiled, it can be exported. Make sure jax.jit works first, or look for existing uses of jax.jit in the project you want to export from (for example, AlphaFold's apply can be exported easily).

See JAX's JIT compilation documentation and the jax.jit API reference and examples for help troubleshooting JIT transformations. The most common issue is Python control flow, which can often be resolved with static_argnums / static_argnames as in the linked example.
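As a sketch of the static_argnames fix, the hypothetical scale function below branches in Python on a boolean argument. Marking double as static lets jax.jit, and therefore export, trace it; the static value is passed as a concrete Python value at export time and baked into the module, so only x remains an input of the exported function.

```python
import jax
import numpy as np
from jax import export

def scale(x, double):
  # Python `if` on `double`: under plain jax.jit, `double` would be a tracer
  # and this branch would raise a tracer-to-bool conversion error.
  if double:
    return x * 2
  return x

scale_jit = jax.jit(scale, static_argnames="double")

# Static arguments are given as concrete values; array inputs as shapes.
scale_exported = export.export(scale_jit)(
    jax.ShapeDtypeStruct((3,), np.float32), double=True)

# Only x is an input of the exported function; double=True is fixed.
scaled = scale_exported.call(np.arange(3, dtype=np.float32))
print(scaled)
```

Each distinct static value produces a separate trace, so exporting with double=False would yield a different module.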

Support tickets

You can open an issue on GitHub for further help. Include a reproducible example using one of the above APIs in your issue report; this will help get the issue resolved much more quickly!