Shardy is a new propagation system being introduced into the XLA stack, and below we want to introduce JAX users to:
- What has changed in JAX
- Why Shardy?
- Future plans
This is meant for JAX users who use jax.jit for running training/inference models across more than 1 GPU or TPU (batch parallelism, megatron, ZeRO, etc.). They would be using things like PartitionSpecs and NamedShardings.
1. What has changed in JAX?
State of JAX before: GSPMD
Prior to Shardy, JAX users who partitioned their models across multiple devices used GSPMD behind the scenes.
GSPMD is the propagation+partitioning system that lives in the middle of the XLA pipeline. It operates on HLO - the IR that comes after StableHLO (the program you get after running jax.jit.lower).
JAX doesn't run GSPMD directly, but encodes instructions into the StableHLO IR for GSPMD to read later on.
But before we go any further, let's introduce our working example.
Make sure you have installed jax>=0.4.35.
pip install jax==0.4.35
Imports and utilities
import os
# make sure our code runs on 8 devices
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
import jax
import numpy as np
from jax import numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from jax.experimental.shard_map import shard_map
First, let's create our mesh.
mesh = Mesh(
    np.reshape(np.array(jax.devices()), (4, 2)),
    ('data', 'model'))
print(mesh.shape)
OrderedDict([('data', 4), ('model', 2)])
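As a quick usage sketch (our own example, not part of the walkthrough that follows), we can place an array on this mesh with a NamedSharding built from a PartitionSpec:
# Shard dim 0 across 'data' and replicate across 'model'.
arr = jax.device_put(
    jnp.ones((16, 128)),
    NamedSharding(mesh, PartitionSpec('data', None)))
print(arr.sharding.spec)  # the PartitionSpec we passed in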
In/Out shardings
Let's look at what changed the most: how sharding attributes are encoded in the JAX program for the compiler to read.
Let's look at it through an example: an MLP-like model with two layers (two matmuls) and no bias tensors.
def predict(x, w1, w2):
  x = jnp.tanh(x)
  z1 = jnp.einsum('ij,jk->ik', x, w1)
  z2 = jnp.einsum('ij,jk->ik', z1, w2)
  return jnp.sin(z2)
What we want to do here, sharding wise, is:
- data parallelism on the x tensor
- model parallelism on w1 and w2 through the megatron sharding strategy
Now let's prepare the model for GSPMD sharding. Note that we will explicitly shard w1, but let GSPMD propagation shard w2.
def run_in_out_shardings():
  samples = jax.ShapeDtypeStruct((16, 128), jnp.float32, sharding=NamedSharding(mesh, PartitionSpec('data', None)))
  samples_sharding = NamedSharding(mesh, PartitionSpec('data', None))
  w1 = jax.ShapeDtypeStruct((128, 256), jnp.float32, sharding=NamedSharding(mesh, PartitionSpec(None, 'model')))
  w1_sharding = NamedSharding(mesh, PartitionSpec(None, 'model'))
  w2 = jax.ShapeDtypeStruct((256, 10), jnp.float32)
  w2_sharding = None
  print(jax.jit(predict, in_shardings=(samples_sharding, w1_sharding, w2_sharding)).lower(samples, w1, w2).as_text())
run_in_out_shardings()
module @jit_predict attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<16x128xf32> {mhlo.sharding = "{devices=[4,1,2]<=[8] last_tile_dim_replicate}"}, %arg1: tensor<128x256xf32> {mhlo.sharding = "{devices=[1,2,4]<=[4,2]T(1,0) last_tile_dim_replicate}"}, %arg2: tensor<256x10xf32>) -> (tensor<16x10xf32> {jax.result_info = ""}) {
    %0 = stablehlo.tanh %arg0 : tensor<16x128xf32>
    %1 = stablehlo.dot_general %0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x128xf32>, tensor<128x256xf32>) -> tensor<16x256xf32>
    %2 = stablehlo.dot_general %1, %arg2, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x256xf32>, tensor<256x10xf32>) -> tensor<16x10xf32>
    %3 = stablehlo.sine %2 : tensor<16x10xf32>
    return %3 : tensor<16x10xf32>
  }
}
GSPMD's sharding annotations look like the following:
| JAX sharding | GSPMD sharding |
|---|---|
| NamedSharding(mesh, PartitionSpec('data', None)) | {devices=[4,1,2]<=[8] last_tile_dim_replicate} |
| NamedSharding(mesh, PartitionSpec(None, 'model')) | {devices=[1,2,4]<=[4,2]T(1,0) last_tile_dim_replicate} |
| None | nothing |
None has no sharding, as expected, since GSPMD will populate it during sharding propagation.
Notice how all the axis names go away? While there is a 1:1 correspondence between a NamedSharding and a GSPMD sharding, the GSPMD form can be difficult to read, and it only gets harder once you introduce more axis names.
Let's look at Shardy for comparison. To enable Shardy in JAX, simply enable the flag:
jax.config.update("jax_use_shardy_partitioner", True)
run_in_out_shardings()
jax.config.update("jax_use_shardy_partitioner", False)
module @jit_predict attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <["data"=4, "model"=2]>
  func.func public @main(%arg0: tensor<16x128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"data"}, {}]>}, %arg1: tensor<128x256xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"model"}]>}, %arg2: tensor<256x10xf32>) -> (tensor<16x10xf32> {jax.result_info = ""}) {
    %0 = stablehlo.tanh %arg0 : tensor<16x128xf32>
    %1 = stablehlo.dot_general %0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x128xf32>, tensor<128x256xf32>) -> tensor<16x256xf32>
    %2 = stablehlo.dot_general %1, %arg2, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x256xf32>, tensor<256x10xf32>) -> tensor<16x10xf32>
    %3 = stablehlo.sine %2 : tensor<16x10xf32>
    return %3 : tensor<16x10xf32>
  }
}
Now we have:
| JAX sharding | Shardy sharding |
|---|---|
| NamedSharding(mesh, PartitionSpec('data', None)) | #sdy.sharding<@mesh, [{"data"}, {}]> |
| NamedSharding(mesh, PartitionSpec(None, 'model')) | #sdy.sharding<@mesh, [{}, {"model"}]> |
| None | nothing |
Shardy's representation is a lot closer to what JAX NamedShardings look like. So when looking at a file dump of your program after propagation, it will be a lot easier to understand what is going on, since the correspondence to JAX is much closer.
Note that instead of the total devices/axes living on each sharding, they live on a top-level @mesh value.
jax.lax.with_sharding_constraint
GSPMD currently lowers it to a custom call:
def run_with_sharding_constraint():
  x = jax.ShapeDtypeStruct((32, 64), jnp.float32)
  def f(x):
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, PartitionSpec('data', PartitionSpec.UNCONSTRAINED)))
  print(jax.jit(f).lower(x).as_text())
run_with_sharding_constraint()
module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<32x64xf32>) -> (tensor<32x64xf32> {jax.result_info = ""}) {
    %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "unspecified_dims=[1]", mhlo.sharding = "{devices=[4,1,2]<=[8] last_tile_dim_replicate}"} : (tensor<32x64xf32>) -> tensor<32x64xf32>
    return %0 : tensor<32x64xf32>
  }
}
But under Shardy it's an explicit op:
jax.config.update("jax_use_shardy_partitioner", True)
run_with_sharding_constraint()
jax.config.update("jax_use_shardy_partitioner", False)
module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <["data"=4, "model"=2]>
  func.func public @main(%arg0: tensor<32x64xf32>) -> (tensor<32x64xf32> {jax.result_info = ""}) {
    %0 = sdy.sharding_constraint %arg0 <@mesh, [{"data"}, {?}]> : tensor<32x64xf32>
    return %0 : tensor<32x64xf32>
  }
}
Note that under GSPMD, UNCONSTRAINED makes the custom call carry the op attribute backend_config = "unspecified_dims=[1]". Under Shardy, it makes dim 1 be {?}. In Shardy, a dimension sharding without a trailing ? is closed, meaning that dimension can't be further sharded, while a dimension sharding with a trailing ? can still be further sharded. Refer to Sharding representation for more info on the sharding representation.
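For contrast, here is a small sketch of our own (not part of the original walkthrough) that uses a fully specified PartitionSpec instead of UNCONSTRAINED. Based on the mapping above, with Shardy enabled we would expect a closed constraint along both dimensions, i.e. <@mesh, [{"data"}, {}]>:
def run_closed_constraint():
  x = jax.ShapeDtypeStruct((32, 64), jnp.float32)
  def f(x):
    # No UNCONSTRAINED dims here, so both dimension shardings should be closed.
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, PartitionSpec('data', None)))
  print(jax.jit(f).lower(x).as_text())
jax.config.update("jax_use_shardy_partitioner", True)
run_closed_constraint()
jax.config.update("jax_use_shardy_partitioner", False)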
jax.experimental.shard_map
Under GSPMD this is a few different custom calls with various shard_map-specific attributes on the GSPMD sharding. In the example below both mesh axes are manual; an axis that is instead auto is free to be used inside the body of the shard_map by sharding constraints.
def run_shard_map():
  x = jax.ShapeDtypeStruct((32, 64), jnp.float32)
  def body(x):
    return jax.lax.all_gather(x, 'data', tiled=True)
  shmaped_f = shard_map(
      body,
      mesh=mesh,
      in_specs=(jax.sharding.PartitionSpec('data',),),
      out_specs=jax.sharding.PartitionSpec(),
      check_rep=False)
  print(jax.jit(shmaped_f).lower(x).as_text())
run_shard_map()
module @jit_body attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<32x64xf32>) -> (tensor<32x64xf32> {jax.result_info = ""}) {
    %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[4,1,2]<=[8] last_tile_dim_replicate}"} : (tensor<32x64xf32>) -> tensor<32x64xf32>
    %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<32x64xf32>) -> tensor<8x64xf32>
    %2 = call @shmap_body(%1) : (tensor<8x64xf32>) -> tensor<32x64xf32>
    %3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<32x64xf32>) -> tensor<32x64xf32>
    %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{replicated}"} : (tensor<32x64xf32>) -> tensor<32x64xf32>
    return %4 : tensor<32x64xf32>
  }
  func.func private @shmap_body(%arg0: tensor<8x64xf32>) -> (tensor<32x64xf32> {jax.result_info = "[None, None]"}) {
    %0 = "stablehlo.all_gather"(%arg0) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, use_global_device_ids}> : (tensor<8x64xf32>) -> tensor<32x64xf32>
    return %0 : tensor<32x64xf32>
  }
}
With the custom calls and GSPMD sharding, it's getting pretty confusing. Let's look at what Shardy gives:
jax.config.update("jax_use_shardy_partitioner", True)
run_shard_map()
jax.config.update("jax_use_shardy_partitioner", False)
module @jit_body attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <["data"=4, "model"=2]>
  func.func public @main(%arg0: tensor<32x64xf32>) -> (tensor<32x64xf32> {jax.result_info = ""}) {
    %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"data"}, {}]>] out_shardings=[<@mesh, [{}, {}]>] manual_axes={"data", "model"} (%arg1: tensor<8x64xf32>) {
      %1 = "stablehlo.all_gather"(%arg1) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, use_global_device_ids}> : (tensor<8x64xf32>) -> tensor<32x64xf32>
      sdy.return %1 : tensor<32x64xf32>
    } : (tensor<32x64xf32>) -> tensor<32x64xf32>
    return %0 : tensor<32x64xf32>
  }
}
We now:
- Have a single op called sdy.manual_computation which holds:
  - the in_specs
  - the out_specs
  - the body of the shard_map
  - the inverse of the auto axes, which we call manual_axes
A lot easier to read!
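If you do want an auto axis, recent JAX versions let you pass an auto argument to shard_map. Here is a sketch of our own (assuming the auto parameter of jax.experimental.shard_map): leaving 'model' out of the manual axes, we would then expect Shardy to emit manual_axes={"data"}, with 'model' free to be used by sharding constraints inside the body.
def run_shard_map_with_auto():
  x = jax.ShapeDtypeStruct((32, 64), jnp.float32)
  def body(x):
    return jax.lax.all_gather(x, 'data', tiled=True)
  # 'model' stays auto (non-manual), so it remains available to sharding
  # constraints inside the body.
  shmaped_f = shard_map(
      body,
      mesh=mesh,
      in_specs=(PartitionSpec('data'),),
      out_specs=PartitionSpec(),
      check_rep=False,
      auto=frozenset({'model'}))
  print(jax.jit(shmaped_f).lower(x).as_text())
jax.config.update("jax_use_shardy_partitioner", True)
run_shard_map_with_auto()
jax.config.update("jax_use_shardy_partitioner", False)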
Auto partitioners
In progress.
XLA_DUMP_TO
When specifying XLA_DUMP_TO, you will see an additional shardy/ directory containing various dumps of the StableHLO program. A lot of them are currently only relevant to the Shardy team for debugging issues. The one you should focus on when debugging is sdy_module_after_sdy_export.mlir, which is the module after propagation has finished on the StableHLO program.
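For example, one common way to request these dumps is through the --xla_dump_to flag in XLA_FLAGS (the path below is just an illustration; set it before JAX initializes its backend, e.g. at the top of your script):
# Must run before the first JAX computation so the backend picks it up.
os.environ['XLA_FLAGS'] = os.environ.get('XLA_FLAGS', '') + ' --xla_dump_to=/tmp/xla_dump'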
2. Why Shardy?
Readability
As seen above, it's much easier to read the shardings and shard_maps and understand how they match what is happening in the JAX code. In addition, GSPMD propagation gives back HLO code - not the MLIR that both Shardy and jax.jit.lower return.
Interpretability
We are planning on exposing a feature we call "user priorities" (not in JAX yet!). It allows you to attach a value telling Shardy how important a tensor's dimension sharding is over other constraints in the program.
Higher priorities are defined as lower values (the lowest being 0; think of it as a p0 priority task).
PartitionSpec(None, 'x', 'y', priorities=(None, 0, 1))
Here the sharding of dim 1 on x has a higher priority than dim 2 on y, so dim 1 will be propagated through the program first and dim 2 second. Any potential sharding conflicts are explicitly avoided by having x propagated first.
This can also be helpful for debugging models, since it lets you break your sharding strategy down into separate rounds of propagation in Shardy. For example:
- Priority 0: data parallelism
- Priority 1: megatron
- Priority 2: ZeRO sharding
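A sketch of what that could look like with the proposed priorities argument (hypothetical: the feature isn't in JAX yet, and these specs are illustrative assumptions rather than a real API today):
# Hypothetical, mirroring the proposed `priorities=` syntax shown earlier.
# Priority 0: data parallelism on the activations' batch dim.
x_spec = PartitionSpec('data', None, priorities=(0, None))
# Priority 1: megatron-style model parallelism on the weights.
w1_spec = PartitionSpec(None, 'model', priorities=(None, 1))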
FAQs
Below is a list of questions you may have on various JAX features and capabilities.
JAX Sharding types
What about GSPMDSharding?
GSPMDSharding is closely tied to the C++/protobuf representation inside the XLA compiler. As such, the type itself won't be supported.
What about PositionalSharding?
This won't be supported. Instead, use a NamedSharding with device_ids.
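A minimal sketch of that idea (our own example): build the Mesh from an explicitly ordered device array, and the NamedSharding inherits that device order:
# Reverse the default device order just to show that the mesh controls placement.
ordered_devices = np.array(jax.devices())[::-1].reshape(4, 2)
custom_mesh = Mesh(ordered_devices, ('data', 'model'))
custom_sharding = NamedSharding(custom_mesh, PartitionSpec('data', 'model'))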
What about PmapSharding?
This won't be supported. Shardy is meant for jax.jit, not jax.pmap.
Propagation Questions
Section for questions about what you may see during propagation.
What are split axes in Shardy, aka "x":(2)2?
Refer to the "Axis splitting and sub-axes" section of the Shardy sharding representation docs.