'mpmd' Dialect

Operations

mpmd.assign (mpmd::AssignOp)

Assign operation

Syntax:

operation ::= `mpmd.assign` attr-dict $tensor `:` functional-type(operands, results)

Assigns a local tensor to a mesh as fully replicated within that mesh.

This is a temporary op introduced when lowering JAX ops, to move from local types to mesh types. These ops are eliminated during import, once the inputs and results of the func op become mesh tensors.

The mesh name of the result type should correspond to a mesh in the topology, and its global type should be identical to the operand type.

The origin of the assign op is the origin of the mesh assignment, e.g. a named_computation or mesh inference.

Traits: AlwaysSpeculatableImplTrait, HasParent<::mlir::func::FuncOp, ForOp>

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute MLIR Type Description
origin ::mlir::StringAttr string attribute

Operands:

Operand Description
tensor tensor of any type values

Results:

Result Description
result mesh tensor type

mpmd.broadcast (mpmd::BroadcastOp)

Broadcast operation

Syntax:

operation ::= `mpmd.broadcast` attr-dict $tensor `:` type($tensor)

Allows a tensor to be transferred (or replicated) to any mesh where it is used. Whenever it is transferred, the origin of the transfer is the current location of the operand.

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultType

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand Description
tensor tensor of any type values

Results:

Result Description
result tensor of any type values

mpmd.call (mpmd::CallOp)

MPMD-specific call function

Syntax:

operation ::= `mpmd.call` $callee `(` $tensors `)` attr-dict `:` functional-type(operands, results)

A function call operation. Useful to wrap the body of loops in function declarations to reduce code size, for example.

Interfaces: ArgAndResultAttrsOpInterface, CallOpInterface, SymbolUserOpInterface

Attributes:

Attribute MLIR Type Description
callee ::mlir::FlatSymbolRefAttr flat symbol reference attribute

Operands:

Operand Description
tensors variadic of tensor of any type values or mesh tensor type

Results:

Result Description
«unnamed» variadic of tensor of any type values or mesh tensor type

mpmd.for (mpmd::ForOp)

For operator

Returns the result of executing a body function for a fixed number of iterations, with the iteration index available in the body.

An optional unroll factor, which must divide the number of iterations, can be specified to unroll the body of the op by that factor: for unroll factor N, the body is replicated to create N copies and the number of iterations is divided by N. Each copy except the first uses the results of the previous copy instead of the block arguments, and the iteration index is multiplied by the unroll factor and incremented after every copy.
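The unrolling rule can be illustrated with a small Python model. This is only a sketch of the semantics described above, not the dialect's implementation; the names `run_for`, `run_for_unrolled` and `example_body` are hypothetical, and the loop-carried values are modelled as a plain tuple of integers.

```python
from typing import Callable, Tuple

Carry = Tuple[int, ...]
Body = Callable[[int, Carry], Carry]


def run_for(iterations: int, body: Body, init: Carry) -> Carry:
    """Reference loop: the body runs once per iteration and sees the index."""
    carry = init
    for index in range(iterations):
        carry = body(index, carry)
    return carry


def run_for_unrolled(iterations: int, unroll_factor: int,
                     body: Body, init: Carry) -> Carry:
    """Same loop with the body replicated `unroll_factor` times per iteration."""
    if unroll_factor <= 0 or iterations % unroll_factor != 0:
        raise ValueError("unroll_factor must divide the number of iterations")
    carry = init
    # The outer loop runs iterations / unroll_factor times.
    for outer_index in range(iterations // unroll_factor):
        # The index is multiplied by the unroll factor ...
        index = outer_index * unroll_factor
        for _ in range(unroll_factor):
            # ... each copy consumes the results of the previous copy ...
            carry = body(index, carry)
            # ... and the index is incremented after every copy.
            index += 1
    return carry


def example_body(index: int, carry: Carry) -> Carry:
    """Hypothetical body: accumulate the index and double a scale value."""
    total, scale = carry
    return (total + index, scale * 2)
```

Under this model, `run_for(6, example_body, (0, 1))` and `run_for_unrolled(6, 3, example_body, (0, 1))` produce the same carried values, while an unroll factor of 4 is rejected because it does not divide 6.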

A for operator can accept and return any types, but the TypeID of these must be the same -- e.g. all tensor types or all MPMD mesh types etc. This allows us to use the op at various levels, sharing implementation and transformations.

Traits: HLO_PairwiseSameOperandAndResultType, RecursiveMemoryEffects, RecursivelySpeculatableImplTrait, SingleBlockImplicitTerminator<ReturnOp>, SingleBlock

Interfaces: ConditionallySpeculatable, LoopLikeOpInterface, OpAsmOpInterface, ShardableDataFlowOpInterface

Attributes:

Attribute MLIR Type Description
iterations ::mlir::IntegerAttr 32-bit unsigned integer attribute
unroll_factor ::mlir::IntegerAttr 32-bit unsigned integer attribute

Operands:

Operand Description
tensors variadic of any type

Results:

Result Description
results variadic of any type

mpmd.fragment (mpmd::FragmentOp)

Fragment operation

Assigns a computation, i.e. a block of operations, to a specific mesh in an MPMD topology, that is intended to be executed as an individual SPMD program fragment.

The fragment takes and returns only mesh tensors that are assigned to the same mesh as the fragment.

The mesh name of the fragment should correspond to a mesh in the topology.

The fragment includes a list of origins, i.e., metadata about the original named_computations that formed this fragment, and a stage_id defined iff it is a user-defined fragment, i.e., it has a non-empty list of origins. The optional in_shardings specifies the shardings of the block arguments of a fragment, which correspond to the operands. The optional out_shardings specifies the shardings of the results.

The fragment's region shouldn't have any free variables, and the type of each block argument and returned value in the region is the global tensor type of the corresponding mesh tensor.

Traits: HasParent<::mlir::func::FuncOp, ForOp>, IsolatedFromAbove, RecursiveMemoryEffects, RecursivelySpeculatableImplTrait, SingleBlockImplicitTerminator<ReturnOp>, SingleBlock

Interfaces: ConditionallySpeculatable, ShardableDataFlowOpInterface

Attributes:

Attribute MLIR Type Description
origin ::mlir::ArrayAttr array of origin infos
mesh_name ::mlir::StringAttr string attribute
stage_id ::mlir::IntegerAttr 64-bit signless integer attribute
in_shardings ::mlir::sdy::TensorShardingPerValueAttr Tensor sharding per operand/result of an op
out_shardings ::mlir::sdy::TensorShardingPerValueAttr Tensor sharding per operand/result of an op

Operands:

Operand Description
inputs variadic of mesh tensor type

Results:

Result Description
results variadic of mesh tensor type

mpmd.fragment_call (mpmd::FragmentCallOp)

Fragment call operation

Represents a call to a function that holds an MPMD fragment body, i.e. a computation assigned to a specific mesh in an MPMD topology, that is intended to be executed as an individual SPMD program fragment.

The mesh name of the fragment should correspond to a mesh in the topology of the enclosing function, and that mesh shape should match that of the callee.

The origin specifies the user named computations that contributed to this fragment call e.g. through merging.

The function input and result types of the callee must be the local tensor types of the corresponding mesh tensors of this op's operands and results respectively.

Example:

%2 = mpmd.fragment_call<mesh="m1", origin=[]> @my_fragment(%0, %1) :
  (mesh_tensor<...>, mesh_tensor<...>) -> mesh_tensor<...>

Traits: HasParent<::mlir::func::FuncOp>, MemRefsNormalizable

Interfaces: ArgAndResultAttrsOpInterface, CallOpInterface, SymbolUserOpInterface

Attributes:

Attribute MLIR Type Description
origin ::mlir::ArrayAttr array of origin infos
mesh_name ::mlir::StringAttr string attribute
callee ::mlir::FlatSymbolRefAttr flat symbol reference attribute

Operands:

Operand Description
tensors variadic of mesh tensor type

Results:

Result Description
«unnamed» variadic of mesh tensor type

mpmd.named_computation (mpmd::NamedComputationOp)

Named scope operation

Groups a computation, i.e. a block of operations, and gives it a name and a transpose count via the UserOrigin attribute. This NamedComputation can be used to assign a mesh to the computation in MPMD or for optimizations.

The transpose count (default=0) denotes whether the named computation has been produced by a certain number of JAX AD transpose transformations.
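As a rough illustration, the name and transpose count carried by the UserOrigin attribute can be modelled as a small immutable record. This is a sketch only: the `UserOrigin` class and the `transposed` helper are hypothetical, and incrementing the count once per transpose is an assumption about how the count is maintained.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class UserOrigin:
    """Hypothetical model of the name and transpose count of a named computation."""
    user_name: str
    transpose_count: int = 0  # default=0: not produced by any transpose


def transposed(origin: UserOrigin) -> UserOrigin:
    """Assumption: each AD transpose transformation bumps the count by one."""
    return replace(origin, transpose_count=origin.transpose_count + 1)
```

For example, `transposed(UserOrigin("stage1"))` keeps the name `stage1` and carries a transpose count of 1, which would distinguish a backward computation from the forward one it was derived from.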

The op's region shouldn't have any free variables, and the types of the block arguments and returned values in the region must be the same as the types of the op's inputs and results, respectively.

Traits: IsolatedFromAbove, RecursiveMemoryEffects, RecursivelySpeculatableImplTrait, SingleBlockImplicitTerminator<ReturnOp>, SingleBlock

Interfaces: ConditionallySpeculatable

Attributes:

Attribute MLIR Type Description
origin ::mlir::mpmd::UserOriginAttr Origin of user-specified computation.

Operands:

Operand Description
tensors variadic of ranked tensor of 4/6/8/16/32/64-bit float or bool or 2/4/8/16/32/64-bit integer or complex type with 32/64-bit float elements or per-tensor integer quantized values or token

Results:

Result Description
results variadic of ranked tensor of 4/6/8/16/32/64-bit float or bool or 2/4/8/16/32/64-bit integer or complex type with 32/64-bit float elements or per-tensor integer quantized values or token

mpmd.named_tensor (mpmd::NamedTensorOp)

Assign a tensor to a mesh

Syntax:

operation ::= `mpmd.named_tensor` $tensor `name` `` `=` `` $name attr-dict `:` type($result)

An identity op that associates its result tensor with a given name. This NamedTensor can be used to assign a mesh to the tensor in MPMD.

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultType

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute MLIR Type Description
name ::mlir::StringAttr string attribute

Operands:

Operand Description
tensor tensor of any type values

Results:

Result Description
result tensor of any type values

mpmd.reduce (mpmd::ReduceOp)

Cross-mesh reduce operation

Syntax:

operation ::= `mpmd.reduce` `` $reduction attr-dict $tensors `:` functional-type(operands, results)

Allows for a tensor to be reduced across different meshes, and then broadcast to wherever it needs to be used.

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultType

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute MLIR Type Description
reduction ::mlir::mpmd::ReductionAttr Denotes a reduction.

Operands:

Operand Description
tensors variadic of tensor of any type values

Results:

Result Description
result tensor of any type values

mpmd.return (mpmd::ReturnOp)

The mpmd.return operation terminates the regions attached to mpmd region-based ops. It is variadic: it takes as arguments a list of values of any type (but all of the same kind, e.g. AnyTensor), and can therefore be reused at various levels of the MPMD IR stack.

Syntax:

operation ::= `mpmd.return` attr-dict $results (`:` type($results)^)?

Traits: AlwaysSpeculatableImplTrait, Terminator

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand Description
results variadic of any type

mpmd.transfer (mpmd::TransferOp)

Transfer operation

Syntax:

operation ::= `mpmd.transfer` attr-dict $tensor `:` functional-type(operands, results)

Transfers a distributed tensor from one mesh to another.

The mesh names of the operand and result types should correspond to meshes in the topology, and their global types should be identical.
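The two constraints above can be sketched as a small checker in Python. This is not the dialect's verifier; `MeshTensorInfo` and `check_transfer` are hypothetical names, the topology is reduced to a set of mesh names, and the global type is modelled as a shape plus an element type string.

```python
from typing import FrozenSet, List, NamedTuple, Tuple


class MeshTensorInfo(NamedTuple):
    """Hypothetical stand-in for a mesh tensor: a mesh name plus a global type."""
    mesh_name: str
    global_shape: Tuple[int, ...]
    element_type: str


def check_transfer(topology: FrozenSet[str],
                   operand: MeshTensorInfo,
                   result: MeshTensorInfo) -> List[str]:
    """Returns a list of violated constraints (empty if the transfer is valid)."""
    errors = []
    # Both mesh names must correspond to meshes in the topology.
    for role, tensor in (("operand", operand), ("result", result)):
        if tensor.mesh_name not in topology:
            errors.append(f"{role} mesh '{tensor.mesh_name}' is not in the topology")
    # The global types must be identical; only the mesh may differ.
    if (operand.global_shape, operand.element_type) != (
            result.global_shape, result.element_type):
        errors.append("operand and result global types differ")
    return errors
```

With a topology of `frozenset({"m1", "m2"})`, transferring an `8x16` `f32` tensor from `m1` to `m2` yields no errors, whereas changing the shape or naming an unknown mesh yields one error each.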

Traits: AlwaysSpeculatableImplTrait, HasParent<::mlir::func::FuncOp>

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), ShardingRuleOpInterface

Effects: MemoryEffects::Effect{}

Operands:

Operand Description
tensor mesh tensor type

Results:

Result Description
result mesh tensor type

mpmd.unassign (mpmd::UnassignOp)

Unassign operation

Syntax:

operation ::= `mpmd.unassign` attr-dict $tensor `:` functional-type(operands, results)

Unassigns a fully replicated tensor from a mesh.

This is a temporary op introduced when lowering JAX ops, to move from local types to mesh types. These ops are eliminated during import, once the inputs and results of the func op become mesh tensors.

The mesh name of the operand type should correspond to a mesh in the topology, and its global type should be identical to the result type.

Traits: AlwaysSpeculatableImplTrait, HasParent<::mlir::func::FuncOp, ForOp>, InferTensorType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute MLIR Type Description
origin ::mlir::StringAttr string attribute

Operands:

Operand Description
tensor mesh tensor type

Results:

Result Description
result tensor of any type values

Attributes

MeshWithOriginsAttr

Mesh with its origins.

Syntax:

#mpmd.mesh_with_origins<
  ::llvm::StringRef,   # mesh_name
  ::llvm::ArrayRef<OriginAttr>   # origins
>

Parameters:

Parameter C++ type Description
mesh_name ::llvm::StringRef mesh_name
origins ::llvm::ArrayRef<OriginAttr> origins

MeshesWithOriginsAttr

A list of meshes with their origins.

Syntax:

#mpmd.meshes_with_origins<
  ::llvm::ArrayRef<MeshWithOriginsAttr>   # value
>

Parameters:

Parameter C++ type Description
value ::llvm::ArrayRef<MeshWithOriginsAttr>

NamedMeshAttr

A pair with a name and a Mesh.

Syntax:

#mpmd.named_mesh<
  ::llvm::StringRef,   # name
  sdy::MeshAttr   # mesh
>

Parameters:

Parameter C++ type Description
name ::llvm::StringRef name
mesh sdy::MeshAttr mesh

OriginAttr

Origin of mesh assignment.

Syntax:

#mpmd.origin<
  ::llvm::StringRef   # origin_label
>

The origin of a mesh assignment.

origin_label is a human-readable label for the origin. It is intended to be used for debugging purposes.

Parameters:

Parameter C++ type Description
origin_label ::llvm::StringRef origin_label

ReductionAttr

Denotes a reduction.

Syntax:

#mpmd.reduction<
  ::mlir::mpmd::ReductionType   # reduction_type
>

Parameters:

Parameter C++ type Description
reduction_type ::mlir::mpmd::ReductionType an enum of type ReductionType

TopologyAttr

Topology of named meshes.

Syntax:

#mpmd.topology<
  ::llvm::ArrayRef<NamedMeshAttr>   # meshes
>

Parameters:

Parameter C++ type Description
meshes ::llvm::ArrayRef<NamedMeshAttr> topology meshes

UserOriginAttr

Origin of user-specified computation.

Syntax:

#mpmd.user_origin<
  ::mlir::StringAttr,   # userName
  int64_t   # transposeCount
>

Parameters:

Parameter C++ type Description
userName ::mlir::StringAttr
transposeCount int64_t

Types

MeshTensorType

Mesh tensor type

Assigns a RankedTensorType to a specific SPMD mesh in the program's MPMD topology of meshes. The type holds an optional sharding that specifies how the tensor is sharded with respect to the SPMD mesh. If the sharding is not present, the tensor is fully replicated.

Parameters:

Parameter C++ type Description
mesh_name ::llvm::StringRef mesh name
ranked_tensor_type ::mlir::RankedTensorType ranked tensor type
sharding ::mlir::sdy::TensorShardingAttr
memory_kind ::mlir::StringAttr
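The four parameters can be pictured as a plain record, shown here as a Python sketch. The `RankedTensor` and `MeshTensor` classes are hypothetical models rather than bindings to the dialect, the sharding and memory kind are kept as opaque optional strings, and treating `memory_kind` as optional is an assumption.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class RankedTensor:
    """Hypothetical model of a ranked tensor type: static shape and element type."""
    shape: Tuple[int, ...]
    element_type: str


@dataclass(frozen=True)
class MeshTensor:
    """Hypothetical model of the mesh tensor type's parameters."""
    mesh_name: str
    ranked_tensor_type: RankedTensor
    sharding: Optional[str] = None     # opaque here; absent means fully replicated
    memory_kind: Optional[str] = None  # assumption: optional, opaque string

    @property
    def replicated_by_default(self) -> bool:
        """True when no sharding is attached, i.e. the tensor is fully replicated.

        A sharding that is present may still describe a replicated tensor;
        this sketch does not inspect it.
        """
        return self.sharding is None
```

For instance, `MeshTensor("m1", RankedTensor((8, 16), "f32"))` has no sharding and is therefore fully replicated within mesh `m1`.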

Enums

ReductionType

Reduction type attribute

Cases:

Symbol Value String
kNone 0 none
kAdd 1 add
kMax 2 max
kMin 3 min
kMul 4 mul
kOr 5 or
kAnd 6 and
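To make the cases concrete, the following Python sketch mirrors the table above and applies a reduction elementwise across several same-shaped inputs, in the spirit of mpmd.reduce. The helper `reduce_elementwise`, the use of flat lists as tensors, and the choice to reject `kNone` are assumptions made for illustration; the mapping of `or` and `and` to Python's `|` and `&` is likewise a guess at the intended semantics.

```python
import enum
import functools
import operator
from typing import List, Sequence


class ReductionType(enum.Enum):
    """Mirrors the enum cases and values listed in the table."""
    kNone = 0
    kAdd = 1
    kMax = 2
    kMin = 3
    kMul = 4
    kOr = 5
    kAnd = 6


# Assumed binary combiner for each reduction; kNone has no combiner here.
_COMBINERS = {
    ReductionType.kAdd: operator.add,
    ReductionType.kMax: max,
    ReductionType.kMin: min,
    ReductionType.kMul: operator.mul,
    ReductionType.kOr: operator.or_,
    ReductionType.kAnd: operator.and_,
}


def reduce_elementwise(reduction: ReductionType,
                       tensors: Sequence[Sequence]) -> List:
    """Combines same-length flat tensors position by position."""
    if reduction not in _COMBINERS:
        raise ValueError(f"no combiner modelled for {reduction.name}")
    if not tensors or any(len(t) != len(tensors[0]) for t in tensors):
        raise ValueError("expected at least one tensor, all of the same length")
    combine = _COMBINERS[reduction]
    # zip(*tensors) groups the values found at the same position in each input.
    return [functools.reduce(combine, column) for column in zip(*tensors)]
```

As a usage example, `reduce_elementwise(ReductionType.kAdd, [[1, 2], [3, 4], [5, 6]])` sums the inputs position by position, and `ReductionType(4)` looks up the `kMul` case by its value.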

EdgeNodeType

Edge node type enum

Cases:

Symbol Value String
OPERAND 0 operand
RESULT 1 result

PropagationDirection

Propagation direction enum

Cases:

Symbol Value String
NONE 0 NONE
FORWARD 1 FORWARD
BACKWARD 2 BACKWARD
BOTH 3 BOTH
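The numeric values in this table happen to compose like bit flags (1 | 2 = 3). The Python sketch below models them that way; whether the dialect itself treats the enum as a bit mask is an assumption, and `propagates_forward` is a hypothetical helper.

```python
import enum


class PropagationDirection(enum.IntFlag):
    """Mirrors the enum cases and values listed in the table."""
    NONE = 0
    FORWARD = 1
    BACKWARD = 2
    BOTH = 3  # numerically equal to FORWARD | BACKWARD


def propagates_forward(direction: PropagationDirection) -> bool:
    """Assumed reading: the FORWARD bit is set for FORWARD and BOTH."""
    return bool(direction & PropagationDirection.FORWARD)
```

Under this reading, `propagates_forward` is true for `FORWARD` and `BOTH` and false for `NONE` and `BACKWARD`.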