-mpmd-delay-inferred-fragments

Delays inferred fragments to be executed as late as possible.

Moves inferred fragments to be executed as late as possible, i.e., right before their first consumer.
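The reordering described above can be illustrated with a small Python sketch. This is not the pass's implementation: the `Op` class, its `inferred` flag, and `delay_inferred` are hypothetical names for illustration, and leaving an op with no consumer in place is an assumption.

```python
from dataclasses import dataclass, field


@dataclass
class Op:
    # Hypothetical stand-in for an op in a block; not an MLIR class.
    name: str
    operands: list = field(default_factory=list)  # names of producer ops
    inferred: bool = False


def delay_inferred(ops):
    """Return a new order with each inferred op right before its first consumer."""
    order = list(ops)
    # Walk backwards so that consumers which are themselves inferred have
    # already reached their final position before their producers move.
    for op in reversed(ops):
        if not op.inferred:
            continue
        pos = order.index(op)
        first_user = next(
            (i for i in range(pos + 1, len(order)) if op.name in order[i].operands),
            None,
        )
        if first_user is None:
            continue  # Assumption: an op without consumers stays where it is.
        order.pop(pos)
        # After the pop the consumer sits at first_user - 1, so inserting
        # there places the op immediately before it.
        order.insert(first_user - 1, op)
    return order


ops = [
    Op("inferred_a", inferred=True),
    Op("b"),
    Op("c", operands=["inferred_a"]),
]
delayed_names = [op.name for op in delay_inferred(ops)]
```

Here `inferred_a` ends up between `b` and `c`, directly ahead of its only consumer.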

-mpmd-delay-transfers-from-cpu

Delays transfers from CPU to be executed as late as possible.

Moves CPU-to-device transfer ops to right before their first consumer. This postpones the allocation of memory for the transferred buffers, which can reduce HBM usage.

-mpmd-lower-to-fragment-calls

Lowers MPMD fragments to fragment calls and functions.

Replaces all fragments with fragment calls.

This pass creates a function for every group of fragments that have an identical body and mesh shape, with the name of the first encountered fragment in the group as the symbol name, and adds it to the symbol table.

The body of each function is extracted from the body of the first encountered fragment in the respective group, and the mesh shape is taken from the topology by the mesh name of that fragment. If the fragment has argument attributes about input-output aliasing, they will be assigned to the argument attributes of the lowered function.

Since functions must have unique names, this pass appends an index to the name of all but the first function with the same original name, i.e., the ith function with name "some_name" for i > 0 will have the name "some_name_i".
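The grouping and naming scheme can be sketched in Python as follows. This is a simplified model under stated assumptions, not the pass itself: fragments are represented as hypothetical `(name, body, mesh_name)` tuples, bodies are compared as plain strings, `topology` is an illustrative dict from mesh name to mesh shape, and the `-group-across-meshes` option is not modeled.

```python
def assign_function_names(fragments, topology):
    """Map each fragment to the symbol name of the function it would call.

    fragments: list of (name, body, mesh_name) tuples (hypothetical model).
    topology:  dict from mesh name to mesh shape (hypothetical model).
    """
    functions = {}    # (body, mesh_shape) -> function symbol name
    name_counts = {}  # original name -> number of functions created with it
    calls = []
    for name, body, mesh_name in fragments:
        key = (body, topology[mesh_name])
        if key not in functions:
            # The first fragment seen in a group names the function; later
            # functions with the same original name get an index suffix.
            i = name_counts.get(name, 0)
            functions[key] = name if i == 0 else f"{name}_{i}"
            name_counts[name] = i + 1
        calls.append(functions[key])
    return calls


topology = {"m1": (2,), "m2": (2,)}
fragments = [
    ("f", "body1", "m1"),
    ("f", "body2", "m1"),  # same name, different body: new function "f_1"
    ("g", "body1", "m2"),  # same body and mesh shape as the first: reuses "f"
    ("f", "body3", "m1"),  # third distinct function named "f": "f_2"
]
calls = assign_function_names(fragments, topology)
```

The sketch does not guard against a suffixed name such as `f_1` colliding with a fragment that was originally named `f_1`.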

Options

-group-across-meshes : Whether to do more aggressive fragment grouping, across meshes. This may not be desirable for heterogeneous systems.
-verbose-logging     : Whether to enable verbose logging.

-mpmd-mark-aliasing-and-donation

Marks each fragment with aliasing or donation information.

Sets an arg_attrs attribute on fragment ops when any of their inputs can be aliased with an output or donated. Each input that can be aliased gets a tf.aliasing_output attribute; an input that can instead be donated gets a jax.buffer_donor = true attribute. For example, {arg_attrs = [{tf.aliasing_output = 0 : i32}, {jax.buffer_donor = true}, {}]} shows that the first input can be aliased with output 0, the second input can be donated so that XLA will find an aliased output, and the third input can be neither aliased nor donated.

-mpmd-mark-fragment-reserved-memory

Marks each fragment with the amount of memory that needs to be reserved for compilation.

Assigns an xla_tpu_user_reserved_hbm_bytes attribute to each fragment, which tells XLA how many bytes to keep reserved while compiling that fragment. Because the pass keeps track of the live tensors on a mesh, XLA knows the actual minimum memory usage at the time of execution, and we can prevent it from applying optimizations in the executable that would increase memory usage beyond the device capacity. NOTE: this pass assumes that fragments are executed in program order.
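One way to picture the live-tensor bookkeeping is the Python sketch below. It is a rough model and not the pass's actual accounting. It assumes that the reserved amount for a fragment is the total size of tensors that are live on the mesh but are not inputs of that fragment, and that a tensor is freed after its last use. Neither assumption is stated in this document. The function name and the `(name, inputs, outputs)` tuples are hypothetical.

```python
def reserved_bytes_per_fragment(fragments, sizes):
    """Estimate reserved bytes for each fragment on a single mesh.

    fragments: list of (name, inputs, outputs) in program order (hypothetical model).
    sizes:     dict from tensor name to size in bytes.
    """
    # Index of the last fragment that reads each tensor.
    last_use = {}
    for idx, (_, inputs, _) in enumerate(fragments):
        for tensor in inputs:
            last_use[tensor] = idx

    # Tensors not produced by any fragment are program inputs and are
    # treated as live from the start.
    produced = {t for _, _, outputs in fragments for t in outputs}
    live = {t for _, inputs, _ in fragments for t in inputs if t not in produced}

    reserved = {}
    for idx, (name, inputs, outputs) in enumerate(fragments):
        # Assumption: only live tensors the fragment does not read count.
        others = live - set(inputs)
        reserved[name] = sum(sizes[t] for t in others)
        live |= set(outputs)
        # Assumption: a tensor is freed once its last reader has run.
        live = {t for t in live if last_use.get(t, -1) > idx}
    return reserved


sizes = {"a": 100, "b": 200, "c": 50, "d": 10}
fragments = [
    ("f0", ["a"], ["c"]),
    ("f1", ["b"], ["d"]),
    ("f2", ["c", "d"], []),
]
reserved = reserved_bytes_per_fragment(fragments, sizes)
```

While `f0` runs, `b` is live but unused by it, so 200 bytes are reserved. While `f1` runs, only `c` is live besides its own input, so 50 bytes are reserved. The walk relies on the same program-order assumption as the pass.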

-mpmd-mark-input-output-with-layouts

Propagates layouts from func args/results to fragment args/results.

Propagates mhlo.layout_mode attributes from program inputs to the fragments that consume them, and from program outputs to the fragments that produce them. If a program argument is returned, mhlo.layout_mode is propagated to/from the program results. See the comment on PropagateLayoutsForReturnedFuncArgs for a detailed description of the propagation logic.

-mpmd-mark-offloaded-input-output

Marks offloaded input and output values so the compiler knows they are in host memory.

Marks fragment args and results with attributes to identify which values live in host memory, so that this information can be used by XLA. Also marks the entrypoint func args and results so that Pathways can use this information.