Tutorial Setup
Install required dependencies
We'll be using jax and jaxlib (JAX's XLA package), along with flax and transformers for some models to export.
pip install -U jax jaxlib flax transformers
Define get_stablehlo_asm to help with MLIR printing
from jax._src.interpreters import mlir as jax_mlir
from jax._src.lib.mlir import ir

# Returns prettyprint of StableHLO module without large constants
def get_stablehlo_asm(module_str):
  with jax_mlir.make_ir_context():
    stablehlo_module = ir.Module.parse(module_str, context=jax_mlir.make_ir_context())
    return stablehlo_module.operation.get_asm(large_elements_limit=20)

# Disable logging for better tutorial rendering
import logging
logging.disable(logging.WARNING)
Note: This helper uses a JAX internal API that may break at any time; it serves no functional purpose in the tutorial beyond making the printed MLIR more readable.
Export JAX model to StableHLO using jax.export
In this section we'll export a very basic JAX function to StableHLO.
The preferred API for export is jax.experimental.export, which uses jax.jit under the hood. As a rule of thumb, a function must be jit-able to be exported to StableHLO.
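For instance, here is a minimal sketch (not from the original tutorial; the function f is a placeholder) of how to confirm a function traces under jax.jit before attempting export:

import jax

# Placeholder function for illustration; substitute the function you intend to export.
def f(x):
  return x * 2

jax.jit(f)(1)  # If this traces and runs without error, export should work too.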
Export basic JAX model to StableHLO
Let's start by exporting a basic plus function to StableHLO, using np.int32 argument types to trace the function.
Export requires specifying shapes using jax.ShapeDtypeStruct, which are trivial to construct from numpy values.
import jax
from jax.experimental import export
import jax.numpy as jnp
import numpy as np
def plus(x, y):
  return jnp.add(x, y)
# Create abstract input shapes:
inputs = (np.int32(1), np.int32(1),)
input_shapes = [jax.ShapeDtypeStruct(input.shape, input.dtype) for input in inputs]
stablehlo_add = export.export(plus)(*input_shapes).mlir_module()
print(get_stablehlo_asm(stablehlo_add))
module @jit_plus attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}, %arg1: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<i32>
    return %0 : tensor<i32>
  }
}
Export Huggingface FlaxResNet18 to StableHLO
Now let's look at a simple model that appears in the wild, resnet18.
This example will export a flax model from the huggingface transformers ResNet page, FlaxResNetModel. Much of this step's setup was copied from the huggingface documentation.
The documentation also states: "Finally, this model supports inherent JAX features such as: Just-In-Time (JIT) compilation ..." which means it is perfect for export.
Similar to our very basic example, our steps for export are:
- Instantiate a callable (model/function) that can be JIT'ed.
- Specify shapes for export using jax.ShapeDtypeStruct on numpy values.
- Use the JAX export API to get a StableHLO module.
from transformers import AutoImageProcessor, FlaxResNetModel
import jax
import numpy as np
# Construct flax model with sample inputs
resnet18 = FlaxResNetModel.from_pretrained("microsoft/resnet-18", return_dict=False)
sample_input = np.random.randn(1, 3, 224, 224)
input_shape = jax.ShapeDtypeStruct(sample_input.shape, sample_input.dtype)
# Export to StableHLO
stablehlo_resnet18_export = export.export(resnet18)(input_shape)
resnet18_stablehlo = get_stablehlo_asm(stablehlo_resnet18_export.mlir_module())
print(resnet18_stablehlo[:600], "\n...\n", resnet18_stablehlo[-345:])
module @jit__unnamed_wrapped_function_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<1x3x224x224xf32> {mhlo.layout_mode = "default"}) -> (tensor<1x512x7x7xf32> {mhlo.layout_mode = "default"}, tensor<1x512x1x1xf32> {mhlo.layout_mode = "default"}) {
    %0 = stablehlo.constant dense_resource<__elided__> : tensor<7x7x3x64xf32>
    %1 = stablehlo.constant dense_resource<__elided__> : tensor<64xf32>
    %2 = stablehlo.constant dense_resource<__elided__> : tensor<64xf32>
    %3 = stablehlo.constan
 ...
  }
  func.func private @relu_3(%arg0: tensor<1x7x7x512xf32>) -> tensor<1x7x7x512xf32> {
    %0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %1 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor<f32>) -> tensor<1x7x7x512xf32>
    %2 = stablehlo.maximum %arg0, %1 : tensor<1x7x7x512xf32>
    return %2 : tensor<1x7x7x512xf32>
  }
}
Export with dynamic batch size
Now let's export that same model with a dynamic batch size!
In the first example we used an input shape of tensor<1x3x224x224xf32>, specifying strict constraints on the input shape. If we want to defer the concrete shapes used in compilation until a later point, we can instead specify a symbolic_shape; in this case we'll export using tensor<?x3x224x224xf32>.
Symbolic shapes are specified using export.symbolic_shape, with letters representing symint dimensions. For example, a valid 2-d matrix multiplication could use the symbolic constraints 2,a * a,5 to ensure the refined program will have valid shapes. Symbolic integer names are tracked by an export.SymbolicScope to avoid unintentional name clashes.
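To make the matmul example concrete, here is a minimal sketch (assuming the imports from the cells above; matmul_scope, lhs_shape, rhs_shape, and matmul are illustrative names, not part of the original tutorial) that shares the symbolic dimension a across both operands:

# Hedged sketch of the "2,a * a,5" example: both operands share one scope so
# that the dimension name `a` refers to the same symbolic value.
matmul_scope = export.SymbolicScope()
lhs_shape = jax.ShapeDtypeStruct(export.symbolic_shape("2,a", scope=matmul_scope), np.float32)
rhs_shape = jax.ShapeDtypeStruct(export.symbolic_shape("a,5", scope=matmul_scope), np.float32)

def matmul(x, y):
  return jnp.matmul(x, y)

print(get_stablehlo_asm(export.export(matmul)(lhs_shape, rhs_shape).mlir_module()))

Back to ResNet18, we now export with a symbolic batch dimension: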
dyn_scope = export.SymbolicScope()
dyn_input_shape = jax.ShapeDtypeStruct(export.symbolic_shape("a,3,224,224", scope=dyn_scope), np.float32)
dyn_resnet18_export = export.export(resnet18)(dyn_input_shape)
dyn_resnet18_stablehlo = get_stablehlo_asm(dyn_resnet18_export.mlir_module())
print(dyn_resnet18_stablehlo[:1900], "\n...\n", dyn_resnet18_stablehlo[-1000:])
module @jit__unnamed_wrapped_function_ attributes {jax.uses_shape_polymorphism = true, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<?x3x224x224xf32> {mhlo.layout_mode = "default"}) -> (tensor<?x512x7x7xf32> {mhlo.layout_mode = "default"}, tensor<?x512x1x1xf32> {mhlo.layout_mode = "default"}) {
    %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?x3x224x224xf32>) -> tensor<i32>
    %1 = stablehlo.constant dense<1> : tensor<i32>
    %2 = stablehlo.compare GE, %0, %1, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
    stablehlo.custom_call @shape_assertion(%2, %0) {api_version = 2 : i32, error_message = "Input shapes do not match the polymorphic shapes specification. Expected value >= 1 for dimension variable 'a'. Using the following polymorphic shapes specifications: args[0].shape = (a, 3, 224, 224). Obtained dimension variables: 'a' = {0} from specification 'a' for dimension args[0].shape[0] (= {0}), . Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README#shape-assertion-errors for more details.", has_side_effect = true} : (tensor<i1>, tensor<i32>) -> ()
    %3:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<i32>, tensor<?x3x224x224xf32>) -> (tensor<?x512x7x7xf32>, tensor<?x512x1x1xf32>)
    return %3#0, %3#1 : tensor<?x512x7x7xf32>, tensor<?x512x1x1xf32>
  }
  func.func private @_wrapped_jax_export_main(%arg0: tensor<i32> {jax.global_constant = "a"}, %arg1: tensor<?x3x224x224xf32> {mhlo.layout_mode = "default"}) -> (tensor<?x512x7x7xf32> {mhlo.layout_mode = "default"}, tensor<?x512x1x1xf32> {mhlo.layout_mode = "default"}) {
    %0 = stablehlo.constant dense_resource<__elided__> : tensor<7x7x3x64xf32>
    %1 = stablehlo.constant dense_resource<__elided__> : tensor<64xf32>
    %2 = stablehlo.constant dense_resource<__elided__> : tensor<64xf32>
    %3 = stablehl
 ...
 <?x14x14x256xf32>
    return %10 : tensor<?x14x14x256xf32>
  }
  func.func private @relu_3(%arg0: tensor<i32> {jax.global_constant = "a"}, %arg1: tensor<?x7x7x512xf32>) -> tensor<?x7x7x512xf32> {
    %0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %1 = stablehlo.constant dense<7> : tensor<i32>
    %2 = stablehlo.constant dense<7> : tensor<i32>
    %3 = stablehlo.constant dense<512> : tensor<i32>
    %4 = stablehlo.reshape %arg0 : (tensor<i32>) -> tensor<1xi32>
    %5 = stablehlo.constant dense<7> : tensor<1xi32>
    %6 = stablehlo.constant dense<7> : tensor<1xi32>
    %7 = stablehlo.constant dense<512> : tensor<1xi32>
    %8 = stablehlo.concatenate %4, %5, %6, %7, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
    %9 = stablehlo.dynamic_broadcast_in_dim %0, %8, dims = [] : (tensor<f32>, tensor<4xi32>) -> tensor<?x7x7x512xf32>
    %10 = stablehlo.maximum %arg1, %9 : tensor<?x7x7x512xf32>
    return %10 : tensor<?x7x7x512xf32>
  }
}
A few things to note in the exported StableHLO:
- The exported program now has tensor<?x3x224x224xf32>. These input types can be refined in many ways: SavedModel execution takes care of refinement, which we'll see in the next example, but StableHLO also has APIs to refine shapes and canonicalize dynamic programs to static programs (see the sketch after this list).
- JAX will generate guards to ensure the values of a are valid; in this case a >= 1 is checked. These can be washed away at compile time once refined.
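As a hedged sketch of the refinement point (not part of the original tutorial, and using StableHLO's refinement APIs is the more general route), one simple option once the batch size is known is to re-export the same callable with a concrete shape, which yields a fully static program with no shape_assertion guard; static_input_shape and static_resnet18_export are illustrative names:

# Hedged sketch: re-export with a concrete batch size (here 8) to obtain a
# fully static module instead of refining the dynamic one.
static_input_shape = jax.ShapeDtypeStruct((8, 3, 224, 224), np.float32)
static_resnet18_export = export.export(resnet18)(static_input_shape)
print(get_stablehlo_asm(static_resnet18_export.mlir_module())[:400])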
Export to SavedModel
It is common to want to export a StableHLO model to SavedModel for interop with existing compilation pipelines, existing TF tooling, or serving via TF Serving.
JAX makes it easy to pack StableHLO into a SavedModel, and load that SavedModel in the future. For this section we'll be using our dynamic model from the previous section.
Install latest TF
The SavedModel definition lives in TF, so we need to install the dependency. We recommend using tensorflow-cpu or tf-nightly.
pip install tensorflow-cpu
Export to SavedModel using jax2tf
JAX provides a simple API for exporting StableHLO into a format that can be packaged in a SavedModel: jax.experimental.jax2tf. This uses the export function under the hood, so the same jit requirements apply.
Full details on jax2tf can be found in the README; for this example we'll only need the polymorphic_shapes option to specify our dynamic batch dimension.
from jax.experimental import jax2tf
import tensorflow as tf
exported_f = jax2tf.convert(resnet18, polymorphic_shapes=["(a,3,224,224)"])
# Copied from the jax2tf README.md > Usage: saved model
my_model = tf.Module()
my_model.f = tf.function(exported_f, autograph=False).get_concrete_function(tf.TensorSpec([None, 3, 224, 224], tf.float32))
tf.saved_model.save(my_model, '/tmp/resnet18_tf', options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))
!ls /tmp/resnet18_tf
assets fingerprint.pb saved_model.pb variables
Reload and call the SavedModel
Now we can load that SavedModel and compile using our sample_input from a previous example.
Note: the restored model does not require JAX to run, just XLA.
restored_model = tf.saved_model.load('/tmp/resnet18_tf')
restored_result = restored_model.f(tf.constant(sample_input, tf.float32))
print("Result shape:", restored_result[0].shape)
Result shape: (1, 512, 7, 7)
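As a quick sanity check (a hedged sketch, not part of the original tutorial), we can compare the restored SavedModel's output against the original flax model's output on the same input; small numerical drift is expected.

# Hedged sanity check: the restored SavedModel should closely match the
# original flax model on the same input.
jax_result = resnet18(jnp.asarray(sample_input, dtype=jnp.float32))
print(np.allclose(np.asarray(jax_result[0]), restored_result[0].numpy(), atol=1e-4))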
Common Troubleshooting
If the function can be JIT'ed, then it can be exported. Focus on making jax.jit work first, or look in the desired project for existing uses of JIT (ex: AlphaFold's apply can be exported easily).
See the JAX JIT Examples for troubleshooting. The most common issue is control flow, which can often be resolved with static_argnums / static_argnames, as in the linked example and the sketch below.
When opening a ticket for help, include a repro using one of the above APIs; this will help get the issue resolved much quicker!