-mpmd-delay-inferred-fragments

Delays inferred fragments to be executed as late as possible.

Moves inferred fragments to be executed as late as possible, i.e., right before their first consumer.
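The reordering can be pictured on a flat list of ops. Below is a minimal sketch, assuming a hypothetical `Op` record with a name, the names of the ops it consumes, and an `inferred` flag. It is not the pass's implementation. It walks the original order back to front so that a chain of inferred ops moves together. It also assumes that an op with no consumer keeps its position.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)  # compare ops by identity, not by field values
class Op:
    name: str
    operands: list = field(default_factory=list)  # names of the ops it consumes
    inferred: bool = False  # hypothetical flag: True for an inferred fragment


def delay_inferred(ops):
    """Return a new order in which every inferred op sits right before its first consumer."""
    order = list(ops)
    # Walk the original ops back to front so that a chain of inferred ops
    # (a feeds b feeds c) is delayed as a whole, not one link at a time.
    for op in reversed(ops):
        if not op.inferred:
            continue
        first_use = next((i for i, o in enumerate(order) if op.name in o.operands), None)
        if first_use is None:
            continue  # assumption: an op with no consumer keeps its position
        pos = next(i for i, o in enumerate(order) if o is op)
        order.pop(pos)
        # The consumer comes after the op, so popping shifts it one slot left.
        order.insert(first_use - 1, op)
    return order


a = Op("a", inferred=True)
x = Op("x")
b = Op("b", ["a"], inferred=True)
y = Op("y")
c = Op("c", ["b"])
print([o.name for o in delay_inferred([a, x, b, y, c])])  # ['x', 'y', 'a', 'b', 'c']
```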

-mpmd-delay-transfers-from-cpu

Delays transfers from CPU to be executed as late as possible.

Moves cpu-to-device transfer ops right before their first consumer. This postpones the allocation of memory for the transferred buffers, which can reduce peak HBM usage.
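The HBM benefit can be illustrated with a toy live-buffer model. This is a sketch under simplifying assumptions, not a model of XLA's allocator. Each op allocates at most one buffer when it runs, and that buffer is freed right after its last consumer. The op names and byte sizes are hypothetical. Issuing the transfer early keeps its buffer live alongside unrelated work, while delaying it shortens that overlap.

```python
def peak_live_bytes(schedule, sizes, last_use):
    """Peak bytes held by live buffers while running `schedule` in order.

    schedule: op names in execution order.
    sizes:    op name -> bytes of the buffer it allocates (ops not listed allocate nothing).
    last_use: buffer (op) name -> name of the op after which the buffer is freed.
    """
    live = {}
    peak = 0
    for op in schedule:
        if op in sizes:
            live[op] = sizes[op]  # allocate the op's output buffer
        peak = max(peak, sum(live.values()))
        # Free every buffer whose last consumer just ran.
        for buf in [b for b, last in last_use.items() if last == op and b in live]:
            del live[buf]
    return peak


sizes = {"transfer": 64, "f": 128}   # hypothetical buffer sizes in bytes
last_use = {"transfer": "g", "f": "h"}
early = ["transfer", "f", "h", "g"]  # transfer issued at the top of the program
late = ["f", "h", "transfer", "g"]   # transfer delayed to just before its consumer g
print(peak_live_bytes(early, sizes, last_use), peak_live_bytes(late, sizes, last_use))  # 192 128
```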

-mpmd-lower-to-fragment-calls

Lowers MPMD fragments to fragment calls and functions.

Replaces all fragments with fragment calls.

This pass creates a function for every group of fragments that have an identical body and mesh shape, with the name of the first encountered fragment in the group as the symbol name, and adds it to the symbol table.

The body of each function is extracted from the body of the first encountered fragment in the respective group, and the mesh shape is taken from the topology by the mesh name of that fragment. If the fragment has argument attributes describing input-output aliasing, they are copied to the argument attributes of the lowered function.

Since functions must have unique names, this pass appends an index to the name of all but the first function with the same original name, i.e., the ith function with name "some_name" for i > 0 will have the name "some_name_i".
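The grouping and naming rules can be sketched as follows. This assumes each fragment is a hypothetical `(name, body, mesh_shape)` tuple, with the body represented as a hashable stand-in (here a string) so that "identical body" becomes plain equality. The mesh shape is passed in directly rather than looked up from a topology. The sketch does not guard against a clash with a fragment literally named, say, `f_1`.

```python
from collections import Counter


def lower_to_calls(fragments):
    """fragments: (name, body, mesh_shape) tuples in program order.

    Returns (functions, calls): functions maps each unique symbol name to
    (body, mesh_shape), and calls gives the callee symbol for each fragment.
    """
    group_to_symbol = {}
    functions = {}
    name_counts = Counter()  # how many functions already use each original name
    calls = []
    for name, body, mesh_shape in fragments:
        key = (body, mesh_shape)  # one function per (identical body, mesh shape) group
        if key not in group_to_symbol:
            i = name_counts[name]
            name_counts[name] += 1
            # The first function keeps the fragment's name; later ones get "_i".
            symbol = name if i == 0 else f"{name}_{i}"
            group_to_symbol[key] = symbol
            functions[symbol] = (body, mesh_shape)
        calls.append(group_to_symbol[key])
    return functions, calls


frags = [("f", "add", (2,)), ("g", "mul", (2,)), ("f", "add", (2,)),
         ("f", "sub", (2,)), ("f", "add", (4,))]
functions, calls = lower_to_calls(frags)
print(list(functions), calls)  # ['f', 'g', 'f_1', 'f_2'] ['f', 'g', 'f', 'f_1', 'f_2']
```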

Options

-verbose-logging : Whether to enable verbose logging.

-mpmd-mark-aliasing-and-donation

Marks each fragment with aliasing or donation information.

Sets an arg_attrs attribute on Fragment ops when any of their inputs can be aliased with an output or donated. Each input that can be aliased gets a tf.aliasing_output attribute; otherwise, an input that can be donated gets a jax.buffer_donor = true attribute. For example, {arg_attrs = [{tf.aliasing_output = 0 : i32}, {jax.buffer_donor = true}, {}]} shows that the first input can be aliased with output 0, the second input can be donated so that XLA can find an output to alias it with, and the third input can be neither aliased nor donated.
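The attribute layout can be mirrored with plain dictionaries. This is a sketch that assumes the aliasing and donation decisions are already known and passed in as hypothetical arguments. How the pass actually decides them is not described here. It returns `None` when no input qualifies, reflecting that the attribute is only set when at least one input can be aliased or donated.

```python
def make_arg_attrs(num_inputs, aliased, donatable):
    """aliased: input index -> output index it can alias; donatable: set of input indices."""
    attrs = []
    for i in range(num_inputs):
        if i in aliased:
            attrs.append({"tf.aliasing_output": aliased[i]})  # aliasing takes precedence
        elif i in donatable:
            attrs.append({"jax.buffer_donor": True})  # XLA picks the aliased output
        else:
            attrs.append({})  # neither aliased nor donated
    # Only attach arg_attrs when at least one input carries an attribute.
    return attrs if any(attrs) else None


print(make_arg_attrs(3, {0: 0}, {1}))
# [{'tf.aliasing_output': 0}, {'jax.buffer_donor': True}, {}]
```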

-mpmd-mark-fragment-reserved-memory

Marks each fragment with the amount of memory that needs to be reserved for compilation.

Assigns a reserved_hbm_bytes attribute to each fragment, which tells XLA how many bytes to keep reserved while compiling that fragment. By tracking the tensors that are live on each mesh, the pass lets XLA know the actual minimum memory usage at execution time, which prevents XLA from applying optimizations in the executable that would push memory usage beyond the device capacity. NOTE: this pass assumes that fragments are executed in program order.
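One plausible reading of "reserved" bytes is the total size of tensors live on the fragment's mesh that the fragment itself neither reads nor writes. The sketch below assumes that reading and walks fragments in program order, as the pass assumes. All input shapes are hypothetical. It frees a tensor after its last reader and keeps never-read tensors live, on the assumption that they are program outputs. It ignores tensors that are both read and returned.

```python
def reserved_hbm_bytes(fragments, sizes, initial_live):
    """Sketch: bytes held on a fragment's mesh by tensors it does not touch.

    fragments:    (mesh, inputs, outputs) tuples in program order.
    sizes:        tensor name -> bytes.
    initial_live: mesh name -> tensor names resident there at program start.
    """
    # Index of the last fragment that reads each tensor.
    last_use = {}
    for idx, (_, inputs, _) in enumerate(fragments):
        for t in inputs:
            last_use[t] = idx

    live = {mesh: set(names) for mesh, names in initial_live.items()}
    reserved = []
    for idx, (mesh, inputs, outputs) in enumerate(fragments):
        on_mesh = live.setdefault(mesh, set())
        # Live tensors on this mesh that this fragment neither reads nor writes.
        others = on_mesh - set(inputs) - set(outputs)
        reserved.append(sum(sizes[t] for t in others))
        on_mesh |= set(outputs)
        # Free inputs that no later fragment reads; never-read tensors stay live.
        on_mesh -= {t for t in inputs if last_use.get(t) == idx}
    return reserved


sizes = {"w1": 100, "w2": 200, "x": 10, "y": 20, "z": 5, "p": 1000}
frags = [("m1", ["w1", "x"], ["y"]), ("m1", ["w2", "y"], ["z"])]
print(reserved_hbm_bytes(frags, sizes, {"m1": {"w1", "w2", "x"}, "m2": {"p"}}))  # [200, 0]
```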

-mpmd-mark-input-output-with-layouts

Propagates layouts from func args/results to fragment args/results.

Propagates mhlo.layout_mode attributes from program inputs to the fragments that consume them, and from program outputs to the fragments that produce them. If a program argument is returned, mhlo.layout_mode is propagated to/from the program results. Inputs, outputs, and fragment arguments/results that are connected by a transfer op are always overwritten with the DEFAULT layout, because transfers only support default layouts.

If a program argument is set to AUTO layout and is used in multiple fragments, we set it to the DEFAULT layout to keep the layout consistent across fragments.

If a program output that is also a fragment result is set to AUTO layout and is used as an input by other fragments, we also set it to the DEFAULT layout to keep the layout consistent across fragments.
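The three overriding rules can be summarized as small decision functions. This is a sketch only. The string values `"default"` and `"auto"` stand in for the mhlo.layout_mode values, and the boolean and list parameters are hypothetical summaries of how a value is connected in the program. An explicit layout (any other string) passes through unless a transfer forces DEFAULT.

```python
DEFAULT, AUTO = "default", "auto"  # stand-ins for mhlo.layout_mode values


def resolve_arg_layout(layout, consumer_fragments, via_transfer):
    """Layout to propagate from a program argument to the fragments consuming it."""
    if via_transfer:
        return DEFAULT  # transfers only support default layouts
    if layout == AUTO and len(set(consumer_fragments)) > 1:
        return DEFAULT  # AUTO used by several fragments: pick one consistent layout
    return layout


def resolve_output_layout(layout, used_by_other_fragments, via_transfer):
    """Layout for a program output that is also a fragment result."""
    if via_transfer:
        return DEFAULT
    if layout == AUTO and used_by_other_fragments:
        return DEFAULT  # the result also feeds other fragments: keep it consistent
    return layout


print(resolve_arg_layout(AUTO, ["f0", "f1"], False))  # default
print(resolve_arg_layout(AUTO, ["f0"], False))        # auto
```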

-mpmd-mark-offloaded-input-output

Marks offloaded input and output values so the compiler knows they are in host memory.

Marks fragment args and results with attributes to identify which values live in host memory, so that this information can be used by XLA. Also marks the entrypoint func args and results so that Pathways can use this information.

-mpmd-sink-create-token-into-fragments

Sinks stablehlo.create_token ops into fragment bodies.

The MPMD dialect doesn't support cross-fragment token passing yet. So instead, this pass duplicates the tokens and sinks them into each fragment, severing the token-mediated sequencing between fragments.

After this pass, no token ever crosses a fragment boundary.

-mpmd-validate-no-backward-deps

Validates no backward dependencies exist in forward-only programs.

Checks that every data dependency between fragment calls goes from a lexicographically earlier mesh name to a later one. A backward dependency (from a later mesh to an earlier one) creates pipeline bubbles.

By default, this pass emits a warning for each such dependency. If fail-on-backward-deps is set, it emits an error and fails.

This pass only applies to forward-only programs, i.e., programs where all fragments have a transpose count of 0. Programs with backward fragments (transpose count != 0) are skipped.
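The check can be sketched over a simple dependency list. This assumes a hypothetical mapping from fragment name to `(mesh_name, transpose_count)` and a list of `(producer, consumer)` pairs. Dependencies within the same mesh are assumed not to count as backward. The `fail_on_backward_deps` parameter mirrors the pass option, with Python's `warnings` and `ValueError` standing in for MLIR diagnostics.

```python
import warnings


def find_backward_deps(fragments, deps):
    """fragments: name -> (mesh_name, transpose_count); deps: (producer, consumer) pairs."""
    if any(count != 0 for _, count in fragments.values()):
        return []  # not a forward-only program: the check is skipped
    # Backward: the producer's mesh name sorts after the consumer's.
    return [(p, c) for p, c in deps if fragments[p][0] > fragments[c][0]]


def validate_no_backward_deps(fragments, deps, fail_on_backward_deps=False):
    backward = find_backward_deps(fragments, deps)
    for p, c in backward:
        msg = f"backward dependency: {p} ({fragments[p][0]}) -> {c} ({fragments[c][0]})"
        if fail_on_backward_deps:
            raise ValueError(msg)  # error and fail
        warnings.warn(msg)  # default: warn only
    return backward


frags = {"f0": ("mesh0", 0), "f1": ("mesh1", 0)}
print(find_backward_deps(frags, [("f0", "f1"), ("f1", "f0")]))  # [('f1', 'f0')]
```

Because the ordering is lexicographic on strings, a mesh named "mesh10" sorts before "mesh2"; zero-padded mesh names avoid surprises with this rule.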

Options

-fail-on-backward-deps : Whether to emit an error (and fail) instead of a warning (default: false).

-mpmd-validate-no-inferred-fragments

Validates that no inferred fragments exist.

Checks that all inferred fragments have been merged.

By default, this pass emits an error if any inferred fragment remains. If fail-on-inferred-fragments is set to false, it emits a warning instead.

Options

-fail-on-inferred-fragments : Whether to emit an error (and fail) instead of a warning (default: true).

-mpmd-validate-no-param-transfers

Validates that no parameter tensors are transferred across meshes.

Checks that TransferOps do not transfer tensors whose JAX location info matches a configurable pattern (default: "params['transformer"). Such transfers typically indicate that model parameters are being erroneously moved across meshes.

By default, this pass emits a warning. Set fail-on-param-transfers=true to emit an error and fail instead.
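A rough model of the check follows. This sketch assumes that "matches" means the pattern occurs as a substring of the transferred tensor's location string. The location strings themselves are hypothetical examples of JAX location info, not captured output. The parameter names mirror the pass options.

```python
import warnings

DEFAULT_PARAM_PATTERN = "params['transformer"


def check_param_transfers(transfer_locs, param_pattern=DEFAULT_PARAM_PATTERN,
                          fail_on_param_transfers=False):
    """Return the locations of transfers that look like parameter transfers."""
    # Assumption: a location matches if it contains the pattern as a substring.
    offending = [loc for loc in transfer_locs if param_pattern in loc]
    for loc in offending:
        msg = f"parameter tensor transferred across meshes: {loc}"
        if fail_on_param_transfers:
            raise ValueError(msg)
        warnings.warn(msg)
    return offending


locs = ["params['transformer']['layer_0']['w']", "activations['x']"]
print(check_param_transfers(locs, param_pattern="params['embed"))  # []
```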

Options

-fail-on-param-transfers : Whether to emit an error (and fail) instead of a warning (default: false).
-param-pattern           : Pattern to match against the location info of transferred tensors.

-mpmd-validate-no-reshards

Validates that no reshard-only fragments exist.

A reshard-only fragment is a fragment that contains only an mpmd.return of one of its arguments. These fragments usually indicate an unexpected reshard.

By default, this pass emits a warning for each such fragment. If fail-on-reshard-only-fragments is set, it emits an error and fails.
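The structural test behind "reshard-only" can be sketched on a toy body representation. This assumes a hypothetical encoding of a fragment body as a list of `(op_name, operand_names)` pairs, with block arguments named like `%arg0`. A fragment qualifies only when its whole body is a single mpmd.return of one of its own arguments.

```python
def is_reshard_only(arg_names, body):
    """arg_names: the fragment's block arguments; body: (op_name, operand_names) pairs."""
    if len(body) != 1:
        return False  # any real computation makes it more than a reshard
    op_name, operands = body[0]
    # The only op returns one value, and that value is a block argument.
    return op_name == "mpmd.return" and len(operands) == 1 and operands[0] in arg_names


print(is_reshard_only(["%arg0"], [("mpmd.return", ["%arg0"])]))  # True
```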

Options

-fail-on-reshard-only-fragments : Whether to emit an error (and fail) instead of a warning (default: false).