-sdy-add-data-flow-edges

Inserts DataFlowEdgeOp for every data-flow edge.

Inserts DataFlowEdgeOp for every value that is the owner of a data-flow edge, i.e., all values returned by getDataFlowEdgeOwners on every op in the module.

The inserted DataFlowEdgeOp will take the existing sharding of the owner target if it exists.
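As a rough illustration, here is a sketch of what this might look like for an op whose result owns a data-flow edge; the mesh, shapes, and the choice of stablehlo.optimization_barrier as the edge-owning op are assumptions of this sketch, and the printed op syntax is approximate:

```mlir
sdy.mesh @mesh = <["a"=2]>

// Before: the owner (the op result) carries a sharding.
%0 = stablehlo.optimization_barrier %arg0
  {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>]>}
  : tensor<8x8xf32>

// After: a DataFlowEdgeOp is inserted for the owner and takes over its
// existing sharding; the edge op holds the sharding of all targets of
// the edge.
%1 = stablehlo.optimization_barrier %arg0 : tensor<8x8xf32>
%2 = sdy.data_flow_edge %1 sharding=<@mesh, [{"a"}, {}]> : tensor<8x8xf32>
```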

TODO(b/330339693): update this doc when getDataFlowEdgeOwners is removed.

-sdy-apply-sharding-constraints

Applies constraints that dictate the sharding of their input.

Copies the sharding of a ShardingConstraintOp to its input if it satisfies all of the following:

  • The input doesn't have an existing sharding.
  • The input isn't produced by a DataFlowEdgeOp, which holds the sharding of all targets of the edge.
  • The sharding of the ShardingConstraintOp is fully closed.
  • The input doesn't have any other users of type ShardingConstraintOp or ManualComputationOp with a different sharding.

Together, these conditions indicate that the ShardingConstraintOp dictates the sharding of its input.

Note that the sharding of a ShardingConstraintOp will propagate to its input or users during propagation regardless of this pass, but since the closed property of a dimension doesn't propagate, it's important to copy the sharding to fully respect the constraint in the above cases.
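For example, a minimal sketch of the copy; the mesh name, the producing op, and the shapes are illustrative, not part of the pass contract:

```mlir
sdy.mesh @mesh = <["a"=2]>

// Before: %0 has no sharding, and its only constraint-like user is a
// fully closed sharding constraint.
%0 = stablehlo.add %arg0, %arg1 : tensor<8x8xf32>
%1 = sdy.sharding_constraint %0 <@mesh, [{"a"}, {}]> : tensor<8x8xf32>

// After: the constraint's sharding is copied onto %0, so the closed
// dimensions are respected even though closedness doesn't propagate.
%0 = stablehlo.add %arg0, %arg1
  {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>]>}
  : tensor<8x8xf32>
%1 = sdy.sharding_constraint %0 <@mesh, [{"a"}, {}]> : tensor<8x8xf32>
```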

In addition, if a tensor is used by a chain of ShardingConstraintOps that satisfy all of the following:

  • The tensor isn't produced by a ShardingConstraintOp and doesn't have any other users of type ShardingConstraintOp or ManualComputationOp.
  • None of the ShardingConstraintOps in the chain, except possibly the last one, has more than one use.
  • The last ShardingConstraintOp in the chain doesn't have any users of type ShardingConstraintOp or ManualComputationOp (otherwise it's not the last in the chain).

then this pass replaces all other uses of the chain's input that are defined after the last ShardingConstraintOp in the chain (and within the same block) with the result of the chain, since the chain should dictate the sharding of those uses (see the sketch below).
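A minimal sketch of the chain rewrite; the ops, shardings, and shapes are illustrative:

```mlir
// Before: %0 is used by a chain of two constraints and, after the last
// constraint in the chain, by an add.
%1 = sdy.sharding_constraint %0 <@mesh, [{"a", ?}, {?}]> : tensor<8x8xf32>
%2 = sdy.sharding_constraint %1 <@mesh, [{"a"}, {}]> : tensor<8x8xf32>
%3 = stablehlo.add %0, %0 : tensor<8x8xf32>

// After: the use of %0 that comes after the chain is replaced with the
// chain's result, so the chain dictates its sharding.
%1 = sdy.sharding_constraint %0 <@mesh, [{"a", ?}, {?}]> : tensor<8x8xf32>
%2 = sdy.sharding_constraint %1 <@mesh, [{"a"}, {}]> : tensor<8x8xf32>
%3 = stablehlo.add %2, %2 : tensor<8x8xf32>
```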

-sdy-constant-splitter

Splits constant sub-computations so each has a single use.

Splits constant sub-computations such that they have a single user.

This ensures that a sharding isn't propagated between different uses of a constant sub-computation, as this is considered a false dependency (the uses of a constant shouldn't be sharded in the same way just because they use the same constant). In effect, each use can have a different sharding that can propagate in isolation to its own copy of the constant sub-computation.

A constant sub-computation is either:

  • a constant or iota op (no operands)
  • a broadcast, slice, or pure element-wise op whose operands are all defined by constant sub-computations (recursively), together with the entire sub-computations that define those operands.

Note that within a constant sub-computation, a value can have multiple uses within that sub-computation.
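A minimal sketch of the split; the StableHLO ops, values, and shapes are illustrative:

```mlir
// Before: one constant sub-computation (constant + broadcast) with two
// users; a sharding could otherwise propagate between the two users
// through the shared constant.
%0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%1 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor<f32>) -> tensor<8x8xf32>
%2 = stablehlo.add %arg0, %1 : tensor<8x8xf32>
%3 = stablehlo.multiply %arg1, %1 : tensor<8x8xf32>

// After: each user gets its own copy of the entire sub-computation.
%0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%1 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor<f32>) -> tensor<8x8xf32>
%2 = stablehlo.add %arg0, %1 : tensor<8x8xf32>
%4 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor<f32>) -> tensor<8x8xf32>
%3 = stablehlo.multiply %arg1, %5 : tensor<8x8xf32>
```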

-sdy-lift-inlined-meshes

Lifts inlined MeshAttrs in shardings as symbol MeshOps.

Replaces any inlined MeshAttr in a TensorShardingAttr with a mesh symbol name, referencing either an existing or new MeshOp in the module, such that no two MeshOps have an identical MeshAttr (existing MeshOps are deduped as well).
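For example, a sketch of the lifting; the op, shapes, and the exact inlined-mesh syntax are illustrative:

```mlir
// Before: the sharding holds an inlined MeshAttr.
%0 = stablehlo.add %arg0, %arg1
  {sdy.sharding = #sdy.sharding_per_value<[<mesh<["a"=2]>, [{"a"}, {}]>]>}
  : tensor<8x8xf32>

// After: the mesh is lifted to a symbol (named per the rules below) and
// referenced by name.
sdy.mesh @mesh = <["a"=2]>
%0 = stablehlo.add %arg0, %arg1
  {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>]>}
  : tensor<8x8xf32>
```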

The name of each new MeshOp will either be:

  • maximal_mesh_{device-id}, for a maximal mesh (i.e., empty axis list and a single device ID).
  • The first available name in [mesh, mesh_0, mesh_1, ...], otherwise.

-sdy-manual-axes-cleanup

Cleans up the use of manual axes in ManualComputationOps.

1) For any in/out sharding that doesn't specify a manual axis, adds that manual axis to its replicated_axes. This ensures that manual axes are always fully specified.

2) Sorts the manual axes in mesh axis declaration order.
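A sketch of both steps; the mesh, shapes, and the printed op syntax are illustrative:

```mlir
sdy.mesh @mesh = <["a"=2, "b"=2]>

// Before: "b" is a manual axis but isn't mentioned in the in/out
// shardings, and the manual axes aren't in mesh declaration order.
%0 = sdy.manual_computation(%arg0)
    in_shardings=[<@mesh, [{"a"}, {}]>]
    out_shardings=[<@mesh, [{"a"}, {}]>]
    manual_axes={"b", "a"} (%arg1: tensor<4x8xf32>) {
  sdy.return %arg1 : tensor<4x8xf32>
} : (tensor<8x8xf32>) -> tensor<8x8xf32>

// After: "b" is added to the replicated axes of both shardings, and the
// manual axes are sorted in mesh declaration order.
%0 = sdy.manual_computation(%arg0)
    in_shardings=[<@mesh, [{"a"}, {}], replicated={"b"}>]
    out_shardings=[<@mesh, [{"a"}, {}], replicated={"b"}>]
    manual_axes={"a", "b"} (%arg1: tensor<4x8xf32>) {
  sdy.return %arg1 : tensor<4x8xf32>
} : (tensor<8x8xf32>) -> tensor<8x8xf32>
```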

-sdy-sharding-group-import

Canonicalization and validation pass for sharding groups.

Applies canonicalization and validation to sharding groups upon import. Namely, these are:

1) Sharding Group Unification - Combines sharding groups using the transitive property of group membership. Whenever a tensor T is in both sharding group G1 and sharding group G2, we can infer that all members of G1 and G2 should be sharded in the same way, so G1 and G2 can be combined into a single group. The set of canonical group ids after merging will be 0, 1, ..., N-1 for the minimal set of groups.
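For instance, a schematic sketch; the group ids, values, and shapes are illustrative:

```mlir
// Before: %t is in both group 1 and group 2, so by transitivity every
// member of the two groups must be sharded in the same way.
sdy.sharding_group %t group_id=1 : tensor<8x8xf32>
sdy.sharding_group %t group_id=2 : tensor<8x8xf32>
sdy.sharding_group %u group_id=2 : tensor<8x8xf32>

// After: the groups are merged into one canonical group with id 0.
sdy.sharding_group %t group_id=0 : tensor<8x8xf32>
sdy.sharding_group %u group_id=0 : tensor<8x8xf32>
```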

2) Sharding Group Validation - Validates that sharding groups are well-formed and conform to assumptions within the implementation. This currently asserts that if a sharding group contains a Value defined inside the block of a ManualComputationOp, then all other values in that group must reside in the same block.
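As a sketch of what this validation rejects; the syntax is heavily abbreviated and the values, group id, and shapes are illustrative:

```mlir
// Invalid per this validation: %inner is defined inside the block of a
// ManualComputationOp, but %t (in the same group) is defined outside.
sdy.sharding_group %t group_id=0 : tensor<8x8xf32>
%0 = sdy.manual_computation(%t) ... (%arg1: tensor<8x8xf32>) {
  %inner = stablehlo.add %arg1, %arg1 : tensor<8x8xf32>
  sdy.sharding_group %inner group_id=0 : tensor<8x8xf32>
  sdy.return %inner : tensor<8x8xf32>
} : (tensor<8x8xf32>) -> tensor<8x8xf32>
```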