PyTorch is a popular library for building deep learning models. In this tutorial, you will learn how to export a PyTorch model to StableHLO, and then directly to a TensorFlow SavedModel.
Tutorial Setup
Install required dependencies
We use torch and torchvision to get a ResNet18 model, and torch_xla to export it to StableHLO. We also need tensorflow to work with SavedModel; we recommend tensorflow-cpu or tf-nightly for this tutorial.
pip install torch_xla==2.5.0 torch==2.5.0 torchvision==0.20.0 tensorflow-cpu
Export PyTorch model to StableHLO
The general set of steps for exporting a PyTorch model to StableHLO is:
- Use PyTorch's torch.export API to generate an exported FX graph (i.e., an ExportedProgram).
- Use PyTorch/XLA's torch_xla.stablehlo API to convert the ExportedProgram to StableHLO.
Export model to FX graph using torch.export
This step uses vanilla PyTorch APIs to export a resnet18 model from torchvision. Sample inputs are required for graph tracing; we use a tensor<4x3x224x224xf32> in this case.
import torch
import torchvision
from torch.export import export

# Load a ResNet18 with pretrained weights from torchvision
resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)

# Sample inputs are used only for tracing; a tuple matching the model's positional args
sample_input = (torch.randn(4, 3, 224, 224),)
exported = export(resnet18, sample_input)
Export FX graph to StableHLO using torch_xla.stablehlo
Once we have an exported FX graph, we can convert it to StableHLO using the exported_program_to_stablehlo function in the torch_xla.stablehlo module. We can then inspect the exported StableHLO program with get_stablehlo_text.
from torch_xla.stablehlo import exported_program_to_stablehlo

stablehlo_program = exported_program_to_stablehlo(exported)
# Print the first 4,000 characters of the 'forward' function's StableHLO text
print(stablehlo_program.get_stablehlo_text('forward')[0:4000], "\n...")
WARNING:root:Defaulting to PJRT_DEVICE=CPU module @IrToHlo.484 attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} { func.func @main(%arg0: tensor<1000xf32>, %arg1: tensor<1000x512xf32>, %arg2: tensor<512xf32>, %arg3: tensor<512xf32>, %arg4: tensor<512xf32>, %arg5: tensor<512xf32>, %arg6: tensor<512x256x1x1xf32>, %arg7: tensor<256xf32>, %arg8: tensor<256xf32>, %arg9: tensor<256xf32>, %arg10: tensor<256xf32>, %arg11: tensor<256x128x1x1xf32>, %arg12: tensor<128xf32>, %arg13: tensor<128xf32>, %arg14: tensor<128xf32>, %arg15: tensor<128xf32>, %arg16: tensor<128x64x1x1xf32>, %arg17: tensor<64xf32>, %arg18: tensor<64xf32>, %arg19: tensor<64xf32>, %arg20: tensor<64xf32>, %arg21: tensor<64x3x7x7xf32>, %arg22: tensor<4x3x224x224xf32>, %arg23: tensor<64xf32>, %arg24: tensor<64xf32>, %arg25: tensor<64xf32>, %arg26: tensor<64xf32>, %arg27: tensor<64x64x3x3xf32>, %arg28: tensor<64xf32>, %arg29: tensor<64xf32>, %arg30: tensor<64xf32>, %arg31: tensor<64xf32>, %arg32: tensor<64x64x3x3xf32>, %arg33: tensor<64xf32>, %arg34: tensor<64xf32>, %arg35: tensor<64xf32>, %arg36: tensor<64xf32>, %arg37: tensor<64x64x3x3xf32>, %arg38: tensor<64xf32>, %arg39: tensor<64xf32>, %arg40: tensor<64xf32>, %arg41: tensor<64xf32>, %arg42: tensor<64x64x3x3xf32>, %arg43: tensor<128xf32>, %arg44: tensor<128xf32>, %arg45: tensor<128xf32>, %arg46: tensor<128xf32>, %arg47: tensor<128x128x3x3xf32>, %arg48: tensor<128xf32>, %arg49: tensor<128xf32>, %arg50: tensor<128xf32>, %arg51: tensor<128xf32>, %arg52: tensor<128x64x3x3xf32>, %arg53: tensor<128xf32>, %arg54: tensor<128xf32>, %arg55: tensor<128xf32>, %arg56: tensor<128xf32>, %arg57: tensor<128x128x3x3xf32>, %arg58: tensor<128xf32>, %arg59: tensor<128xf32>, %arg60: tensor<128xf32>, %arg61: tensor<128xf32>, %arg62: tensor<128x128x3x3xf32>, %arg63: tensor<256xf32>, %arg64: tensor<256xf32>, %arg65: tensor<256xf32>, %arg66: tensor<256xf32>, %arg67: tensor<256x256x3x3xf32>, %arg68: tensor<256xf32>, %arg69: tensor<256xf32>, %arg70: tensor<256xf32>, %arg71: tensor<256xf32>, %arg72: tensor<256x128x3x3xf32>, %arg73: tensor<256xf32>, %arg74: tensor<256xf32>, %arg75: tensor<256xf32>, %arg76: tensor<256xf32>, %arg77: tensor<256x256x3x3xf32>, %arg78: tensor<256xf32>, %arg79: tensor<256xf32>, %arg80: tensor<256xf32>, %arg81: tensor<256xf32>, %arg82: tensor<256x256x3x3xf32>, %arg83: tensor<512xf32>, %arg84: tensor<512xf32>, %arg85: tensor<512xf32>, %arg86: tensor<512xf32>, %arg87: tensor<512x512x3x3xf32>, %arg88: tensor<512xf32>, %arg89: tensor<512xf32>, %arg90: tensor<512xf32>, %arg91: tensor<512xf32>, %arg92: tensor<512x256x3x3xf32>, %arg93: tensor<512xf32>, %arg94: tensor<512xf32>, %arg95: tensor<512xf32>, %arg96: tensor<512xf32>, %arg97: tensor<512x512x3x3xf32>, %arg98: tensor<512xf32>, %arg99: tensor<512xf32>, %arg100: tensor<512xf32>, %arg101: tensor<512xf32>, %arg102: tensor<512x512x3x3xf32>) -> tensor<4x1000xf32> { %cst = stablehlo.constant dense<0.0204081628> : tensor<4x512xf32> %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<4x512x7x7xf32> %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<4x256x14x14xf32> %cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<4x128x28x28xf32> %cst_3 = stablehlo.constant dense<0.000000e+00> : tensor<4x64x56x56xf32> %cst_4 = stablehlo.constant dense<0.000000e+00> : tensor<4x64x112x112xf32> %cst_5 = stablehlo.constant dense<0xFF800000> : tensor<f32> %cst_6 = stablehlo.constant dense<0.000000e+00> : tensor<f32> %0 = stablehlo.convolution(%arg22, 
%arg21) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 2], pad = [[3, 3], [3, 3]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [false, false]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<4x3x224x224xf32>, tensor<64x3x7x7xf32>) -> tensor<4x64x112x112xf32> %output, %batch_mean, %batch_var = "stablehlo.ba ...
Tip: Dynamic batch dimensions can be specified as part of the initial torch.export step. torch_xla's support for exporting dynamic models is limited; for these cases we recommend torch_xla2, whose lowering path leverages JAX to produce StableHLO and has high opset coverage, with much broader support for exported programs with dynamic shapes.
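For reference, here is a minimal sketch of marking the batch dimension as dynamic during torch.export. It uses the standard torch.export.Dim API; the names batch and exported_dynamic are illustrative:

import torch
from torch.export import export, Dim

# Mark dimension 0 of the first positional input as a dynamic batch dimension
batch = Dim("batch")
exported_dynamic = export(resnet18, sample_input, dynamic_shapes=({0: batch},))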
Save and reload StableHLO
StableHLOGraphModule has save and load methods for StableHLO artifacts. Saving stores StableHLO portable bytecode artifacts, which have complete forward and backward compatibility guarantees.
from torch_xla.stablehlo import StableHLOGraphModule
# Save to tmp
stablehlo_program.save('/tmp/stablehlo_dir')
!ls /tmp/stablehlo_dir
!ls /tmp/stablehlo_dir/functions
constants data functions
forward.bytecode forward.meta forward.mlir
# Reload and execute - Stable serialization, forward / backward compatible.
reloaded = StableHLOGraphModule.load('/tmp/stablehlo_dir')
print(reloaded(sample_input[0]))
tensor([[-2.3258, -0.9606, -0.9439, ..., 0.3519, 0.6261, 2.3971], [ 1.6479, -0.0268, 1.0511, ..., -1.2512, 2.2042, 1.8865], [ 0.1756, -0.3658, -0.0651, ..., 0.0661, 2.1358, 0.5009], [-1.6709, -0.7363, -2.0963, ..., -1.3716, 0.3321, -0.9199]], device='xla:0')
Note: You can also use convenience wrappers like save_torch_model_as_stablehlo to export and save in one step. Learn more in the PyTorch/XLA documentation on exporting to StableHLO.
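As an illustrative sketch (the output path here is hypothetical), the wrapper takes the original module, the sample inputs, and a destination directory:

from torch_xla.stablehlo import save_torch_model_as_stablehlo

# Runs torch.export and saves the StableHLO artifacts in one call
save_torch_model_as_stablehlo(resnet18, sample_input, '/tmp/stablehlo_dir2')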
Export to TensorFlow SavedModel
It is common to want to export a StableHLO model to TensorFlow SavedModel for interoperability with existing compilation pipelines, existing TensorFlow tooling, or serving via TensorFlow Serving.
PyTorch/XLA's torch_xla.tf_saved_model_integration module makes it easy to pack StableHLO into a SavedModel, which can be loaded back and executed.
Export to SavedModel with torch_xla.tf_saved_model_integration
We use the save_torch_module_as_tf_saved_model function for this conversion; it calls torch.export and torch_xla.stablehlo.exported_program_to_stablehlo under the hood. The input to the API is a PyTorch model, and we use the same resnet18 from the previous examples.
from torch_xla.tf_saved_model_integration import save_torch_module_as_tf_saved_model
save_torch_module_as_tf_saved_model(
resnet18, # original pytorch torch.nn.Module
sample_input, # sample inputs used to trace
'/tmp/resnet_tf' # directory for tf.saved_model
)
!ls /tmp/resnet_tf/
assets fingerprint.pb saved_model.pb variables
Reload and call the SavedModel
Now we can load that SavedModel and call it with the sample_input from the previous example; XLA compiles the computation on first execution.
Note: The restored model does not require PyTorch or PyTorch/XLA to run, just XLA.
import tensorflow as tf
loaded_m = tf.saved_model.load('/tmp/resnet_tf')
print(loaded_m.f(tf.constant(sample_input[0].numpy())))
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1730760467.760638 8492 service.cc:148] XLA service 0x7ede002016e0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1730760467.760777 8492 service.cc:156] StreamExecutor device (0): Host, Default Version
I0000 00:00:1730760468.613723 8492 device_compiler.h:188] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
[<tf.Tensor: shape=(4, 1000), dtype=float32, numpy=
array([[-2.3257551 , -0.96061766, -0.9439326 , ..., 0.35189423, 0.62605226, 2.3971176 ],
       [ 1.6479174 , -0.02676968, 1.0511047 , ..., -1.2511721 , 2.2041895 , 1.8865337 ],
       [ 0.17559683, -0.365776  , -0.06507193, ..., 0.06606296, 2.135755  , 0.500913  ],
       [-1.6709077 , -0.7362997 , -2.0962732 , ..., -1.3716122 , 0.33205754, -0.91991633]], dtype=float32)>]
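If you are unsure which callables a restored SavedModel exposes, you can also list its serving signatures. This is a small illustrative aside; the exact signature keys depend on how the model was saved:

import tensorflow as tf

loaded_m = tf.saved_model.load('/tmp/resnet_tf')
# Inspect the available serving signatures, e.g. ['serving_default']
print(list(loaded_m.signatures.keys()))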
Troubleshooting
Version mismatch
Ensure that you have matching versions of PyTorch/XLA and PyTorch (for example, torch==2.5.0 with torch_xla==2.5.0). A version mismatch can result in import errors as well as runtime issues.
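A quick sanity check, assuming both packages are importable (each exposes the standard __version__ attribute):

import torch
import torch_xla

# The reported versions should match, e.g. 2.5.0 and 2.5.0
print(torch.__version__, torch_xla.__version__)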
Export bugs
If your program fails to export due to a bug, open an issue on GitHub with a reproducible example:
- Issues in torch.export: report these in the upstream pytorch/pytorch repository.
- Issues in torch_xla.stablehlo: open a ticket on the pytorch/xla repository.