-sdy-close-shardings

Closes tensor shardings and drops replicated axes.
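
An illustrative sketch (hypothetical mesh and sharding): the open dimension shardings (marked with ?) are closed, and the explicitly replicated axis is dropped.

Input:

mesh = <"x"=2, "y"=2>
%arg0 : tensor<8x8xf32> {sdy.sharding=<@mesh, [{"x", ?}, {?}], replicated={"y"}>}

Output:

%arg0 : tensor<8x8xf32> {sdy.sharding=<@mesh, [{"x"}, {}]>}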

-sdy-insert-explicit-reshards

Inserts explicit reshards to make all operations have compatible shardings.

A compatible sharding essentially means that the operation can accept the sharded operands and produce a sharded result without requiring any reshard communications (note that the operation might still require communication such as all-reduce or halo-swaps).
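
For instance, the following dot (a hypothetical sketch, in the style of the example further below) has compatible shardings even though it requires an all-reduce over axis "y": the contracting dimensions of lhs and rhs are sharded in the same way, and no operand needs to be resharded.

mesh = <"x"=4, "y"=2>
%lhs : tensor<8x32xf32> {sdy.sharding=<@mesh, [{"x"}, {"y"}]>}
%rhs : tensor<32x16xf32> {sdy.sharding=<@mesh, [{"y"}, {}]>}
stablehlo.dot %lhs, %rhs {sdy.sharding_per_value=<[<@mesh, [{"x"}, {}]>]>}
  : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>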

After propagation, some operations may still have incompatible shardings.

Note that when an axis (or sub-axis) is used to shard non-corresponding dimensions (e.g. non-contracting dimensions in matmul) across multiple tensors, or when an axis shards a dimension in one tensor but not the corresponding dimension in the other tensor, the operation is said to have a sharding conflict. Hence, after this pass, the operations become conflict-free.

This pass injects reshard operations explicitly so that, for each operation, corresponding dimensions become sharded in the same way across all operands and results, and every axis (or sub-axis) can only be used to shard a single dimension type.

A clarifying example:

Input:

mesh = <"x"=4, "y"=2>
%lhs : tensor<8x32xf32> {sdy.sharding=<@mesh, [{"x"}, {"y"}]>}
%rhs : tensor<32x16xf32> {sdy.sharding=<@mesh, [{"y"}, {"x"}]>}
stablehlo.dot %lhs, %rhs {sdy.sharding_per_value=<[<@mesh, [{"x"}, {}]>]>}
  : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>

Output:

sdy.mesh = <"x"=4, "y"=2>
%lhs : tensor<8x32xf32> {sdy.sharding=<@mesh, [{"x"}, {"y"}]>}
%rhs : tensor<32x16xf32> {sdy.sharding=<@mesh, [{"y"}, {"x"}]>}
%0 = sdy.reshard %rhs <@mesh, [{"y"}, {}]> : tensor<32x16xf32>
stablehlo.dot %lhs, %0 {sdy.sharding_per_value=<[<@mesh, [{"x"}, {}]>]>}
  : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>

In the example above, there is a conflict since the lhs and rhs tensors are both sharded on axis "x" on their non-contracting dimensions. Here, the rhs tensor is explicitly resharded, before the dot operation, to be sharded only on its first dimension and on axis "y". This way, the dot operation becomes compatible.

-sdy-remove-sharding-groups

Removes ShardingGroupOps after propagation.
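
A minimal sketch (hypothetical values and group id): after propagation the group markers carry no further information, so they are simply erased.

Input:

sdy.sharding_group %arg0 group_id=0 : tensor<8x8xf32>
sdy.sharding_group %arg1 group_id=0 : tensor<8x8xf32>

Output:

(both sharding group ops are removed; %arg0 and %arg1 keep the shardings propagation assigned)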

-sdy-sharding-constraint-to-reshard

Converts ShardingConstraintOp into ReshardOp.
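
A minimal sketch (hypothetical value and sharding), in the style of the examples above:

Input:

%1 = sdy.sharding_constraint %0 <@mesh, [{"x"}, {}]> : tensor<8x16xf32>

Output:

%1 = sdy.reshard %0 <@mesh, [{"x"}, {}]> : tensor<8x16xf32>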

-sdy-sink-data-flow-edges

Sinks all DataFlowEdgeOps into their input.

Moves the sharding of each DataFlowEdgeOp to its input (the root target of the edge), and replaces the op with its input.

TODO(tomnatan): consider moving the sharding to all targets that can have a sharding attached.
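
A minimal sketch (hypothetical edge and sharding): the edge's sharding is moved onto its input, and all uses of the edge are replaced with that input.

Input:

%1 = sdy.data_flow_edge %0 sharding=<@mesh, [{"x"}]> : tensor<32xf32>

Output:

%0 : tensor<32xf32> {sdy.sharding=<@mesh, [{"x"}]>}   (all uses of %1 are replaced by %0)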

-sdy-update-non-divisible-input-output-shardings

Makes FuncOp inputs/outputs evenly sharded, removing any need for padding due to non-divisible shardings.

Users of Shardy expect the function inputs/outputs to be evenly shardable, so that their tensors do not need to be padded. Propagation may give inputs/outputs non-divisible shardings, so this pass updates each dimension sharding to the largest prefix of the original sharding that shards the dimension evenly.
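
For example (hypothetical mesh and sharding), with a dimension of size 6 sharded on both axes of a <"x"=2, "y"=4> mesh, the prefix {"x"} shards the dimension evenly (6/2 = 3) but adding "y" does not (6 is not divisible by 2*4 = 8), so the sharding is truncated:

Input:

mesh = <"x"=2, "y"=4>
%arg0 : tensor<6x8xf32> {sdy.sharding=<@mesh, [{"x", "y"}, {}]>}

Output:

%arg0 : tensor<6x8xf32> {sdy.sharding=<@mesh, [{"x"}, {}]>}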