JAX is a Python library for high-performance numerical computing. This tutorial shows how to export JAX and Flax (a JAX-powered neural network library) models to StableHLO, and directly to TensorFlow SavedModel.
Tutorial Setup
Install required dependencies
We use jax and jaxlib (JAX's support library with compiled binaries), along with flax and transformers for some of the models we export. We also need tensorflow to work with SavedModel, and recommend using tensorflow-cpu or tf-nightly for this tutorial.
pip install -U jax jaxlib flax transformers tensorflow-cpu
Define get_stablehlo_asm to help with MLIR printing
from jax._src.interpreters import mlir as jax_mlir
from jax._src.lib.mlir import ir
# Returns prettyprint of StableHLO module without large constants
def get_stablehlo_asm(module_str):
  with jax_mlir.make_ir_context() as ctx:
    stablehlo_module = ir.Module.parse(module_str, context=ctx)
    return stablehlo_module.operation.get_asm(large_elements_limit=20)
# Disable logging for better tutorial rendering
import logging
logging.disable(logging.WARNING)
Note: This helper uses a JAX internal API that may break at any time; it has no functional role in the tutorial beyond making the printed modules easier to read.
Export JAX model to StableHLO using jax.export
In this section we'll export a basic JAX function and a Flax model to StableHLO.
The preferred API for export is jax.export. The function to export must be JIT-transformed, specifically the result of jax.jit, to be exported to StableHLO.
Export basic JAX model to StableHLO
Let's start by exporting a basic plus function to StableHLO, using np.int32 argument types to trace the function.
Export requires specifying shapes using jax.ShapeDtypeStruct, which can be constructed from NumPy values.
import jax
from jax import export
import jax.numpy as jnp
import numpy as np
# Create a JIT-transformed function
@jax.jit
def plus(x, y):
  return jnp.add(x, y)
# Create abstract input shapes
inputs = (np.int32(1), np.int32(1),)
input_shapes = [jax.ShapeDtypeStruct(input.shape, input.dtype) for input in inputs]
# Export the function to StableHLO
stablehlo_add = export.export(plus)(*input_shapes).mlir_module()
print(get_stablehlo_asm(stablehlo_add))
module @jit_plus attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32> {jax.result_info = ""}) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<i32>
    return %0 : tensor<i32>
  }
}
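The Exported object holds more than just the MLIR text. As an aside (a minimal sketch assuming the jax.export serialization APIs, which the rest of this tutorial does not rely on), it can be serialized to bytes, rehydrated later, and called directly:
# Keep the full Exported object, not just its MLIR text (assumes jax.export's
# serialize/deserialize/call APIs; not part of the original tutorial flow).
exported_plus = export.export(plus)(*input_shapes)
serialized = exported_plus.serialize()
rehydrated = export.deserialize(serialized)
print(rehydrated.call(np.int32(2), np.int32(3)))  # 5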
Export Hugging Face FlaxResNet18 to StableHLO
Now let's look at a simple model that appears in the wild, resnet18.
We'll export a flax model from the Hugging Face transformers ResNet page, FlaxResNetModel. The setup for this step was copied from the Hugging Face documentation.
The documentation also states: "Finally, this model supports inherent JAX features such as: Just-In-Time (JIT) compilation ..." which means it is perfect for export.
Similar to our very basic example, our steps for export are:
- Instantiate a callable (model/function)
- JIT-transform it with jax.jit
- Specify shapes for export using jax.ShapeDtypeStruct on NumPy values
- Use the JAX export API to get a StableHLO module
from transformers import AutoImageProcessor, FlaxResNetModel
import jax
import numpy as np
# Construct jit-transformed flax model with sample inputs
resnet18 = FlaxResNetModel.from_pretrained("microsoft/resnet-18", return_dict=False)
resnet18_jit = jax.jit(resnet18)
sample_input = np.random.randn(1, 3, 224, 224)
input_shape = jax.ShapeDtypeStruct(sample_input.shape, sample_input.dtype)
# Export to StableHLO
stablehlo_resnet18_export = export.export(resnet18_jit)(input_shape)
resnet18_stablehlo = get_stablehlo_asm(stablehlo_resnet18_export.mlir_module())
print(resnet18_stablehlo[:600], "\n...\n", resnet18_stablehlo[-345:])
module @jit__unnamed_wrapped_function_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<1x3x224x224xf32>) -> (tensor<1x512x7x7xf32> {jax.result_info = "[0]"}, tensor<1x512x1x1xf32> {jax.result_info = "[1]"}) {
    %c = stablehlo.constant dense<49> : tensor<i32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_0 = stablehlo.constant dense<0xFF800000> : tensor<f32>
    %cst_1 = stablehlo.constant dense<9.99999974E-6> : tensor<f32>
    %cst_2 = stablehlo.constant dense_reso
...
  func.func private @relu_3(%arg0: tensor<1x7x7x512xf32>) -> tensor<1x7x7x512xf32> {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<1x7x7x512xf32>
    %1 = stablehlo.maximum %arg0, %0 : tensor<1x7x7x512xf32>
    return %1 : tensor<1x7x7x512xf32>
  }
}
Export with dynamic batch size
Now let's export that same model with a dynamic batch size!
In the first example, we used an input shape of tensor<1x3x224x224xf32>, specifying strict constraints on the input shape. If we want to defer the concrete shapes used in compilation until a later point, we can specify a symbolic_shape. In this example, we'll export using tensor<?x3x224x224xf32>.
Symbolic shapes are specified using export.symbolic_shape, with letters representing symbolic integer dimensions. For example, a valid 2-d matrix multiplication could use the symbolic constraints 2,a * a,5 to ensure the refined program will have valid shapes. Symbolic integer names are tracked by an export.SymbolicScope to avoid unintentional name clashes.
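To illustrate that matmul constraint concretely, here is a small sketch (not part of the ResNet export flow; matmul_scope and the shapes below are made up for demonstration, and it assumes the jax/jnp/np imports from earlier cells):
# Hypothetical 2-d matmul where both operands share the symbol "a".
matmul_scope = export.SymbolicScope()
lhs_shape = jax.ShapeDtypeStruct(export.symbolic_shape("2,a", scope=matmul_scope), np.float32)
rhs_shape = jax.ShapeDtypeStruct(export.symbolic_shape("a,5", scope=matmul_scope), np.float32)
matmul_exported = export.export(jax.jit(jnp.matmul))(lhs_shape, rhs_shape)
print(matmul_exported.in_avals)  # both inputs refer to the same dimension variable "a"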
# Construct dynamic sample inputs
dyn_scope = export.SymbolicScope()
dyn_input_shape = jax.ShapeDtypeStruct(export.symbolic_shape("a,3,224,224", scope=dyn_scope), np.float32)
# Export to StableHLO
dyn_resnet18_export = export.export(resnet18_jit)(dyn_input_shape)
dyn_resnet18_stablehlo = get_stablehlo_asm(dyn_resnet18_export.mlir_module())
print(dyn_resnet18_stablehlo[:1900], "\n...\n", dyn_resnet18_stablehlo[-1000:])
module @jit__unnamed_wrapped_function_ attributes {jax.uses_shape_polymorphism = true, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<?x3x224x224xf32>) -> (tensor<?x512x7x7xf32> {jax.result_info = "[0]"}, tensor<?x512x1x1xf32> {jax.result_info = "[1]"}) {
    %c = stablehlo.constant dense<1> : tensor<i32>
    %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?x3x224x224xf32>) -> tensor<i32>
    %1 = stablehlo.compare GE, %0, %c, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
    stablehlo.custom_call @shape_assertion(%1, %0) {api_version = 2 : i32, error_message = "Input shapes do not match the polymorphic shapes specification. Expected value >= 1 for dimension variable 'a'. Using the following polymorphic shapes specifications: args[0].shape = (a, 3, 224, 224). Obtained dimension variables: 'a' = {0} from specification 'a' for dimension args[0].shape[0] (= {0}), . Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details.", has_side_effect = true} : (tensor<i1>, tensor<i32>) -> ()
    %2:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<i32>, tensor<?x3x224x224xf32>) -> (tensor<?x512x7x7xf32>, tensor<?x512x1x1xf32>)
    return %2#0, %2#1 : tensor<?x512x7x7xf32>, tensor<?x512x1x1xf32>
  }
  func.func private @_wrapped_jax_export_main(%arg0: tensor<i32> {jax.global_constant = "a"}, %arg1: tensor<?x3x224x224xf32>) -> (tensor<?x512x7x7xf32> {jax.result_info = "[0]"}, tensor<?x512x1x1xf32> {jax.result_info = "[1]"}) {
    %c = stablehlo.constant dense<1> : tensor<1xi32>
    %c_0 = stablehlo.constant dense<49> : tensor<i32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %c_1 = stablehlo.constant dense<512> : tensor<1xi32>
    %c_2 = stablehlo.constant dense<7> : tensor<1xi32>
    %c_3 = stablehlo.constant dense<256> : tensor<1x
...
, tensor<1xi32>) -> tensor<4xi32>
    %2 = stablehlo.dynamic_broadcast_in_dim %cst, %1, dims = [] : (tensor<f32>, tensor<4xi32>) -> tensor<?x14x14x256xf32>
    %3 = stablehlo.maximum %arg1, %2 : tensor<?x14x14x256xf32>
    return %3 : tensor<?x14x14x256xf32>
  }
  func.func private @relu_3(%arg0: tensor<i32> {jax.global_constant = "a"}, %arg1: tensor<?x7x7x512xf32>) -> tensor<?x7x7x512xf32> {
    %c = stablehlo.constant dense<512> : tensor<1xi32>
    %c_0 = stablehlo.constant dense<7> : tensor<1xi32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.reshape %arg0 : (tensor<i32>) -> tensor<1xi32>
    %1 = stablehlo.concatenate %0, %c_0, %c_0, %c, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
    %2 = stablehlo.dynamic_broadcast_in_dim %cst, %1, dims = [] : (tensor<f32>, tensor<4xi32>) -> tensor<?x7x7x512xf32>
    %3 = stablehlo.maximum %arg1, %2 : tensor<?x7x7x512xf32>
    return %3 : tensor<?x7x7x512xf32>
  }
}
A few things to note in the exported StableHLO:
- The exported program now has tensor<?x3x224x224xf32>. These input types can be refined in many ways: StableHLO has APIs to refine shapes and canonicalize dynamic programs to static programs. TensorFlow SavedModel execution also takes care of refinement, which we'll see in the next example.
- JAX will generate guards to ensure the values of a are valid; in this case a >= 1 is checked. These can be washed away at compile time once refined.
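As an optional sanity check (a sketch assuming the jax.export Exported.call API, which the rest of this tutorial does not rely on), the same dynamic artifact can be called with different batch sizes:
# The dimension variable "a" is bound to the concrete batch size at call time.
batch4 = np.random.randn(4, 3, 224, 224).astype(np.float32)
dyn_outputs = dyn_resnet18_export.call(batch4)
print(dyn_outputs[0].shape)  # (4, 512, 7, 7)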
Export to TensorFlow SavedModel
It is common to export a StableHLO model to SavedModel for interoperability with existing compilation pipelines, existing TensorFlow tooling, or serving via TensorFlow Serving.
JAX makes it easy to pack StableHLO into a SavedModel, and load that SavedModel in the future. For this section, we'll be using our dynamic model from the previous section.
Export to SavedModel using jax2tf
JAX provides a simple API for exporting StableHLO into a format that can be packaged in a SavedModel: jax.experimental.jax2tf. This uses the export function under the hood, so the same jit requirements apply.
Full details on jax2tf can be found in the README. For this example, we'll only need to know the polymorphic_shapes option to specify our dynamic batch dimension.
from jax.experimental import jax2tf
import tensorflow as tf
exported_f = jax2tf.convert(resnet18, polymorphic_shapes=["(a,3,224,224)"])
# Copied from the jax2tf README.md > Usage: saved model
my_model = tf.Module()
my_model.f = tf.function(exported_f, autograph=False).get_concrete_function(tf.TensorSpec([None, 3, 224, 224], tf.float32))
tf.saved_model.save(my_model, '/tmp/resnet18_tf', options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))
!ls /tmp/resnet18_tf
assets fingerprint.pb saved_model.pb variables
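If you want to inspect what was saved before reloading it (optional; assumes the saved_model_cli tool that ships with the tensorflow pip package is on your PATH):
!saved_model_cli show --dir /tmp/resnet18_tf --all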
Reload and call the SavedModel
Now we can load that SavedModel and call it with our sample_input from a previous example.
Note: The restored model does not require JAX to run, just XLA.
restored_model = tf.saved_model.load('/tmp/resnet18_tf')
restored_result = restored_model.f(tf.constant(sample_input, tf.float32))
print("Result shape:", restored_result[0].shape)
Result shape: (1, 512, 7, 7)
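Because the saved signature used TensorSpec([None, 3, 224, 224]), the restored function should also accept other batch sizes. A quick sketch (not part of the original tutorial):
# Dynamic batch dimension: the same restored function handles batch size 8.
batch8 = tf.constant(np.random.randn(8, 3, 224, 224), tf.float32)
print("Batch-8 result shape:", restored_model.f(batch8)[0].shape)  # (8, 512, 7, 7)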
Troubleshooting
jax.jit issues
If the function can be JIT'ed, then it can be exported. Ensure jax.jit works first, or look in the desired project for existing uses of JIT (for example, AlphaFold's apply can be exported easily).
See JAX's JIT compilation documentation and the jax.jit API reference and examples for troubleshooting JIT transformations. The most common issue is control flow, which can often be resolved with static_argnums / static_argnames as in the linked example.
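For instance, here is a minimal sketch (maybe_relu and its flag are hypothetical, not taken from the linked example) of Python control flow that fails on a traced argument but works once the argument is marked static:
from functools import partial

@partial(jax.jit, static_argnames=["apply_relu"])
def maybe_relu(x, apply_relu):
  # A Python `if` on a traced value raises an error (TracerBoolConversionError
  # in recent JAX); marking apply_relu as static lets jit specialize on it.
  if apply_relu:
    return jnp.maximum(x, 0)
  return x

print(maybe_relu(jnp.array([-1.0, 2.0]), apply_relu=True))  # [0. 2.]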
Support tickets
You can open an issue on GitHub for further help. Include a reproducible example using one of the above APIs in your issue report; this will help get the issue resolved much more quickly!