-mpmd-absorb-inferred-fragments
Root fragments absorb inferred fragments.
Makes root fragments absorb inferred fragments by merging inferred producer/consumer fragments into root fragments, where a root fragment is any fragment that is:
- a user fragment, or
- not used by any other fragment (e.g. a fragment used by the return op or a transfer only), or
- not a user of a value produced by any other fragment (e.g., user of block arguments or transfers).
In order to do so, the pass applies the following patterns until it reaches a fixed point:
(1) Given a root fragment rf, if there is an inferred fragment ipf such that ipf is a producer of rf and rf is the closest consumer of ipf, then ipf is merged into rf.
And dually:
(2) Given a root fragment rf, if there is an inferred consumer icf such that icf is a consumer of rf and rf is the closest producer of icf, then icf is merged into rf.
This means we preserve the structure/shape of the program as defined by the user, via named computations and stage/mesh assignment.
Note that this pass is quite aggressive in merging inferred fragments, and in particular, it could cause small differences across different stages that could increase the number of unique fragments to compile.
This pass emits a warning if the final non-entry-point functions still include inferred fragments, as these could cause performance issues (e.g., gradient accumulation gone wrong).
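The two patterns and the fixed-point loop can be sketched on a toy model. This is only an illustration, not the pass's implementation: the `Fragment` class and every helper below are hypothetical, fragments are kept in program order, and "closest consumer/producer" is assumed to mean the first consumer or last producer in that order.

```python
from dataclasses import dataclass, field


@dataclass
class Fragment:
    name: str
    inferred: bool
    operands: set = field(default_factory=set)  # names of producer fragments


def producers(frags, f):
    return [g for g in frags if g.name in f.operands]


def consumers(frags, f):
    return [g for g in frags if f.name in g.operands]


def is_root(frags, f):
    # User fragment, or no fragment uses it, or it uses no other fragment.
    return (not f.inferred) or not consumers(frags, f) or not producers(frags, f)


def merge(frags, absorbed, root):
    # The root inherits the absorbed fragment's operands; users are redirected.
    root.operands |= absorbed.operands
    root.operands -= {absorbed.name, root.name}
    for g in frags:
        if g is not root and absorbed.name in g.operands:
            g.operands.discard(absorbed.name)
            g.operands.add(root.name)
    frags[:] = [g for g in frags if g is not absorbed]


def find_absorbable(frags):
    for rf in frags:
        if not is_root(frags, rf):
            continue
        for ipf in producers(frags, rf):  # pattern (1)
            if ipf.inferred and consumers(frags, ipf)[0] is rf:
                return ipf, rf
        for icf in consumers(frags, rf):  # pattern (2)
            if icf.inferred and producers(frags, icf)[-1] is rf:
                return icf, rf
    return None


def absorb(frags):
    # Apply the patterns until nothing changes (fixed point).
    while (match := find_absorbable(frags)) is not None:
        merge(frags, *match)
    return frags


program = [
    Fragment("i0", inferred=True),
    Fragment("u1", inferred=False, operands={"i0"}),
    Fragment("i2", inferred=True, operands={"u1"}),
    Fragment("u3", inferred=False, operands={"i2"}),
]
result = absorb(program)
print([f.name for f in result])  # ['u1', 'u3']
```

In this toy program the user fragment u1 absorbs its inferred producer i0 and its inferred consumer i2, so only the user-defined fragments remain.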
Options
-absorb-on-entry-point-function : Whether to absorb inferred fragments into user-defined fragments on entry-point functions, in addition to targets of mpmd.calls.
-mpmd-call-inline
Inlines all mpmd.call operations.
Inlines mpmd.call operations, copying their attributes to any inlined operations.
-mpmd-copy-constants
Copies constants produced in one fragment to their consumers.
Copies constants from their producer fragments to their consumer fragments, possibly through transfers.
Example:
%f = fragment () () {
  return constant
}
%t = transfer %f
fragment (%t) (%arg) {
  op(... %arg ...)
  ...
}
~~>
%f = fragment () () {
  return constant
}
%t = transfer %f
fragment (%t) (%arg) {
  %c = constant
  op(... %c ...)
  ...
}
This can be beneficial for runtime performance: it enables potential optimizations by putting the constant together with its users and it avoids transfers of constants. Additionally, it will improve memory usage: we reduce the space needed for parameters of the computation.
-mpmd-erase-unused-callee-block-arguments
Erases any mpmd callee block argument that isn't used by an (hlo) computation.
Erases unused block arguments from functions called by mpmd.calls. We consider a block argument to be unused if it has no uses or is used only by the function's terminator, i.e., if it is not used by any hlo computation.
-mpmd-fragment-dce
Eliminates unused fragment arguments/results and simplifies fragment regions.
Removes unused fragment arguments and results, while simplifying fragment regions, essentially eliminating dead MPMD code.
-mpmd-fragment-dedup
Removes any duplicated operands and results in fragments.
Removes any duplicated used arguments and results in fragments. This will leave the duplicate arguments and results unused. Other passes should be run to remove the unused arguments and results.
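As a rough sketch of the idea (not the pass itself), deduplication can be seen as remapping each argument position to the first position that carries the same value. The function name `dedup_operands` and the string-based value names are hypothetical.

```python
def dedup_operands(operands):
    """Maps each operand index to the first index holding the same value."""
    first_index = {}
    return [first_index.setdefault(value, i) for i, value in enumerate(operands)]


operands = ["%a", "%b", "%a", "%c", "%b"]
remap = dedup_operands(operands)
# Positions whose uses were redirected elsewhere are now unused.
unused = [i for i, target in enumerate(remap) if target != i]
print(remap)   # [0, 1, 0, 3, 1]
print(unused)  # [2, 4]
```

The positions listed in `unused` still exist after this step, which is why a later clean-up pass is needed to actually remove them.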
-mpmd-from-unroll-to-call-counter
Converts the unroll counter attr of a call op to a call counter attr.
Whenever a call op has an unroll_counter attribute, this pass replaces it with a call_counter attribute. This is needed for cases in which a sequence of calls results from unrolling a loop in MLIR (e.g., via -mpmd-unroll-for-loops) instead of from unrolling a loop at the Python level.
-mpmd-merge-forward-with-backward
Merges forward fragments with backward fragments.
Merges a producer forward fragment with a consumer backward fragment if the former is immediately before the latter. This only holds for the last stage in a 1F1B schedule, so the pass does not merge any fragments in earlier stages, which is the intended behavior.
-mpmd-merge-inferred-fragments
Merges inferred fragments with user defined fragments.
Merges inferred fragments with user-defined or other inferred fragments. This pass is useful to clean up/simplify the module and can be useful after other compiler passes that introduce inferred fragments, while -mpmd-transfer-aware-merge above is more invasive and should be used for optimization purposes only.
When clone_inferred_fragments=true, this merging pass allows certain fragments to be cloned. In particular, if we encounter a pair of fragments f1 and f2 such that:
- f2 uses f1, and
- f1 is inferred, pure, and sufficiently simple (single non-return op and single result),
then we merge a clone of f1 into f2, i.e., f1 itself (and its other users) remain independent of f2.
It may be undesirable to merge inferred producer fragments without cloning, because it can create unnecessary dependencies between fragments. E.g.,
%inferred = frag m1 { return stablehlo.const … }
%frag1 = frag m1 (%inferred, …)
%frag2 = frag m1 (%inferred, …)
~>
%inferred_frag1 = frag m1 (…) { … return const_m1, … }
%frag2 = frag m1 (%inferred_frag1, …)
So frag2 now depends on inferred_frag1, i.e., the merge creates a dependency between two fragments that were previously independent of each other.
However, sometimes we do want to merge in place, e.g., when the inferred fragment has collectives inside.
Options
-clone-inferred-fragments : Whether to clone inferred fragments. Chains of clonable fragments are recursively merged, one by one, into their consumers.
-merge-any-consumer : Whether to merge with any consumer or only the closest consumer.
-merge-sideways : Whether to merge with the next fragment in the same mesh (neighbor), even if not a consumer.
-mpmd-merge-transfers
Merges sets of transfers that share the same producer and consumer fragments.
Merges sets of transfers of the same payload type which share the same producer and consumer fragments. The payload values of these transfers have fewer elements than a given threshold, are not sharded, and do not live in pinned_host.
Merging a set of transfers means: concatenating the transferred values at producer site and splitting them at consumer site.
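The concatenate-then-split idea can be illustrated with flat Python lists standing in for tensors. This is a simplified sketch: the helper names are hypothetical, and the sharding and pinned_host conditions are not modeled, only the element-count threshold.

```python
def concat_at_producer(values, threshold):
    """Packs every value with fewer than `threshold` elements into one payload."""
    small = {name: v for name, v in values.items() if len(v) < threshold}
    payload = [x for v in small.values() for x in v]
    sizes = [(name, len(v)) for name, v in small.items()]
    return payload, sizes


def split_at_consumer(payload, sizes):
    """Recovers the individual values from the merged payload."""
    out, offset = {}, 0
    for name, size in sizes:
        out[name] = payload[offset:offset + size]
        offset += size
    return out


values = {"a": [1.0], "b": [2.0, 3.0], "big": list(range(100))}
payload, sizes = concat_at_producer(values, threshold=10)
received = split_at_consumer(payload, sizes)
print(payload)   # [1.0, 2.0, 3.0]
print(received)  # {'a': [1.0], 'b': [2.0, 3.0]}
```

Here two small transfers become a single three-element transfer, while the large value is left alone and would still be transferred on its own.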
-mpmd-merge-user-fragments-into-scheduling-units
Merges user-defined fragments before pipeline scheduling passes.
Merges pairs of user-defined fragments to be used together with pipeline scheduling passes.
-mpmd-remove-transfer-cycles
Removes device-only transfer cycles from the program, avoiding unnecessary transfers.
Removes transfer cycles.
E.g. in symbols:
x1 = transfer(x0) : m0 -> m1
x2 = transfer(x1) : m1 -> m2
x3 = transfer(x2) : m2 -> m3
x0_1 = transfer(x3) : m3 -> m0
x1_1 = transfer(x0_1) : m0 -> m1
~~>
x1 = transfer(x0) : m0 -> m1
x2 = transfer(x1) : m1 -> m2
x3 = transfer(x2) : m2 -> m3
x0_1 = x0
x1_1 = x1
i.e., we break the cycle by using the existing values, removing the unnecessary transfers.
Note that this could increase memory overhead: transferring the data away and back again means that there's a period where the data isn't on the device, whereas reusing the existing value keeps it resident. Thus, we only do this if the cycle contains only device-to-device transfers, since, e.g., a device -> host -> device cycle could be there for memory purposes.
This doesn't use the MLIR Canonicalizer, because that doesn't guarantee that everything is canonicalized, and also it's more expensive to apply.
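The rewrite in the example can be reproduced with a small bookkeeping sketch that remembers which copy of a value already lives on each mesh. It is a toy model under stated assumptions: the function and variable names are hypothetical, every mesh is assumed to be a device mesh, and transfers are given in program order.

```python
def remove_transfer_cycles(transfers, initial):
    """transfers: list of (result, source, src_mesh, dst_mesh).
    initial: mapping from each non-transfer value to the mesh it lives on."""
    origin = {v: v for v in initial}              # value -> value it copies
    copies = {(v, m): v for v, m in initial.items()}  # (origin, mesh) -> copy
    kept, aliases = [], {}
    for result, source, src, dst in transfers:
        source = aliases.get(source, source)
        root = origin[source]
        if (root, dst) in copies:
            # A copy already lives on dst: reuse it instead of transferring.
            aliases[result] = copies[(root, dst)]
        else:
            origin[result] = root
            copies[(root, dst)] = result
            kept.append((result, source, src, dst))
    return kept, aliases


transfers = [
    ("x1", "x0", "m0", "m1"),
    ("x2", "x1", "m1", "m2"),
    ("x3", "x2", "m2", "m3"),
    ("x0_1", "x3", "m3", "m0"),
    ("x1_1", "x0_1", "m0", "m1"),
]
kept, aliases = remove_transfer_cycles(transfers, {"x0": "m0"})
print([t[0] for t in kept])  # ['x1', 'x2', 'x3']
print(aliases)               # {'x0_1': 'x0', 'x1_1': 'x1'}
```

The three transfers that move the value to a new mesh are kept, and the two that would bring it back to a mesh that already holds it are replaced by the existing values.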
-mpmd-rule-based-merge
Merges fragments based on user-defined rules.
Merges fragments based on a specified list of rules, each specifying a list of source fragments to merge (by their fragment info) and the target info to label the merged fragment.
Options
-rules : A list of fragment merge rules, each with a list of source fragment infos and a target fragment info.
-mpmd-sink-negligible-ops-into-call-op
Sinks negligible ops into call-ops (i.e., the called function).
Sinks (negligible) ops into called functions: if there is an op with zero operands and a single result which is used as a specific operand of all call ops to the same function, then we sink it into those call ops, i.e., we clone it into the called function and replace all uses of the respective argument with the clone. Sunk ops are removed from the caller function, and the unused arguments of the callee (and the operands of the respective call ops) are removed.
Note that this can potentially duplicate computation across many microbatches when using call ops for microbatching. However, this computation is most likely negligible, as it takes no operands.
-mpmd-split-and-prioritize-transfer-independent-computations
Splits fragments so that transfer-independent computation can start early.
Splits a fragment into two fragments so that computation can start early. That is, we split the fragment into two fragments A -> B, where A does not rely on any transfer result and is maximally large, and B relies on transfer results.
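One way to picture the A -> B split is a forward dependency sweep over the fragment's ops. The sketch below is an assumption-laden illustration rather than the pass's algorithm: ops are hypothetical `(name, operands)` pairs listed in program order, and `transfer_args` names the fragment arguments that come from transfers.

```python
def split_transfer_independent(ops, transfer_args):
    """Returns (A, B): ops independent of transfers, and ops that depend on them."""
    dependent = set(transfer_args)
    a, b = [], []
    for name, operands in ops:
        if any(operand in dependent for operand in operands):
            dependent.add(name)  # dependence propagates to users of this op
            b.append(name)
        else:
            a.append(name)
    return a, b


ops = [
    ("u", ["w"]),       # uses only a local argument
    ("v", ["t"]),       # uses a transferred argument
    ("x", ["u", "v"]),  # transitively depends on the transfer
    ("y", ["u"]),
]
a, b = split_transfer_independent(ops, transfer_args={"t"})
print(a)  # ['u', 'y']
print(b)  # ['v', 'x']
```

Every op that does not transitively depend on a transferred argument lands in A, which makes A as large as possible and lets it run before the transfer completes.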
-mpmd-split-bwd-fragments
Splits backward fragments based on transferred results.
Splits backward fragments so that any computation that does not flow into transferred results becomes a fragment of its own. The original fragment will return some residual values that will be passed in as extra operands to the split-out fragments.
This split allows us to transfer results into other meshes earlier. One canonical use of this optimization will be splitting the activation gradient computation in back-propagation from the parameter gradient computation. Note also that due to the need to potentially thread through some residual values in the new fragments, memory pressure will increase.
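The split can be approximated by a backward sweep from the transferred results. The following is a minimal sketch under assumptions: the op names mimic a backward pass but are made up, ops are `(name, operands)` pairs in program order, and a residual is taken to be any value computed in the original fragment that a split-out op still needs.

```python
def split_bwd(ops, transferred_results):
    """Returns (kept ops, split-out ops, residual values)."""
    needed = set(transferred_results)
    for name, operands in reversed(ops):
        if name in needed:
            needed.update(operands)  # everything feeding a needed op is needed
    op_names = {name for name, _ in ops}
    kept = [name for name, _ in ops if name in needed]
    split_out = [name for name, _ in ops if name not in needed]
    residuals = sorted({
        operand
        for name, operands in ops if name not in needed
        for operand in operands if operand in needed and operand in op_names
    })
    return kept, split_out, residuals


ops = [
    ("h", ["x", "w"]),   # shared intermediate
    ("dx", ["h", "g"]),  # activation gradient, transferred to another mesh
    ("dw", ["h", "g"]),  # parameter gradient, stays local
]
kept, split_out, residuals = split_bwd(ops, transferred_results={"dx"})
print(kept)       # ['h', 'dx']
print(split_out)  # ['dw']
print(residuals)  # ['h']
```

The residual `h` has to stay alive until the split-out fragment runs, which is where the extra memory pressure mentioned above comes from.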
-mpmd-uniquify-function-inputs-outputs
Uniquifies any value returned multiple times or any block argument directly returned by the function.
If a function returns the same value multiple times, creates multiple versions for that value, by creating a fragment assigned to that value's mesh which returns the value multiple times. After this pass, each return operand is unique. This is important to ensure that the respective results are allocated in different buffers, as in the following jax.jit example:
import jax

@jax.jit
def f(x):
    y = x + x
    return y, y

z1, z2 = f(5)
z1 += 1
print(z1)  # ~~> 11
print(z2)  # ~~> 10
Similarly, if a function returns a block argument, this pass creates an identity fragment for that block argument, guaranteeing that values are passed by value to the function, not by reference.
-mpmd-unroll-for-loops
Fully unrolls mpmd.for loops.
Completely unrolls mpmd.for ops, while attaching an unroll_counter attribute to each unrolled op.
Requires: the unroll factor to be equal to the number of iterations.
-mpmd-verify-stage-merging
Verifies that merging of fragments assigned to stages succeeded.
Verifies that fragments with stage assignment have been correctly merged. This means that it's not possible to have in the module any two equivalent fragments in terms of assignment and counters.
Two fragments are equivalent in terms of assignment and counters iff:
- they are assigned to the same mesh,
- they are assigned to the same stage,
- they have the same transpose count, and
- either both have the same call counter or one of them doesn't have a call counter defined (i.e., an undefined call counter matches any call counter).
This is needed to guarantee to the user that any computation assigned to the same stage is executed contiguously.
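The equivalence relation used by the verifier can be written down as a small predicate. This is only a sketch of the stated conditions: the `FragmentInfo` record and its field names are hypothetical, and an undefined call counter is modeled as `None`.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class FragmentInfo:
    mesh: str
    stage: int
    transpose_count: int
    call_counter: Optional[int] = None  # None models an undefined call counter


def equivalent(a: FragmentInfo, b: FragmentInfo) -> bool:
    """True iff a and b match on mesh, stage, transpose count and call counter."""
    counters_match = (
        a.call_counter is None
        or b.call_counter is None
        or a.call_counter == b.call_counter
    )
    return (
        a.mesh == b.mesh
        and a.stage == b.stage
        and a.transpose_count == b.transpose_count
        and counters_match
    )


fwd0 = FragmentInfo("m0", stage=0, transpose_count=0, call_counter=0)
fwd0_any = FragmentInfo("m0", stage=0, transpose_count=0)
fwd0_next = FragmentInfo("m0", stage=0, transpose_count=0, call_counter=1)
bwd0 = FragmentInfo("m0", stage=0, transpose_count=1, call_counter=0)

print(equivalent(fwd0, fwd0_any))   # True: undefined counter matches any
print(equivalent(fwd0, fwd0_next))  # False: different call counters
print(equivalent(fwd0, bwd0))       # False: different transpose counts
```

Under this reading, the verifier would reject a module in which any two distinct fragments satisfy `equivalent`, since they should have been merged.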