-mpmd-convert-sdy-constants
Converts sdy.constant ops into stablehlo.constant.
Converts any sdy.constant op (which isn't foldable) into a stablehlo.constant op. There is no reason to prevent constant folding, since shardings are stripped away from constants in mpmd-convert-sdy-shardings-to-mpmd-types.
-mpmd-convert-sdy-shardings-to-mpmd-types
Moves shardings from op attrs to !mpmd.mesh_tensor types.
Moves shardings from the attributes of MPMD ops (e.g. fragments, transfer) to the MeshTensorType of their results. Assuming we apply SDY propagation before this pass, the SPMD shardings are attached to the op's attributes. This pass moves the sharding to MeshTensorTypes since later passes require the type to contain a sharding.
This pass also removes any sharding from ops that don't have a MeshTensorType, i.e., ops inside mpmd.fragment ops.
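The attribute-to-type move described above can be sketched with a toy model. This is only an illustration in Python, not the pass's implementation: `ToyOp`, its fields, and `move_shardings_to_types` are hypothetical names, and shardings are represented as plain strings.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyOp:
    """Hypothetical stand-in for an op; not an MLIR or Shardy class."""
    name: str
    has_mesh_tensor_result: bool          # True for MPMD ops such as fragments/transfers
    sharding_attr: Optional[str] = None   # sharding attached by propagation
    type_sharding: Optional[str] = None   # sharding carried by the result type


def move_shardings_to_types(ops: List[ToyOp]) -> List[ToyOp]:
    for op in ops:
        if op.has_mesh_tensor_result:
            # MPMD op: the sharding moves from the attribute onto the type.
            op.type_sharding = op.sharding_attr
        # In both cases the attribute is dropped; ops without a mesh tensor
        # result (e.g. ops nested inside a fragment) simply lose it.
        op.sharding_attr = None
    return ops


ops = move_shardings_to_types([
    ToyOp("fragment", True, sharding_attr="{x}"),
    ToyOp("inner_add", False, sharding_attr="{x}"),
])
```

After the call, the toy fragment carries its sharding on the type, and the nested op carries no sharding at all.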
-mpmd-enforce-user-shardings
Enforces the user specified shardings for inputs and outputs.
Enforces that fragments which take function arguments or produce function results use the shardings specified by the user for those values, i.e., the input and output shardings of the function.
After this pass, fragment and transfer users of function arguments and producers of function results should have the same shardings as the ones specified by the user. If the user did not specify a sharding for an input or output, this pass keeps the sharding that propagation assigned.
Precondition:
- The user shardings are set on the function's arguments and results as attributes.
- The fragment shardings are set on the in_shardings and out_shardings attributes.
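The override rule can be sketched as follows: a user-specified sharding wins, and an unspecified one falls back to what propagation assigned. This is a minimal Python illustration, not the pass itself; `enforce_user_shardings` and the string-valued shardings are assumptions made for the example.

```python
from typing import Dict, Optional


def enforce_user_shardings(
    user: Dict[str, Optional[str]],
    propagated: Dict[str, str],
) -> Dict[str, str]:
    """Returns the sharding each function argument/result ends up with.

    `user` maps a value name to the user-specified sharding, or None when
    the user did not specify one. `propagated` maps the same names to the
    sharding assigned by propagation. Both are hypothetical representations.
    """
    enforced = {}
    for value, propagated_sharding in propagated.items():
        user_sharding = user.get(value)
        # Keep the propagated sharding only when the user left it unspecified.
        enforced[value] = (
            propagated_sharding if user_sharding is None else user_sharding
        )
    return enforced


result = enforce_user_shardings(
    user={"arg0": "{x}", "arg1": None},
    propagated={"arg0": "{y}", "arg1": "{y}"},
)
```

Here `arg0` takes the user's sharding while `arg1`, which the user left unspecified, keeps the propagated one.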
-mpmd-extract-reshards-from-inter-mesh-transfers
Moves SPMD resharding around an inter-mesh transfer to inside a fragment.
Ensures that no inter-mesh transfer (SPMD) reshards the array, i.e., that its in and out shardings are the same, by updating the types of producer/consumer fragments or by creating inferred fragments for non-fragment producers/consumers.
This is needed as MPMD runtimes have limitations w.r.t. supported reshardings.
This pass is only applied to MPMD functions in global view and with a homogeneous topology.
Precondition: all shardings are specified as op attributes and not in types.
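The case analysis can be sketched with a toy model. This Python sketch is an illustration only: `ToyTransfer` and `extract_reshard` are hypothetical names, shardings are plain strings, and the choice to resolve the mismatch on the producer side is an assumption made for the example (the description above allows either the producer or the consumer).

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyTransfer:
    """Hypothetical stand-in for an inter-mesh transfer."""
    producer_is_fragment: bool
    in_sharding: str
    out_sharding: str


def extract_reshard(transfer: ToyTransfer) -> List[str]:
    """Makes the transfer sharding-preserving; returns the actions taken."""
    actions: List[str] = []
    if transfer.in_sharding == transfer.out_sharding:
        return actions  # Already a pure transfer: nothing to extract.
    if transfer.producer_is_fragment:
        # Assumption: the producer fragment absorbs the reshard.
        actions.append("update producer fragment result sharding")
    else:
        # Non-fragment producer: wrap the reshard in a new inferred fragment.
        actions.append("create inferred fragment for reshard")
    # Either way, the transfer itself no longer changes the sharding.
    transfer.in_sharding = transfer.out_sharding
    return actions


t = ToyTransfer(producer_is_fragment=False, in_sharding="{x}", out_sharding="{y}")
actions = extract_reshard(t)
```

After the call, the toy transfer has identical in and out shardings, and the reshard has been attributed to a fragment.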