Tutorial: Exporting StableHLO from JAX


Tutorial Setup

Install required dependencies

We'll be using jax and jaxlib (JAX's XLA package), along with flax and transformers for some models to export.

pip install -U jax jaxlib flax transformers

Define get_stablehlo_asm to help with MLIR printing

Note: This helper uses a JAX internal API that may break at any time, but it serves no functional purpose in the tutorial aside from readability.
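A minimal sketch of such a helper follows. It assumes two private JAX modules: jax._src.interpreters.mlir (for make_ir_context) and jax._src.lib.mlir.ir (JAX's bundled MLIR Python bindings). Neither is public API, so treat this as version-dependent. The idea is to re-parse the module text and print it with MLIR's large_elements_limit option, so that large weight constants appear in elided form instead of as huge literals.

```python
import jax

# Private JAX modules (assumption): not public API, may change in any release.
from jax._src.interpreters import mlir as jax_mlir
from jax._src.lib.mlir import ir


def get_stablehlo_asm(module_str: str) -> str:
  """Re-print a StableHLO module, eliding large constants for readability."""
  # make_ir_context() returns an MLIR context with JAX's dialects registered.
  ctx = jax_mlir.make_ir_context()
  module = ir.Module.parse(module_str, context=ctx)
  # Dense constants with more than 20 elements are printed in elided form.
  return module.operation.get_asm(large_elements_limit=20)


# Usage: pretty-print a tiny lowered module.
pretty = get_stablehlo_asm(jax.jit(lambda x, y: x + y).lower(1, 2).as_text())
print(pretty)
```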

Export JAX model to StableHLO using jax.export

In this section we'll export a very basic JAX function to StableHLO.

The preferred API for export is jax.experimental.export (exposed as jax.export in newer JAX releases), which uses jax.jit under the hood. As a rule of thumb, a function must be jit-able to be exported to StableHLO.
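Since export goes through jax.jit, one quick way to check that a function is jit-able is to lower it with the public jax.jit(...).lower(...) API before exporting. A minimal sketch, assuming a recent JAX release where Lowered.as_text() emits StableHLO by default:

```python
import jax
import jax.numpy as jnp
import numpy as np


def plus(x, y):
  return jnp.add(x, y)


# If this lowering succeeds, the function is jit-able and a good export candidate.
lowered = jax.jit(plus).lower(np.int32(1), np.int32(1))
module_text = lowered.as_text()
print(module_text)
```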

Export basic JAX model to StableHLO

Let's start by exporting a basic plus function to StableHLO, using np.int32 argument types to trace the function.

Export requires specifying shapes using jax.ShapeDtypeStruct, which are trivial to construct from numpy values.

import jax
from jax.experimental import export
import jax.numpy as jnp
import numpy as np

def plus(x, y):
  return jnp.add(x, y)

# Create abstract input shapes:
inputs = (np.int32(1), np.int32(1),)
input_shapes = [jax.ShapeDtypeStruct(x.shape, x.dtype) for x in inputs]
stablehlo_add = export.export(plus)(*input_shapes).mlir_module()
print(stablehlo_add)
module @jit_plus attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}, %arg1: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<i32>
    return %0 : tensor<i32>
  }
}
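Only shapes and dtypes are needed for this kind of trace. As an illustration (the matmul function and the sizes are hypothetical), jax.eval_shape accepts the same jax.ShapeDtypeStruct values and computes the output shape and dtype without running any computation:

```python
import jax
import jax.numpy as jnp
import numpy as np


def matmul(x, y):
  return jnp.matmul(x, y)


# Any object with .shape and .dtype (NumPy arrays or scalars) can seed a spec.
examples = (np.ones((2, 3), np.float32), np.ones((3, 5), np.float32))
specs = [jax.ShapeDtypeStruct(a.shape, a.dtype) for a in examples]

# Trace with shapes and dtypes only; no numerical work is performed.
out_spec = jax.eval_shape(matmul, *specs)
print(out_spec)  # expected shape (2, 5) with dtype float32
```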

Export Hugging Face FlaxResNet18 to StableHLO

Now let's look at a simple model that appears in the wild: ResNet-18.

This example exports a Flax model from the Hugging Face transformers ResNet documentation page, FlaxResNetModel. Much of this step's setup was copied from the Hugging Face documentation.

The documentation also states: "Finally, this model supports inherent JAX features such as: Just-In-Time (JIT) compilation ...", which makes it a good candidate for export.

Similar to our very basic example, our steps for export are:

  1. Instantiate a callable (model/function) that can be JIT'ed.
  2. Specify shapes for export using jax.ShapeDtypeStruct on numpy values.
  3. Use the JAX export API to get a StableHLO module.
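The three steps above can be wrapped in a small helper. This is a sketch: export_to_stablehlo and affine are hypothetical names, and it assumes the jax.export module from newer JAX releases (the successor of jax.experimental.export), whose export function expects a jax.jit-wrapped callable.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export  # assumes JAX >= 0.4.30, where jax.export is public


def export_to_stablehlo(fn, *example_inputs):
  """Hypothetical helper following the three export steps."""
  # 1. A callable that can be jit'ed (jax.export expects a jax.jit-wrapped function).
  jitted = jax.jit(fn)
  # 2. Abstract input specs built from NumPy values.
  specs = [jax.ShapeDtypeStruct(np.shape(a), np.asarray(a).dtype) for a in example_inputs]
  # 3. Export and return the StableHLO module text.
  return export.export(jitted)(*specs).mlir_module()


def affine(x, w, b):
  return jnp.dot(x, w) + b


text = export_to_stablehlo(
    affine,
    np.ones((4, 8), np.float32),
    np.ones((8, 2), np.float32),
    np.zeros((2,), np.float32),
)
print(text)
```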
from transformers import FlaxResNetModel
import jax
import numpy as np
from jax.experimental import export

# Construct flax model with sample inputs
resnet18 = FlaxResNetModel.from_pretrained("microsoft/resnet-18", return_dict=False)
sample_input = np.random.randn(1, 3, 224, 224).astype(np.float32)
input_shape = jax.ShapeDtypeStruct(sample_input.shape, sample_input.dtype)

# Export to StableHLO
stablehlo_resnet18_export = export.export(resnet18)(input_shape)
resnet18_stablehlo = get_stablehlo_asm(stablehlo_resnet18_export.mlir_module())
# The module is long: print only its beginning and end
print(resnet18_stablehlo[:600], "\n...\n", resnet18_stablehlo[-345:])
module @jit__unnamed_wrapped_function_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<1x3x224x224xf32> {mhlo.layout_mode = "default"}) -> (tensor<1x512x7x7xf32> {mhlo.layout_mode = "default"}, tensor<1x512x1x1xf32> {mhlo.layout_mode = "default"}) {
    %0 = stablehlo.constant dense_resource<__elided__> : tensor<7x7x3x64xf32>
    %1 = stablehlo.constant dense_resource<__elided__> : tensor<64xf32>
    %2 = stablehlo.constant dense_resource<__elided__> : tensor<64xf32>
    %3 = stablehlo.constan
...
  func.func private @relu_3(%arg0: tensor<1x7x7x512xf32>) -> tensor<1x7x7x512xf32> {
    %0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %1 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor<f32>) -> tensor<1x7x7x512xf32>
    %2 = stablehlo.maximum %arg0, %1 : tensor<1x7x7x512xf32>
    return %2 : tensor<1x7x7x512xf32>
  }
}

Export with dynamic batch size

Now let's export that same model with a dynamic batch size!

In the first example we used an input shape of tensor<1x3x224x224xf32>, which fixes the input shape exactly. If we want to defer choosing the concrete shapes used in compilation until a later point, we can specify a symbolic shape; in this case we'll export using tensor<?x3x224x224xf32>.

Symbolic shapes are specified using export.symbolic_shape, with letters representing symbolic integer (symint) dimensions. For example, a valid 2-D matrix multiplication could use the symbolic shapes (2, a) and (a, 5); sharing the dimension a ensures the refined program will have valid shapes. An export.SymbolicScope keeps track of symbolic dimension names to avoid unintentional name clashes.
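The matrix multiplication example can be written out as follows. This sketch assumes the jax.export module from newer JAX releases (in older releases the same functions live in jax.experimental.export); both specs are created in one SymbolicScope so that a refers to the same dimension variable, and matmul is an illustrative name.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export  # assumes JAX >= 0.4.30


def matmul(x, y):
  return jnp.matmul(x, y)


# One shared scope, so both specs refer to the same symbolic dimension 'a'.
scope = export.SymbolicScope()
x_spec = jax.ShapeDtypeStruct(export.symbolic_shape("2,a", scope=scope), np.float32)
y_spec = jax.ShapeDtypeStruct(export.symbolic_shape("a,5", scope=scope), np.float32)

exp = export.export(jax.jit(matmul))(x_spec, y_spec)
print(exp.in_avals)   # the contracting dimension 'a' stays symbolic
print(exp.out_avals)  # the result shape does not depend on 'a'
```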

dyn_scope = export.SymbolicScope()
dyn_input_shape = jax.ShapeDtypeStruct(export.symbolic_shape("a,3,224,224", scope=dyn_scope), np.float32)
dyn_resnet18_export = export.export(resnet18)(dyn_input_shape)
dyn_resnet18_stablehlo = get_stablehlo_asm(dyn_resnet18_export.mlir_module())
print(dyn_resnet18_stablehlo[:1900], "\n...\n", dyn_resnet18_stablehlo[-1000:])
module @jit__unnamed_wrapped_function_ attributes {jax.uses_shape_polymorphism = true, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<?x3x224x224xf32> {mhlo.layout_mode = "default"}) -> (tensor<?x512x7x7xf32> {mhlo.layout_mode = "default"}, tensor<?x512x1x1xf32> {mhlo.layout_mode = "default"}) {
    %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?x3x224x224xf32>) -> tensor<i32>
    %1 = stablehlo.constant dense<1> : tensor<i32>
    %2 = stablehlo.compare  GE, %0, %1,  SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
    stablehlo.custom_call @shape_assertion(%2, %0) {api_version = 2 : i32, error_message = "Input shapes do not match the polymorphic shapes specification. Expected value >= 1 for dimension variable 'a'. Using the following polymorphic shapes specifications: args[0].shape = (a, 3, 224, 224). Obtained dimension variables: 'a' = {0} from specification 'a' for dimension args[0].shape[0] (= {0}), . Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README#shape-assertion-errors for more details.", has_side_effect = true} : (tensor<i1>, tensor<i32>) -> ()
    %3:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<i32>, tensor<?x3x224x224xf32>) -> (tensor<?x512x7x7xf32>, tensor<?x512x1x1xf32>)
    return %3#0, %3#1 : tensor<?x512x7x7xf32>, tensor<?x512x1x1xf32>
  }
  func.func private @_wrapped_jax_export_main(%arg0: tensor<i32> {jax.global_constant = "a"}, %arg1: tensor<?x3x224x224xf32> {mhlo.layout_mode = "default"}) -> (tensor<?x512x7x7xf32> {mhlo.layout_mode = "default"}, tensor<?x512x1x1xf32> {mhlo.layout_mode = "default"}) {
    %0 = stablehlo.constant dense_resource<__elided__> : tensor<7x7x3x64xf32>
    %1 = stablehlo.constant dense_resource<__elided__> : tensor<64xf32>
    %2 = stablehlo.constant dense_resource<__elided__> : tensor<64xf32>
    %3 = stablehl
...
    return %10 : tensor<?x14x14x256xf32>
  }
  func.func private @relu_3(%arg0: tensor<i32> {jax.global_constant = "a"}, %arg1: tensor<?x7x7x512xf32>) -> tensor<?x7x7x512xf32> {
    %0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %1 = stablehlo.constant dense<7> : tensor<i32>
    %2 = stablehlo.constant dense<7> : tensor<i32>
    %3 = stablehlo.constant dense<512> : tensor<i32>
    %4 = stablehlo.reshape %arg0 : (tensor<i32>) -> tensor<1xi32>
    %5 = stablehlo.constant dense<7> : tensor<1xi32>
    %6 = stablehlo.constant dense<7> : tensor<1xi32>
    %7 = stablehlo.constant dense<512> : tensor<1xi32>
    %8 = stablehlo.concatenate %4, %5, %6, %7, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
    %9 = stablehlo.dynamic_broadcast_in_dim %0, %8, dims = [] : (tensor<f32>, tensor<4xi32>) -> tensor<?x7x7x512xf32>
    %10 = stablehlo.maximum %arg1, %9 : tensor<?x7x7x512xf32>
    return %10 : tensor<?x7x7x512xf32>
  }
}

A few things to note in the exported StableHLO:

  1. The exported program now has tensor<?x3x224x224xf32>. These input types can be refined in many ways: SavedModel execution takes care of refinement, which we'll see in the next example, but StableHLO also has APIs to refine shapes and canonicalize dynamic programs to static programs.
  2. JAX generates guards to ensure the values of a are valid; in this case, a >= 1 is checked. These guards can be removed at compile time once the shapes are refined.
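To see refinement in action from JAX itself, an exported function with a symbolic batch dimension can be called with different concrete batch sizes. A minimal sketch, assuming the jax.export API where Exported has a call method; double is a hypothetical stand-in for a real model:

```python
import jax
import numpy as np
from jax import export  # assumes JAX >= 0.4.30, where Exported.call exists


def double(x):
  return x * 2.0


scope = export.SymbolicScope()
spec = jax.ShapeDtypeStruct(export.symbolic_shape("b,3", scope=scope), np.float32)
exp = export.export(jax.jit(double))(spec)

# The same exported artifact is refined to a concrete batch size on each call.
out_small = exp.call(np.ones((2, 3), np.float32))
out_large = exp.call(np.ones((16, 3), np.float32))
print(out_small.shape, out_large.shape)
```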

Export to SavedModel

It is common to want to export a StableHLO model to SavedModel for interop with existing compilation pipelines, existing TF tooling, or serving via TF Serving.

JAX makes it easy to pack StableHLO into a SavedModel, and load that SavedModel in the future. For this section we'll be using our dynamic model from the previous section.
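If TensorFlow is not needed, the exported artifact can also be serialized and reloaded directly from JAX. This is an alternative to the SavedModel route, sketched under the assumption of the jax.export API (Exported.serialize and jax.export.deserialize); double is again a hypothetical toy function.

```python
import jax
import numpy as np
from jax import export  # assumes JAX >= 0.4.30


def double(x):
  return x * 2.0


spec = jax.ShapeDtypeStruct(export.symbolic_shape("b,3"), np.float32)
exp = export.export(jax.jit(double))(spec)

# Serialize the Exported artifact (StableHLO plus metadata) to bytes.
blob = exp.serialize()

# Later, possibly in another process: rebuild the artifact and call it.
restored = export.deserialize(blob)
result = restored.call(np.ones((4, 3), np.float32))
print(result.shape)
```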

Install latest TF

The SavedModel definition lives in TF, so we need to install TF as a dependency. We recommend using tensorflow-cpu or tf-nightly.

pip install tensorflow-cpu

Export to SavedModel using jax2tf

jax.experimental.jax2tf provides a simple API for exporting StableHLO in a format that can be packaged into a SavedModel. It uses the export function under the hood, so the same jit requirements apply.

Full details on jax2tf can be found in its README. For this example, we only need the polymorphic_shapes option to specify our dynamic batch dimension.

from jax.experimental import jax2tf
import tensorflow as tf

# Convert the Flax model to a TF-callable function with a symbolic batch dimension `a`
exported_f = jax2tf.convert(resnet18, polymorphic_shapes=["(a,3,224,224)"])

# Copied from the jax2tf README.md > Usage: saved model
my_model = tf.Module()
# Trace a concrete TF function whose batch dimension is unknown (None)
my_model.f = tf.function(exported_f, autograph=False).get_concrete_function(
    tf.TensorSpec([None, 3, 224, 224], tf.float32))
tf.saved_model.save(my_model, '/tmp/resnet18_tf',
                    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))

!ls /tmp/resnet18_tf
assets  fingerprint.pb  saved_model.pb  variables

Reload and call the SavedModel

Now we can load that SavedModel, then compile and run it on the sample_input from the earlier example.

Note: the restored model does not require JAX to run, just XLA.

restored_model = tf.saved_model.load('/tmp/resnet18_tf')
restored_result = restored_model.f(tf.constant(sample_input, tf.float32))
print("Result shape:", restored_result[0].shape)
Result shape: (1, 512, 7, 7)

Common Troubleshooting

If a function can be JIT'ed, it can be exported. Focus on making jax.jit work first, or look for existing uses of jit in the project you want to export (for example, AlphaFold's apply function can be exported easily).

See JAX JIT Examples for troubleshooting. The most common issue is control flow, which can often be resolved with static_argnums / static_argnames as in the linked example.
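A minimal sketch of the control-flow issue and the static_argnames fix follows; activation and use_relu are hypothetical names. A Python if on a traced argument fails under jit, while marking the argument static bakes the chosen branch into the compiled program:

```python
import jax
import jax.numpy as jnp
import numpy as np


def activation(x, use_relu):
  # Python `if` on an argument: needs a concrete value at trace time.
  if use_relu:
    return jnp.maximum(x, 0.0)
  return jnp.tanh(x)


x = np.array([-1.0, 0.5], np.float32)

caught = False
try:
  jax.jit(activation)(x, True)  # use_relu is traced, so the `if` cannot be evaluated
except jax.errors.ConcretizationTypeError as err:
  caught = True
  print("plain jit fails:", type(err).__name__)

# Marking the flag static makes jit work; each flag value compiles separately.
jitted = jax.jit(activation, static_argnames="use_relu")
result = jitted(x, use_relu=True)
print(result)
```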

When opening a ticket for help, include a reproducer that uses one of the above APIs; this will help get the issue resolved much more quickly!