-mpmd-absorb-inferred-fragments

Root fragments absorb inferred fragments.

Merges inferred producer/consumer fragments into root fragments, where a root fragment is any fragment that is:

  • a user fragment, or
  • not used by any other fragment (e.g. a fragment used by the return op or a transfer only), or
  • not a user of a value produced by any other fragment (e.g., user of block arguments or transfers).
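The three conditions above can be sketched as a predicate over a toy fragment graph. The `Fragment` class and its fields are hypothetical stand-ins used for illustration only; they are not the MPMD dialect's API.

```python
from dataclasses import dataclass, field


@dataclass
class Fragment:
    """Hypothetical toy model of a fragment, not the real MPMD op."""
    name: str
    is_user: bool                                  # user fragment vs. inferred
    producers: list = field(default_factory=list)  # fragments whose results it uses
    consumers: list = field(default_factory=list)  # fragments that use its results


def is_root(frag: Fragment) -> bool:
    # A user fragment is always a root.
    if frag.is_user:
        return True
    # Not used by any other fragment (e.g. only by the return op or a transfer).
    if not frag.consumers:
        return True
    # Not a user of a value produced by any other fragment
    # (e.g. it only uses block arguments or transfer results).
    if not frag.producers:
        return True
    return False


stage0 = Fragment("stage0", is_user=True, producers=["pre"], consumers=["mid"])
pre = Fragment("pre", is_user=False, consumers=["stage0"])
mid = Fragment("mid", is_user=False, producers=["stage0"], consumers=["post"])
post = Fragment("post", is_user=False, producers=["mid"])

roots = [f.name for f in (stage0, pre, mid, post) if is_root(f)]
```

In this toy graph only `mid` is not a root: it is inferred, it consumes a fragment result, and another fragment consumes its result.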

In order to do so, the pass applies the following patterns, until it reaches a fixed point:

(1) Given a root fragment rf, if there is an inferred fragment ipf such that ipf is a producer of rf and rf is the closest consumer of ipf, then ipf is merged into rf.

And dually:

(2) Given a root fragment rf, if there is an inferred fragment icf such that icf is a consumer of rf and rf is the closest producer of icf, then icf is merged into rf.

This means we preserve the structure/shape of the program as defined by the user, via named computations and stage/mesh assignment.
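As a rough illustration of patterns (1) and (2), the following sketch applies them to a toy program until nothing changes. It assumes fragments are plain dicts listed in program order, that "closest" means nearest in that order, and that root status is given up front; all of these are simplifications, and the order in which the real pass tries the patterns is not specified here.

```python
def absorb(fragments):
    """fragments: dicts in program order with keys
    'name', 'root' (bool) and 'uses' (names of producer fragments)."""
    frags = [dict(f, uses=set(f["uses"])) for f in fragments]

    def merge(rf, inferred):
        # The root takes over the inferred fragment's operands ...
        rf["uses"] = (rf["uses"] | inferred["uses"]) - {rf["name"], inferred["name"]}
        # ... and every other user of the inferred fragment now uses the root.
        for other in frags:
            if other is not rf and inferred["name"] in other["uses"]:
                other["uses"].discard(inferred["name"])
                other["uses"].add(rf["name"])
        frags[:] = [g for g in frags if g is not inferred]

    def step():
        for i, f in enumerate(frags):
            if f["root"]:
                continue
            # Pattern (1): f is an inferred producer whose closest consumer is a root.
            consumers = [g for g in frags[i + 1:] if f["name"] in g["uses"]]
            if consumers and consumers[0]["root"]:
                merge(consumers[0], f)
                return True
            # Pattern (2): f is an inferred consumer whose closest producer is a root.
            producers = [g for g in frags[:i] if g["name"] in f["uses"]]
            if producers and producers[-1]["root"]:
                merge(producers[-1], f)
                return True
        return False

    while step():  # iterate until a fixed point is reached
        pass
    return frags


program = [
    {"name": "stage0", "root": True, "uses": set()},
    {"name": "a", "root": False, "uses": {"stage0"}},
    {"name": "b", "root": False, "uses": {"a"}},
    {"name": "stage1", "root": True, "uses": {"b"}},
]
result = absorb(program)
```

Here `a` is absorbed into `stage0` by pattern (2), after which `b` is absorbed into `stage1` by pattern (1), leaving only the two user-defined stages.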

Note that this pass is quite aggressive in merging inferred fragments, and in particular, it could cause small differences across different stages that could increase the number of unique fragments to compile.

This pass will warn us if the final non-entry-point functions still include inferred fragments, as these could cause performance issues (e.g., gradient accumulation gone wrong).

Options

-absorb-on-entry-point-function : Whether to absorb inferred fragments into user-defined fragments on entry-point functions, in addition to targets of mpmd.calls.

-mpmd-call-inline

Inlines all mpmd.call operations.

Inlines mpmd.call operations, copying their attributes to any inlined operations.

-mpmd-copy-constants

Copies constants produced in one fragment to their consumers.

Copies constants from their producer fragments to their consumer fragments, possibly through transfers.

Example:

%f = fragment () () {
  return constant
}
%t = transfer %f
fragment (%t) (%arg) {
  op(... %arg ...)
  ...
}

 ~~>

%f = fragment () () {
  return constant
}
%t = transfer %f
fragment (%t) (%arg) {
  %c = constant
  op(... %c ...)
  ...
}

This can be beneficial for runtime performance: it enables potential optimizations by putting the constant together with its users and it avoids transfers of constants. Additionally, it will improve memory usage: we reduce the space needed for parameters of the computation.

-mpmd-erase-unused-callee-block-arguments

Erases any mpmd callee block argument that isn't used by an (hlo) computation.

Erases unused block arguments from functions called by mpmd.calls. We consider a block argument to be unused if it has no uses or is used only by the function's terminator, i.e., if it is not used by any hlo computation.
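A minimal sketch of the "unused" criterion, assuming a callee body is modelled as a list of `(op_name, operand_names)` pairs that excludes the terminator. This representation and the function name are hypothetical.

```python
def unused_arguments(arg_names, body_ops):
    """Return the block arguments that no computation op uses.

    body_ops excludes the terminator, so an argument that is only
    returned by the function still counts as unused.
    """
    used = {operand for _, operands in body_ops for operand in operands}
    return [arg for arg in arg_names if arg not in used]


# Callee with three arguments: %a0 feeds an add, %a1 is only returned,
# and %a2 has no uses at all.
body = [("add", ["%a0", "%a0"])]
unused = unused_arguments(["%a0", "%a1", "%a2"], body)
```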

-mpmd-fragment-dce

Eliminates unused fragment arguments/results and simplifies fragment regions.

Removes unused fragment arguments and results, while simplifying fragment regions, essentially eliminating dead MPMD code.

-mpmd-fragment-dedup

Removes any duplicated operands and results in fragments.

Replaces the uses of duplicated arguments and results in fragments with uses of a single copy. This leaves the duplicate arguments and results unused. Other passes should be run to remove them.
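The deduplication step can be sketched as computing, for each block argument, the index of the first argument bound to the same operand. The list-of-names representation below is an assumption made for illustration.

```python
def dedup_operands(operands):
    """Map each argument index to the first index bound to the same operand."""
    first_seen = {}
    return [first_seen.setdefault(value, i) for i, value in enumerate(operands)]


# fragment(%x, %y, %x): the third argument duplicates the first.
remap = dedup_operands(["%x", "%y", "%x"])

# Uses inside the fragment body are redirected through the remapping,
# which leaves argument 2 without uses.
body_uses = [2, 1, 0]
rewritten_uses = [remap[i] for i in body_uses]
```

After the rewrite, argument 2 is still present but dead, which matches the note that a later cleanup pass is needed.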

-mpmd-from-unroll-to-call-counter

Converts the unroll counter attr of a call op to a call counter attr.

Whenever a call op has an unroll_counter attribute, this pass replaces it with a call_counter attribute. This is needed for cases in which a sequence of calls results from unrolling a loop in MLIR (e.g., via -mpmd-unroll-for-loops) instead of from unrolling a loop at the Python level.
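In effect this is an attribute rename. A sketch, assuming op attributes are modelled as a plain dict (the real pass operates on MLIR attributes):

```python
def unroll_to_call_counter(attrs):
    """Return a copy of attrs with unroll_counter renamed to call_counter."""
    attrs = dict(attrs)
    if "unroll_counter" in attrs:
        attrs["call_counter"] = attrs.pop("unroll_counter")
    return attrs


converted = unroll_to_call_counter({"callee": "@f", "unroll_counter": 2})
untouched = unroll_to_call_counter({"callee": "@g"})
```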

-mpmd-merge-forward-with-backward

Merges forward fragments with backward fragments.

Merges a producer forward fragment with a consumer backward fragment, if the former is immediately before the latter. This is only true for the last stage in a 1F1B schedule, so it will not merge any fragments in previous stages, which is the intended behavior.

-mpmd-merge-inferred-fragments

Merges inferred fragments with user-defined fragments.

Merges inferred with user-defined or other inferred fragments. This pass is useful to clean up and simplify the module, and can be useful after other compiler passes that introduce inferred fragments, while -mpmd-transfer-aware-merge above is more invasive and should be used for optimization purposes only.

When clone_inferred_fragments=true, then this merging pass allows for certain fragments to be cloned. In particular, if we encounter a pair of fragments f1 and f2 such that:

  • f2 uses f1, and
  • f1 is inferred, pure, and sufficiently simple (single non-return op and single result),

then we merge a clone of f1 into f2, i.e., f1 itself (and other users) remain independent of f2. It may be undesirable to merge inferred producer fragments without cloning, because it can create unnecessary dependencies between fragments. E.g.,

%inferred = frag m1 { return stablehlo.const  }
%frag1 = frag m1 (%inferred, )
%frag2 = frag m1 (%inferred, )

~>

%inferred_frag1 = frag m1 () {  return const_m1,  }
%frag2 = frag m2 (inferred_frag1, )

So frag2 now depends on inferred_frag1, a dependency that did not exist before the merge.

However, sometimes we do want to merge in place, e.g., when the inferred fragment has collectives inside.

Options

-clone-inferred-fragments : Whether to clone inferred fragments. Chains of clonable fragments are merged one-by-one into their consumers and recursively.
-merge-any-consumer       : Whether to merge with any consumer or only the closest consumer.
-merge-sideways           : Whether to merge with the next fragment in the same mesh (neighbor), even if not a consumer.

-mpmd-merge-transfers

Merges sets of transfers that share the same producer and consumer fragments.

Merges sets of transfers of the same payload type which share the same producer and consumer fragments. The payload values of these transfers have fewer elements than a given threshold, are not sharded, and do not live in pinned_host.

Merging a set of transfers means: concatenating the transferred values at producer site and splitting them at consumer site.
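The concatenate-then-split idea can be sketched with flat Python lists standing in for the transferred tensors. The function names are hypothetical, and the element-count threshold and sharding checks are omitted.

```python
def merge_transfers(payloads):
    """Producer side: concatenate the payloads and record their sizes."""
    sizes = [len(p) for p in payloads]
    merged = [x for p in payloads for x in p]
    return merged, sizes


def split_transfer(merged, sizes):
    """Consumer side: recover the original payloads from the merged one."""
    out, start = [], 0
    for size in sizes:
        out.append(merged[start:start + size])
        start += size
    return out


payloads = [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]]
merged, sizes = merge_transfers(payloads)   # a single transfer instead of three
recovered = split_transfer(merged, sizes)
```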

-mpmd-merge-user-fragments-into-scheduling-units

Merges user-defined fragments before the pipeline scheduling passes.

Merges pairs of user-defined fragments to be used together with pipeline scheduling passes.

-mpmd-remove-transfer-cycles

Removes device-only transfer cycles from the program, avoiding unnecessary transfers.

Removes transfer cycles.

E.g. in symbols:

x1 = transfer(x0) : m0 -> m1
x2 = transfer(x1) : m1 -> m2
x3 = transfer(x2) : m2 -> m3
x0_1 = transfer(x3) : m3 -> m0
x1_1 = transfer(x0_1) : m0 -> m1

~~>

x1 = transfer(x0) : m0 -> m1
x2 = transfer(x1) : m1 -> m2
x3 = transfer(x2) : m2 -> m3
x0_1 = x0
x1_1 = x1

i.e., we break the cycle by using the existing values, removing the unnecessary transfers.

Note that this could increase memory overhead: when the data is transferred away and back again, there is a period during which it is not on the device, whereas reusing the existing value keeps it there. Thus, we only do this if the cycle contains only device-to-device transfers, since a device -> host -> device cycle, for example, could be there for memory purposes.

This doesn't use the MLIR Canonicalizer, because that doesn't guarantee that everything is canonicalized, and also it's more expensive to apply.
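The rewrite in the example can be sketched by tracking, along each chain of transfers, which value already lives on each mesh. The tuple representation is hypothetical, and this toy ignores memory kinds, so it does not model the device-only restriction described above.

```python
def remove_transfer_cycles(transfers):
    """transfers: (result, operand, src_mesh, dst_mesh) tuples in program order.

    Returns the transfers to keep and a map from each erased transfer
    result to the existing value that replaces it.
    """
    on_mesh = {}   # value -> {mesh: equivalent value already on that mesh}
    aliases = {}   # erased result -> existing value
    kept = []
    for result, operand, src, dst in transfers:
        operand = aliases.get(operand, operand)
        copies = on_mesh.setdefault(operand, {src: operand})
        if dst in copies:
            aliases[result] = copies[dst]   # the data is already on dst
        else:
            copies[dst] = result
            on_mesh[result] = copies        # share one map along the chain
            kept.append((result, operand, src, dst))
    return kept, aliases


kept, aliases = remove_transfer_cycles([
    ("x1", "x0", "m0", "m1"),
    ("x2", "x1", "m1", "m2"),
    ("x3", "x2", "m2", "m3"),
    ("x0_1", "x3", "m3", "m0"),
    ("x1_1", "x0_1", "m0", "m1"),
])
```

On the example above this keeps the first three transfers and replaces `x0_1` with `x0` and `x1_1` with `x1`.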

-mpmd-rule-based-merge

Merges fragments based on user-defined rules.

Merges fragments based on a specified list of rules, each specifying a list of source fragments to merge (by their fragment info) and the target info to label the merged fragment.

Options

-rules : A list of fragment merge rules, each with a list of source fragment infos and a target fragment info.

-mpmd-sink-negligible-ops-into-call-op

Sinks negligible ops into call-ops (i.e., the called function).

Sinks (negligible) ops into called functions: if there is an op with zero operands and a single result which is used as a specific operand of all call ops to the same function, then we sink it into those call ops, i.e., we clone it into the called function and replace all uses of the respective argument with the clone. Sunk ops are removed from the caller function, and unused arguments of the callee (and the operands of the respective call ops) are removed. Note that this can potentially duplicate computation across many microbatches, when using call ops for microbatching. However, this computation is most likely negligible, as it takes no operands.

-mpmd-split-and-prioritize-transfer-independent-computations

Splits backward fragments based on transferred results.

Splits a fragment into two fragments, so that we can start computation early. I.e. we split the fragment into two fragments A -> B, where A does not rely on any transfer result, and is maximally large, and B relies on transfer results.
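One way to picture the A/B split is a forward dependency walk that marks every op reachable from a transfer result. This is a sketch on a hypothetical `(result, operands)` op list, not the pass's implementation.

```python
def split_transfer_independent(ops, transferred_args):
    """ops: (result, operands) pairs in program order.
    transferred_args: fragment arguments that are transfer results.

    Returns (A, B): results computed without any transfer result, and
    results that depend, directly or transitively, on a transfer result.
    """
    tainted = set(transferred_args)
    part_a, part_b = [], []
    for result, operands in ops:
        if tainted.intersection(operands):
            tainted.add(result)
            part_b.append(result)
        else:
            part_a.append(result)
    return part_a, part_b


# %w is a local value, %t arrives via a transfer.
ops = [("a", ["w"]), ("b", ["t"]), ("c", ["a", "b"]), ("d", ["a"])]
part_a, part_b = split_transfer_independent(ops, {"t"})
```

Every op that can run before the transfer completes lands in A, which is what makes A maximally large in this toy model.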

-mpmd-split-bwd-fragments

Splits backward fragments based on transferred results.

Splits backward fragments so that any computation that does not flow into transferred results becomes a fragment of its own. The original fragment will return some residual values that will be passed in as extra operands to the split-out fragments.

This split allows us to transfer results into other meshes earlier. One canonical use of this optimization will be splitting the activation gradient computation in back-propagation from the parameter gradient computation. Note also that due to the need to potentially thread through some residual values in the new fragments, memory pressure will increase.
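The split can be sketched as a backward slice from the transferred results: whatever the slice does not reach is split out, and values crossing the boundary become residuals. The op-list representation and names below are hypothetical.

```python
def split_bwd(ops, transferred_results):
    """ops: (result, operands) pairs in program order.

    Returns (kept, split_out, residuals): results that stay in the
    original fragment, results moved to the new fragment, and values
    the original fragment must additionally return for the new one.
    """
    needed = set(transferred_results)
    for result, operands in reversed(ops):
        if result in needed:
            needed.update(operands)
    kept = [r for r, _ in ops if r in needed]
    moved = [(r, operands) for r, operands in ops if r not in needed]
    residuals = sorted({o for _, operands in moved for o in operands if o in kept})
    return kept, [r for r, _ in moved], residuals


# g_in (activation gradient) is transferred to another mesh;
# g_w (parameter gradient) stays local.
ops = [("h", ["g_out", "w"]), ("g_in", ["h", "w"]), ("g_w", ["h", "x"])]
kept, split_out, residuals = split_bwd(ops, {"g_in"})
```

In this example the parameter gradient `g_w` moves to its own fragment, and `h` is the residual threaded through to it.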

-mpmd-uniquify-function-inputs-outputs

Uniquifies any value returned multiple times or any block argument directly returned by the function.

If a function returns the same value multiple times, creates multiple versions for that value, by creating a fragment assigned to that value's mesh which returns the value multiple times. After this pass, each return operand is unique. This is important to ensure that the respective results are allocated in different buffers, as in the following jax.jit example:

def f(x):
  y = x + x
  return y, y

z1, z2 = f(5)
z1 += 1
print(z1)  # ~~> 11
print(z2)  # ~~> 10

Similarly, if a function returns a block argument, this pass creates an identity fragment for that block argument, guaranteeing that values are passed by value to the function, not by reference.
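The aliasing hazard can be reproduced in plain Python by letting a one-element list stand in for a device buffer. This is only an analogy: the pass itself works on MPMD functions and inserts fragments rather than copying lists.

```python
def f(x):
    y = [x + x]        # the list stands in for a buffer holding the result
    return y, y        # the same buffer is returned twice


z1, z2 = f(5)
z1[0] += 1
aliased = (z1[0], z2[0])        # the update is visible through both results


def f_unique(x):
    y = [x + x]
    return y, list(y)  # the second result gets its own buffer


u1, u2 = f_unique(5)
u1[0] += 1
independent = (u1[0], u2[0])    # only the first result changes
```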

-mpmd-unroll-for-loops

Fully unrolls mpmd.for loops.

Creates a pass that completely unrolls mpmd.for ops, while attaching an unroll_counter attribute to each unrolled op.

Requires: the unroll factor to be equal to the number of iterations.
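A sketch of full unrolling, assuming the loop body is a list of ops modelled as dicts; the `unroll_counter` key mirrors the attribute named above, while everything else is hypothetical.

```python
def unroll_for(num_iterations, body):
    """Emit one copy of the body per iteration, tagged with its iteration index."""
    unrolled = []
    for i in range(num_iterations):
        for op in body:
            unrolled.append({**op, "unroll_counter": i})
    return unrolled


body = [{"op": "call", "callee": "@f"}]
unrolled = unroll_for(3, body)
```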

-mpmd-verify-stage-merging

Verifies that merging of fragments assigned to stages succeeded.

Verifies that fragments with stage assignment have been correctly merged. This means that the module cannot contain two fragments that are equivalent in terms of assignment and counters.

Two fragments are equivalent in terms of assignment and counters iff:

  a. they are assigned to the same mesh,
  b. they are assigned to the same stage,
  c. they have the same transpose count, and
  d. either both have the same call counter or one of them doesn't have a call counter defined (i.e., an undefined call counter matches any call counter).

This is needed to guarantee to the user that any computation assigned to the same stage is executed contiguously.
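The equivalence relation can be written down directly as a predicate. `FragmentInfo` here is a hypothetical record invented for this sketch, with `None` standing for an undefined call counter.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class FragmentInfo:
    """Hypothetical summary of a fragment's assignment and counters."""
    mesh: str
    stage: int
    transpose_count: int
    call_counter: Optional[int] = None   # None means "not defined"


def equivalent(a: FragmentInfo, b: FragmentInfo) -> bool:
    return (
        a.mesh == b.mesh
        and a.stage == b.stage
        and a.transpose_count == b.transpose_count
        # An undefined call counter matches any call counter.
        and (
            a.call_counter is None
            or b.call_counter is None
            or a.call_counter == b.call_counter
        )
    )


f1 = FragmentInfo("m0", stage=0, transpose_count=0, call_counter=1)
f2 = FragmentInfo("m0", stage=0, transpose_count=0)                  # no call counter
f3 = FragmentInfo("m0", stage=0, transpose_count=0, call_counter=2)
f4 = FragmentInfo("m0", stage=0, transpose_count=1, call_counter=1)
```

Under this sketch a verifier would reject a module containing both `f1` and `f2`, but accept one containing `f1` and `f3`, or `f1` and `f4`.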