Intro to the torch_xla.stablehlo module
Tutorial Setup
Install required dependencies
We'll be using torch and torchvision to get a resnet18 model, and torch_xla to export it to StableHLO.
pip install torch_xla==2.3.0 torch==2.3.0 torchvision==0.18.0
Export PyTorch model to StableHLO
The general set of steps for exporting a PyTorch model to StableHLO is:
- Export the model using PyTorch's torch.export API.
- Convert the exported FX graph to StableHLO using torch_xla.stablehlo APIs.
Export model to FX graph using torch.export
This step uses entirely vanilla PyTorch APIs to export a resnet18 model from torchvision. Sample inputs are required for graph tracing; here we use a tensor<4x3x224x224xf32>.
import torch
import torchvision
from torch.export import export

# Load a pretrained resnet18 and trace it with a sample input batch.
resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
sample_input = (torch.randn(4, 3, 224, 224), )
exported = export(resnet18, sample_input)
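If you want to sanity-check the capture before converting it, the ExportedProgram carries the traced FX graph. A small optional inspection (not part of the export flow itself):

# Print the traced FX graph held by the ExportedProgram.
print(exported.graph_module.graph)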
Export FX Graph to StableHLO using TorchXLA
Once we have an exported FX graph, we can convert it to StableHLO using the torch_xla.stablehlo module. In this case we'll use exported_program_to_stablehlo.
from torch_xla.stablehlo import exported_program_to_stablehlo

# Convert the ExportedProgram to a StableHLOGraphModule and print an excerpt.
stablehlo_program = exported_program_to_stablehlo(exported)
print(stablehlo_program.get_stablehlo_text('forward')[0:4000], "\n...")
module @IrToHlo.508 attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<1000xf32>, %arg1: tensor<1000x512xf32>, %arg2: tensor<512xf32>, %arg3: tensor<512xf32>, %arg4: tensor<512xf32>, %arg5: tensor<512xf32>, %arg6: tensor<512x256x1x1xf32>, %arg7: tensor<256xf32>, %arg8: tensor<256xf32>, %arg9: tensor<256xf32>, %arg10: tensor<256xf32>, %arg11: tensor<256x128x1x1xf32>, %arg12: tensor<128xf32>, %arg13: tensor<128xf32>, %arg14: tensor<128xf32>, %arg15: tensor<128xf32>, %arg16: tensor<128x64x1x1xf32>, %arg17: tensor<64xf32>, %arg18: tensor<64xf32>, %arg19: tensor<64xf32>, %arg20: tensor<64xf32>, %arg21: tensor<64x3x7x7xf32>, %arg22: tensor<4x3x224x224xf32>, %arg23: tensor<64xf32>, %arg24: tensor<64xf32>, %arg25: tensor<64xf32>, %arg26: tensor<64xf32>, %arg27: tensor<64x64x3x3xf32>, %arg28: tensor<64xf32>, %arg29: tensor<64xf32>, %arg30: tensor<64xf32>, %arg31: tensor<64xf32>, %arg32: tensor<64x64x3x3xf32>, %arg33: tensor<64xf32>, %arg34: tensor<64xf32>, %arg35: tensor<64xf32>, %arg36: tensor<64xf32>, %arg37: tensor<64x64x3x3xf32>, %arg38: tensor<64xf32>, %arg39: tensor<64xf32>, %arg40: tensor<64xf32>, %arg41: tensor<64xf32>, %arg42: tensor<64x64x3x3xf32>, %arg43: tensor<128xf32>, %arg44: tensor<128xf32>, %arg45: tensor<128xf32>, %arg46: tensor<128xf32>, %arg47: tensor<128x128x3x3xf32>, %arg48: tensor<128xf32>, %arg49: tensor<128xf32>, %arg50: tensor<128xf32>, %arg51: tensor<128xf32>, %arg52: tensor<128x64x3x3xf32>, %arg53: tensor<128xf32>, %arg54: tensor<128xf32>, %arg55: tensor<128xf32>, %arg56: tensor<128xf32>, %arg57: tensor<128x128x3x3xf32>, %arg58: tensor<128xf32>, %arg59: tensor<128xf32>, %arg60: tensor<128xf32>, %arg61: tensor<128xf32>, %arg62: tensor<128x128x3x3xf32>, %arg63: tensor<256xf32>, %arg64: tensor<256xf32>, %arg65: tensor<256xf32>, %arg66: tensor<256xf32>, %arg67: tensor<256x256x3x3xf32>, %arg68: tensor<256xf32>, %arg69: tensor<256xf32>, %arg70: tensor<256xf32>, %arg71: tensor<256xf32>, %arg72: tensor<256x128x3x3xf32>, %arg73: tensor<256xf32>, %arg74: tensor<256xf32>, %arg75: tensor<256xf32>, %arg76: tensor<256xf32>, %arg77: tensor<256x256x3x3xf32>, %arg78: tensor<256xf32>, %arg79: tensor<256xf32>, %arg80: tensor<256xf32>, %arg81: tensor<256xf32>, %arg82: tensor<256x256x3x3xf32>, %arg83: tensor<512xf32>, %arg84: tensor<512xf32>, %arg85: tensor<512xf32>, %arg86: tensor<512xf32>, %arg87: tensor<512x512x3x3xf32>, %arg88: tensor<512xf32>, %arg89: tensor<512xf32>, %arg90: tensor<512xf32>, %arg91: tensor<512xf32>, %arg92: tensor<512x256x3x3xf32>, %arg93: tensor<512xf32>, %arg94: tensor<512xf32>, %arg95: tensor<512xf32>, %arg96: tensor<512xf32>, %arg97: tensor<512x512x3x3xf32>, %arg98: tensor<512xf32>, %arg99: tensor<512xf32>, %arg100: tensor<512xf32>, %arg101: tensor<512xf32>, %arg102: tensor<512x512x3x3xf32>) -> tensor<4x1000xf32> {
    %0 = stablehlo.constant dense<0.0204081628> : tensor<4x512xf32>
    %1 = stablehlo.constant dense<0.000000e+00> : tensor<4x512x7x7xf32>
    %2 = stablehlo.constant dense<0.000000e+00> : tensor<4x256x14x14xf32>
    %3 = stablehlo.constant dense<0.000000e+00> : tensor<4x128x28x28xf32>
    %4 = stablehlo.constant dense<0.000000e+00> : tensor<4x64x56x56xf32>
    %5 = stablehlo.constant dense<0.000000e+00> : tensor<4x64x112x112xf32>
    %6 = stablehlo.constant dense<0xFF800000> : tensor<f32>
    %7 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %8 = stablehlo.convolution(%arg22, %arg21) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 2], pad = [[3, 3], [3, 3]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [0, 0]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<4x3x224x224xf32>, tensor<64x3x7x7xf32>) -> tensor<4x64x112x112xf32>
    %output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%8, %arg20, %arg19) {epsilon = 9.99999974E-6 : f3
...
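Besides the human-readable text, the same program can be fetched as StableHLO portable bytecode, the form used for serialization. A minimal sketch, assuming the StableHLOGraphModule's get_stablehlo_bytecode method:

# Fetch the 'forward' method as StableHLO portable bytecode (a bytes object).
bytecode = stablehlo_program.get_stablehlo_bytecode('forward')
print(type(bytecode), len(bytecode))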
Export with dynamic batch dimension
This is a new feature and will be included in the pip installation after the 2.4 release, or when using the torch_xla nightly. Once PyTorch/XLA 2.4 is released, this will be converted into a running example. Using the nightly torch and torch_xla will likely lead to notebook failures in the meantime.
Dynamic batch dimensions can be specified as part of the initial torch.export step. The FX graph's symint information is used to export to dynamic StableHLO. In this example, we specify that dim 0 of the sample input is dynamic, which propagates through the graph as a tensor<?x3x224x224xf32>.
from torch.export import Dim

# Create a dynamic batch size for the first dimension of the input
batch = Dim("batch", max=15)
dynamic_shapes = ({0: batch},)
dynamic_export = export(resnet18, sample_input, dynamic_shapes=dynamic_shapes)
dynamic_stablehlo = exported_program_to_stablehlo(dynamic_export)
print(dynamic_stablehlo.get_stablehlo_text('forward')[0:5000], "\n...")
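Before converting, you can confirm the dynamic dimension was recorded on the PyTorch side. A quick check of the ExportedProgram's range_constraints; this part uses only vanilla torch.export, so it runs on the stable release:

# Inspect the symbolic shape constraints recorded by torch.export.
# The batch symbol should show a range with an upper bound of 15 (the lower
# bound depends on the torch version's handling of minimum sizes).
print(dynamic_export.range_constraints)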
Saving and reloading StableHLO
The StableHLOGraphModule has methods to save and load StableHLO artifacts. These are stored as StableHLO portable bytecode artifacts, which carry full backward compatibility guarantees.
from torch_xla.stablehlo import StableHLOGraphModule
# Save to tmp
stablehlo_program.save('/tmp/stablehlo_dir')
!ls /tmp/stablehlo_dir
!ls /tmp/stablehlo_dir/functions
# Reload and execute - Stable serialization, forward / backward compatible.
reloaded = StableHLOGraphModule.load('/tmp/stablehlo_dir')
print(reloaded(sample_input[0]))
constants  data  functions
forward.bytecode  forward.meta  forward.mlir
tensor([[-0.7431, -2.5955, -0.0718,  ..., -1.4230,  1.5928, -0.7693],
        [ 0.6199, -1.5941, -0.9018,  ...,  0.2452,  0.6159,  2.4765],
        [-3.0291,  2.2174, -2.2809,  ..., -0.9081,  1.8253,  2.2141],
        [ 0.9318, -0.0566,  0.8561,  ..., -0.1650,  0.7882, -0.2697]],
       device='xla:0')
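As an extra consistency check (not part of the original flow), the reloaded program's output can be compared against the eager PyTorch model; the atol below is an assumption, chosen to absorb small numeric drift between XLA and eager kernels:

# Compare the reloaded StableHLO program against the eager PyTorch model.
expected = resnet18(sample_input[0])
actual = reloaded(sample_input[0]).cpu()
print(torch.allclose(expected, actual, atol=1e-4))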
Export to SavedModel
It is common to want to export a StableHLO model to SavedModel for interop with existing compilation pipelines, existing TF tooling, or serving via TF Serving.
PyTorch/XLA makes it easy to pack StableHLO into a SavedModel, which can be loaded back and executed.
Install latest TF
The SavedModel definition lives in TF, so we need to install it as a dependency. We recommend using tensorflow-cpu or tf-nightly.
pip install -U tensorflow-cpu
Export to SavedModel using torch_xla.tf_saved_model_integration
PyTorch/XLA provides a simple API, save_torch_module_as_tf_saved_model, for exporting StableHLO as a SavedModel. It uses the torch.export and torch_xla.stablehlo.exported_program_to_stablehlo functions under the hood.
The input to the API is a PyTorch model; we'll use the same resnet18 from the previous examples.
from torch_xla.tf_saved_model_integration import save_torch_module_as_tf_saved_model
save_torch_module_as_tf_saved_model(
resnet18, # original pytorch torch.nn.Module
sample_input, # sample inputs used to trace
'/tmp/resnet_tf' # directory for tf.saved_model
)
!ls /tmp/resnet_tf/
assets fingerprint.pb saved_model.pb variables
Reload and call the SavedModel
Now we can load that SavedModel and call it with the sample_input from the previous example.
Note: the restored model does not require PyTorch or PyTorch/XLA to run, just XLA.
import tensorflow as tf
loaded_m = tf.saved_model.load('/tmp/resnet_tf')
print(loaded_m.f(tf.constant(sample_input[0].numpy())))
[<tf.Tensor: shape=(4, 1000), dtype=float32, numpy=
array([[-0.74313045, -2.595488  , -0.0718156 , ..., -1.4230162 ,  1.5928375 , -0.7693139 ],
       [ 0.61994374, -1.594082  , -0.901797  , ...,  0.2451565 ,  0.6159245 ,  2.4764667 ],
       [-3.029084  ,  2.2174084 , -2.2808676 , ..., -0.90810233,  1.8252805 ,  2.214109  ],
       [ 0.93176216, -0.0566061 ,  0.8560745 , ..., -0.16496754,  0.7881946 , -0.26973075]],
      dtype=float32)>]
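Optionally (an extra check, not in the original flow), the SavedModel output can be compared against the eager PyTorch model; the tolerance is an assumption:

import numpy as np

# The SavedModel output should closely match the eager PyTorch output.
tf_out = loaded_m.f(tf.constant(sample_input[0].numpy()))[0].numpy()
torch_out = resnet18(sample_input[0]).detach().numpy()
print(np.allclose(tf_out, torch_out, atol=1e-4))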
Common Troubleshooting
Most issues in the PyTorch-to-StableHLO flow require a GitHub ticket as a next step. The teams are generally quick to help resolve issues.
- Issues in torch.export: these need to be resolved in upstream PyTorch.
- Issues in torch_xla.stablehlo: open a ticket on the pytorch/xla repo.
The most common issue is that dependencies are out of sync: PyTorch/XLA and PyTorch must be on the same version, and a version mismatch can result in import errors as well as some runtime issues. After that, it is possible that a program fails to export due to a bug in the PyTorch/XLA bridge, for which a ticket with a repro is helpful.
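A quick way to confirm the versions line up (a small diagnostic sketch, not from the original tutorial):

import torch
import torch_xla

# Both versions should match, e.g. 2.3.0 and 2.3.0; a mismatch is the most
# common cause of import and runtime failures.
print(torch.__version__, torch_xla.__version__)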