The Shardy (SDY) dialect
The Shardy (SDY) dialect defines an axis-based tensor sharding representation and additional API components to attach shardings to tensors.
Version log: 0.0.1: Add unreduced axes to TensorShardingAttr.
Operations
sdy.all_gather (sdy::AllGatherOp)
Performs an all-gather communication along axes
Syntax:
operation ::= `sdy.all_gather` $gathering_axes $tensor `out_sharding` `=` $out_sharding attr-dict `:` type($result)
Gathers chunks of a tensor along axes specified in gathering_axes.
The gathering_axes is a list of lists of axes. The outer list is over the
dimensions of the tensor. Each inner list specifies the axes along which a
separate gather should be performed on the respective dimension. It will be
applied to the sharding of the operand (tensor) to obtain the sharding of
the result (out_sharding).
Note that out_sharding is not used to determine the sharding of the
result. Instead, the sharding of the result is determined by the sharding of
the operand and the gathering_axes, and out_sharding must match this
inferred sharding.
Example:
%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "b", "c"}, {}, {"d"}]>]>} : tensor<8x8x8xf32>
%2 = sdy.all_gather [{"b", "c"}, {}, {"d"}] %1 out_sharding=<@mesh, [{"a"}, {}, {}]> : tensor<8x8x8xf32>
Constraints:
- Must satisfy the constraints listed in Sdy_CollectiveOpInterface.
- Elements in `gathering_axes` must satisfy the constraints listed in `AxisRefListAttr`.
- Applying `gathering_axes` to the operand sharding gets `out_sharding`.
Traits: SameOperandsAndResultType
Interfaces: InferTypeOpInterface, Sdy_CollectiveOpInterface, SymbolUserOpInterface
Attributes:
| Attribute | MLIR Type | Description | 
|---|---|---|
| gathering_axes | ::mlir::sdy::ListOfAxisRefListsAttr | List of axis ref lists | 
| out_sharding | ::mlir::sdy::TensorShardingAttr | Tensor sharding | 
Operands:
| Operand | Description | 
|---|---|
| tensor | shaped of any type values | 
Results:
| Result | Description | 
|---|---|
| result | shaped of any type values | 
sdy.all_reduce (sdy::AllReduceOp)
Performs an all-reduce communication along axes
Syntax:
operation ::= `sdy.all_reduce` $reduction_axes $tensor `out_sharding` `=` $out_sharding attr-dict `:` type($result)
Reduces chunks of a tensor along axes specified in reduction_axes.
The order of reduction_axes is not important for the result, but can
affect the order of the corresponding replica groups.
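For illustration, a hedged sketch (the mesh `@mesh`, the producing op, and the `unreduced={...}` printed form are assumptions, not taken from this reference): a tensor that is unreduced along axis "b" is all-reduced along that axis, after which "b" is implicitly replicated.
// Assumed mesh: sdy.mesh @mesh = <["a"=2, "b"=2]>
%1 = stablehlo.add %0, %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}], unreduced={"b"}>]>} : tensor<8x8xf32>
%2 = sdy.all_reduce {"b"} %1 out_sharding=<@mesh, [{"a"}, {}]> : tensor<8x8xf32>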
Constraints:
- Must satisfy the constraints listed in Sdy_CollectiveOpInterface.
- `reduction_axes` must satisfy the constraints listed in `AxisRefListAttr`.
- `reduction_axes` must be sorted w.r.t. the mesh.
- The operand sharding and `out_sharding` must have equivalent dimension shardings.
- `reduction_axes` must not overlap with the operand dimension sharding and replicated axes (it can overlap with unreduced axes).
- `reduction_axes` must not overlap with the unreduced axes of `out_sharding`. In other words, `out_sharding` must be replicated along `reduction_axes` (implicitly or explicitly).
Traits: SameOperandsAndResultType
Interfaces: CollectiveOpInterface, InferTypeOpInterface, SymbolUserOpInterface
Attributes:
| Attribute | MLIR Type | Description | 
|---|---|---|
| reduction_axes | ::mlir::sdy::AxisRefListAttr | List of axis refs | 
| out_sharding | ::mlir::sdy::TensorShardingAttr | Tensor sharding | 
Operands:
| Operand | Description | 
|---|---|
| tensor | shaped of any type values | 
Results:
| Result | Description | 
|---|---|
| result | shaped of any type values | 
sdy.all_slice (sdy::AllSliceOp)
Performs a dynamic-slice operation along axes
Syntax:
operation ::= `sdy.all_slice` $slicing_axes $tensor `out_sharding` `=` $out_sharding attr-dict `:` type($result)
Slices chunks of a tensor along axes specified in slicing_axes. There is
an algebraic duality between sdy.all_slice and sdy.all_gather.
The slicing_axes is a list of lists of axes. The outer list is over the
dimensions of the tensor. Each inner list specifies the axes along which a
slice should be performed on the respective dimension. It will be applied to
the sharding of the operand (tensor) to obtain the sharding of the result
(out_sharding).
Note that out_sharding is not used to determine the sharding of the
result. Instead, the sharding of the result is determined by the sharding of
the operand and the slicing_axes, and out_sharding must match this
inferred sharding.
Example:
%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}, {}]>]>} : tensor<8x8x8xf32>
%2 = sdy.all_slice [{"b", "c"}, {}, {"d"}] %1 out_sharding=<@mesh, [{"a", "b", "c"}, {}, {"d"}]> : tensor<8x8x8xf32>
Constraints:
- Must satisfy the constraints listed in Sdy_CollectiveOpInterface.
- Elements in `slicing_axes` must satisfy the constraints listed in `AxisRefListAttr`.
- Applying `slicing_axes` to the operand sharding gets `out_sharding`.
Traits: SameOperandsAndResultType
Interfaces: CollectiveOpInterface, InferTypeOpInterface, SymbolUserOpInterface
Attributes:
| Attribute | MLIR Type | Description | 
|---|---|---|
| slicing_axes | ::mlir::sdy::ListOfAxisRefListsAttr | List of axis ref lists | 
| out_sharding | ::mlir::sdy::TensorShardingAttr | Tensor sharding | 
Operands:
| Operand | Description | 
|---|---|
| tensor | shaped of any type values | 
Results:
| Result | Description | 
|---|---|
| result | shaped of any type values | 
sdy.all_to_all (sdy::AllToAllOp)
Performs an all-to-all communication along axes
Syntax:
operation ::= `sdy.all_to_all` $params $tensor `out_sharding` `=` $out_sharding attr-dict `:` type($result)
For each (axes, src_dim, tgt_dim) tuple in the parameter list, this
operation slices chunks of a tensor along dimension tgt_dim and axes
specified in axes, scatters those chunks along the axes, and concatenates
them along dimension src_dim.
This operation is essentially a combination of an all-gather along src_dim
and axes, followed by an all-slice along tgt_dim and axes, i.e., a
suffix of the axes sharding dimension src_dim on the input tensor is
appended to the axes sharding dimension tgt_dim on the output tensor.
The all-to-all will be applied to the sharding of the operand (tensor) to
obtain the sharding of the result (out_sharding).
Note that out_sharding is not used to determine the sharding of the
result. Instead, the sharding of the result is determined by the sharding of
the operand, src_dim, tgt_dim, and axes, and out_sharding must match
this inferred sharding.
Example:
%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "b"}, {"c"}, {}, {}]>]>} : tensor<8x8x4x4xf32>
%2 = sdy.all_to_all [{"b"}: 0->2, {"c"}: 1->3] %1 out_sharding=<@mesh, [{"a"}, {}, {"b"}, {"c"}]> : tensor<8x8x4x4xf32>
Constraints:
- Must satisfy the constraints listed in Sdy_CollectiveOpInterface.
- The parameter list must not be empty.
- For each parameter in `params`:
  - Elements in `axes` must satisfy the constraints of `AxisRefAttr`.
  - `src_dim` and `tgt_dim` must be valid dimensions (non-negative and less than rank of tensor).
  - Any `src_dim` or `tgt_dim` must be unique across all parameters.
  - `src_dim` must be sorted in ascending order across all parameters.
- Moving `axes` from `src_dim` to `tgt_dim` in the operand sharding gets `out_sharding`.
Traits: SameOperandsAndResultType
Interfaces: InferTypeOpInterface, Sdy_CollectiveOpInterface, SymbolUserOpInterface
Attributes:
| Attribute | MLIR Type | Description | 
|---|---|---|
| params | ::mlir::sdy::AllToAllParamListAttr | List of all-to-all parameters | 
| out_sharding | ::mlir::sdy::TensorShardingAttr | Tensor sharding | 
Operands:
| Operand | Description | 
|---|---|
| tensor | shaped of any type values | 
Results:
| Result | Description | 
|---|---|
| result | shaped of any type values | 
sdy.collective_permute (sdy::CollectivePermuteOp)
Performs a collective-permute communication to replace axes
Syntax:
operation ::= `sdy.collective_permute` $tensor `out_sharding` `=` $out_sharding attr-dict `:` type($result)
Sends a chunk of the input tensor from each device to another to reorder/replace the axes that shard the tensor.
A collective permute can transform the input sharding such that each dimension must be as sharded as it was before, i.e., it must be sharded along axes whose product of sizes matches that of the axes that previously sharded the tensor.
This is useful for reordering axes in a single dimension or across different dimensions, and swapping sharded axes with replicated ones.
In the below example, the sharded tensor size is tensor<1x4x2xf32>, and
that is preserved by the collective permute.
Example:
sdy.mesh @mesh = <["a"=2, "b"=2, "c"=4, "d"=2, "e"=2, "f"=2]>
%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "c"}, {"f"}, {"d", "e"}]>]>} : tensor<8x8x8xf32>
%2 = sdy.collective_permute %1 out_sharding=<@mesh, [{"c":(1)2, "b", "f"}, {"a"}, {"e", "d"}]> : tensor<8x8x8xf32>
Constraints:
- Must satisfy the constraints listed in Sdy_CollectiveOpInterface.
- If input and output sharding have different meshes, then those meshes must have exactly the same axes and different order of device ids.
- For each dimension, the product of sharding axis sizes in `out_sharding` must match that of the corresponding operand dimension sharding.
Traits: SameOperandsAndResultType
Interfaces: CollectiveOpInterface, InferTypeOpInterface, SymbolUserOpInterface
Attributes:
| Attribute | MLIR Type | Description | 
|---|---|---|
| out_sharding | ::mlir::sdy::TensorShardingAttr | Tensor sharding | 
Operands:
| Operand | Description | 
|---|---|
| tensor | shaped of any type values | 
Results:
| Result | Description | 
|---|---|
| result | shaped of any type values | 
sdy.constant (sdy::ConstantOp)
Constant operation
Produces an output tensor from a constant value.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#constant
Example:
%output = sdy.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Attributes:
| Attribute | MLIR Type | Description | 
|---|---|---|
| value | ::mlir::ElementsAttr | constant vector/tensor attribute | 
Results:
| Result | Description | 
|---|---|
| output | statically shaped tensor of any type values | 
sdy.data_flow_edge (sdy::DataFlowEdgeOp)
Data flow edge op.
Syntax:
operation ::= `sdy.data_flow_edge` $input (`sharding` `=` $sharding^)? attr-dict `:` type($result)
A data flow edge of some op X defines a bridge between a set of sources (each is either an operand of X or an operand of X's block terminator) and a set of targets (each is either a result of X or a block argument of X), such that all sources and targets should be sharded in the same way.
An op can have multiple data flow edges that are orthogonal to one another.
For example:
  y_0, ..., y_n = while (x_0, ..., x_n)
                  ((pred_arg_0,... , pred_arg_n) { ... })
                  ((body_arg_0,..., body_arg_n) {
                    ...
                    return return_value_0, ..., return_value_n
                  })
This while op has n data flow edges; the i-th data flow edge is between
sources x_i, return_value_i and targets y_i, pred_arg_i,
body_arg_i.
An sdy.data_flow_edge takes as input the owner of an edge (can be
any of the targets, but preferably an op result rather than a block
argument), which shouldn't have any other uses. This op isn't pure because
it can take an input that originally didn't have any uses.
The sdy.data_flow_edge also holds an optional sharding for all targets of
the edge, and that sharding should be updated instead of the targets'
sharding (if it can be attached) during propagation. This is useful when an op
has many edges, as it's much more efficient to:
- propagate through each edge separately.
- update the sharding of each edge separately instead of all targets at once
(e.g. an op has a single immutable `TensorShardingPerValueAttr` for result shardings).
- add each edge to the worklist separately when the sharding of a source has changed.
Propagation will propagate shardings between all sources and targets of a
sdy.data_flow_edge as if it was a regular op with the sources as operands
and targets as results, and an identity sdy.op_sharding_rule. That means
that forward propagation is from sources to targets and backwards
propagation is from targets to sources.
We don't allow the input of a sdy.data_flow_edge to be defined by an
SdyDialect op, so we can assume that it's defined by an op that has
an unregistered `sdy.sharding` attribute.
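A minimal illustrative sketch (assuming a mesh `@mesh` with an axis "a", and a non-SDY op defining %0 that owns the edge):
%1 = sdy.data_flow_edge %0 sharding=<@mesh, [{"a"}, {}]> : tensor<8x8xf32>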
Traits: SameOperandsAndResultType
Interfaces: InferTypeOpInterface, SymbolUserOpInterface
Attributes:
| Attribute | MLIR Type | Description | 
|---|---|---|
| sharding | ::mlir::sdy::TensorShardingAttr | Tensor sharding | 
Operands:
| Operand | Description | 
|---|---|
| input | shaped of any type values | 
Results:
| Result | Description | 
|---|---|
| result | shaped of any type values | 
sdy.manual_computation (sdy::ManualComputationOp)
Multi-device parallelism operation with manual collectives
Syntax:
operation ::= `sdy.manual_computation` `(`operands`)`
              `in_shardings` `=` custom<StrippedTensorShardingPerValueAttr>($in_shardings)
              `out_shardings` `=` custom<StrippedTensorShardingPerValueAttr>($out_shardings)
              `manual_axes` `=` $manual_axes
              custom<SingleBlockRegionNoBlockId>($body)
              attr-dict
              `:`
              functional-type(operands, results)
Jump into a region written in terms of per-device local code with explicit collectives, where logical shapes match local per-device physical buffer shapes and collectives correspond exactly to physical cross-device communication.
The body is local w.r.t. the `manual_axes`. Propagation will occur through the body on any free axes - those not in the `manual_axes` list.
Note that any unranked tensors are expected to have a sharding with rank 0, i.e. fully replicated.
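A hedged sketch, assuming a mesh `sdy.mesh @mesh = <["a"=2]>` (not defined in this reference): with dimension 0 sharded along "a", the global tensor<16x32xf32> appears as a local tensor<8x32xf32> inside the body.
%0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"a"}, {}]>] out_shardings=[<@mesh, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<8x32xf32>) {
  // Local per-device code; "a" is manual here, so propagation only touches free axes.
  %1 = stablehlo.add %arg1, %arg1 : tensor<8x32xf32>
  sdy.return %1 : tensor<8x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>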
Constraints:
- Elements in `in_shardings` and `out_shardings` must satisfy the constraints listed in `TensorShardingAttr`.
- The number of global and local tensor inputs/outputs of the op region must match.
- The manual axes must come before any free axes in each dim sharding.
- The manual axes cannot introduce padding. Namely, the dimension size must be divisible by the corresponding manual axes size.
- The global and local shapes of the op region's arguments/results must match.
Traits: IsolatedFromAbove, RecursiveMemoryEffects, SingleBlockImplicitTerminator<ReturnOp>, SingleBlock
Interfaces: ShardableDataFlowOpInterface, SymbolUserOpInterface
Attributes:
| Attribute | MLIR Type | Description | 
|---|---|---|
| in_shardings | ::mlir::sdy::TensorShardingPerValueAttr | Tensor sharding per operand/result of an op | 
| out_shardings | ::mlir::sdy::TensorShardingPerValueAttr | Tensor sharding per operand/result of an op | 
| manual_axes | ::mlir::sdy::ManualAxesAttr | A list of axes that a ManualComputationOp is manual on | 
Operands:
| Operand | Description | 
|---|---|
| tensors | variadic of any type | 
Results:
| Result | Description | 
|---|---|
| results | variadic of any type | 
sdy.mesh (sdy::MeshOp)
Named mesh
Syntax:
operation ::= `sdy.mesh` $sym_name `=` $mesh attr-dict
Defines a new named mesh. All meshes in a module must have the same number
of devices (except for meshes with a single device_id).
The mesh is a Symbol operation that appears in the module's
SymbolTable and can be referenced by its name.
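For example (axis names and sizes are arbitrary, chosen for illustration):
sdy.mesh @mesh_xy = <["x"=2, "y"=4]>
sdy.mesh @maximal_mesh = <[], device_ids=[7]>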
Traits: HasParent<ModuleOp>
Interfaces: Symbol
Attributes:
| Attribute | MLIR Type | Description | 
|---|---|---|
| sym_name | ::mlir::StringAttr | string attribute | 
| mesh | ::mlir::sdy::MeshAttr | Mesh of axes and a list of devices | 
sdy.named_computation (sdy::NamedComputationOp)
Named computation operation
Syntax:
operation ::= `sdy.named_computation` `<`$name`>` `` `(` $operands `)`
              (`in_shardings` `=` custom<StrippedTensorShardingPerValueAttr>($in_shardings)^)?
              (`out_shardings` `=` custom<StrippedTensorShardingPerValueAttr>($out_shardings)^)?
              custom<SingleBlockRegionNoBlockId>($body)
              attr-dict
              `:` functional-type($operands, results)
Groups a computation, i.e. a block of operations, and gives it a name. Propagation will flow in/out of the region as if everything was inlined.
This can be used to handle propagating through call instructions to other
functions. Any users of Shardy should write an import/export pass that
converts their call ops to sdy.named_computation ops, duplicating/copying
the body of the called function into the body of the named_computation.
The types of the block arguments and returned values in the region must match the types of the operands and results of the op.
Example:
%1 = sdy.named_computation<"foo">(%0) (%arg1: tensor<16x32xf32>) {
  sdy.return %arg1 : tensor<16x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>
Traits: IsolatedFromAbove, RecursiveMemoryEffects, RecursivelySpeculatableImplTrait, SingleBlockImplicitTerminator<ReturnOp>, SingleBlock
Interfaces: ConditionallySpeculatable, InferTypeOpInterface, ShardableDataFlowOpInterface, SymbolUserOpInterface
Attributes:
| Attribute | MLIR Type | Description | 
|---|---|---|
| name | ::mlir::StringAttr | string attribute | 
| in_shardings | ::mlir::sdy::TensorShardingPerValueAttr | Tensor sharding per operand/result of an op | 
| out_shardings | ::mlir::sdy::TensorShardingPerValueAttr | Tensor sharding per operand/result of an op | 
Operands:
| Operand | Description | 
|---|---|
| operands | variadic of any type | 
Results:
| Result | Description | 
|---|---|
| «unnamed» | variadic of any type | 
sdy.propagation_barrier (sdy::PropagationBarrierOp)
Propagation barrier operation
Syntax:
operation ::= `sdy.propagation_barrier` $input `allowed_direction` `=` $allowed_direction attr-dict `:` type($input)
This op operates like an identity op, outputting the same value it took as input. But in terms of propagation, this will only allow propagation to flow through it in a certain direction.
This prevents shardings from being propagated between the uses of the result of the barrier op and its operand.
- `FORWARD` means shardings can only flow from the operand to the result.
- `BACKWARD` means shardings can only flow from the result to the operand.
- `NONE` means no sharding can propagate through this op.
- Cannot specify `BOTH`, as this op would be redundant.
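For example, a barrier that only lets shardings propagate forward (the tensor type and operand are illustrative):
%1 = sdy.propagation_barrier %0 allowed_direction=FORWARD : tensor<8x8xf32>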
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultType
Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Attributes:
| Attribute | MLIR Type | Description | 
|---|---|---|
| allowed_direction | ::mlir::sdy::PropagationDirectionAttr | propagation direction enum | 
Operands:
| Operand | Description | 
|---|---|
| input | ranked tensor of any type values | 
Results:
| Result | Description | 
|---|---|
| result | ranked tensor of any type values | 
sdy.reduce_scatter (sdy::ReduceScatterOp)
Performs a reduce-scatter communication along axes
Syntax:
operation ::= `sdy.reduce_scatter` $reduce_scatter_axes $tensor `out_sharding` `=` $out_sharding attr-dict `:` type($result)
Reduces chunks of a tensor along axes specified in reduce_scatter_axes,
and then scatters the result along the same axes. This operation is
essentially a combination of an sdy.all_reduce followed by an
sdy.all_slice along the same reduce_scatter_axes.
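A hedged sketch (the mesh, the producing op, and the `unreduced={...}` printed form are assumptions): the tensor is reduced along "b" and the result is then sliced along "b" on dimension 0.
// Assumed mesh: sdy.mesh @mesh = <["a"=2, "b"=2]>
%1 = stablehlo.add %0, %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}], unreduced={"b"}>]>} : tensor<8x8xf32>
%2 = sdy.reduce_scatter [{"b"}, {}] %1 out_sharding=<@mesh, [{"a", "b"}, {}]> : tensor<8x8xf32>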
Constraints:
- Must satisfy the constraints listed in Sdy_CollectiveOpInterface.
- Elements in `reduce_scatter_axes` must satisfy the constraints listed in `AxisRefListAttr`.
- Applying `reduce_scatter_axes` to the operand sharding gets `out_sharding`.
Traits: SameOperandsAndResultType
Interfaces: CollectiveOpInterface, InferTypeOpInterface, SymbolUserOpInterface
Attributes:
| Attribute | MLIR Type | Description | 
|---|---|---|
| reduce_scatter_axes | ::mlir::sdy::ListOfAxisRefListsAttr | List of axis ref lists | 
| out_sharding | ::mlir::sdy::TensorShardingAttr | Tensor sharding | 
Operands:
| Operand | Description | 
|---|---|
| tensor | shaped of any type values | 
Results:
| Result | Description | 
|---|---|
| result | shaped of any type values | 
sdy.reshard (sdy::ReshardOp)
Reshards a tensor to a different sharding
Syntax:
operation ::= `sdy.reshard` $input $sharding attr-dict `:` type($result)
Reshards the input tensor with the specified sharding, which is different from the input tensor's existing sharding.
Both ShardingConstraintOp and ReshardOp attach a sharding to a tensor. Their lifespan is:
- Before sharding propagation, ShardingConstraintOp is added by users.
- Sharding propagation consumes ShardingConstraintOp. There is no ShardingConstraintOp in the results of sharding propagation. Instead, ReshardOp may be added if needed.
- A partitioner converts a ReshardOp into a collective op (or an identity op). There should be no ReshardOp in the results of the partitioner.
TODO(b/331680067): Add a canonicalization pattern to remove redundant reshard ops.
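An illustrative sketch (mesh and axis names assumed): reshard the input to be sharded along "b" on dimension 0.
%1 = sdy.reshard %0 <@mesh, [{"b"}, {}]> : tensor<8x8xf32>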
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultType
Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
| Attribute | MLIR Type | Description | 
|---|---|---|
| sharding | ::mlir::sdy::TensorShardingAttr | Tensor sharding | 
Operands:
| Operand | Description | 
|---|---|
| input | shaped of any type values | 
Results:
| Result | Description | 
|---|---|
| result | shaped of any type values | 
sdy.return (sdy::ReturnOp)
The sdy.return operation terminates the regions attached to sdy region-based ops and any other Shardy region-based ops. It is variadic: it takes as arguments a list of values whose types can be any (but of the same kind, e.g. AnyTensor) and therefore can be reused at various levels of the Shardy IR stack.
Syntax:
operation ::= `sdy.return` attr-dict ($results^ `:` type($results))?
Traits: AlwaysSpeculatableImplTrait, Terminator
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Operands:
| Operand | Description | 
|---|---|
| results | variadic of any type | 
sdy.sharding_constraint (sdy::ShardingConstraintOp)
Constrains a tensor to the specified sharding
Syntax:
operation ::= `sdy.sharding_constraint` $input $sharding attr-dict `:` type($result)
Attaches a sharding to an intermediate tensor (e.g. the result of a matmul) to indicate that this is how that tensor, or a subset of its uses, should be sharded.
If the sharding has open dimensions and unconstrained axes, it means the tensor can be further sharded along the open dimensions.
This op can either:
- Have no uses (dangling) - which means the attached sharding is how the input tensor itself should be sharded.
- Have uses - which means the attached sharding is how the uses of the sharding constraint op should be sharded, while other uses of the input tensor might have a different sharding (if the input tensor has no other uses then the behavior is the same as the no uses case).
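An illustrative sketch (mesh and axis names assumed; `?` marking an open dimension is an assumption about the printed form): constrain %0, or its uses, to be sharded along "a" on dimension 0, with dimension 1 left open.
%1 = sdy.sharding_constraint %0 <@mesh, [{"a"}, {?}]> : tensor<8x8xf32>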
Traits: SameOperandsAndResultType
Interfaces: InferTypeOpInterface, SymbolUserOpInterface
Attributes:
| Attribute | MLIR Type | Description | 
|---|---|---|
| sharding | ::mlir::sdy::TensorShardingAttr | Tensor sharding | 
Operands:
| Operand | Description | 
|---|---|
| input | shaped of any type values | 
Results:
| Result | Description | 
|---|---|
| result | shaped of any type values | 
sdy.sharding_group (sdy::ShardingGroupOp)
Constrains tensors in the group to have the same sharding.
Syntax:
operation ::= `sdy.sharding_group` $input `group_id` `=` $group_id attr-dict `:` type($input)
This op provides an interface to assign tensors to sharding groups (groups of tensors that will be enforced to have identical shardings). During propagation, as soon as one group element is sharded, all other members will be sharded in exactly the same way. This operation takes the group ID as an argument and returns no result; instead, it modifies the internal sharding-group representation to add the input tensor to the group with the given ID.
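An illustrative sketch (group ID and tensor values are arbitrary): both tensors below are constrained to end up with identical shardings.
sdy.sharding_group %arg0 group_id=0 : tensor<8x8xf32>
sdy.sharding_group %arg1 group_id=0 : tensor<8x8xf32>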
Interfaces: InferTypeOpInterface
Attributes:
| Attribute | MLIR Type | Description | 
|---|---|---|
| group_id | ::mlir::IntegerAttr | 64-bit signless integer attribute | 
Operands:
| Operand | Description | 
|---|---|
| input | ranked tensor of any type values | 
Attributes
AllToAllParamAttr
All-to-all parameter
Syntax:
#sdy.all_to_all_param<
  ::llvm::ArrayRef<AxisRefAttr>,   # axes
  int64_t,   # src_dim
  int64_t   # tgt_dim
>
A tuple containing the axes and source/target dimensions to perform all-to-all on.
Parameters:
| Parameter | C++ type | Description | 
|---|---|---|
| axes | ::llvm::ArrayRef<AxisRefAttr> | the axes to perform all-to-all on | 
| src_dim | int64_t | the source dimension index | 
| tgt_dim | int64_t | the target dimension index | 
AllToAllParamListAttr
List of all-to-all parameters
Syntax:
#sdy.all_to_all_param_list<
  ::llvm::ArrayRef<AllToAllParamAttr>   # value
>
Parameters:
| Parameter | C++ type | Description | 
|---|---|---|
| value | ::llvm::ArrayRef<AllToAllParamAttr> | 
AxisRefAttr
Reference to either a full axis or a split sub-axis
Syntax:
#sdy.axis_ref<
  ::llvm::StringRef,   # name
  SubAxisInfoAttr   # sub_axis_info
>
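For example, assuming an axis "x" of size 8 in the bound mesh:
"x"        // references the full axis
"x":(2)2   // references a sub-axis of size 2 with pre-size 2 (see SubAxisInfoAttr)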
Constraints:
- `name` must be present in the bound `MeshAttr`.
- If `sub_axis_info` is present, it must satisfy the constraints of `SubAxisInfoAttr`.
Parameters:
| Parameter | C++ type | Description | 
|---|---|---|
| name | ::llvm::StringRef | name of this axis | 
| sub_axis_info | SubAxisInfoAttr | additional info if this is a sub axis | 
AxisRefListAttr
List of axis refs
Syntax:
#sdy.axis_ref_list<
  ::llvm::ArrayRef<AxisRefAttr>   # value
>
Constraints:
- Elements in `value` must satisfy the constraints of `AxisRefAttr`.
- There are no duplicate axis-refs or sub-axes that overlap with one another.
- No two adjacent axis-refs are consecutive sub-axes of that same full axis, i.e., they can be merged into one sub-axis or the full axis.
Parameters:
| Parameter | C++ type | Description | 
|---|---|---|
| value | ::llvm::ArrayRef<AxisRefAttr> | 
AxisToPropagationDetailsAttr
Propagation edge flow details for a specific axis and source.
Syntax:
#sdy.axis_to_propagation_details<
  ::mlir::sdy::AxisRefAttr,   # axis_name
  ::mlir::sdy::EdgeValueRefAttr,   # source
  ::llvm::ArrayRef<EdgeValueRefAttr>   # targets
>
Maps a source value reference to a list of target value references along a particular axis.
Parameters:
| Parameter | C++ type | Description | 
|---|---|---|
| axis_name | ::mlir::sdy::AxisRefAttr | Reference to either a full axis or a split sub-axis | 
| source | ::mlir::sdy::EdgeValueRefAttr | Reference to a particular index of a value edge of type type. | 
| targets | ::llvm::ArrayRef<EdgeValueRefAttr> | list of edge target values | 
DimMappingAttr
List of factor indices for a dimension
An empty list indicates that this is a null mapping (this is parsed/printed
with *), i.e. the dimension isn't mapped to any factors.
Constraints:
- There is at least one factor index.
- Factor indices must be in range [0, $factor_sizes).
- If there are multiple factors, none of them can have size 1.
- No duplicate factor indices.
Parameters:
| Parameter | C++ type | Description | 
|---|---|---|
| factor_indices | ::llvm::ArrayRef<int64_t> | factors this dimension is mapped to | 
DimensionShardingAttr
Dimension sharding
List of axis names to shard a tensor dimension on from major to minor, a boolean indicating whether the dimension can be further sharded, and an optional integer denoting the priority of this dimension sharding, which will be respected during sharding propagation. Priorities originate from user sharding annotations; a lower value denotes a higher priority. The highest priority is assumed when the priority is missing in the annotation.
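Illustrative printed forms of a dimension sharding inside a TensorShardingAttr (assuming the usual textual syntax, where `?` marks an open dimension and `pN` a priority; these forms are assumptions, not taken from this reference):
{"a", "b"}   // closed: sharded along "a" (major) then "b" (minor)
{"a", ?}     // open: may be further sharded along more axes
{"a"}p1      // closed, with user priority 1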
Constraints:
- Elements in `axes` must satisfy the constraints listed in `AxisRefListAttr`.
- If a dimension sharding has a priority:
  - The priority is greater than or equal to 0.
  - The dimension has at least one axis if it is closed.
Parameters:
| Parameter | C++ type | Description | 
|---|---|---|
| axes | ::llvm::ArrayRef<AxisRefAttr> | axis refs | 
| is_closed | bool | whether this dimension can't be further sharded | 
| priority | std::optional<int64_t> | the priority used during user priority based propagation | 
EdgeValueRefAttr
Reference to a particular index of a value edge of type type.
Syntax:
#sdy.edge_value_ref<
  ::mlir::sdy::EdgeNodeType,   # type
  int64_t   # index
>
Parameters:
| Parameter | C++ type | Description | 
|---|---|---|
| type | ::mlir::sdy::EdgeNodeType | an enum of type EdgeNodeType | 
| index | int64_t | The integer index (0, 1, 2, etc.) | 
ListOfAxisRefListsAttr
List of axis ref lists
Syntax:
#sdy.list_of_axis_ref_lists<
  ::llvm::ArrayRef<AxisRefListAttr>   # value
>
Parameters:
| Parameter | C++ type | Description | 
|---|---|---|
| value | ::llvm::ArrayRef<AxisRefListAttr> | 
ManualAxesAttr
A list of axes that a ManualComputationOp is manual on
Syntax:
#sdy.manual_axes<
  ::llvm::ArrayRef<StringAttr>   # value
>
Parameters:
| Parameter | C++ type | Description | 
|---|---|---|
| value | ::llvm::ArrayRef<StringAttr> | 
MeshAttr
Mesh of axes and a list of devices
Syntax:
#sdy.mesh<
  ::llvm::ArrayRef<MeshAxisAttr>,   # axes
  ::llvm::ArrayRef<int64_t>   # device_ids
>
A mesh is a list of axes and an optional list of device IDs specifying the device ordering.
If the list of axes is empty:
- If `device_ids` is not provided, it is an empty mesh.
- If `device_ids` is provided, it must be a single non-negative integer; we call it a maximal-sharding mesh.
If the list of axes is provided:
- If a device ID list is specified, the product of the axis sizes should match the number of devices.
- If a device ID list is not specified, the implicit device ID list is iota(product(axes)). For simplicity, we also disallow specifying a device ID list that is the same as iota(product(axes)); in this case, a device ID list shouldn't be specified.
- It is not a maximal-sharding mesh even if the total size of axes is 1.
Here are some examples of meshes:
- An empty mesh represents a placeholder mesh that can be replaced during propagation: <[]>
- A mesh without axes list and a single non-negative device ID, which is a maximal-sharding mesh: <[], device_ids=[3]>
- A mesh with two axes and implicit device IDs iota(6): <["a"=2, "b"=3]>
- A mesh with two axes and explicit device IDs specifying the device ordering: <["a"=3, "b"=2], device_ids=[0, 2, 4, 1, 3, 5]>
Constraints:
- Elements in `device_ids` should be non-negative.
- If `axes` is empty, the size of `device_ids` can be 0 (empty mesh) or 1 (maximal-sharding mesh).
- If `axes` is not empty:
  - Elements in `axes` must not have duplicate names.
  - If `device_ids` is specified, the original `device_ids` is not `iota(product(axis_sizes))` and the sorted `device_ids` is `iota(product(axis_sizes))`.
Parameters:
| Parameter | C++ type | Description | 
|---|---|---|
| axes | ::llvm::ArrayRef<MeshAxisAttr> | mesh axes | 
| device_ids | ::llvm::ArrayRef<int64_t> | explicit device ordering or maximal device id | 
MeshAxisAttr
Named axis in a mesh
Syntax:
#sdy.mesh_axis<
  ::llvm::StringRef,   # name
  int64_t   # size
>
Parameters:
| Parameter | C++ type | Description | 
|---|---|---|
| name | ::llvm::StringRef | name | 
| size | int64_t | size of this axis | 
OpShardingRuleAttr
Specifies how an operation can be partitioned.
Syntax:
#sdy.op_sharding_rule<
  ::llvm::ArrayRef<int64_t>,   # factor_sizes
  ::llvm::ArrayRef<TensorMappingAttr>,   # operand_mappings
  ::llvm::ArrayRef<TensorMappingAttr>,   # result_mappings
  ::llvm::ArrayRef<int64_t>,   # reduction_factors
  ::llvm::ArrayRef<int64_t>,   # need_replication_factors
  ::llvm::ArrayRef<int64_t>,   # permutation_factors
  ::llvm::ArrayRef<int64_t>,   # blocked_propagation_factors
  bool   # is_custom_rule
>
A sharding rule specifies how an operation can be partitioned according to various properties on the op - any attributes, the shape of operands, the shape of the results, etc. For example:
%0 = stablehlo.add %arg0, %arg1 {
    sdy.sharding_rule = #sdy.op_sharding_rule<
        ([i, j],[i, j])->([i, j])
        {i=8, j=8}>
} : tensor<8x8xf32>
%1 = stablehlo.dot_general %arg2, %arg3, contracting_dims = [1] x [0] {
  sdy.sharding_rule = #sdy.op_sharding_rule<
      ([i, k],[k, j])->([i, j])
      {i=8, j=16, k=8}>
}: (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
Note that we allow factors with size 1 even though they cannot be sharded; this is mainly for completeness, as many ops such as pointwise ops have size-one dimensions that correspond across operands and results.
Factor types:
- `reduction_factors` contains the indices of factors requiring reduction, such as the contracting dimensions in a dot operation. These factors can be in operands but not in results.
- `need_replication_factors` contains the indices of factors requiring full replication, such as the sorted dimension in a sort operation.
- `permutation_factors` contains the indices of factors requiring collective-permute if they are sharded, such as the padding dimensions in a pad operation.
- All other factors are considered as pass-through factors, i.e., factors that don't require any communication if sharded in the same way across all tensors that are mapped to them.
blocked_propagation_factors contains the factors along which shardings are
not allowed to be propagated. It is orthogonal to the factor types. Namely,
a blocked-propagation factor can be any of the factor types.
is_custom_rule describes whether this is a rule defined by a user. Users
can define sharding rules for their custom calls or overwrite the
pre-defined sharding rules for the standard operations. A custom rule is
always preserved/never removed.
Constraints:
- Number of operand/result mappings must match the number of operands/results of the op.
- There is at least one mapping (can't have a rule for an op with no operands/results).
- Rank of each `TensorMappingAttr` matches the rank of the corresponding tensor type.
- For each group of factors (`reduction_factors`, `need_replication_factors`, `permutation_factors`):
  - Elements must be in range [0, `$factor_sizes`].
  - No duplicate factor indices within each group and across groups.
Parameters:
| Parameter | C++ type | Description | 
|---|---|---|
| factor_sizes | ::llvm::ArrayRef<int64_t> | sizes of all factors in this rule | 
| operand_mappings | ::llvm::ArrayRef<TensorMappingAttr> | operand mappings | 
| result_mappings | ::llvm::ArrayRef<TensorMappingAttr> | result mappings | 
| reduction_factors | ::llvm::ArrayRef<int64_t> | factors requiring reduction | 
| need_replication_factors | ::llvm::ArrayRef<int64_t> | factors requiring full replication | 
| permutation_factors | ::llvm::ArrayRef<int64_t> | factors requiring collective-permute | 
| blocked_propagation_factors | ::llvm::ArrayRef<int64_t> | factors along which shardings are not propagated | 
| is_custom_rule | bool | whether the rule is for a stablehlo.custom_call | 
PropagationEdgesAttr
Propagation edge metadata for all propagation steps.
Syntax:
#sdy.propagation_edges<
  ::llvm::ArrayRef<PropagationOneStepAttr>   # value
>
A list of per-axis propagation details for a value, grouped by step index.
Parameters:
| Parameter | C++ type | Description | 
|---|---|---|
| value | ::llvm::ArrayRef<PropagationOneStepAttr> | 
PropagationOneStepAttr
Per-step propagation metadata.
Syntax:
#sdy.propagation_one_step<
  int64_t,   # step_index
  ::llvm::ArrayRef<AxisToPropagationDetailsAttr>   # axis_entries
>
Propagation details for all axes for a single propagation step.
Parameters:
| Parameter | C++ type | Description | 
|---|---|---|
| step_index | int64_t | step index | 
| axis_entries | ::llvm::ArrayRef<AxisToPropagationDetailsAttr> | Axis propagation details per propagation decision | 
SubAxisInfoAttr
Info about how this sub-axis is derived from the full axis
Syntax:
#sdy.sub_axis_info<
  int64_t,   # pre_size
  int64_t   # size
>
When splitting a full axis into n sub-axes, the axis is reshaped into
[k_1,...,k_n], and the ith sub-axis can be expressed by the product of all
axis sizes to its left m=prod(k_1,...,k_(i-1)) (aka pre-size) and size
k_i. Therefore, the sub-axis-info attribute holds those two numbers and is
denoted as follows: (m)k for pre-size m and size k.
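A worked example: a full axis "x" of size 8 viewed as [2, 4] yields two sub-axes, printed as in the axis-ref syntax above:
"x":(1)2   // first sub-axis: pre-size 1, size 2
"x":(2)4   // second sub-axis: pre-size 2, size 4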
Constraints:
- `pre-size` is at least 1.
- `size` is greater than 1.
- `pre-size` must divide the size of the full axis, i.e., both `pre-size` and `size` divide the size of the full axis, and the sub-axis doesn't go beyond the full axis.
- The size of the sub-axis isn't equal to the size of the corresponding full axis, in which case the full axis should be used instead.
Parameters:
| Parameter | C++ type | Description | 
|---|---|---|
| pre_size | int64_t | product of sub-axis sizes to the left of this sub-axis | 
| size | int64_t | size of this sub-axis | 
TensorMappingAttr
Factor mappings for each dimension of a tensor.
Syntax:
#sdy.tensor_mapping<
  ::llvm::ArrayRef<DimMappingAttr>   # dim_mappings
>
Constraints:
- Elements in `dim_mappings` must satisfy the constraints in `DimMappingAttr`.
- No duplicate factor indices across dimensions.
Parameters:
| Parameter | C++ type | Description | 
|---|---|---|
| dim_mappings | ::llvm::ArrayRef<DimMappingAttr> | dimension mappings | 
TensorShardingAttr
Tensor sharding
Syntax:
#sdy.sharding<
  ::mlir::Attribute,   # mesh_or_ref
  ::llvm::ArrayRef<DimensionShardingAttr>,   # dim_shardings
  ::llvm::ArrayRef<AxisRefAttr>,   # replicated_axes
  ::llvm::ArrayRef<AxisRefAttr>   # unreduced_axes
>
A tensor sharding is bound to a specific mesh, and can only reference axis names from that mesh. The dimension shardings tell us for each dimension of the tensor, along which axes (or sub-axes) it is sharded from major to minor. All other axes that don’t shard a dimension are either implicitly or explicitly (if they appear in the list of replicated axes) replicated.
Note that no sharding attribute on a tensor is equivalent to a fully open tensor sharding.
The mesh this sharding is bound to can either be specified by a symbol
name, referencing a corresponding MeshOp symbol, or an inlined MeshAttr.
A sharding can have unreduced axes (specified by unreduced_axes), meaning
the tensor is unreduced along these axes. For example, if the contracting
dimension of a matmul is sharded along axis x in both the lhs and rhs, the
result is unreduced along x. Applying an all-reduce on the tensor along
the unreduced axes will make the tensor replicated along those axes.
However, a tensor with unreduced axes doesn't have to be all-reduced
immediately; it can remain unreduced when passed to linear operations like
stablehlo.add (as long as both lhs and rhs are unreduced) and be all-reduced
afterwards. We assume the reduction type is sum; other reductions may be
supported in the future.
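An illustrative sketch (the mesh, and the `replicated={...}` and `unreduced={...}` printed forms, are assumptions): dimension 0 is sharded along "a" then "b", dimension 1 is unsharded, and the tensor is explicitly replicated along "c" and unreduced along "d".
// Assumed mesh: sdy.mesh @mesh = <["a"=2, "b"=2, "c"=2, "d"=2]>
#sdy.sharding<@mesh, [{"a", "b"}, {}], replicated={"c"}, unreduced={"d"}>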
Constraints:
- Elements in `dim_shardings` must satisfy the constraints listed in `DimensionShardingAttr`.
- Elements in `replicated_axes` must satisfy the constraints listed in `AxisRefListAttr`.
- Elements in `unreduced_axes` must satisfy the constraints listed in `AxisRefListAttr`.
- If the corresponding tensor type isn't a `ShapedType`, the sharding must have rank 0 and no replicated axes.
- If it is a `ShapedType`, then:
  - The tensor should have a rank.
  - The number of dimension shardings is equal to the rank of the tensor.
  - Dimensions of size 0 aren't sharded.
- There are no duplicate axis-refs or sub-axes that overlap with one another across `dim_shardings`, `replicated_axes`, and `unreduced_axes`.
- Items in `replicated_axes` and `unreduced_axes` are ordered w.r.t. `mesh_or_ref` (see `AxisRefAttr::getMeshComparator`).
Parameters:
| Parameter | C++ type | Description | 
|---|---|---|
| mesh_or_ref | ::mlir::Attribute | mesh attr or flat mesh symbol reference attr | 
| dim_shardings | ::llvm::ArrayRef<DimensionShardingAttr> | dimension shardings | 
| replicated_axes | ::llvm::ArrayRef<AxisRefAttr> | axis refs | 
| unreduced_axes | ::llvm::ArrayRef<AxisRefAttr> | axis refs | 
TensorShardingPerValueAttr
Tensor sharding per operand/result of an op
Syntax:
#sdy.sharding_per_value<
  ::llvm::ArrayRef<TensorShardingAttr>   # shardings
>
A list of TensorShardingAttrs, one for each operand/result of an op.
Constraints:
- Elements in `shardings` must satisfy the constraints of `TensorShardingAttr`.
Parameters:
| Parameter | C++ type | Description | 
|---|---|---|
| shardings | ::llvm::ArrayRef<TensorShardingAttr> | sharding per value | 
Enums
EdgeNodeType
Edge node type enum
Cases:
| Symbol | Value | String | 
|---|---|---|
| OPERAND | 0 | operand | 
| RESULT | 1 | result | 
PropagationDirection
Propagation direction enum
Cases:
| Symbol | Value | String | 
|---|---|---|
| NONE | 0 | NONE | 
| FORWARD | 1 | FORWARD | 
| BACKWARD | 2 | BACKWARD | 
| BOTH | 3 | BOTH |