-mpmd-delay-inferred-fragments
Delays inferred fragments to be executed as late as possible.
Moves inferred fragments to be executed as late as possible, i.e., right before their first consumer.
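As a rough illustration of "as late as possible", the sketch below models a program as a flat list of ops and moves each inferred op to immediately before its first consumer. The Op class and its inferred flag are hypothetical stand-ins for MPMD fragments, not the pass's actual data structures.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)  # compare ops by identity, not by field values
class Op:
    name: str
    operands: list[str] = field(default_factory=list)  # names of ops this op consumes
    inferred: bool = False  # hypothetical marker for an inferred fragment


def delay_inferred(ops: list[Op]) -> list[Op]:
    """Moves every inferred op right before its first consumer; ops without consumers stay put."""
    result = list(ops)
    # Visit inferred ops from last to first, so that an op is moved only after
    # any inferred consumers of it have already been delayed.
    for op in reversed([o for o in ops if o.inferred]):
        idx = result.index(op)
        consumers = [i for i, o in enumerate(result) if op.name in o.operands]
        if not consumers:
            continue
        first = min(consumers)  # always > idx in a well-ordered program
        result.pop(idx)
        # Popping shifted the consumer one slot left, so insert at first - 1.
        result.insert(first - 1, op)
    return result


ops = [Op("a", inferred=True), Op("b", ["a"], inferred=True), Op("c"), Op("d", ["b"])]
print([o.name for o in delay_inferred(ops)])  # ['c', 'a', 'b', 'd']
```

Visiting inferred ops in reverse order matters: here "a" can only slide past "c" once its consumer "b" has itself been delayed.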
-mpmd-delay-transfers-from-cpu
Delays transfers from CPU to be executed as late as possible.
Moves CPU-to-device transfer ops to right before their first consumer. This postpones the memory allocation for the transferred buffers, which can reduce peak HBM usage.
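To see why postponing the allocation can help, the toy liveness model below computes the peak live bytes of two schedules that differ only in where the transfer runs. The op names, buffer sizes, and the rule "a buffer is live from its producer through its last consumer" are illustrative assumptions, not how XLA actually accounts for memory.

```python
def peak_live_bytes(schedule, sizes, uses):
    """Returns the peak number of live bytes for a given execution order.

    schedule: op names in execution order.
    sizes:    op name -> bytes of the buffer the op produces.
    uses:     op name -> names of ops whose buffers it consumes.
    A buffer is live from its producing step through its last consuming step.
    """
    pos = {name: i for i, name in enumerate(schedule)}
    last_use = dict(pos)  # an unused buffer is live only at its producing step
    for consumer, operands in uses.items():
        for operand in operands:
            last_use[operand] = max(last_use[operand], pos[consumer])
    peak = 0
    for step in range(len(schedule)):
        live = sum(sizes[n] for n in schedule if pos[n] <= step <= last_use[n])
        peak = max(peak, live)
    return peak


sizes = {"xfer": 100, "f1": 80, "f2": 10, "g": 10}  # "xfer" is the CPU-to-device transfer
uses = {"f2": ["f1"], "g": ["xfer", "f2"]}           # "g" is the transfer's only consumer
early = ["xfer", "f1", "f2", "g"]  # transfer issued first
late = ["f1", "f2", "xfer", "g"]   # transfer delayed to just before "g"
print(peak_live_bytes(early, sizes, uses), peak_live_bytes(late, sizes, uses))  # 190 120
```

In the early schedule the transferred buffer overlaps with the large intermediate produced by "f1"; delaying it until "f1" is dead lowers the peak.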
-mpmd-lower-to-fragment-calls
Lowers MPMD fragments to fragment calls and functions.
Replaces all fragments with fragment calls.
This pass creates a function for every group of fragments that have an identical body and mesh shape, with the name of the first encountered fragment in the group as the symbol name, and adds it to the symbol table.
The body of each function is extracted from the body of the first encountered fragment in the respective group, and the mesh shape is taken from the topology by the mesh name of that fragment. If the fragment has argument attributes about input-output aliasing, they will be assigned to the argument attributes of the lowered function.
Since functions must have unique names, this pass appends an index to the name of all but the first function with the same original name, i.e., the ith function with name "some_name" for i > 0 will have the name "some_name_i".
Options
-verbose-logging : Whether to enable verbose logging.
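The grouping and naming rules above can be sketched as follows. Fragment bodies are modeled as hashable placeholders (for example strings) so that "identical body" becomes plain equality; the real pass compares IR, and the function name lower_to_functions is hypothetical.

```python
def lower_to_functions(fragments):
    """Groups fragments by (body, mesh shape) and names one function per group.

    fragments: list of (name, body, mesh_shape) in program order, where `body`
    is any hashable stand-in for the fragment body.
    Returns (functions, calls): functions maps symbol -> (body, mesh_shape),
    and calls[i] is the symbol that fragment i is replaced with a call to.
    """
    symbol_of_group = {}  # (body, mesh_shape) -> function symbol
    uses_of_name = {}     # original name -> number of functions already given that name
    functions, calls = {}, []
    for name, body, mesh_shape in fragments:
        key = (body, mesh_shape)
        if key not in symbol_of_group:
            i = uses_of_name.get(name, 0)
            symbol = name if i == 0 else f"{name}_{i}"  # the ith function named `name` gets `name_i`
            uses_of_name[name] = i + 1
            symbol_of_group[key] = symbol
            functions[symbol] = key
        calls.append(symbol_of_group[key])
    return functions, calls


frags = [
    ("f", "add", (2,)),
    ("g", "mul", (2,)),
    ("f", "add", (2,)),  # same body and mesh shape as the first "f": reuses it
    ("f", "sub", (2,)),  # same name, different body: becomes "f_1"
    ("h", "add", (4,)),  # same body as "f" but a different mesh shape: new group
]
functions, calls = lower_to_functions(frags)
print(calls)  # ['f', 'g', 'f', 'f_1', 'h']
```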
-mpmd-mark-aliasing-and-donation
Marks each fragment with aliasing or donation information.
Sets an arg_attrs attribute on Fragment ops when any of their inputs
can be aliased with an output or donated. Each input that can be aliased
gets a tf.aliasing_output attribute; each input that can only be donated
gets a jax.buffer_donor = true attribute instead. For example, {arg_attrs =
[{tf.aliasing_output = 0 : i32}, {jax.buffer_donor = true}, {}]} shows
that the first input can be aliased with output 0, the second input can be
donated so that XLA will find an aliased output, and the third input can't
be aliased or donated.
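The text does not say how aliasing candidates are chosen, so the sketch below assumes a simple rule: an input qualifies only if its buffer is dead after the fragment (the donatable flags stand in for a hypothetical liveness analysis), and it aliases the first not-yet-claimed output with the same shape and dtype. Otherwise it is only marked as a buffer donor. It reproduces the attribute pattern from the example above.

```python
def compute_arg_attrs(inputs, outputs, donatable):
    """Builds per-input attribute dicts in the style of the arg_attrs example.

    inputs, outputs: lists of (shape, dtype) tuples.
    donatable[i]:    True if input i's buffer is not needed after the fragment
                     (assumed to come from a separate liveness analysis).
    """
    claimed = set()  # output indices already aliased by an earlier input
    attrs = []
    for i, in_type in enumerate(inputs):
        if not donatable[i]:
            attrs.append({})  # cannot be aliased or donated
            continue
        match = next(
            (j for j, out_type in enumerate(outputs) if j not in claimed and out_type == in_type),
            None,
        )
        if match is None:
            attrs.append({"jax.buffer_donor": True})  # donate; let XLA pick an output
        else:
            claimed.add(match)
            attrs.append({"tf.aliasing_output": match})
    return attrs


inputs = [((8,), "f32"), ((4,), "f32"), ((8,), "f32")]
outputs = [((8,), "f32")]
print(compute_arg_attrs(inputs, outputs, donatable=[True, True, False]))
# [{'tf.aliasing_output': 0}, {'jax.buffer_donor': True}, {}]
```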
-mpmd-mark-fragment-reserved-memory
Marks each fragment with the amount of memory that needs to be reserved for compilation.
Assigns a reserved_hbm_bytes attribute to each fragment, which tells XLA
how many bytes to keep reserved while compiling that fragment. By keeping
track of the live tensors on a mesh, we let XLA know the actual minimum
memory usage at the time of execution, so that we can prevent it from
applying optimizations in the executable that would increase memory usage
beyond the device capacity.
NOTE: this pass assumes that fragments are executed in program order.
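A minimal sketch of the idea, under two assumptions not stated above: a value counts toward a fragment's reservation when it lives on the same mesh, is defined before the fragment and used after it, and is not one of the fragment's own operands (XLA already sees those). The Value class and reserved_hbm_bytes function are hypothetical; like the pass, the sketch assumes fragments run in program order.

```python
from dataclasses import dataclass


@dataclass
class Value:
    name: str
    mesh: str
    nbytes: int
    producer: int  # index of the producing fragment; -1 for program inputs
    last_use: int  # index of the last consuming fragment; len(fragments) if returned


def reserved_hbm_bytes(fragments, values):
    """fragments: list of (mesh, operand_names), assumed to run in program order.

    For each fragment, sums the bytes of values on the same mesh that are live
    across it (defined before, used after) but are not among its operands.
    """
    reserved = []
    for k, (mesh, operands) in enumerate(fragments):
        reserved.append(sum(
            v.nbytes for v in values
            if v.mesh == mesh and v.producer < k < v.last_use and v.name not in operands
        ))
    return reserved


fragments = [("m0", ["x"]), ("m0", ["a"]), ("m0", ["w", "b"])]
values = [
    Value("x", "m0", 100, -1, 0),
    Value("w", "m0", 400, -1, 2),   # e.g. weights kept alive until fragment 2
    Value("a", "m0", 50, 0, 1),
    Value("b", "m0", 20, 1, 2),
    Value("y", "m1", 1000, -1, 3),  # lives on another mesh: never counted for m0
]
print(reserved_hbm_bytes(fragments, values))  # [400, 400, 0]
```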
-mpmd-mark-input-output-with-layouts
Propagates layouts from func args/results to fragment args/results.
Propagates mhlo.layout_mode attributes from program inputs to the fragments
that consume them, and from program outputs to the fragments that produce
them. If a program argument is returned, then mhlo.layout_mode is
propagated to/from program results. Inputs, outputs and fragment arguments/results
that are connected with a transfer op are always overwritten with the
DEFAULT layout, because transfers only support default layouts.
If a program argument is set to AUTO layout and is used in multiple fragments, then we set it to the DEFAULT layout to ensure a consistent layout across fragments.
If a program output that is also a fragment result is set to AUTO layout and is used as an input by other fragments, we also set it to the DEFAULT layout, for the same reason.
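The override rules above can be summarized as two small decision functions. This is only a sketch: layouts are modeled as strings ("AUTO", "DEFAULT", or any other value standing for an explicit layout, which is passed through unchanged), and the function names and parameters are hypothetical, whereas the pass itself rewrites mhlo.layout_mode attributes in the IR.

```python
def resolve_input_layout(layout, num_fragment_users, via_transfer):
    """Layout mode to use for a program input."""
    if via_transfer:
        return "DEFAULT"  # transfers only support the default layout
    if layout == "AUTO" and num_fragment_users > 1:
        return "DEFAULT"  # one layout shared by every consuming fragment
    return layout


def resolve_output_layout(layout, num_other_fragment_users, via_transfer):
    """Layout mode for a program output that is also a fragment result."""
    if via_transfer:
        return "DEFAULT"  # transfers only support the default layout
    if layout == "AUTO" and num_other_fragment_users > 0:
        return "DEFAULT"  # consuming fragments must agree with the producer's layout
    return layout


print(resolve_input_layout("AUTO", 2, False))   # DEFAULT
print(resolve_input_layout("AUTO", 1, False))   # AUTO
print(resolve_output_layout("AUTO", 1, False))  # DEFAULT
```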
-mpmd-mark-offloaded-input-output
Marks offloaded input and output values so the compiler knows they are in host memory.
Marks fragment args and results with attributes to identify which values live in host memory, so that this information can be used by XLA. Also marks the entrypoint func args and results so that Pathways can use this information.
-mpmd-sink-create-token-into-fragments
Sinks stablehlo.create_token ops into fragment bodies.
The MPMD dialect doesn't yet support passing tokens across fragments. Instead, this pass duplicates the tokens and sinks a copy into each fragment, which severs the token-mediated sequencing between fragments.
After this pass, no token ever crosses a fragment boundary.
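The transformation can be pictured on a toy representation in which fragments list the values they take in and carry their body as text. The Fragment class, the top_level map, and the body strings are all hypothetical; only the op name stablehlo.create_token comes from StableHLO. Each fragment that used a shared token gets its own fresh copy, so any ordering the shared token used to impose between fragments is gone.

```python
from dataclasses import dataclass, field


@dataclass
class Fragment:
    name: str
    operands: list[str]                            # values passed into the fragment
    body: list[str] = field(default_factory=list)  # textual ops inside the fragment


def sink_create_tokens(top_level, fragments):
    """top_level: value name -> op name, for values defined outside fragments."""
    tokens = {v for v, op in top_level.items() if op == "stablehlo.create_token"}
    for frag in fragments:
        used = sorted({v for v in frag.operands if v in tokens})
        # Drop token operands and recreate each token locally at the top of the body.
        frag.operands = [v for v in frag.operands if v not in tokens]
        frag.body = [f"%{v} = stablehlo.create_token" for v in used] + frag.body
    for v in tokens:
        del top_level[v]  # no token is defined outside a fragment anymore


top = {"t": "stablehlo.create_token", "x": "stablehlo.constant"}
f1 = Fragment("f1", ["t", "x"], ["%y = stablehlo.after_all %t"])
f2 = Fragment("f2", ["t"])
sink_create_tokens(top, [f1, f2])
print(f1.operands, f2.operands)  # ['x'] []
```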
-mpmd-validate-no-backward-deps
Validates no backward dependencies exist in forward-only programs.
Checks that every data dependency between fragment calls goes from a lexicographically earlier mesh name to a later one. A backward dependency creates pipeline bubbles.
By default, this pass emits a warning for each such dependency. If
fail-on-backward-deps is set, it emits an error and fails.
This pass only applies to forward-only programs, i.e., programs where all fragments have a transpose count of 0. Programs with backward fragments (transpose count != 0) are skipped.
Options
-fail-on-backward-deps : Whether to emit an error (and fail) instead of a warning.
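A minimal sketch of the check, assuming fragment calls are summarized by their mesh name and transpose count. Python compares strings lexicographically, which matches the ordering described above (note that "mesh10" sorts before "mesh2"). Dependencies between calls on the same mesh are assumed not to count as backward. All names here are hypothetical.

```python
import warnings


def check_backward_deps(calls, edges, fail_on_backward_deps=False):
    """calls: fragment-call name -> (mesh_name, transpose_count).
    edges: (producer, consumer) pairs of data dependencies between calls.
    Returns the backward edges found (empty if the program is skipped).
    """
    if any(count != 0 for _, count in calls.values()):
        return []  # has backward fragments: not a forward-only program, so skip
    # A dependency is backward if it flows to a lexicographically earlier mesh.
    backward = [(p, c) for p, c in edges if calls[p][0] > calls[c][0]]
    if backward and fail_on_backward_deps:
        raise ValueError(f"backward dependencies: {backward}")
    for p, c in backward:
        warnings.warn(f"backward dependency: {p} on {calls[p][0]} -> {c} on {calls[c][0]}")
    return backward


calls = {"f0": ("mesh0", 0), "f1": ("mesh1", 0), "f2": ("mesh0", 0)}
edges = [("f0", "f1"), ("f1", "f2")]
print(check_backward_deps(calls, edges))  # [('f1', 'f2')], plus a warning
```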
-mpmd-validate-no-inferred-fragments
Validates that no inferred fragments exist.
Checks that all inferred fragments have been merged.
By default, this pass emits an error if any inferred fragment remains. If
fail-on-inferred-fragments is set to false, it emits a warning instead.
Options
-fail-on-inferred-fragments : Whether to emit an error (and fail) instead of a warning (default: true).
-mpmd-validate-no-param-transfers
Validates that no parameter tensors are transferred across meshes.
Checks that TransferOps do not transfer tensors whose JAX location info matches a configurable pattern (default: "params['transformer"). Such transfers typically indicate that model parameters are being erroneously moved across meshes.
By default, this pass emits a warning. Set
fail-on-param-transfers=true to emit an error and fail instead.
Options
-fail-on-param-transfers : Whether to emit an error (and fail) instead of a warning (default: false).
-param-pattern : Pattern to match against the location info of transferred tensors.
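The sketch below mimics the check on (transfer name, location info) pairs. The text does not say how the pattern is matched, so a plain substring test is an assumption here, as are the function name and the example location strings.

```python
import warnings

DEFAULT_PARAM_PATTERN = "params['transformer"  # the default pattern quoted above


def check_param_transfers(transfers, param_pattern=DEFAULT_PARAM_PATTERN,
                          fail_on_param_transfers=False):
    """transfers: list of (transfer_name, location_info) pairs.

    Matching is modeled as a plain substring test (an assumption).
    """
    flagged = [name for name, loc in transfers if param_pattern in loc]
    if flagged and fail_on_param_transfers:
        raise ValueError(f"parameter tensors transferred across meshes: {flagged}")
    for name in flagged:
        warnings.warn(f"transfer {name} moves a parameter tensor across meshes")
    return flagged


transfers = [("t0", "params['transformer']['layer_0']['w']"), ("t1", "batch['inputs']")]
print(check_param_transfers(transfers))  # ['t0'], plus a warning
```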
-mpmd-validate-no-reshards
Validates that no reshard-only fragments exist.
A reshard-only fragment is a fragment that contains only an mpmd.return of
one of its arguments. These fragments usually indicate an unexpected reshard.
By default, this pass emits a warning for each such fragment. If
fail-on-reshard-only-fragments is set, it emits an error and fails.
Options
-fail-on-reshard-only-fragments : Whether to emit an error (and fail) instead of a warning.
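A minimal sketch of the detection, assuming a fragment body is given as an ordered list of (op name, operand names) pairs and taking the definition above literally: the body must be a single mpmd.return of exactly one of the fragment's arguments. The function names and the body encoding are hypothetical.

```python
import warnings


def is_reshard_only(arg_names, body):
    """body: list of (op_name, operand_names) in order.

    True when the body is a single mpmd.return of one of the fragment's arguments.
    """
    if len(body) != 1:
        return False
    op_name, operands = body[0]
    return op_name == "mpmd.return" and len(operands) == 1 and operands[0] in arg_names


def check_reshards(fragments, fail_on_reshard_only_fragments=False):
    """fragments: name -> (arg_names, body). Returns names of reshard-only fragments."""
    flagged = [name for name, (args, body) in fragments.items() if is_reshard_only(args, body)]
    if flagged and fail_on_reshard_only_fragments:
        raise ValueError(f"reshard-only fragments: {flagged}")
    for name in flagged:
        warnings.warn(f"fragment {name} only returns one of its arguments (likely a reshard)")
    return flagged


frags = {
    "r": (["%arg0"], [("mpmd.return", ["%arg0"])]),
    "c": (["%arg0"], [("stablehlo.add", ["%arg0", "%arg0"]), ("mpmd.return", ["%0"])]),
}
print(check_reshards(frags))  # ['r'], plus a warning
```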