Tutorial: Exporting StableHLO from PyTorch


PyTorch is a popular library for building deep learning models. In this tutorial, you will learn to export a PyTorch model to StableHLO, and from there to a TensorFlow SavedModel.

Tutorial Setup

Install required dependencies

We use torch and torchvision to get a ResNet18 model, and torch_xla to export it to StableHLO. We also need tensorflow to work with SavedModel; we recommend tensorflow-cpu or tf-nightly for this tutorial.

pip install torch_xla==2.5.0 torch==2.5.0 torchvision==0.20.0 tensorflow-cpu

Export PyTorch model to StableHLO

The general set of steps for exporting a PyTorch model to StableHLO is:

  1. Use PyTorch's torch.export API to generate an exported FX graph (i.e., ExportedProgram)
  2. Use PyTorch/XLA's torch_xla.stablehlo API to convert the ExportedProgram to StableHLO

Export model to FX graph using torch.export

This step uses vanilla PyTorch APIs to export a resnet18 model from torchvision. Sample inputs are required for graph tracing; we use a tensor<4x3x224x224xf32> in this case.

import torch
import torchvision
from torch.export import export

resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
sample_input = (torch.randn(4, 3, 224, 224), )
exported = export(resnet18, sample_input)

Export FX graph to StableHLO using torch_xla.stablehlo

Once we have an exported FX graph, we can convert it to StableHLO using exported_program_to_stablehlo in the torch_xla.stablehlo module.

We can then look at the exported StableHLO program with get_stablehlo_text.

from torch_xla.stablehlo import exported_program_to_stablehlo

stablehlo_program = exported_program_to_stablehlo(exported)
print(stablehlo_program.get_stablehlo_text('forward')[0:4000],"\n...")
WARNING:root:Defaulting to PJRT_DEVICE=CPU
module @IrToHlo.484 attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<1000xf32>, %arg1: tensor<1000x512xf32>, %arg2: tensor<512xf32>, %arg3: tensor<512xf32>, %arg4: tensor<512xf32>, %arg5: tensor<512xf32>, %arg6: tensor<512x256x1x1xf32>, %arg7: tensor<256xf32>, %arg8: tensor<256xf32>, %arg9: tensor<256xf32>, %arg10: tensor<256xf32>, %arg11: tensor<256x128x1x1xf32>, %arg12: tensor<128xf32>, %arg13: tensor<128xf32>, %arg14: tensor<128xf32>, %arg15: tensor<128xf32>, %arg16: tensor<128x64x1x1xf32>, %arg17: tensor<64xf32>, %arg18: tensor<64xf32>, %arg19: tensor<64xf32>, %arg20: tensor<64xf32>, %arg21: tensor<64x3x7x7xf32>, %arg22: tensor<4x3x224x224xf32>, %arg23: tensor<64xf32>, %arg24: tensor<64xf32>, %arg25: tensor<64xf32>, %arg26: tensor<64xf32>, %arg27: tensor<64x64x3x3xf32>, %arg28: tensor<64xf32>, %arg29: tensor<64xf32>, %arg30: tensor<64xf32>, %arg31: tensor<64xf32>, %arg32: tensor<64x64x3x3xf32>, %arg33: tensor<64xf32>, %arg34: tensor<64xf32>, %arg35: tensor<64xf32>, %arg36: tensor<64xf32>, %arg37: tensor<64x64x3x3xf32>, %arg38: tensor<64xf32>, %arg39: tensor<64xf32>, %arg40: tensor<64xf32>, %arg41: tensor<64xf32>, %arg42: tensor<64x64x3x3xf32>, %arg43: tensor<128xf32>, %arg44: tensor<128xf32>, %arg45: tensor<128xf32>, %arg46: tensor<128xf32>, %arg47: tensor<128x128x3x3xf32>, %arg48: tensor<128xf32>, %arg49: tensor<128xf32>, %arg50: tensor<128xf32>, %arg51: tensor<128xf32>, %arg52: tensor<128x64x3x3xf32>, %arg53: tensor<128xf32>, %arg54: tensor<128xf32>, %arg55: tensor<128xf32>, %arg56: tensor<128xf32>, %arg57: tensor<128x128x3x3xf32>, %arg58: tensor<128xf32>, %arg59: tensor<128xf32>, %arg60: tensor<128xf32>, %arg61: tensor<128xf32>, %arg62: tensor<128x128x3x3xf32>, %arg63: tensor<256xf32>, %arg64: tensor<256xf32>, %arg65: tensor<256xf32>, %arg66: tensor<256xf32>, %arg67: tensor<256x256x3x3xf32>, %arg68: tensor<256xf32>, %arg69: tensor<256xf32>, %arg70: tensor<256xf32>, %arg71: tensor<256xf32>, %arg72: tensor<256x128x3x3xf32>, %arg73: tensor<256xf32>, %arg74: tensor<256xf32>, %arg75: tensor<256xf32>, %arg76: tensor<256xf32>, %arg77: tensor<256x256x3x3xf32>, %arg78: tensor<256xf32>, %arg79: tensor<256xf32>, %arg80: tensor<256xf32>, %arg81: tensor<256xf32>, %arg82: tensor<256x256x3x3xf32>, %arg83: tensor<512xf32>, %arg84: tensor<512xf32>, %arg85: tensor<512xf32>, %arg86: tensor<512xf32>, %arg87: tensor<512x512x3x3xf32>, %arg88: tensor<512xf32>, %arg89: tensor<512xf32>, %arg90: tensor<512xf32>, %arg91: tensor<512xf32>, %arg92: tensor<512x256x3x3xf32>, %arg93: tensor<512xf32>, %arg94: tensor<512xf32>, %arg95: tensor<512xf32>, %arg96: tensor<512xf32>, %arg97: tensor<512x512x3x3xf32>, %arg98: tensor<512xf32>, %arg99: tensor<512xf32>, %arg100: tensor<512xf32>, %arg101: tensor<512xf32>, %arg102: tensor<512x512x3x3xf32>) -> tensor<4x1000xf32> {
    %cst = stablehlo.constant dense<0.0204081628> : tensor<4x512xf32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<4x512x7x7xf32>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<4x256x14x14xf32>
    %cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<4x128x28x28xf32>
    %cst_3 = stablehlo.constant dense<0.000000e+00> : tensor<4x64x56x56xf32>
    %cst_4 = stablehlo.constant dense<0.000000e+00> : tensor<4x64x112x112xf32>
    %cst_5 = stablehlo.constant dense<0xFF800000> : tensor<f32>
    %cst_6 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.convolution(%arg22, %arg21) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 2], pad = [[3, 3], [3, 3]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [false, false]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<4x3x224x224xf32>, tensor<64x3x7x7xf32>) -> tensor<4x64x112x112xf32>
    %output, %batch_mean, %batch_var = "stablehlo.ba 
...

Tip:

Dynamic batch dimensions can be specified as part of the initial torch.export step.
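A minimal sketch, assuming the resnet18 and sample_input defined above; the Dim name batch is ours, and depending on your model and PyTorch version you may need range constraints such as Dim("batch", min=2):

from torch.export import Dim, export

# Mark dimension 0 of the first (and only) positional input as dynamic.
batch = Dim("batch")
exported_dynamic = export(
    resnet18,
    sample_input,
    dynamic_shapes=({0: batch},),
)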

torch_xla's support for exporting dynamic models is limited; for these cases we recommend torch_xla2. That lowering path leverages JAX to lower to StableHLO and has high opset coverage, with much broader support for exported programs with dynamic shapes.

Save and reload StableHLO

StableHLOGraphModule has methods to save and load StableHLO artifacts. This stores StableHLO portable bytecode artifacts, which come with forward and backward compatibility guarantees.

from torch_xla.stablehlo import StableHLOGraphModule

# Save to tmp
stablehlo_program.save('/tmp/stablehlo_dir')
!ls /tmp/stablehlo_dir
constants  data  functions
!ls /tmp/stablehlo_dir/functions
forward.bytecode  forward.meta  forward.mlir
# Reload and execute - Stable serialization, forward / backward compatible.
reloaded = StableHLOGraphModule.load('/tmp/stablehlo_dir')
print(reloaded(sample_input[0]))
tensor([[-2.3258, -0.9606, -0.9439,  ...,  0.3519,  0.6261,  2.3971],
        [ 1.6479, -0.0268,  1.0511,  ..., -1.2512,  2.2042,  1.8865],
        [ 0.1756, -0.3658, -0.0651,  ...,  0.0661,  2.1358,  0.5009],
        [-1.6709, -0.7363, -2.0963,  ..., -1.3716,  0.3321, -0.9199]],
       device='xla:0')

Note: You can also use convenience wrappers like save_torch_model_as_stablehlo to export and save. Learn more in the PyTorch/XLA documentation on exporting to StableHLO.
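For instance, a minimal sketch of the wrapper, assuming the same (model, sample inputs, output directory) calling convention as the SavedModel helper shown later; '/tmp/stablehlo_dir2' is an arbitrary path:

from torch_xla.stablehlo import save_torch_model_as_stablehlo

# Export resnet18 and save the StableHLO artifacts in one call;
# torch.export and the StableHLO conversion happen under the hood.
save_torch_model_as_stablehlo(resnet18, sample_input, '/tmp/stablehlo_dir2')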

Export to TensorFlow SavedModel

It is common to want to export a StableHLO model to TensorFlow SavedModel for interoperability with existing compilation pipelines, TensorFlow tooling, or serving via TensorFlow Serving.

PyTorch/XLA's torch_xla.tf_saved_model_integration module makes it easy to pack StableHLO into a SavedModel, which can be loaded back and executed.

Export to SavedModel with torch_xla.tf_saved_model_integration

We use the save_torch_module_as_tf_saved_model function for this conversion, which uses the torch.export and torch_xla.stablehlo.exported_program_to_stablehlo functions under the hood.

The input to the API is a PyTorch model, and we use the same resnet18 from the previous examples.

from torch_xla.tf_saved_model_integration import save_torch_module_as_tf_saved_model

save_torch_module_as_tf_saved_model(
    resnet18,         # original pytorch torch.nn.Module
    sample_input,     # sample inputs used to trace
    '/tmp/resnet_tf'  # directory for tf.saved_model
)

!ls /tmp/resnet_tf/
assets  fingerprint.pb  saved_model.pb  variables

Reload and call the SavedModel

Now we can load the SavedModel and call it with our sample_input from the previous example; the first call triggers XLA compilation.

Note: The restored model does not require PyTorch or PyTorch/XLA to run, just XLA.

import tensorflow as tf

loaded_m = tf.saved_model.load('/tmp/resnet_tf')
print(loaded_m.f(tf.constant(sample_input[0].numpy())))
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1730760467.760638    8492 service.cc:148] XLA service 0x7ede002016e0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1730760467.760777    8492 service.cc:156]   StreamExecutor device (0): Host, Default Version
I0000 00:00:1730760468.613723    8492 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
[<tf.Tensor: shape=(4, 1000), dtype=float32, numpy=
array([[-2.3257551 , -0.96061766, -0.9439326 , ...,  0.35189423,
         0.62605226,  2.3971176 ],
       [ 1.6479174 , -0.02676968,  1.0511047 , ..., -1.2511721 ,
         2.2041895 ,  1.8865337 ],
       [ 0.17559683, -0.365776  , -0.06507193, ...,  0.06606296,
         2.135755  ,  0.500913  ],
       [-1.6709077 , -0.7362997 , -2.0962732 , ..., -1.3716122 ,
         0.33205754, -0.91991633]], dtype=float32)>]

Troubleshooting

Version mismatch

Ensure that your torch and torch_xla versions match. A version mismatch can result in import errors as well as runtime issues.
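A quick sanity check (a sketch; both packages expose __version__, and the two version strings should match, e.g. both 2.5.0):

import torch
import torch_xla

print(torch.__version__)      # e.g. 2.5.0
print(torch_xla.__version__)  # should match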

Export bugs

If your program fails to export due to a bug in the PyTorch/XLA bridge, open an issue on GitHub with a reproducible example (see the sketch after this list):

  • Issues in torch.export: report these in the upstream pytorch/pytorch repository
  • Issues in torch_xla.stablehlo: open a ticket in the pytorch/xla repository
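A minimal reproducible example might look like the following sketch; TinyModel is a hypothetical stand-in for the smallest module that still triggers the failure, and noting which of the two steps raises tells you which repository to report to:

import torch
from torch.export import export
from torch_xla.stablehlo import exported_program_to_stablehlo

class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)

# Step 1: torch.export (failures here go to pytorch/pytorch).
exported = export(TinyModel(), (torch.randn(2, 4),))

# Step 2: StableHLO conversion (failures here go to pytorch/xla).
stablehlo_program = exported_program_to_stablehlo(exported)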