-sdy-close-shardings
Closes tensor shardings and drops replicated axes.
-sdy-constant-or-scalar-merger
Merge identical constants and scalar expansions with matching shardings.
Performs a lightweight CSE on constants with identical shardings.
The import pipeline splits and duplicates constants and scalar expansions so that sharding is not propagated between different uses of a constant sub-computation. If the duplicated constants end up with the same shardings after propagation, this pass merges them to save compilation time. See -sdy-constant-or-scalar-splitter for more info.
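The merge can be pictured as a deduplication keyed on the constant's value and its sharding. The following is a minimal sketch of that idea, not the pass's implementation: the function name `merge_constants`, the tuple representation, and the string-encoded shardings are all assumptions made for illustration.

```python
def merge_constants(constants):
    """Map each duplicate constant to the first constant with the same key.

    `constants` is a list of (name, value, sharding) tuples (hypothetical
    representation). Two constants merge only if both the value and the
    sharding match.
    """
    canonical = {}      # (value, sharding) -> name of the constant that is kept
    replacements = {}   # duplicate name -> name of the kept constant
    for name, value, sharding in constants:
        key = (value, sharding)
        if key in canonical:
            replacements[name] = canonical[key]
        else:
            canonical[key] = name
    return replacements


constants = [
    ("%c0", 1.0, '[{"x"}]'),
    ("%c1", 1.0, '[{"x"}]'),  # same value and sharding as %c0: merged
    ("%c2", 1.0, '[{}]'),     # same value, different sharding: kept
]
print(merge_constants(constants))  # {'%c1': '%c0'}
```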
-sdy-convert-global-to-local
Converts an SDY program from global shapes to local shapes.
Converts an SDY program from global shapes to local shapes by partitioning logical dimensions based on sharding attributes.
This pass leverages a type converter to map RankedTensorType from global logical shapes to device-local physical shapes.
Options
-per-dim-all-gather : Keep per-dimension all-gather without combining them into a single all-gather.
-combine-multi-dimension-reduce-scatter : Combine multi-dimension reduce-scatter into a single reduce-scatter.
-enable-rgv3 : Use StableHLO ReplicaGroupV3 (mesh-axes based) for collectives.
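To make the global-to-local shape mapping concrete, here is a minimal sketch that divides each global dimension by the number of shards along it. It is an illustrative model only: `local_shape`, the dict-based mesh, and the list-of-axes sharding format are assumptions, and the sketch simply rejects non-divisible dimensions rather than modelling how the pass treats them.

```python
from math import prod


def local_shape(global_shape, dim_shardings, mesh):
    """Return the per-device shape of a tensor sharded along mesh axes.

    `mesh` maps axis names to sizes; `dim_shardings` lists, per dimension,
    the axes that shard it (hypothetical representation).
    """
    local = []
    for size, axes in zip(global_shape, dim_shardings):
        num_shards = prod(mesh[axis] for axis in axes)  # 1 if unsharded
        if size % num_shards != 0:
            # Assumption: this sketch only handles evenly divisible dims.
            raise ValueError(f"dimension {size} not divisible by {num_shards}")
        local.append(size // num_shards)
    return tuple(local)


mesh = {"x": 4, "y": 2}
# tensor<8x32xf32> sharded as [{"x"}, {"y"}]
print(local_shape((8, 32), [["x"], ["y"]], mesh))  # (2, 16)
```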
-sdy-drop-sharding-and-mesh
Removes the mesh op and sharding notation from the program.
-sdy-drop-sharding-rules
Drops OpShardingRuleAttr from all registered ops.
-sdy-export-named-computations
Outline calls from NamedComputationOp.
Converts each NamedComputationOp to a CallOp with a new private function
named after the NamedComputationOp. The new FuncOp and CallOp have the
same shardings as the original NamedComputationOp's operands/results.
If there is a function with the same name as the NamedComputationOp in the
module, the MLIR symbol table will change it to {name}_#.
-sdy-insert-explicit-reshards
Inserts explicit reshards to make all operations have compatible shardings.
A compatible sharding essentially means that the operation can accept the sharded operands and produce a sharded result without requiring any reshard communications (note that the operation might still require communication such as all-reduce or halo-swaps).
After propagation, some operations may still have incompatible shardings.
Note that when an axis (or sub-axis) is used to shard non-corresponding dimensions (e.g. non-contracting dimensions in matmul) across multiple tensors, or when an axis shards a dimension in one tensor but not the corresponding dimension in the other tensor, it is said that the operation has a sharding conflict. Hence, after this pass, the operations become conflict-free.
This pass injects reshard operations explicitly so that, for each operation, corresponding dimensions become sharded in the same way across all operands and results, and every axis (or sub-axis) can only be used to shard a single dimension type.
Example:
Input:
mesh = <"x"=4, "y"=2>
%lhs : tensor<8x32xf32> {sdy.sharding=<@mesh, [{"x"}, {"y"}]>}
%rhs : tensor<32x16xf32> {sdy.sharding=<@mesh, [{"y"}, {"x"}]>}
stablehlo.dot %lhs, %rhs {sdy.sharding_per_value=<[<@mesh, [{"x"}, {}]>]>}
: (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
Output:
mesh = <"x"=4, "y"=2>
%lhs : tensor<8x32xf32> {sdy.sharding=<@mesh, [{"x"}, {"y"}]>}
%rhs : tensor<32x16xf32> {sdy.sharding=<@mesh, [{"y"}, {"x"}]>}
%0 = sdy.reshard %rhs <@mesh, [{"y"}, {}]> : tensor<32x16xf32>
stablehlo.dot %lhs, %0 {sdy.sharding_per_value=<[<@mesh, [{"x"}, {}]>]>}
: (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
In the example above, lhs and rhs are both sharded on axis "x" on their
non-contracting dimensions, which is incompatible. The pass inserts an
explicit reshard on rhs before the dot operation, so that the dot
operation has compatible shardings.
Options
-enable-full-version : Enable full version.
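The first kind of conflict described above (one axis sharding non-corresponding dimensions) can be checked mechanically. The sketch below is a simplified model and not the pass's algorithm: `find_conflicting_axes`, the factor labels `i`, `j`, `k`, and the list-based shardings are assumptions. It does not cover the second kind of conflict, where an axis shards a dimension in one tensor but not the corresponding dimension in another.

```python
def find_conflicting_axes(tensor_factors, tensor_shardings):
    """Return the axes that shard more than one factor across all tensors.

    `tensor_factors` gives, per tensor, the factor name of each dimension;
    `tensor_shardings` gives, per tensor, the axes sharding each dimension.
    Corresponding dimensions share a factor name (hypothetical encoding).
    """
    axis_to_factors = {}
    for factors, sharding in zip(tensor_factors, tensor_shardings):
        for factor, axes in zip(factors, sharding):
            for axis in axes:
                axis_to_factors.setdefault(axis, set()).add(factor)
    return sorted(a for a, fs in axis_to_factors.items() if len(fs) > 1)


# Matmul: lhs is (i, k), rhs is (k, j), result is (i, j).
factors = [["i", "k"], ["k", "j"], ["i", "j"]]

before = [[["x"], ["y"]], [["y"], ["x"]], [["x"], []]]
print(find_conflicting_axes(factors, before))  # ['x']

# After resharding rhs to [{"y"}, {}], "x" only shards factor i.
after = [[["x"], ["y"]], [["y"], []], [["x"], []]]
print(find_conflicting_axes(factors, after))  # []
```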
-sdy-insert-func-call-reshards
Inserts reshards for func and call sharding conflicts.
Inserts reshards for func and call sharding conflicts on results.
-sdy-pad-for-divisibility
Pads tensors with non-divisible shardings to divisible shapes.
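One natural reading of "divisible shapes" is that each sharded dimension is rounded up to the next multiple of its shard count. The sketch below illustrates that reading only; `padded_shape` and its data representation are assumptions, and the actual pass may choose padding differently.

```python
from math import prod


def padded_shape(shape, dim_shardings, mesh):
    """Round each dimension up to a multiple of its number of shards."""
    padded = []
    for size, axes in zip(shape, dim_shardings):
        num_shards = prod(mesh[axis] for axis in axes)  # 1 if unsharded
        # Ceiling division, then scale back up to a multiple of num_shards.
        padded.append(-(-size // num_shards) * num_shards)
    return tuple(padded)


mesh = {"x": 4}
# A dimension of size 10 sharded 4 ways is padded to 12; 6 is unsharded.
print(padded_shape((10, 6), [["x"], []], mesh))  # (12, 6)
```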
-sdy-propagate-to-func-results
Propagate shardings from the func terminator to func results.
Copies the shardings of func terminator values to the corresponding func.func results,
except for the main func.
-sdy-remove-all-gather-reduce-scatter-for-cmv1
Removes sdy.all_gather and sdy.reduce_scatter for CMV1.
Removes all-gather in the pattern all-gather + dot. Removes reduce-scatter in the pattern dot + reduce-scatter. This pass is for compatibility with collective matmul V1 (CMV1). It is a temporary solution for b/432019089.
-sdy-remove-propagation-debug-info
Removes propagation debug info (propagation edges and origin shardings) during export.
-sdy-remove-sharding-groups
Removes ShardingGroupOps after propagation.
-sdy-remove-sub-axes-in-input-output-shardings
Removes sub-axes in input/output shardings.
Some users of Shardy expect the function inputs/outputs to have shardings
without sub-axes. This pass removes sub-axes and their trailing axes from
input/output open dimension shardings. This pass usually runs after
sdy-update-non-divisible-input-output-shardings to ensure that the removal
of sub-axes does not introduce any non-divisible shardings.
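Removing a sub-axis together with its trailing axes amounts to truncating a dimension sharding at the first sub-axis. The following is a minimal sketch of that truncation under an assumed representation: full axes are plain strings, sub-axes are `(name, pre_size, size)` tuples, and `drop_sub_axes` is a hypothetical helper. It does not model the open/closed distinction mentioned above.

```python
def drop_sub_axes(axes):
    """Keep the axes that precede the first sub-axis of a dimension sharding."""
    kept = []
    for axis in axes:
        if isinstance(axis, tuple):  # sub-axis: drop it and everything after
            break
        kept.append(axis)
    return kept


# "x" is a full axis, ("y", 1, 2) is a sub-axis of "y", "z" trails it.
print(drop_sub_axes(["x", ("y", 1, 2), "z"]))  # ['x']
print(drop_sub_axes(["x", "z"]))               # ['x', 'z']
```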
-sdy-reshard-to-collectives
Converts ReshardOp into various Shardy collective ops.
Matches reshard ops and rewrites them into various Shardy collective ops. After this pass, no reshard ops remain in the module.
If keepRedundantReshards is true, the only reshard ops that remain are the
redundant ones. By default, the pass assumes that explicit reshards have
already been inserted (sdy-insert-explicit-reshards) and does not keep
redundant reshards. Keep redundant reshards if explicit reshards may not
have been inserted yet.
Example:
Input:
mesh = <"x"=2, "y"=2, "z"=2>
%0 : tensor<16x2xf32> {sdy.sharding=<@mesh, [{"x", "y", "z"}, {}]>}
%1 = sdy.reshard %0 <@mesh, [{"x"}, {}]> : tensor<16x2xf32>
Output:
mesh = <"x"=2, "y"=2, "z"=2>
%0 : tensor<16x2xf32> {sdy.sharding=<@mesh, [{"x", "y", "z"}, {}]>}
%1 = sdy.all_gather [{"y", "z"}, {}] %0 out_sharding=<@mesh, [{"x"}, {}]> : tensor<16x2xf32>
In the example above, the tensor %0 : tensor<16x2xf32> is sharded as
[{"x", "y", "z"}, {}]. Then, a reshard op reshards it as
[{"x"}, {}]. On the first dimension, since the suffix {"y", "z"} is removed
after the reshard, we infer that {"y", "z"} have been all-gathered. The
second dimension is not changed.
Options
-keep-redundant-reshards : Whether to keep redundant reshards instead of removing them.
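The suffix inference in the example can be sketched as follows. This covers only the pure all-gather case, where every output dimension sharding is a prefix of the input one. `gathered_axes` and the list-based sharding format are assumptions for illustration, not the pass's matching logic.

```python
def gathered_axes(in_sharding, out_sharding):
    """Return, per dimension, the axes removed from the end of the sharding.

    Raises ValueError if some output dimension sharding is not a prefix of
    the input one, since such a reshard is not a pure all-gather.
    """
    result = []
    for in_axes, out_axes in zip(in_sharding, out_sharding):
        if in_axes[:len(out_axes)] != out_axes:
            raise ValueError("reshard is not a pure all-gather")
        result.append(in_axes[len(out_axes):])
    return result


# [{"x", "y", "z"}, {}] resharded to [{"x"}, {}]
print(gathered_axes([["x", "y", "z"], []], [["x"], []]))  # [['y', 'z'], []]
```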
-sdy-resolve-permutation-factors
Resolves sharding on dimensions mapped to kPermutation factors.
Sharding dimensions with kPermutation factors may require cross-device
communication (e.g., halo exchange for windows or collective permutes for
reverses).
If enableHaloExchange is true, the pass uses the available optimized
communication logic to resolve the permutation factors. Otherwise, the pass
simply inserts sdy.reshard ops to replicate those dimensions.
enableHaloExchange defaults to true.
Options
-enable-halo-exchange : Implement halo exchange logic for windowed operations.
-sdy-sharding-constraint-to-reshard
Converts ShardingConstraintOp into ReshardOp.
-sdy-sink-data-flow-edges
Sinks all DataFlowEdgeOp into their input.
Moves the sharding of each DataFlowEdgeOp to its input (the root target of
the edge), and replaces the op with its input.
Options
-sink-debug-sharding-origins : Whether to sink the debug sharding origins info. See `debug-sharding-origins` option in propagation for more info.
-sink-debug-propagation-edge-sharding : Whether to sink the debug propagation edge sharding info. See `debug-propagation-edge-sharding` option in propagation for more info.
-sdy-sink-func-data-flow-edges
Sinks all FuncDataFlowEdgeOp into their input.
Moves the sharding of each FuncDataFlowEdgeOp to its input and replaces
the op with its input.
-sdy-unflatten-call-graph
Unflattens the call graph.
Unflattens the call graph by deduplicating functions that have the same input/output shardings and the same origin, as described by the 'original_func_name' attribute attached to the functions.
Options
-dedup-functions-fully : If true, regardless of the input and output shardings of functions, it keeps one callee function for each caller function. The default is false, meaning it will deduplicate only if the input and output shardings are the same.
-sdy-update-non-divisible-input-output-shardings
Makes FuncOp inputs/outputs evenly sharded, removing any need for padding due to non-divisible shardings.
Users of Shardy expect the function inputs/outputs to be evenly divisible/shardable to avoid requiring padding their tensors. Propagation may make inputs/outputs have non-divisible shardings, so this pass updates them to the largest dimension sharding prefix of the original sharding that is evenly sharded.
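Finding the largest evenly sharded prefix can be done greedily: once a prefix fails to divide the dimension, every longer prefix fails too, because its shard count is a multiple of the failing one. The sketch below illustrates this for whole axes only. `largest_divisible_prefix` and the dict-based mesh are assumptions, and the sketch does not model sub-axes, which the actual pass may use.

```python
def largest_divisible_prefix(dim_size, axes, mesh):
    """Return the longest prefix of `axes` that evenly shards `dim_size`."""
    kept = []
    num_shards = 1
    for axis in axes:
        if dim_size % (num_shards * mesh[axis]) != 0:
            break  # longer prefixes cannot be divisible either
        num_shards *= mesh[axis]
        kept.append(axis)
    return kept


mesh = {"x": 2, "y": 2}
# 6 is divisible by 2 ("x") but not by 4 ("x" and "y" together).
print(largest_divisible_prefix(6, ["x", "y"], mesh))  # ['x']
print(largest_divisible_prefix(8, ["x", "y"], mesh))  # ['x', 'y']
```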
-sdy-verify-unreduced-axes
Verifies the consistency of unreduced axis usage.
Verifies that for every operation, if its operands have unreduced axes, the
operation either explicitly reduces them (e.g., via sdy.reshard) or passes
them through to its results (or is a bounding operation like func.call).