Shardy Guide for JAX Users


Shardy is a new propagation system being introduced into the XLA stack. This guide introduces JAX users to:

  1. What has changed in JAX
  2. Why Shardy?
  3. Future plans

This guide is meant for JAX users who use jax.jit to run training or inference models across more than one GPU or TPU (batch parallelism, Megatron, ZeRO, etc.). Such users typically work with PartitionSpecs and NamedShardings.

1. What has changed in JAX?

State of JAX before: GSPMD

Prior to Shardy, JAX users who partitioned their models across multiple devices used GSPMD behind the scenes.

GSPMD is the propagation and partitioning system that lives in the middle of the XLA pipeline. It operates on HLO, the IR that comes after StableHLO (the program you get from jax.jit(f).lower(...)).

JAX doesn't run GSPMD directly, but encodes instructions into the StableHLO IR for GSPMD to read later on.

But before we go any further, let's introduce our working example.

Make sure you have installed jax>=0.4.35:

pip install "jax>=0.4.35"

Imports and utilities

import os
# make sure our code runs on 8 devices
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax
import numpy as np
from jax import numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from jax.experimental.shard_map import shard_map

First, let's create our mesh.

mesh = Mesh(
    np.reshape(np.array(jax.devices()), (4, 2)),
    ('data', 'model'))

print(mesh.shape)
OrderedDict([('data', 4), ('model', 2)])

In/Out shardings

Let's look at what changed the most: how sharding attributes are encoded in the JAX program for the compiler to read.

Let's look at it through an example: an MLP-like model with no bias tensors and two layers (two matmuls).

def predict(x, w1, w2):
  x = jnp.tanh(x)
  z1 = jnp.einsum('ij,jk->ik', x, w1)
  z2 = jnp.einsum('ij,jk->ik', z1, w2)
  return jnp.sin(z2)

Sharding-wise, what we want here is:

  1. data parallelism on x
  2. tensor parallelism on w1 and w2 through the Megatron sharding strategy.
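The Megatron strategy splits w1 by columns and w2 by the matching rows, so each 'model' shard can compute its part of both matmuls without communication, and only the partial results of the second matmul need to be summed (an all-reduce). The NumPy sketch below simulates two 'model' shards by hand to check that this gives the same answer as the unsharded model. It illustrates the math only, not what GSPMD or Shardy literally emit, and it leaves out the 'data' split of x for brevity.

```python
import numpy as np

def predict_reference(x, w1, w2):
    # Unsharded version of the example model (NumPy instead of jnp).
    z1 = np.tanh(x) @ w1
    z2 = z1 @ w2
    return np.sin(z2)

def predict_megatron(x, w1, w2, num_model_shards=2):
    # Megatron-style split: w1 by columns, w2 by the matching rows.
    w1_shards = np.split(w1, num_model_shards, axis=1)
    w2_shards = np.split(w2, num_model_shards, axis=0)
    h = np.tanh(x)  # x is replicated across the 'model' shards
    # Each shard computes a partial z2 locally, without any communication.
    partials = [(h @ a) @ b for a, b in zip(w1_shards, w2_shards)]
    # Summing the partials plays the role of the all-reduce over 'model'.
    z2 = np.sum(partials, axis=0)
    return np.sin(z2)

rng = np.random.default_rng(0)
x = rng.standard_normal((16, 128))
w1 = rng.standard_normal((128, 256))
w2 = rng.standard_normal((256, 10))
print(np.allclose(predict_reference(x, w1, w2), predict_megatron(x, w1, w2)))
```

This is also why letting propagation shard w2 works: once z1 is sharded on its columns by 'model', the natural choice for w2 is to shard its rows the same way.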

Now let's prepare the model for GSPMD sharding. Note that we will explicitly shard w1, but let GSPMD propagation shard w2.

def run_in_out_shardings():
  samples = jax.ShapeDtypeStruct((16, 128), jnp.float32, sharding=NamedSharding(mesh, PartitionSpec('data', None)))
  samples_sharding = NamedSharding(mesh, PartitionSpec('data', None))
  w1 = jax.ShapeDtypeStruct((128, 256), jnp.float32, sharding=NamedSharding(mesh, PartitionSpec(None, 'model')))
  w1_sharding = NamedSharding(mesh, PartitionSpec(None, 'model'))
  w2 = jax.ShapeDtypeStruct((256, 10), jnp.float32)
  w2_sharding = None

  print(jax.jit(predict, in_shardings=(samples_sharding, w1_sharding, w2_sharding)).lower(samples, w1, w2).as_text())

run_in_out_shardings()
module @jit_predict attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<16x128xf32> {mhlo.sharding = "{devices=[4,1,2]<=[8] last_tile_dim_replicate}"}, %arg1: tensor<128x256xf32> {mhlo.sharding = "{devices=[1,2,4]<=[4,2]T(1,0) last_tile_dim_replicate}"}, %arg2: tensor<256x10xf32>) -> (tensor<16x10xf32> {jax.result_info = ""}) {
    %0 = stablehlo.tanh %arg0 : tensor<16x128xf32>
    %1 = stablehlo.dot_general %0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x128xf32>, tensor<128x256xf32>) -> tensor<16x256xf32>
    %2 = stablehlo.dot_general %1, %arg2, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x256xf32>, tensor<256x10xf32>) -> tensor<16x10xf32>
    %3 = stablehlo.sine %2 : tensor<16x10xf32>
    return %3 : tensor<16x10xf32>
  }
}

GSPMD's sharding annotations look like the following:

JAX sharding                                        | GSPMD sharding
NamedSharding(mesh, PartitionSpec('data', None))    | {devices=[4,1,2]<=[8] last_tile_dim_replicate}
NamedSharding(mesh, PartitionSpec(None, 'model'))   | {devices=[1,2,4]<=[4,2]T(1,0) last_tile_dim_replicate}
None                                                | (nothing)

None produces no sharding attribute, as expected, since GSPMD will populate it during sharding propagation.

Notice how all the axis names go away? While there is a 1:1 correspondence between NamedSharding and GSPMD sharding, the GSPMD form is difficult to read, and it only gets harder as you introduce more axis names.
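To make that correspondence concrete, here is a small NumPy sketch that expands the compact GSPMD device notation, based on our reading of XLA's "iota" tile-assignment format: `devices=[dims]<=[reshape]T(perm)` means "take device ids 0..N-1, reshape them to `reshape`, transpose by `perm`, then reshape to `dims`", and `last_tile_dim_replicate` means the last entry of `dims` is a replication factor rather than a tensor dimension. The helper `expand_iota` is hypothetical, and the sketch assumes the mesh above holds device ids 0..7 in order.

```python
import numpy as np

def expand_iota(dims, reshape_dims, perm=None):
    """Hypothetical helper: expand `devices=[dims]<=[reshape_dims]T(perm)`."""
    # Start from device ids 0..N-1 laid out with shape reshape_dims.
    ids = np.arange(int(np.prod(reshape_dims))).reshape(reshape_dims)
    if perm is not None:
        ids = ids.transpose(perm)  # the optional T(...) part
    # Reinterpret as the tile grid; with last_tile_dim_replicate the last
    # entry of dims counts replicas, not tensor blocks.
    return ids.reshape(dims)

# Mesh(np.reshape(jax.devices(), (4, 2)), ('data', 'model')), assuming device
# ids 0..7 in order: the device at mesh position [i, j] has id 2 * i + j.
mesh_ids = np.arange(8).reshape(4, 2)

# PartitionSpec('data', None) -> {devices=[4,1,2]<=[8] last_tile_dim_replicate}
data_tiles = expand_iota([4, 1, 2], [8])
for i in range(4):
    # Row block i is replicated over mesh row i, i.e. over the 'model' axis.
    assert np.array_equal(data_tiles[i, 0], mesh_ids[i, :])

# PartitionSpec(None, 'model') -> {devices=[1,2,4]<=[4,2]T(1,0) last_tile_dim_replicate}
model_tiles = expand_iota([1, 2, 4], [4, 2], (1, 0))
# Column block j is replicated over mesh column j, i.e. over the 'data' axis.
print(model_tiles[0, 0], model_tiles[0, 1])  # [0 2 4 6] [1 3 5 7]
```

The transpose T(1,0) is what turns "shard on the second mesh axis" into a flat device list; the axis name itself never appears.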

Let's look at Shardy for comparison. To enable Shardy in JAX, simply enable the flag:

jax.config.update("jax_use_shardy_partitioner", True)
run_in_out_shardings()
jax.config.update("jax_use_shardy_partitioner", False)
module @jit_predict attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <["data"=4, "model"=2]>
  func.func public @main(%arg0: tensor<16x128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"data"}, {}]>}, %arg1: tensor<128x256xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"model"}]>}, %arg2: tensor<256x10xf32>) -> (tensor<16x10xf32> {jax.result_info = ""}) {
    %0 = stablehlo.tanh %arg0 : tensor<16x128xf32>
    %1 = stablehlo.dot_general %0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x128xf32>, tensor<128x256xf32>) -> tensor<16x256xf32>
    %2 = stablehlo.dot_general %1, %arg2, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x256xf32>, tensor<256x10xf32>) -> tensor<16x10xf32>
    %3 = stablehlo.sine %2 : tensor<16x10xf32>
    return %3 : tensor<16x10xf32>
  }
}

Now we have:

JAX sharding                                        | Shardy sharding
NamedSharding(mesh, PartitionSpec('data', None))    | #sdy.sharding<@mesh, [{"data"}, {}]>
NamedSharding(mesh, PartitionSpec(None, 'model'))   | #sdy.sharding<@mesh, [{}, {"model"}]>
None                                                | (nothing)

Shardy's representation is much closer to JAX's NamedShardings. So when you look at a file dump of your program after propagation, it is much easier to understand what is going on, since the shardings map directly back to your JAX code.

Note that instead of the device count and axis sizes living on each sharding, they live on a top-level @mesh value (sdy.mesh @mesh = <["data"=4, "model"=2]>).
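Because the mapping is so direct, it fits in a few lines. The sketch below is a hypothetical helper (not part of JAX or Shardy) that renders a PartitionSpec-like tuple as the Shardy attribute text shown above. It only handles None, single axis names, and tuples of axis names, and it ignores open dimensions and priorities.

```python
def to_sdy_sharding(spec, mesh_name="mesh"):
    """Hypothetical helper: render a PartitionSpec-like tuple in Shardy syntax."""
    dims = []
    for entry in spec:
        if entry is None:
            axes = ()            # replicated dimension -> {}
        elif isinstance(entry, str):
            axes = (entry,)      # one mesh axis -> {"axis"}
        else:
            axes = tuple(entry)  # several mesh axes, listed major to minor
        dims.append("{" + ", ".join(f'"{a}"' for a in axes) + "}")
    # The mesh itself is only referenced by name (@mesh); sizes live on sdy.mesh.
    return f"#sdy.sharding<@{mesh_name}, [{', '.join(dims)}]>"

print(to_sdy_sharding(('data', None)))   # #sdy.sharding<@mesh, [{"data"}, {}]>
print(to_sdy_sharding((None, 'model')))  # #sdy.sharding<@mesh, [{}, {"model"}]>
```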

jax.lax.with_sharding_constraint

GSPMD currently lowers jax.lax.with_sharding_constraint to a custom call:

def run_with_sharding_constraint():
  x = jax.ShapeDtypeStruct((32, 64), jnp.float32)

  def f(x):
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, PartitionSpec('data', PartitionSpec.UNCONSTRAINED)))

  print(jax.jit(f).lower(x).as_text())

run_with_sharding_constraint()
module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<32x64xf32>) -> (tensor<32x64xf32> {jax.result_info = ""}) {
    %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "unspecified_dims=[1]", mhlo.sharding = "{devices=[4,1,2]<=[8] last_tile_dim_replicate}"} : (tensor<32x64xf32>) -> tensor<32x64xf32>
    return %0 : tensor<32x64xf32>
  }
}

But under Shardy it's an explicit op:

jax.config.update("jax_use_shardy_partitioner", True)
run_with_sharding_constraint()
jax.config.update("jax_use_shardy_partitioner", False)
module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <["data"=4, "model"=2]>
  func.func public @main(%arg0: tensor<32x64xf32>) -> (tensor<32x64xf32> {jax.result_info = ""}) {
    %0 = sdy.sharding_constraint %arg0 <@mesh, [{"data"}, {?}]> : tensor<32x64xf32>
    return %0 : tensor<32x64xf32>
  }
}

Note that under GSPMD, UNCONSTRAINED is encoded as an attribute on the custom call, backend_config = "unspecified_dims=[1]", whereas under Shardy dim 1 simply becomes {?}. In Shardy, dimension shardings without a trailing ? are closed, meaning the dimension can't be further sharded; a dimension sharding with a trailing ? is open and can be further sharded. Refer to Sharding representation for more info on the sharding representation.
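To compare the two encodings side by side, here is a hypothetical pure-Python helper (not a JAX or Shardy API). It takes a PartitionSpec-like tuple in which a local sentinel stands in for PartitionSpec.UNCONSTRAINED, assumes at most one mesh axis per dimension, and returns the GSPMD backend_config string together with the Shardy dimension shardings. How GSPMD spells several unconstrained dims, or none, is an assumption of the sketch.

```python
# Stand-in for PartitionSpec.UNCONSTRAINED so the sketch needs no JAX import.
UNCONSTRAINED = object()

def constraint_encodings(spec):
    """Hypothetical helper: GSPMD vs Shardy encoding of a sharding constraint."""
    unspecified = [i for i, entry in enumerate(spec) if entry is UNCONSTRAINED]
    # GSPMD: unconstrained dims go into the custom call's backend_config
    # (assumed comma-separated, and empty when no dimension is unconstrained).
    backend_config = (
        f"unspecified_dims=[{','.join(map(str, unspecified))}]" if unspecified else "")
    sdy_dims = []
    for entry in spec:
        if entry is UNCONSTRAINED:
            sdy_dims.append("{?}")             # open: propagation may shard it further
        elif entry is None:
            sdy_dims.append("{}")              # closed and replicated
        else:
            sdy_dims.append(f'{{"{entry}"}}')  # closed, sharded on one mesh axis
    return backend_config, "[" + ", ".join(sdy_dims) + "]"

print(constraint_encodings(('data', UNCONSTRAINED)))
# ('unspecified_dims=[1]', '[{"data"}, {?}]')
```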

jax.experimental.shard_map

Under GSPMD, shard_map lowers to a few different custom calls with various shard_map-specific attributes on the GSPMD sharding. In this example every mesh axis is manual; an axis could instead be auto, meaning it is free to be used inside the body of the shard_map by sharding constraints.

def run_shard_map():
  x = jax.ShapeDtypeStruct((32, 64), jnp.float32)

  def body(x):
    return jax.lax.all_gather(x, 'data', tiled=True)

  shmaped_f = shard_map(
        body,
        mesh=mesh,
        in_specs=(jax.sharding.PartitionSpec('data',),),
        out_specs=jax.sharding.PartitionSpec(),
        check_rep=False)

  print(jax.jit(shmaped_f).lower(x).as_text())

run_shard_map()
module @jit_body attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<32x64xf32>) -> (tensor<32x64xf32> {jax.result_info = ""}) {
    %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[4,1,2]<=[8] last_tile_dim_replicate}"} : (tensor<32x64xf32>) -> tensor<32x64xf32>
    %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<32x64xf32>) -> tensor<8x64xf32>
    %2 = call @shmap_body(%1) : (tensor<8x64xf32>) -> tensor<32x64xf32>
    %3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<32x64xf32>) -> tensor<32x64xf32>
    %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {backend_config = "", mhlo.sharding = "{replicated}"} : (tensor<32x64xf32>) -> tensor<32x64xf32>
    return %4 : tensor<32x64xf32>
  }
  func.func private @shmap_body(%arg0: tensor<8x64xf32>) -> (tensor<32x64xf32> {jax.result_info = "[None, None]"}) {
    %0 = "stablehlo.all_gather"(%arg0) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, use_global_device_ids}> : (tensor<8x64xf32>) -> tensor<32x64xf32>
    return %0 : tensor<32x64xf32>
  }
}

With the custom calls and GSPMD sharding, it's getting pretty confusing. Let's look at what Shardy gives:

jax.config.update("jax_use_shardy_partitioner", True)
run_shard_map()
jax.config.update("jax_use_shardy_partitioner", False)
module @jit_body attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <["data"=4, "model"=2]>
  func.func public @main(%arg0: tensor<32x64xf32>) -> (tensor<32x64xf32> {jax.result_info = ""}) {
    %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"data"}, {}]>] out_shardings=[<@mesh, [{}, {}]>] manual_axes={"data", "model"} (%arg1: tensor<8x64xf32>) {
      %1 = "stablehlo.all_gather"(%arg1) <{all_gather_dim = 0 : i64, channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, use_global_device_ids}> : (tensor<8x64xf32>) -> tensor<32x64xf32>
      sdy.return %1 : tensor<32x64xf32>
    } : (tensor<32x64xf32>) -> tensor<32x64xf32>
    return %0 : tensor<32x64xf32>
  }
}

We now:

  • Have a single op called sdy.manual_computation which holds:
    • the in_specs
    • the out_specs
    • the body of the shard_map
    • the complement of the auto axes, which we call manual_axes

A lot easier to read!
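The tensor<8x64xf32> seen inside the body and the manual_axes set both follow from simple arithmetic on the mesh and the specs. Here is a small pure-Python sketch of that arithmetic; `local_shape` and `manual_axes` are hypothetical helper names, and the sketch assumes every sharded dimension divides evenly.

```python
from math import prod

def local_shape(global_shape, spec, mesh_shape):
    """Per-device block shape seen inside a shard_map body (hypothetical helper)."""
    # Missing trailing entries mean "not sharded", as with PartitionSpec.
    spec = tuple(spec) + (None,) * (len(global_shape) - len(spec))
    out = []
    for size, entry in zip(global_shape, spec):
        axes = () if entry is None else (entry,) if isinstance(entry, str) else tuple(entry)
        factor = prod(mesh_shape[a] for a in axes)  # 1 when the dim is not sharded
        assert size % factor == 0, "sketch assumes even splits"
        out.append(size // factor)
    return tuple(out)

def manual_axes(mesh_shape, auto=frozenset()):
    """Manual axes are all mesh axes that are not auto."""
    return [a for a in mesh_shape if a not in auto]

mesh_shape = {'data': 4, 'model': 2}
print(local_shape((32, 64), ('data',), mesh_shape))  # (8, 64): the body's input
# The tiled all_gather over 'data' concatenates the 4 blocks back to 32 rows,
# matching out_specs=PartitionSpec(), i.e. the full shape on every device.
print(local_shape((32, 64), (), mesh_shape))         # (32, 64)
print(manual_axes(mesh_shape))                       # ['data', 'model']
```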

Auto partitioners

In progress.

XLA_DUMP_TO

When you specify XLA_DUMP_TO (the --xla_dump_to XLA flag), you will see an additional shardy/ directory containing various dumps of the StableHLO program. Most of them are currently only relevant to the Shardy team for debugging issues. The one to focus on when debugging is sdy_module_after_sdy_export.mlir, which is the module after propagation has finished on the StableHLO program.
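A minimal sketch of that workflow, assuming the dump directory /tmp/xla_dump (any writable path works) and assuming the file of interest has sdy_module_after_sdy_export.mlir somewhere in its name, since XLA may prepend a module-specific prefix. `find_shardy_export` is a hypothetical helper, not a JAX API.

```python
import os
from pathlib import Path

# Must be set before JAX initializes its backends; keeps the 8 fake CPU devices.
os.environ['XLA_FLAGS'] = (
    '--xla_force_host_platform_device_count=8 --xla_dump_to=/tmp/xla_dump')

def find_shardy_export(dump_dir):
    """Return post-propagation Shardy dumps under dump_dir (hypothetical helper)."""
    # rglob searches subdirectories too, so the shardy/ folder is included.
    return sorted(Path(dump_dir).rglob('*sdy_module_after_sdy_export.mlir'))

# After running a jitted function with Shardy enabled:
# for path in find_shardy_export('/tmp/xla_dump'):
#     print(path.read_text())
```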

2. Why Shardy?

Readability

As seen above, it's much easier to read the shardings and shard_maps and understand how they match what is happening in the JAX code. Also, GSPMD propagation gives back HLO, not MLIR, whereas both Shardy and jax.jit(f).lower(...) return MLIR.

Interpretability

We are planning on exposing a feature we call "user priorities" (not in JAX yet!). It allows you to attach a value telling Shardy how important a tensor's dimension sharding is over other constraints in the program.

Higher priorities are defined as lower values (0 is the highest; think of it like a P0 task).

PartitionSpec(None, 'x', 'y', priorities=(None, 0, 1))

Here the sharding of dim 1 on x has a higher priority than the sharding of dim 2 on y. This means dim 1 is propagated through the program first and dim 2 afterwards, so any potential sharding conflict is resolved by having x propagated first.

This can be helpful for debugging models as well by having you break down your sharding strategies to separate rounds of propagation in Shardy. For example:

  • Priority 0: data parallelism
  • Priority 1: megatron
  • Priority 2: ZeRO sharding
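Since user priorities are not in JAX yet, the following is only a toy model in plain Python of the round-by-round idea described above, not Shardy's propagation algorithm or API. Each hypothetical annotation is a (tensor, dim, mesh axis, priority) tuple, and round pN handles the priority-N annotations after all lower-valued priorities have been handled. The tensors and axes chosen for each strategy are illustrative assumptions.

```python
from itertools import groupby

# Hypothetical annotations: (tensor, dim, mesh axis, priority).
annotations = [
    ('w1', 0, 'data', 2),   # priority 2: ZeRO-style parameter sharding
    ('x', 0, 'data', 0),    # priority 0: data parallelism
    ('w1', 1, 'model', 1),  # priority 1: megatron
]

def propagation_rounds(annotations):
    """Group annotations by priority; a lower value means an earlier round (toy model)."""
    ordered = sorted(annotations, key=lambda a: a[3])
    return [(p, [a[:3] for a in group])
            for p, group in groupby(ordered, key=lambda a: a[3])]

for priority, active in propagation_rounds(annotations):
    print(f"round p{priority}: {active}")
# round p0: [('x', 0, 'data')]
# round p1: [('w1', 1, 'model')]
# round p2: [('w1', 0, 'data')]
```

Splitting strategies this way means that if a later round produces a surprising sharding, you know which strategy introduced it.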

FAQs

Below is a list of questions you may have on various JAX features and capabilities.

JAX Sharding types

What about GSPMDSharding?

GSPMDSharding is closely tied to the C++/protobuf representation inside the XLA compiler. As such, the type itself won't be supported.

What about PositionalSharding?

This won't be supported. Instead, use a NamedSharding with device_ids.

What about PmapSharding?

This won't be supported. Shardy is meant for jax.jit, not jax.pmap.

Propagation Questions

Section for questions about what you may see during propagation.

What are split axes in Shardy, a.k.a. "x":(2)2?

Refer to the "Axis splitting and sub-axes" section of the Shardy sharding representation documentation.