Sharding Representation

Background

The purpose of the sharding representation is to specify how a tensor is sharded with respect to a set of available devices.

The sharding representation can either be:

  • Manually specified by the user as sharding constraints on inputs, outputs, or intermediates.
  • Transformed per operation in the process of sharding propagation.

Overview

Basic structure

A logical mesh is a multi-dimensional view of devices, defined by a list of axis names and sizes.

The proposed sharding representation is bound to a specific logical mesh by its name, and can only reference axis names from that mesh. The sharding of a tensor specifies along which axes (of a specific logical mesh) each dimension of the tensor is sharded, ordered from major to minor. The tensor is replicated along all other axes of the mesh.

Let’s explore the sharding representation with a simple rank 2 tensor and 4 devices.

We first reshape the 4 devices [0, 1, 2, 3] into a 2-d array [[0, 1], [2, 3]] to create a mesh with 2 axes:

@mesh_xy = <["x"=2, "y"=2]>

We can then shard the following rank 2 tensor [[a, b], [c, d]] as follows:

(Figure: sharding representation of a rank 2 tensor.)
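For instance (a sketch of what the figure illustrates, not copied from it), two possible shardings of this tensor are:

// The 1st dimension is sharded along "x" and the 2nd along "y": device 0 holds
// [[a]], device 1 holds [[b]], device 2 holds [[c]] and device 3 holds [[d]].
sharding<@mesh_xy, [{"x"}, {"y"}]> : tensor<2x2xf32>

// Only the 1st dimension is sharded along "x", and the tensor is replicated
// along "y": devices 0 and 1 hold [[a, b]], devices 2 and 3 hold [[c, d]].
sharding<@mesh_xy, [{"x"}, {}]> : tensor<2x2xf32>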

Other key components

  • Open/Closed dimensions - dimensions can either be open - can be further sharded on available axes; or closed - are fixed and can’t be changed.
  • Explicitly replicated axes - all axes that are not used to shard a dimension are implicitly replicated, but the sharding can specify axes that are explicitly replicated and therefore cannot be used to shard a dimension later on.
  • Axis splitting and sub-axes - a (full) mesh axis can be split into multiple sub-axes that can be individually used to shard a dimension or be explicitly replicated.
  • Multiple logical meshes - different shardings can be bound to different logical meshes, which can have different axes or even a different order of logical device ids.
  • Priorities - to partition a program incrementally, priorities can be attached to dimension shardings, which determine in which order per-dimension sharding constraints will be propagated throughout the module.
  • Dimension sharding divisibility - a dimension can be sharded on axes whose product of sizes doesn’t divide the dimension size.

Detailed Design

We expand the basic structure and each key component in this section.

Basic structure

The dimension shardings tell us for each dimension of the tensor, along which axes (or sub-axes) it is sharded from major to minor. All other axes that don't shard a dimension are implicitly replicated (or explicitly replicated, if they appear in the list of replicated axes described below).

We will start with a simple example and extend it as we describe additional features.

@mesh_xy = <["x"=2, "y"=4, "z"=2]>

// The 1st tensor dimension is sharded along axis "x" and the 2nd tensor dimension is
// sharded along axis "z" then further along axis "y". The local shape of this
// tensor (i.e. the shape on a single device) would be tensor<2x1xf32>.
sharding<@mesh_xy, [{"x"}, {"z", "y"}]> : tensor<4x8xf32>

Invariants

  • The number of dimension shardings must match the rank of the tensor.
  • All axis names must exist in the referenced mesh.
  • Axes or sub-axes can only appear once in the sharding representation (each one either shards a dimension or is explicitly replicated).
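For example, the following (hypothetical) sharding violates the last invariant, because "x" appears twice:

// Invalid: axis "x" is used to shard both dimensions.
sharding<@mesh_xy, [{"x"}, {"x", "y"}]> : tensor<4x8xf32>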

Open/closed dimensions

Each dimension of a tensor can either be open or closed.

Open

An open dimension is one that propagation may further shard along additional axes, i.e. the specified dimension sharding doesn't have to be the final sharding of that dimension. This is similar, but not identical, to jax.sharding.PartitionSpec.UNCONSTRAINED in JAX, or to leaving a dimension's sharding unspecified in GSPMD.

If a dimension is open, we add a ? following the axes that the dimension is already sharded on (see example below).

Closed

A closed dimension is one that isn't available for propagation to add further sharding to, i.e. the specified dimension sharding is the final sharding of that dimension and it can't be changed. A common use case of this is how GSPMD (usually) doesn't modify the input/output arguments of a module, or how with jax.jit, the user-specified in_shardings are static - they can't change.

We can extend the example from above to have an open dimension and a closed dimension.

@mesh_xy = <["x"=2, "y"=4, "z"=2]>

// The 1st dimension is closed, therefore it can't be further sharded and {"x"}
// will remain its sharding. The 2nd dimension is open, and can therefore be
// further sharded during propagation, e.g. by "y".
sharding<@mesh_xy, [{"x"}, {"z", ?}]> : tensor<4x8xf32>

Explicitly replicated axes

An explicit set of axes that a tensor is replicated on. While it can be determined that a tensor not sharded on an axis is implicitly replicated on it (like jax.sharding.PartitionSpec today), having it explicit makes sure that propagation cannot use these axes to further shard an open dimension. With implicit replication, a tensor can be further partitioned along that axis; with explicit replication, nothing can partition the tensor along it.

Ordering of replicated axes has no effect on how the data of a tensor is stored. But, for consistency only, the axes will be stored in the order they are specified in the top level mesh. For example, if the mesh is:

@mesh_xy = <["c"=2, "a"=2, "b"=2]>

And we want axes "a" and "c" to be explicitly replicated, the order should be:

replicated={"c", "a"}

We can extend our example from above to have an explicitly replicated axis.

@mesh_xyz = <["x"=2, "y"=4, "z"=2]>

// Since "y" is explicitly replicated, it can't be used to shard the 2nd
// dimension that is open. However, "z" is implicitly replicated so it can be
// used to shard that dimension. The local shape of this tensor (i.e. the
// shape on a single device) would be tensor<2x8xf32>.
sharding<@mesh_xyz, [{"x"}, {?}], replicated={"y"}> : tensor<4x8xf32>

Axis splitting and sub-axes

A logical mesh of n axes is created by reshaping a 1-dimensional array of devices into an n-dimensional array, where each dimension forms an axis with a user-defined name.

The same process can be done in the compiler to split an axis of size k further into m sub-axes, by reshaping the mesh from [...,k,...] into [...,k1,...,km,...].

Motivation

To understand the motivation behind splitting axes, we will look at the following example:

@mesh_x = <["x"=4]>

%arg0 : tensor<8xf32> {sdy.sharding=<@mesh_x, [{"x"}]>}
%0 = reshape %arg0 : (tensor<8xf32>) -> tensor<2x4xf32>

We want to shard the result of the reshape in a way that would avoid communication (i.e. keep the data where it is). Since the size of "x" is greater than the 1st dimension of the result, we need to split the axis into two sub-axes "x.0" and "x.1" of size 2 each, and shard the 1st dimension on "x.0" and the 2nd dimension on "x.1".

Function input/output shardings

It is possible that during propagation an input or output of the main function will become sharded along a sub-axis. This can be a problem for some frameworks, where we can’t express such shardings to give back to the user (e.g. in JAX we can’t express sub-axes with jax.sharding.NamedSharding).

We have a few options for dealing with such cases:

  • Allow, and return the sharding in a different format (e.g. jax.sharding.PositionalSharding instead of jax.sharding.NamedSharding in JAX).
  • Disallow, and all-gather sub-axes that shard the input/output.

Currently we allow sub-axes on the inputs/outputs in the propagation pipeline. Let us know if you want a way to disable this.

Representation

In the same way that we can reference specific full axes from the mesh by their name, we can reference specific sub-axes by their size and the product of all sub-axis sizes (of the same full axis) to their left (that are major to them).

To extract a specific sub-axis of size k from a full axis "x" of size n, we effectively reshape the size n (in the mesh) into [m, k, n/(m*k)] and use the 2nd dimension as the sub-axis. A sub-axis can thus be specified by two numbers, m and k, and we use the following concise notation to denote sub-axes: "x":(m)k.

  • m>=1 is the pre-size of this sub-axis (m should be a divisor of n). The pre-size is the product of all sub-axis sizes to the left of (that are major to) this sub-axis (if equal to 1 it means there are none, if larger than 1 it corresponds to a single or multiple sub-axes).

  • k>1 is the actual size of this sub-axis (k should be a divisor of n).

  • n/(m*k) is the post-size. It is the product of all sub-axis sizes to the right of (that are minor to) this sub-axis (if equal to 1 it means there are none, if larger than 1 it corresponds to a single or multiple sub-axes).

However, the number of other sub-axes doesn't make a difference when using a specific sub-axis "x":(m)k, and any other sub-axis doesn't need to be referenced in the tensor sharding if it neither shards a dimension nor is explicitly replicated.
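To make the notation concrete, here is a sketch with a hypothetical single-axis mesh @mesh_y, using the same decomposition as the example further below:

@mesh_y = <["y"=8]>

// "y":(2)2 has pre-size m=2, size k=2 and post-size 8/(2*2)=2, i.e. the axis of
// size 8 is reshaped into [2, 2, 2] and the middle dimension is the sub-axis.
// The local shape of this tensor would be tensor<2xf32>.
sharding<@mesh_y, [{"y":(2)2}]> : tensor<4xf32>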

Going back to the example in the Motivation section, we can shard the result as follows:

@mesh_x = <["x"=4]>

%arg0 : tensor<8xf32> {sdy.sharding=<@mesh_x, [{"x"}]>}
%0 = reshape %arg0 {sdy.sharding_per_value=<[<@mesh_x, [{"x":(1)2}, {"x":(2)2}]>]>}
    : (tensor<8xf32>) -> tensor<2x4xf32>

Here is another example of a split axis where only some of its sub-axes are used.

@mesh_xyz = <["x"=2, "y"=8, "z"=2]>

// Axis "y" is effectively split into 3 sub-axes denoted as
//   "y":(1)2, "y":(2)2, "y":(4)2
// in order, but only "y":(2)2 is used, to shard the 2nd dimension. The local
// shape of this tensor (i.e. the shape on a single device) would be
// tensor<2x4xf32>.
sharding<@mesh_xyz, [{"x"}, {"y":(2)2}]> : tensor<4x8xf32>

Similarly, the following two shardings are semantically equivalent. We can think of mesh_xy as a splitting of mesh_full.

@mesh_full = <["devices"=8]>
@mesh_xy = <["x"=4, "y"=2]>

sharding<@mesh_xy, [{"x"}, {"y"}]> : tensor<4x4xf32>
sharding<@mesh_full, [{"devices":(1)4}, {"devices":(4)2}]> : tensor<4x4xf32>

Explicitly replicated sub-axes

In addition to sub-axes being used to shard a dimension, they can also be marked as explicitly replicated. We allow this in the representation because sub-axes behave just like full axes, i.e. when you shard a dimension along a sub-axis of axis "x", the other sub-axes of "x" are implicitly replicated, and therefore can be explicitly replicated to indicate that a sub-axis must stay replicated and can't be used to shard a dimension.

For example:

@mesh_xyz = <["x"=2, "y"=8, "z"=2]>

// Sub-axis "y":(1)2 is explicitly replicated and "y":(4)2 is implicitly replicated.
sharding<@mesh_xyz, [{"x"}, {"y":(2)2}], replicated={"y":(1)2}> : tensor<4x8xf32>

Replicated sub-axes of the same full axis should be ordered in increasing order by their pre-size, for example:

replicated={"y":(4)2, "x", "y":(1)2} ~> replicated={"x", "y":(1)2, "y":(4)2}

Invariants

  • Sub-axes referenced in a tensor sharding must not overlap, e.g. "x":(1)4 and "x":(2)4 overlap.

  • Sub-axes referenced in a tensor sharding must be as big as possible, i.e. if a dimension sharding has two adjacent sub-axes A and B in order, or sub-axes A and B are explicitly replicated, they must not be consecutive, e.g. "x":(1)2 and "x":(2)4, as they can be replaced with a single "x":(1)8.
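For illustration, here is a hypothetical sharding (reusing @mesh_xyz from the example above) that violates the second invariant, together with its canonical form:

// Invalid: "y":(1)2 and "y":(2)4 are consecutive and together cover all of "y".
sharding<@mesh_xyz, [{"x"}, {"y":(1)2, "y":(2)4}]> : tensor<4x8xf32>

// Canonical equivalent:
sharding<@mesh_xyz, [{"x"}, {"y"}]> : tensor<4x8xf32>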

Multiple logical meshes

One logical mesh is a multi-dimensional view of devices. We may need multiple views of the devices to represent our shardings, especially for arbitrary device assignments.

For example, jax.sharding.PositionalSharding does not have one common logical mesh. GSPMD currently supports that with HloSharding, where the representation can be an ordered list of devices and dimension sizes, but this can’t be represented with the axis splitting above.

We overcome this limitation and handle existing corner cases by defining multiple logical meshes at the top level of the program. Each mesh can have a different number of axes with different names, as well as its own arbitrary assignment for the same set of devices, i.e. each mesh refers to the same set of devices (by their unique logical ID) but with an arbitrary order, similar to the GSPMD representation.

Each sharding representation is linked to a specific logical mesh, therefore it will only reference axes from that mesh.

A tensor that is assigned to one logical mesh can be used by an op that is assigned to a different mesh, by naively resharding the tensor to match the destination mesh. In GSPMD this is what is usually done to resolve conflicting meshes.

For example, users can specify multiple meshes with different named axes (e.g. via jax.sharding.NamedSharding) that have the same order of devices. In this example, <@mesh_0, "b"> is identical to <@mesh_1, "z">.

@mesh_0 = {<["a"=4, "b"=2]>, device_ids=[0, 1, 2, 3, 4, 5, 6, 7]}
@mesh_1 = {<["x"=2, "y"=2, "z"=2]>, device_ids=[0, 1, 2, 3, 4, 5, 6, 7]}
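For instance (an illustrative sketch, not taken from the original examples), the following two shardings place the data on the same devices:

sharding<@mesh_0, [{"b"}, {}]> : tensor<4x8xf32>
sharding<@mesh_1, [{"z"}, {}]> : tensor<4x8xf32>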

Priorities

Priority is a way to prioritize certain partitioning+propagation decisions over others, and allows for incremental partitioning of a program.

Priorities are values attached to some or all dimensions of a sharding representation (replicated axes don’t have priorities).

For example:

@mesh_xy = <["w"=6, "x"=2, "y"=4, "z"=2]>

//                                    |-> y is implicitly p0
%arg4 : sharding<@mesh_xy, [{"x"}p1, {"y"}, {"z", ?}p2], replicated={}>

Priorities give users more fine grained control over propagation, e.g., batch parallelism first, then megatron, and finally ZeRO sharding. This allows for strong guarantees about what’s partitioned and allows for better debuggability by having more fine grained sharding strategies (can see how the program looks after just megatron in isolation).

We allow attaching a priority to each dimension sharding (0 by default), which indicates that all shardings with priority <i will be propagated to the entire program before shardings with priority i.

Even if a sharding has an open dimension with lower priority, e.g., {"z",?}p2, it won’t be overridden by another tensor sharding with a higher priority during propagation. However, such an open dimension can be further sharded after all higher priority shardings have been propagated.

In other words, priorities are NOT about which dimension sharding is more important than another - it’s the order in which distinct groups of dimension shardings should propagate to the entire program, and how conflicts on intermediate, unannotated tensors should be resolved.
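As a hypothetical sketch of incremental partitioning (the mesh and the axis names "data" and "model" are made up for illustration):

@mesh_dp = <["data"=2, "model"=4]>

// The batch dimension has priority p0, so it is sharded along "data" across the
// entire program first; only afterwards is the feature dimension sharded along
// "model" (p1), at which point open dimensions left by p0 may also be filled in.
%arg0 : sharding<@mesh_dp, [{"data"}p0, {"model"}p1]>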

Invariants

  • Priorities start at 0 (highest priority) and increase (to allow users to add and remove priorities easily, we allow gaps between priorities, e.g., p0 and p2 are used but p1 isn’t).

  • An empty closed dimension sharding (i.e., {}), shouldn’t have a priority, as this won’t have any effect.

Dimension sharding divisibility

It’s possible for a dimension of size d to be sharded along axes whose product of sizes is n, such that d is not divisible by n (which in practice would require the dimension to be padded).

For example:

@mesh_xy = <["x"=8, "y"=2, "z"=3]>

sharding<@mesh_xy, [{"x"}, {"y"}, {"z"}]> : tensor<7x3x8xf32>

Grammar

Each logical mesh is defined as follows:

@mesh_name = <[mesh_axis_1,...,mesh_axis_n]>

mesh_axis ::= axis_name=axis_size

axis_name ::= str
axis_size ::= int

The sharding representation will have the following structure for a tensor of rank r:

sharding<@mesh_name, dim_shardings, replicated=replicated_axes>

mesh_name ::= str
dim_shardings ::= [dim_sharding_1,...,dim_sharding_r]
replicated_axes ::= {axis_1,...,axis_m}

dim_sharding ::=
  {axis_1,...,axis_k} |  // closed dimension
  {axis_1,...,axis_k,?}  // open dimension

axis ::=
  axis_name  |  // a full axis
  sub_axis      // a sub axis

axis_name ::= str

sub_axis ::= axis_name:(pre_size)size
pre_size ::= int
size ::= int