The Shardy (SDY) dialect defines an axis-based tensor sharding representation and additional API components to attach shardings to tensors.
Operations
sdy.constant (sdy::ConstantOp)
Constant operation
Produces an output tensor from a constant value.
See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#constant
Example:
%output = sdy.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
value | ::mlir::ElementsAttr | constant vector/tensor attribute |
Results:
Result | Description |
---|---|
output | tensor of any type values |
sdy.data_flow_edge (sdy::DataFlowEdgeOp)
Data flow edge op.
Syntax:
operation ::= `sdy.data_flow_edge` $input (`sharding````=``` $sharding^)? attr-dict `:` type($result)
A data flow edge of some op X defines a bridge between a set of sources (each is either an operand of X or an operand of X's block terminator) and a set of targets (each is either a result of X or a block argument of X), such that all sources and targets should be sharded in the same way.
An op can have multiple data flow edges that are orthogonal to one another.
For example:
y_0, ..., y_n = while (x_0, ..., x_n)
((pred_arg_0,... , pred_arg_n) { ... })
((body_arg_0,..., body_arg_n) {
...
return return_value_0, ..., return_value_n
})
This while op has one data flow edge per index i: the i-th data flow edge is between sources x_i, return_value_i and targets y_i, pred_arg_i, body_arg_i.
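To make the edge structure concrete, here is a small Python sketch. It is a toy model, not part of the Shardy API, and the function name while_data_flow_edges is hypothetical; it lists the sources and targets of each edge for a while op with num_values loop-carried values, reusing the value names from the example above.

```python
# Toy model, not Shardy code: enumerate the data flow edges of a while op.
# Edge i groups the sources (x_i, return_value_i) with the targets
# (y_i, pred_arg_i, body_arg_i), which should all be sharded the same way.
def while_data_flow_edges(num_values):
    edges = []
    for i in range(num_values):
        sources = [f"x_{i}", f"return_value_{i}"]
        # y_i is listed first: an op result is the preferred root target.
        targets = [f"y_{i}", f"pred_arg_{i}", f"body_arg_{i}"]
        edges.append((sources, targets))
    return edges


for sources, targets in while_data_flow_edges(2):
    print(sources, "->", targets)
# ['x_0', 'return_value_0'] -> ['y_0', 'pred_arg_0', 'body_arg_0']
# ['x_1', 'return_value_1'] -> ['y_1', 'pred_arg_1', 'body_arg_1']
```

Because the edges share no values, each one can be propagated and queued independently.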
An sdy.data_flow_edge takes as input the root target of an edge (this can be any of the targets, but preferably an op result rather than a block argument), which shouldn't have any other uses. This op isn't pure because it can take an input that originally didn't have any uses.
The sdy.data_flow_edge also holds an optional sharding for all targets of the edge, and that sharding should be updated instead of the targets' sharding (if one can be attached) during propagation. This is useful when an op has many edges, as it's much more efficient to:
- propagate through each edge separately.
- update the sharding of each edge separately instead of all targets at once (e.g. an op has a single immutable TensorShardingPerValueAttr for result shardings).
- add each edge to the worklist separately when the sharding of a source has changed.
Propagation will propagate shardings between all sources and targets of a sdy.data_flow_edge as if it was a regular op with the sources as operands and targets as results, and an identity sdy.op_sharding_rule. That means that forward propagation is from sources to targets and backward propagation is from targets to sources.
We don't allow the input of a sdy.data_flow_edge to be defined by an SdyDialect op, so we can assume that it's defined by an op that has an unregistered sdy.sharding attribute.
Traits: SameOperandsAndResultType
Interfaces: InferTypeOpInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
sharding | ::mlir::sdy::TensorShardingAttr | Tensor sharding |
Operands:
Operand | Description |
---|---|
input | shaped of any type values |
Results:
Result | Description |
---|---|
result | shaped of any type values |
sdy.manual_computation (sdy::ManualComputationOp)
Multi-device parallelism operation with manual collectives
Syntax:
operation ::= `sdy.manual_computation` `(`operands`)`
`in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)
`out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)
`manual_axes````=```$manual_axes
custom<SingleBlockRegionNoBlockId>($body)
attr-dict
`:`
functional-type(operands, results)
Jump into a region written in terms of per-device local code with explicit collectives, where logical shapes match local per-device physical buffer shapes and collectives correspond exactly to physical cross-device communication.
The body is local with respect to the manual_axes. Propagation will occur through the body on any free axes, i.e. those not in the manual_axes list.
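As a rough illustration of what "local with respect to the manual axes" means for shapes, the sketch below computes the per-device shape the body would see. It assumes that a dimension is divided only by the manual axes that shard it (free axes leave the logical size unchanged) and that every size divides evenly; the helper manual_local_shape and its argument layout are hypothetical, not Shardy API.

```python
from math import prod


def manual_local_shape(global_shape, dim_shardings, mesh_axes, manual_axes):
    """Hypothetical helper: per-device shape inside a manual computation body.

    dim_shardings: for each dimension, the axis names sharding it (major to
    minor). mesh_axes: {axis name: size}. manual_axes: set of manual axis names.
    """
    local = []
    for size, axes in zip(global_shape, dim_shardings):
        # Only manual axes shrink the local size; free axes are left to
        # propagation and keep the logical size.
        divisor = prod(mesh_axes[a] for a in axes if a in manual_axes)
        if size % divisor != 0:
            # This sketch assumes even division.
            raise ValueError(f"size {size} is not divisible by {divisor}")
        local.append(size // divisor)
    return local


mesh = {"a": 2, "b": 4}
print(manual_local_shape([16, 32], [["a"], ["b"]], mesh, {"a"}))  # [8, 32]
```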
Traits: IsolatedFromAbove, RecursiveMemoryEffects, SingleBlockImplicitTerminator<ReturnOp>, SingleBlock
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
in_shardings | ::mlir::sdy::TensorShardingPerValueAttr | Tensor sharding per operand/result of an op |
out_shardings | ::mlir::sdy::TensorShardingPerValueAttr | Tensor sharding per operand/result of an op |
manual_axes | ::mlir::sdy::ManualAxesAttr |
Operands:
Operand | Description |
---|---|
tensors | variadic of ranked tensor of any type values |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of any type values |
sdy.mesh (sdy::MeshOp)
Named mesh
Syntax:
operation ::= `sdy.mesh` $sym_name `=` $mesh attr-dict
Defines a new named mesh. All meshes in a module must have the same number
of devices (except for meshes with a single device_id).
The mesh is a Symbol operation that appears in the module's SymbolTable and can be referenced by its name.
Traits: HasParent<ModuleOp>
Interfaces: Symbol
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
sym_name | ::mlir::StringAttr | string attribute |
mesh | ::mlir::sdy::MeshAttr | Mesh of axes and a list of devices |
sdy.named_computation (sdy::NamedComputationOp)
Named computation operation
Syntax:
operation ::= `sdy.named_computation` `<`$name`>` `` `(` $operands `)`
(`in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)^)?
(`out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)^)?
custom<SingleBlockRegionNoBlockId>($body)
attr-dict
`:` functional-type($operands, results)
Groups a computation, i.e. a block of operations, and gives it a name. Propagation will flow in/out of the region as if everything was inlined.
This can be used to handle propagating through call instructions to other functions. Users of Shardy should write an import/export pass that converts their call ops to sdy.named_computation ops, duplicating/copying the body of the called function into the body of the named_computation.
The types of the block arguments and returned values in the region must be the same as the operand and result types of the op.
Example:
%1 = sdy.named_computation<"foo">(%0) (%arg1: tensor<16x32xf32>) {
sdy.return %arg1 : tensor<16x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>
Traits: IsolatedFromAbove, RecursiveMemoryEffects, RecursivelySpeculatableImplTrait, SingleBlockImplicitTerminator<ReturnOp>, SingleBlock
Interfaces: ConditionallySpeculatable, ShardableDataFlowOpInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
name | ::mlir::StringAttr | string attribute |
in_shardings | ::mlir::sdy::TensorShardingPerValueAttr | Tensor sharding per operand/result of an op |
out_shardings | ::mlir::sdy::TensorShardingPerValueAttr | Tensor sharding per operand/result of an op |
Operands:
Operand | Description |
---|---|
operands | variadic of any type |
Results:
Result | Description |
---|---|
«unnamed» | variadic of any type |
sdy.propagation_barrier (sdy::PropagationBarrierOp)
Propagation barrier operation
Syntax:
operation ::= `sdy.propagation_barrier` $input `allowed_direction````=```$allowed_direction attr-dict `:` type($input)
This op operates like an identity op, outputting the same value it took as input. But in terms of propagation, this will only allow propagation to flow through it in a certain direction.
This prevents shardings from being propagated between the uses of the result of the barrier op and its operand.
- FORWARD means shardings can only flow from the operand to the result.
- BACKWARD means shardings can only flow from the result to the operand.
- NONE means no sharding can propagate through this op.
- Cannot specify BOTH, as this op would be redundant.
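The direction rules above can be summarized in a short Python sketch. The enum mirrors the PropagationDirection values listed under Enums in this reference; the helper barrier_allows is hypothetical and only models the allowed/blocked decision, not Shardy's propagation code.

```python
from enum import IntEnum


# Mirrors the PropagationDirection enum (values as listed under Enums).
class PropagationDirection(IntEnum):
    NONE = 0
    FORWARD = 1
    BACKWARD = 2
    BOTH = 3


def barrier_allows(allowed_direction, flow):
    """Hypothetical helper: can a sharding cross an sdy.propagation_barrier?

    flow is FORWARD (operand -> result) or BACKWARD (result -> operand).
    """
    if allowed_direction == PropagationDirection.BOTH:
        raise ValueError("BOTH is not a valid allowed_direction for a barrier")
    if flow not in (PropagationDirection.FORWARD, PropagationDirection.BACKWARD):
        raise ValueError("flow must be FORWARD or BACKWARD")
    # NONE never equals a flow, so nothing propagates through the barrier.
    return allowed_direction == flow


print(barrier_allows(PropagationDirection.FORWARD, PropagationDirection.BACKWARD))  # False
```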
Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultType
Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
allowed_direction | ::mlir::sdy::PropagationDirectionAttr | propagation direction enum |
Operands:
Operand | Description |
---|---|
input | ranked tensor of any type values |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
sdy.reshard (sdy::ReshardOp)
Reshards a tensor to a different sharding
Syntax:
operation ::= `sdy.reshard` $input $sharding attr-dict `:` type($result)
Reshards the input tensor with the specified sharding, which is different from the input tensor's existing sharding.
Both ShardingConstraintOp and ReshardOp attach a sharding to a tensor. Their lifespan is:
- Before sharding propagation, ShardingConstraintOp is added by users.
- Sharding propagation consumes ShardingConstraintOp. There is no ShardingConstraintOp in the results of sharding propagation. Instead, ReshardOp may be added if needed.
- A partitioner converts a ReshardOp into a collective op (or an identity op). There should be no ReshardOp in the results of the partitioner.
// TODO(b/331680067): Add a canonicalization pattern to remove redundant reshard ops.
Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultType
Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
sharding | ::mlir::sdy::TensorShardingAttr | Tensor sharding |
Operands:
Operand | Description |
---|---|
input | tensor of any type values |
Results:
Result | Description |
---|---|
result | tensor of any type values |
sdy.return (sdy::ReturnOp)
The sdy.return operation terminates the regions attached to sdy region-based ops and any other Shardy region-based ops. It is variadic: it takes as arguments a list of values whose types can be any (but of the same kind, e.g. AnyTensor) and therefore can be reused at various levels of the Shardy IR stack.
Syntax:
operation ::= `sdy.return` attr-dict ($results^ `:` type($results))?
Traits: AlwaysSpeculatableImplTrait, Terminator
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Operands:
Operand | Description |
---|---|
results | variadic of any type |
sdy.sharding_constraint (sdy::ShardingConstraintOp)
Constrains a tensor to the specified sharding
Syntax:
operation ::= `sdy.sharding_constraint` $input $sharding attr-dict `:` type($result)
Attaches a sharding to an intermediate tensor (e.g. the result of a matmul) to indicate that this is how that tensor, or a subset of its uses, should be sharded.
If the sharding has open dimensions and unconstrained axes, it means the tensor can be further sharded along the open dimensions.
This op can either:
- Have no uses (dangling) - which means the attached sharding is how the input tensor itself should be sharded.
- Have uses - which means the attached sharding is how the uses of the sharding constraint op should be sharded, while other uses of the input tensor might have a different sharding (if the input tensor has no other uses then the behavior is the same as the no uses case).
Traits: Elementwise, SameOperandsAndResultType
Interfaces: InferTypeOpInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
sharding | ::mlir::sdy::TensorShardingAttr | Tensor sharding |
Operands:
Operand | Description |
---|---|
input | tensor of any type values |
Results:
Result | Description |
---|---|
result | tensor of any type values |
sdy.sharding_group (sdy::ShardingGroupOp)
Sharding group operation
Syntax:
operation ::= `sdy.sharding_group` $input `group_id````=```$group_id attr-dict `:` type($input)
This op provides an interface to assign tensors to sharding groups (groups of tensors that will be enforced to have identical shardings). During propagation, as soon as one group element is sharded, all other members will be sharded in exactly the same way. This operation takes the argument group ID and returns no result, but instead modifies the internal sharding group representation to add the input tensor to the group with the given ID.
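The group semantics can be pictured with a toy Python model. The class ShardingGroups is hypothetical and is not how Shardy stores groups internally; it also assumes each tensor is added to at most one group, so it does not model overlapping groups.

```python
from collections import defaultdict


class ShardingGroups:
    """Toy model (hypothetical, not Shardy's implementation) of sharding groups.

    Assumes each tensor is added to at most one group.
    """

    def __init__(self):
        self.group_of = {}                # tensor name -> group_id
        self.members = defaultdict(list)  # group_id -> tensor names
        self.sharding = {}                # tensor name -> sharding

    def add(self, tensor, group_id):
        # Plays the role of one `sdy.sharding_group %t group_id=...` op.
        if tensor in self.group_of:
            raise ValueError(f"{tensor} is already in a group")
        self.group_of[tensor] = group_id
        self.members[group_id].append(tensor)

    def set_sharding(self, tensor, sharding):
        # As soon as one member is sharded, all members get the same sharding.
        group_id = self.group_of.get(tensor)
        targets = self.members[group_id] if group_id is not None else [tensor]
        for t in targets:
            self.sharding[t] = sharding


groups = ShardingGroups()
groups.add("%0", 0)
groups.add("%1", 0)
groups.add("%2", 1)
groups.set_sharding("%1", "S")
print(groups.sharding)  # {'%0': 'S', '%1': 'S'}
```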
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
group_id | ::mlir::IntegerAttr | 64-bit signless integer attribute |
Operands:
Operand | Description |
---|---|
input | ranked tensor of any type values |
Attributes
AxisRefAttr
Reference to either a full axis or a split sub-axis
Syntax:
#sdy.axis_ref<
::llvm::StringRef, # name
SubAxisInfoAttr # sub_axis_info
>
Parameters:
Parameter | C++ type | Description |
---|---|---|
name | ::llvm::StringRef | name |
sub_axis_info | SubAxisInfoAttr |
DimMappingAttr
List of factor indices for a dimension
All factor indices must be in the range [0, num_factors), and an empty list indicates that this is a null mapping (this is parsed/printed with *), i.e. the dimension isn't mapped to any factors.
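The two rules for a dimension mapping (indices in range, empty list means null mapping) fit in a few lines. This is a hypothetical Python check written for illustration, not the Shardy verifier.

```python
def is_null_dim_mapping(factor_indices, num_factors):
    """Hypothetical check mirroring the DimMappingAttr rules.

    Validates that every factor index is in [0, num_factors) and returns True
    when the mapping is the null mapping (an empty list, printed as *).
    """
    for index in factor_indices:
        if not 0 <= index < num_factors:
            raise ValueError(f"factor index {index} is not in [0, {num_factors})")
    return len(factor_indices) == 0


print(is_null_dim_mapping([], 3))      # True
print(is_null_dim_mapping([0, 2], 3))  # False
```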
Parameters:
Parameter | C++ type | Description |
---|---|---|
factor_indices | ::llvm::ArrayRef<int64_t> |
DimensionShardingAttr
Dimension sharding
List of axis names to shard a tensor dimension on from major to minor, a boolean indicating whether the dimension can be further sharded, and an optional integer denoting the priority of this dimension sharding, which will be respected during sharding propagation. Priorities originate from user sharding annotations and a lower value denotes a higher priority. The highest priority is assumed when the priority is missing in the annotation.
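The priority ordering can be sketched as a sort key. This Python sketch is hypothetical: it treats a missing priority as 0, assuming explicit priorities are non-negative so that 0 is the highest, and it only orders (name, priority) pairs rather than modeling how propagation uses them.

```python
def effective_priority(priority):
    # A missing priority counts as the highest one. Assumption: explicit
    # priorities are non-negative, so 0 is the highest.
    return 0 if priority is None else priority


def order_by_priority(dim_shardings):
    """Sort (name, priority) pairs so higher-priority shardings come first.

    Lower values mean higher priority; sorted() is stable, so ties keep
    their input order.
    """
    return sorted(dim_shardings, key=lambda d: effective_priority(d[1]))


print(order_by_priority([("d0", 2), ("d1", None), ("d2", 0)]))
# [('d1', None), ('d2', 0), ('d0', 2)]
```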
Parameters:
Parameter | C++ type | Description |
---|---|---|
axes | ::llvm::ArrayRef<AxisRefAttr> | list of axis refs |
is_closed | bool | |
priority | std::optional<int64_t> |
ManualAxesAttr
Syntax:
#sdy.manual_axes<
::llvm::ArrayRef<StringAttr> # value
>
Parameters:
Parameter | C++ type | Description |
---|---|---|
value | ::llvm::ArrayRef<StringAttr> |
MeshAttr
Mesh of axes and a list of devices
Syntax:
#sdy.mesh<
::llvm::ArrayRef<MeshAxisAttr>, # axes
::llvm::ArrayRef<int64_t> # device_ids
>
A mesh is a list of axes and an optional list of device IDs specifying the device ordering.
If the list of axes is empty, the mesh has an implicit unnamed axis of size 1. In this case, if a device ID list is not provided, the implicit device ID list is [0]; if a device ID list is provided, it must contain a single integer of any non-negative value. We call this the maximal-sharding case.
For all non-maximal-sharding cases, if a device ID list is specified, the product of the axis sizes should match the number of devices. If a device ID list is not specified, the implicit device ID list is iota(product(axes)). For simplicity, we also disallow specifying a device ID list that is the same as iota(product(axes)); in this case, a device ID list shouldn't be specified.
Here are some examples of meshes:
- An empty mesh represents a placeholder mesh that can be replaced during propagation: <[]>
- A mesh with an unnamed axis and an explicit device ID, which is typically used to represent maximal sharding: <[], device_ids=[3]>
- A mesh with two axes and implicit device IDs iota(6): <["a"=2, "b"=3]>
- A mesh with two axes and explicit device IDs specifying the device ordering: <["a"=3, "b"=2], device_ids=[0, 2, 4, 1, 3, 5]>
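The device-ID rules above, including the four example meshes, can be checked with a short Python sketch. The function effective_device_ids is hypothetical and is not the Shardy verifier; axes are given as (name, size) pairs and device_ids may be omitted.

```python
from math import prod


def effective_device_ids(axes, device_ids=None):
    """Hypothetical validator for MeshAttr device IDs; returns the effective list."""
    if not axes:
        # Maximal-sharding case: implicit unnamed axis of size 1.
        if device_ids is None:
            return [0]
        if len(device_ids) != 1 or device_ids[0] < 0:
            raise ValueError("expected a single non-negative device ID")
        return list(device_ids)
    num_devices = prod(size for _, size in axes)
    iota = list(range(num_devices))
    if device_ids is None:
        return iota  # implicit iota(product(axes))
    if len(device_ids) != num_devices:
        raise ValueError("device ID count must match the product of axis sizes")
    if list(device_ids) == iota:
        raise ValueError("an explicit iota device ID list should be omitted")
    return list(device_ids)


print(effective_device_ids([]))                        # [0]
print(effective_device_ids([], [3]))                   # [3]
print(effective_device_ids([("a", 2), ("b", 3)]))      # [0, 1, 2, 3, 4, 5]
print(effective_device_ids([("a", 3), ("b", 2)], [0, 2, 4, 1, 3, 5]))  # [0, 2, 4, 1, 3, 5]
```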
Parameters:
Parameter | C++ type | Description |
---|---|---|
axes | ::llvm::ArrayRef<MeshAxisAttr> | |
device_ids | ::llvm::ArrayRef<int64_t> |
MeshAxisAttr
Named axis in a mesh
Syntax:
#sdy.mesh_axis<
::llvm::StringRef, # name
int64_t # size
>
Parameters:
Parameter | C++ type | Description |
---|---|---|
name | ::llvm::StringRef | name |
size | int64_t |
OpShardingRuleAttr
Specifies how an operation can be partitioned.
Syntax:
#sdy.op_sharding_rule<
::llvm::ArrayRef<int64_t>, # factor_sizes
::llvm::ArrayRef<TensorMappingAttr>, # operand_mappings
::llvm::ArrayRef<TensorMappingAttr>, # result_mappings
bool # is_custom_rule
>
A sharding rule specifies how an operation can be partitioned according to various properties on the op - any attributes, the shape of operands, the shape of the results, etc. For example:
%0 = stablehlo.add %arg0, %arg1 {
sdy.sharding_rule = #sdy.op_sharding_rule<
([i, j],[i, j])->([i, j])
{i=8, j=8}>
} : tensor<8x8xf32>
%1 = stablehlo.dot_general %arg2, %arg3, contracting_dims = [1] x [0] {
sdy.sharding_rule = #sdy.op_sharding_rule<
([i, k],[k, j])->([i, j])
{i=8, j=16, k=8}>
}: (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
Note that we allow factors with size 1 even though they cannot be sharded. This is mainly for completeness, as many ops, such as pointwise ops, have size-one dimensions that correspond across operands and results.
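To show how factor sizes relate to tensor shapes in the examples above, here is a hypothetical Python check. It uses factor names instead of the factor indices the attribute actually stores, skips null-mapped dimensions, and assumes (as in the examples) that each mapped dimension's size equals the product of the sizes of its factors.

```python
from math import prod


def check_factor_sizes(factor_sizes, mappings, shapes):
    """Hypothetical check for a sharding rule written with named factors.

    factor_sizes: {factor name: size}; mappings: one entry per tensor, each a
    list of per-dimension factor lists ([] is a null mapping and is skipped);
    shapes: the matching tensor shapes.
    """
    if len(mappings) != len(shapes):
        raise ValueError("need one mapping per tensor")
    for mapping, shape in zip(mappings, shapes):
        if len(mapping) != len(shape):
            raise ValueError("mapping rank does not match tensor rank")
        for factors, dim_size in zip(mapping, shape):
            if factors and prod(factor_sizes[f] for f in factors) != dim_size:
                raise ValueError(f"factors {factors} do not match size {dim_size}")
    return True


# The dot_general rule above: ([i, k],[k, j])->([i, j]) {i=8, j=16, k=8}
sizes = {"i": 8, "j": 16, "k": 8}
operands_ok = check_factor_sizes(sizes, [[["i"], ["k"]], [["k"], ["j"]]], [(8, 8), (8, 16)])
results_ok = check_factor_sizes(sizes, [[["i"], ["j"]]], [(8, 16)])
print(operands_ok and results_ok)  # True
```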
is_custom_rule describes whether this is a rule defined by a user for a stablehlo.custom_call op. The partitioner doesn't know how to partition these ops, so a user must tell it how. When it is a custom rule, the rule is always preserved/never removed. is_custom_rule can only be true for stablehlo.custom_call ops.
Parameters:
Parameter | C++ type | Description |
---|---|---|
factor_sizes | ::llvm::ArrayRef<int64_t> | |
operand_mappings | ::llvm::ArrayRef<TensorMappingAttr> | |
result_mappings | ::llvm::ArrayRef<TensorMappingAttr> | |
is_custom_rule | bool |
SubAxisInfoAttr
Info about how this sub-axis is derived from the full axis
Syntax:
#sdy.sub_axis_info<
int64_t, # pre_size
int64_t # size
>
When splitting a full axis into n sub-axes, the axis is reshaped into [k_1,...,k_n], and the i-th sub-axis can be expressed by the product of all axis sizes to its left, m=prod(k_1,...,k_(i-1)) (aka pre-size), and size k_i. Therefore, the sub-axis-info attribute holds those two numbers and is denoted as follows: (m)k for pre-size m and size k.
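The pre-size computation is a running product, as in this hypothetical Python helper (not Shardy API). It also assumes the sub-axis sizes must multiply to the full axis size, since the axis is reshaped into them.

```python
from math import prod


def split_axis(axis_size, sub_sizes):
    """Hypothetical helper: (pre_size, size) of each sub-axis of a full axis.

    The axis is reshaped into sub_sizes = [k_1, ..., k_n]; the pre-size of the
    i-th sub-axis is the product of the sizes to its left.
    """
    if prod(sub_sizes) != axis_size:
        raise ValueError("sub-axis sizes must multiply to the axis size")
    infos = []
    pre_size = 1
    for k in sub_sizes:
        infos.append((pre_size, k))
        pre_size *= k
    return infos


# Split an axis of size 8 into [2, 2, 2] and print each sub-axis as (m)k.
print([f"({m}){k}" for m, k in split_axis(8, [2, 2, 2])])
# ['(1)2', '(2)2', '(4)2']
```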
Parameters:
Parameter | C++ type | Description |
---|---|---|
pre_size | int64_t | |
size | int64_t |
TensorMappingAttr
Factor mappings for each dimension of a tensor.
Syntax:
#sdy.tensor_mapping<
::llvm::ArrayRef<DimMappingAttr> # dim_mappings
>
Parameters:
Parameter | C++ type | Description |
---|---|---|
dim_mappings | ::llvm::ArrayRef<DimMappingAttr> |
TensorShardingAttr
Tensor sharding
Syntax:
#sdy.sharding<
::mlir::Attribute, # mesh_or_ref
::llvm::ArrayRef<DimensionShardingAttr>, # dim_shardings
::llvm::ArrayRef<AxisRefAttr> # replicated_axes
>
A tensor sharding is bound to a specific mesh, and can only reference axis names from that mesh. The dimension shardings tell us for each dimension of the tensor, along which axes (or sub-axes) it is sharded from major to minor. All other axes that don’t shard a dimension are either implicitly or explicitly (if they appear in the list of replicated axes) replicated.
The mesh this sharding is bound to can either be specified by a symbol name, referencing a corresponding MeshOp symbol, or an inlined MeshAttr.
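The split between sharding axes and explicitly or implicitly replicated axes can be sketched in Python. This is a toy model with a hypothetical function name, classify_axes. The mesh is a {axis name: size} dict, sub-axes and open/closed dimensions are not modeled, and rejecting an axis used twice is an assumption about the validity rules.

```python
def classify_axes(mesh, dim_shardings, replicated_axes=()):
    """Toy model: which mesh axes shard a tensor and which replicate it.

    Returns (sharding axes, explicitly replicated axes, implicitly
    replicated axes).
    """
    sharding_axes = [a for axes in dim_shardings for a in axes]
    used = sharding_axes + list(replicated_axes)
    for a in used:
        if a not in mesh:
            raise ValueError(f"axis {a!r} is not in the mesh")
    if len(used) != len(set(used)):
        raise ValueError("an axis may be used at most once")  # assumed rule
    # Any mesh axis that neither shards a dimension nor is listed explicitly
    # is implicitly replicated.
    implicit = [a for a in mesh if a not in used]
    return sharding_axes, list(replicated_axes), implicit


mesh = {"a": 2, "b": 4, "c": 2}
print(classify_axes(mesh, [["b", "a"], []], replicated_axes=["c"]))
# (['b', 'a'], ['c'], [])
print(classify_axes(mesh, [["a"], []]))
# (['a'], [], ['b', 'c'])
```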
Parameters:
Parameter | C++ type | Description |
---|---|---|
mesh_or_ref | ::mlir::Attribute | mesh attr or flat mesh symbol reference attr |
dim_shardings | ::llvm::ArrayRef<DimensionShardingAttr> | |
replicated_axes | ::llvm::ArrayRef<AxisRefAttr> | list of axis refs |
TensorShardingPerValueAttr
Tensor sharding per operand/result of an op
Syntax:
#sdy.sharding_per_value<
::llvm::ArrayRef<TensorShardingAttr> # shardings
>
Parameters:
Parameter | C++ type | Description |
---|---|---|
shardings | ::llvm::ArrayRef<TensorShardingAttr> |
Enums
PropagationDirection
propagation direction enum
Cases:
Symbol | Value | String |
---|---|---|
NONE | 0 | NONE |
FORWARD | 1 | FORWARD |
BACKWARD | 2 | BACKWARD |
BOTH | 3 | BOTH |