Operations
mpmd.assign (mpmd::AssignOp)
Assign operation
Syntax:
operation ::= `mpmd.assign` attr-dict $tensor `:` functional-type(operands, results)
Assigns a local tensor to a mesh as fully replicated within that mesh.
This is a temporary op that is introduced when lowering JAX ops, to move from local types to mesh types. These ops are eliminated during import, once the inputs and results of the func op become mesh tensors.
The mesh name of the result type should correspond to a mesh in the topology, and its global type should be identical to the operand type.
The origin attribute records where the mesh assignment came from, e.g. a named_computation, mesh inference, etc.
Traits: AlwaysSpeculatableImplTrait, HasParent<::mlir::func::FuncOp, ForOp>
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
origin | ::mlir::StringAttr | string attribute |
Operands:
Operand | Description |
---|---|
tensor | tensor of any type values |
Results:
Result | Description |
---|---|
result | mesh tensor type |
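The type rule above can be illustrated with a small model. This is a minimal Python sketch, not the dialect's C++ verifier. `RankedTensor`, `MeshTensor`, `check_assign` and the set-of-names topology are hypothetical stand-ins, and a tensor type is reduced to a shape and an element type. The replication check mirrors the description ("fully replicated within that mesh") and is not a documented verifier rule.

```python
from dataclasses import dataclass
from typing import Optional, Set, Tuple

@dataclass(frozen=True)
class RankedTensor:
    """Hypothetical stand-in for a RankedTensorType: shape plus element type."""
    shape: Tuple[int, ...]
    element_type: str

@dataclass(frozen=True)
class MeshTensor:
    """Hypothetical stand-in for a mesh tensor type (sharding None = fully replicated)."""
    mesh_name: str
    global_type: RankedTensor
    sharding: Optional[str] = None

def check_assign(operand: RankedTensor, result: MeshTensor, topology: Set[str]) -> None:
    """Raise ValueError if assigning `operand` to `result` breaks the rules described above."""
    if result.mesh_name not in topology:
        raise ValueError(f"mesh {result.mesh_name!r} is not in the topology")
    if result.global_type != operand:
        raise ValueError("result global type must be identical to the operand type")
    if result.sharding is not None:
        # Assumption taken from the prose: the assigned tensor is fully replicated.
        raise ValueError("assign result must be fully replicated within the mesh")

topology = {"m1", "m2"}
local = RankedTensor((4, 8), "f32")
check_assign(local, MeshTensor("m1", local), topology)  # accepted
```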
mpmd.broadcast (mpmd::BroadcastOp)
Broadcast operation
Syntax:
operation ::= `mpmd.broadcast` attr-dict $tensor `:` type($tensor)
Allows a tensor to be transferred to (or replicated in) any mesh where it is used. Whenever a transfer happens, its source is the current location of the operand.
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultType
Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Operands:
Operand | Description |
---|---|
tensor | tensor of any type values |
Results:
Result | Description |
---|---|
result | tensor of any type values |
mpmd.call (mpmd::CallOp)
MPMD-specific call function
Syntax:
operation ::= `mpmd.call` $callee `(` $tensors `)` attr-dict `:` functional-type(operands, results)
A function call operation. It is useful, for example, for wrapping the body of a loop in a function declaration to reduce code size.
Interfaces: ArgAndResultAttrsOpInterface, CallOpInterface, SymbolUserOpInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
callee | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
Operands:
Operand | Description |
---|---|
tensors | variadic of tensor of any type values or mesh tensor type |
Results:
Result | Description |
---|---|
«unnamed» | variadic of tensor of any type values or mesh tensor type |
mpmd.for (mpmd::ForOp)
For operator
Returns the result of executing a body function for a fixed number of iterations, with the iteration index available in the body.
An optional unroll factor, which must divide the number of iterations, can be specified to unroll the body of the op by that factor. For unroll factor N, the body is replicated to create N copies and the number of iterations is divided by N. Each copy except the first uses the results of the previous copy instead of the block arguments. The iteration index is multiplied by the unroll factor and incremented after every copy, so copy j of unrolled iteration i sees the index i * N + j.
A for operator can accept and return values of any type, but all of them must have the same TypeID, e.g. all tensor types or all MPMD mesh types. This allows us to use the op at various levels, sharing implementation and transformations.
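The unrolling rule can be checked against a small reference model. The Python sketch below illustrates the rule under stated assumptions and is not the MPMD implementation. `run_for` and `run_for_unrolled` are hypothetical helpers, and the loop-carried values are modelled as a plain tuple passed through a `body(index, state)` callback.

```python
from typing import Callable, Tuple

State = Tuple
Body = Callable[[int, State], State]

def run_for(body: Body, iterations: int, init: State) -> State:
    """Reference semantics: apply the body once per iteration with index 0..iterations-1."""
    state = init
    for index in range(iterations):
        state = body(index, state)
    return state

def run_for_unrolled(body: Body, iterations: int, init: State, unroll_factor: int) -> State:
    """Unrolled semantics: iterations // N loop trips, each running N copies of the body."""
    if unroll_factor <= 0 or iterations % unroll_factor != 0:
        raise ValueError("unroll factor must divide the number of iterations")
    state = init
    for i in range(iterations // unroll_factor):
        index = i * unroll_factor  # the index is multiplied by the unroll factor...
        for _ in range(unroll_factor):
            # every copy after the first consumes the previous copy's results
            state = body(index, state)
            index += 1  # ...and incremented after every copy
    return state

def body(i, state):
    total, power = state
    return (total + i, power * 2)

assert run_for(body, 6, (0, 1)) == (15, 64)
assert run_for_unrolled(body, 6, (0, 1), unroll_factor=3) == (15, 64)
```

Because each copy consumes the previous copy's results and sees the indices i * N + j in order, the unrolled loop produces the same final state as the original.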
Traits: HLO_PairwiseSameOperandAndResultType, RecursiveMemoryEffects, RecursivelySpeculatableImplTrait, SingleBlockImplicitTerminator<ReturnOp>, SingleBlock
Interfaces: ConditionallySpeculatable, LoopLikeOpInterface, OpAsmOpInterface, ShardableDataFlowOpInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
iterations | ::mlir::IntegerAttr | 32-bit unsigned integer attribute |
unroll_factor | ::mlir::IntegerAttr | 32-bit unsigned integer attribute |
Operands:
Operand | Description |
---|---|
tensors | variadic of any type |
Results:
Result | Description |
---|---|
results | variadic of any type |
mpmd.fragment (mpmd::FragmentOp)
Fragment operation
Assigns a computation, i.e. a block of operations, to a specific mesh in an MPMD topology, that is intended to be executed as an individual SPMD program fragment.
The fragment takes and returns only mesh tensors that are assigned to the same mesh as the fragment.
The mesh name of the fragment should correspond to a mesh in the topology.
The fragment includes a list of origins, i.e. metadata about the original named_computations that formed this fragment. It also has a stage_id, which is defined iff the fragment is user-defined, i.e. it has a non-empty list of origins. The optional in_shardings specifies the shardings of the block arguments of the fragment, which correspond to the operands. The optional out_shardings specifies the shardings of the results.
The fragment's region shouldn't have any free variables, and the type of each block argument and returned value in the region is the global tensor type of the corresponding mesh tensor.
Traits: HasParent<::mlir::func::FuncOp, ForOp>, IsolatedFromAbove, RecursiveMemoryEffects, RecursivelySpeculatableImplTrait, SingleBlockImplicitTerminator<ReturnOp>, SingleBlock
Interfaces: ConditionallySpeculatable, ShardableDataFlowOpInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
origin | ::mlir::ArrayAttr | array of origin infos |
mesh_name | ::mlir::StringAttr | string attribute |
stage_id | ::mlir::IntegerAttr | 64-bit signless integer attribute |
in_shardings | ::mlir::sdy::TensorShardingPerValueAttr | Tensor sharding per operand/result of an op |
out_shardings | ::mlir::sdy::TensorShardingPerValueAttr | Tensor sharding per operand/result of an op |
Operands:
Operand | Description |
---|---|
inputs | variadic of mesh tensor type |
Results:
Result | Description |
---|---|
results | variadic of mesh tensor type |
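The structural constraints in the fragment description can be restated as a checker. This is a minimal Python sketch. `Fragment` and `fragment_errors` are hypothetical, and types are reduced to the mesh names of the operand and result mesh tensors. The stage_id rule is encoded exactly as the prose states it (defined iff the origin list is non-empty), so it may be stricter or looser than the real verifier.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Sequence, Set

@dataclass
class Fragment:
    """Hypothetical, simplified view of an mpmd.fragment for checking the stated rules."""
    mesh_name: str
    input_meshes: Sequence[str]   # mesh name of each operand's mesh tensor type
    result_meshes: Sequence[str]  # mesh name of each result's mesh tensor type
    origins: List[str] = field(default_factory=list)
    stage_id: Optional[int] = None

def fragment_errors(frag: Fragment, topology: Set[str]) -> List[str]:
    """Return a list of human-readable rule violations (empty if none)."""
    errors = []
    if frag.mesh_name not in topology:
        errors.append(f"mesh {frag.mesh_name!r} is not in the topology")
    # Operands and results must live on the same mesh as the fragment itself.
    for kind, meshes in (("operand", frag.input_meshes), ("result", frag.result_meshes)):
        for i, mesh in enumerate(meshes):
            if mesh != frag.mesh_name:
                errors.append(f"{kind} #{i} is on mesh {mesh!r}, expected {frag.mesh_name!r}")
    # stage_id defined iff user-defined, i.e. the origin list is non-empty.
    if (frag.stage_id is not None) != bool(frag.origins):
        errors.append("stage_id must be set iff the origin list is non-empty")
    return errors

ok = Fragment("m1", input_meshes=["m1"], result_meshes=["m1"], origins=["stage0"], stage_id=0)
assert fragment_errors(ok, {"m1", "m2"}) == []
```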
mpmd.fragment_call (mpmd::FragmentCallOp)
Fragment call operation
Represents a call to a function that holds an MPMD fragment body, i.e. a computation assigned to a specific mesh in an MPMD topology, that is intended to be executed as an individual SPMD program fragment.
The mesh name of the fragment should correspond to a mesh in the topology of the enclosing function, and that mesh shape should match that of the callee.
The origin lists the user's named computations that contributed to this fragment call, e.g. through merging.
The function input and result types of the callee must be the local tensor types of the corresponding mesh tensors of this op's operands and results respectively.
Example:
```mlir
%2 = mpmd.fragment_call<mesh="m1", origin=[]> @my_fragment(%0, %1) :
  (mesh_tensor<...>, mesh_tensor<...>) -> mesh_tensor<...>
```
Traits: HasParent<::mlir::func::FuncOp>, MemRefsNormalizable
Interfaces: ArgAndResultAttrsOpInterface, CallOpInterface, SymbolUserOpInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
origin | ::mlir::ArrayAttr | array of origin infos |
mesh_name | ::mlir::StringAttr | string attribute |
callee | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
Operands:
Operand | Description |
---|---|
tensors | variadic of mesh tensor type |
Results:
Result | Description |
---|---|
«unnamed» | variadic of mesh tensor type |
mpmd.named_computation (mpmd::NamedComputationOp)
Named scope operation
Groups a computation, i.e. a block of operations, and gives it a name and a transpose count via the UserOrigin attribute. This NamedComputation can be used to assign a mesh to the computation in MPMD or for optimizations.
The transpose count (default 0) records how many JAX AD transpose transformations produced the named computation.
The op's region shouldn't have any free variables, and the type of each block argument and returned value in the region must be the same as the type of the corresponding input and result of the op.
Traits: IsolatedFromAbove, RecursiveMemoryEffects, RecursivelySpeculatableImplTrait, SingleBlockImplicitTerminator<ReturnOp>, SingleBlock
Interfaces: ConditionallySpeculatable
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
origin | ::mlir::mpmd::UserOriginAttr | Origin of user-specified computation. |
Operands:
Operand | Description |
---|---|
tensors | variadic of ranked tensor of 4/6/8/16/32/64-bit float or bool or 2/4/8/16/32/64-bit integer or complex type with 32/64-bit float elements or per-tensor integer quantized values or token |
Results:
Result | Description |
---|---|
results | variadic of ranked tensor of 4/6/8/16/32/64-bit float or bool or 2/4/8/16/32/64-bit integer or complex type with 32/64-bit float elements or per-tensor integer quantized values or token |
mpmd.named_tensor (mpmd::NamedTensorOp)
Assign a tensor to a mesh
Syntax:
operation ::= `mpmd.named_tensor` $tensor `name````=```$name attr-dict `:` type($result)
An identity op that associates its result tensor with a given name. This NamedTensor can be used to assign a mesh to the tensor in MPMD.
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultType
Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
name | ::mlir::StringAttr | string attribute |
Operands:
Operand | Description |
---|---|
tensor | tensor of any type values |
Results:
Result | Description |
---|---|
result | tensor of any type values |
mpmd.reduce (mpmd::ReduceOp)
Cross-mesh reduce operation
Syntax:
operation ::= `mpmd.reduce` `` $reduction attr-dict $tensors `:` functional-type(operands, results)
Allows a tensor to be reduced across different meshes and then broadcast to wherever it is used.
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultType
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
reduction | ::mlir::mpmd::ReductionAttr | Denotes a reduction. |
Operands:
Operand | Description |
---|---|
tensors | variadic of tensor of any type values |
Results:
Result | Description |
---|---|
result | tensor of any type values |
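Conceptually, the op combines the per-mesh values elementwise with the chosen reduction and then makes the result available on every mesh that uses it. The Python sketch below models only that data flow. `cross_mesh_reduce` and `COMBINERS` are hypothetical names, tensors are flattened Python lists keyed by mesh name, and the `none` reduction type is not modelled.

```python
import operator
from functools import reduce
from typing import Callable, Dict, List, Sequence

# Hypothetical mapping from ReductionType strings to binary combiners;
# "none" is intentionally left out of this sketch.
COMBINERS: Dict[str, Callable] = {
    "add": operator.add,
    "mul": operator.mul,
    "max": max,
    "min": min,
    "or": operator.or_,
    "and": operator.and_,
}

def cross_mesh_reduce(per_mesh: Dict[str, Sequence], reduction: str,
                      users: Sequence[str]) -> Dict[str, List]:
    """Reduce same-shaped flattened tensors across meshes, then broadcast the result."""
    operands = list(per_mesh.values())
    if len({len(v) for v in operands}) != 1:
        raise ValueError("all operands must have the same shape")
    combine = COMBINERS[reduction]
    # Combine position by position across meshes (elementwise reduction).
    reduced = [reduce(combine, column) for column in zip(*operands)]
    # Broadcast: every using mesh receives its own copy of the reduced value.
    return {mesh: list(reduced) for mesh in users}

partials = {"m1": [1, 5], "m2": [3, 2]}
assert cross_mesh_reduce(partials, "add", ["m1", "m3"]) == {"m1": [4, 7], "m3": [4, 7]}
```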
mpmd.return (mpmd::ReturnOp)
The mpmd.return operation terminates the regions attached to mpmd region-based ops. It is variadic: it takes as arguments a list of values whose types can be anything (but must all be of the same kind, e.g. AnyTensor), and it can therefore be reused at various levels of the MPMD IR stack.
Syntax:
operation ::= `mpmd.return` attr-dict $results (`:` type($results)^)?
Traits: AlwaysSpeculatableImplTrait, Terminator
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Operands:
Operand | Description |
---|---|
results | variadic of any type |
mpmd.transfer (mpmd::TransferOp)
Transfer operation
Syntax:
operation ::= `mpmd.transfer` attr-dict $tensor `:` functional-type(operands, results)
Transfers a distributed tensor from one mesh to another.
The mesh names of the operand and result types should correspond to meshes in the topology, and their global types should be identical.
Traits: AlwaysSpeculatableImplTrait, HasParent<::mlir::func::FuncOp>
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), ShardingRuleOpInterface
Effects: MemoryEffects::Effect{}
Operands:
Operand | Description |
---|---|
tensor | mesh tensor type |
Results:
Result | Description |
---|---|
result | mesh tensor type |
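The transfer rule constrains only mesh membership and global types, which a few lines can capture. This is a minimal Python sketch under that reading. `MeshTensor` and `check_transfer` are hypothetical, and the topology is modelled as a set of mesh names. Shardings are opaque strings that the check deliberately ignores, because the description says nothing about them.

```python
from dataclasses import dataclass
from typing import Optional, Set, Tuple

@dataclass(frozen=True)
class MeshTensor:
    """Hypothetical mesh tensor: mesh name, global type, optional opaque sharding."""
    mesh_name: str
    shape: Tuple[int, ...]
    element_type: str
    sharding: Optional[str] = None  # None means fully replicated within the mesh

def check_transfer(operand: MeshTensor, result: MeshTensor, topology: Set[str]) -> None:
    """Raise ValueError if the transfer violates the type rule described above."""
    for role, t in (("operand", operand), ("result", result)):
        if t.mesh_name not in topology:
            raise ValueError(f"{role} mesh {t.mesh_name!r} is not in the topology")
    # Only global types are compared; shardings are deliberately not checked here.
    if (operand.shape, operand.element_type) != (result.shape, result.element_type):
        raise ValueError("operand and result global types must be identical")

topology = {"m1", "m2"}
src = MeshTensor("m1", (8, 16), "bf16", sharding="dim0 on x")
dst = MeshTensor("m2", (8, 16), "bf16")
check_transfer(src, dst, topology)  # accepted: same global type, different meshes
```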
mpmd.unassign (mpmd::UnassignOp)
Unassign operation
Syntax:
operation ::= `mpmd.unassign` attr-dict $tensor `:` functional-type(operands, results)
Unassigns a fully replicated tensor from a mesh.
This is a temporary op that is introduced when lowering JAX ops, to move from local types to mesh types. These ops are eliminated during import, once the inputs and results of the func op become mesh tensors.
The mesh name of the operand type should correspond to a mesh in the topology, and its global type should be identical to the result type.
Traits: AlwaysSpeculatableImplTrait, HasParent<::mlir::func::FuncOp, ForOp>, InferTensorType
Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
origin | ::mlir::StringAttr | string attribute |
Operands:
Operand | Description |
---|---|
tensor | mesh tensor type |
Results:
Result | Description |
---|---|
result | tensor of any type values |
Attributes
MeshWithOriginsAttr
Mesh with its origins.
Syntax:
#mpmd.mesh_with_origins<
::llvm::StringRef, # mesh_name
::llvm::ArrayRef<OriginAttr> # origins
>
Parameters:
Parameter | C++ type | Description |
---|---|---|
mesh_name | ::llvm::StringRef | mesh_name |
origins | ::llvm::ArrayRef<OriginAttr> | origins |
MeshesWithOriginsAttr
A list of meshes with their origins.
Syntax:
#mpmd.meshes_with_origins<
::llvm::ArrayRef<MeshWithOriginsAttr> # value
>
Parameters:
Parameter | C++ type | Description |
---|---|---|
value | ::llvm::ArrayRef<MeshWithOriginsAttr> |
NamedMeshAttr
A pair with a name and a Mesh.
Syntax:
#mpmd.named_mesh<
::llvm::StringRef, # name
sdy::MeshAttr # mesh
>
Parameters:
Parameter | C++ type | Description |
---|---|---|
name | ::llvm::StringRef | name |
mesh | sdy::MeshAttr | mesh |
OriginAttr
Origin of mesh assignment.
Syntax:
#mpmd.origin<
::llvm::StringRef # origin_label
>
The origin of a mesh assignment. origin_label is a human-readable label for the origin. It is intended to be used for debugging purposes.
Parameters:
Parameter | C++ type | Description |
---|---|---|
origin_label | ::llvm::StringRef | origin_label |
ReductionAttr
Denotes a reduction.
Syntax:
#mpmd.reduction<
::mlir::mpmd::ReductionType # reduction_type
>
Parameters:
Parameter | C++ type | Description |
---|---|---|
reduction_type | ::mlir::mpmd::ReductionType | an enum of type ReductionType |
TopologyAttr
Topology of named meshes.
Syntax:
#mpmd.topology<
::llvm::ArrayRef<NamedMeshAttr> # meshes
>
Parameters:
Parameter | C++ type | Description |
---|---|---|
meshes | ::llvm::ArrayRef<NamedMeshAttr> | topology meshes |
UserOriginAttr
Origin of user-specified computation.
Syntax:
#mpmd.user_origin<
::mlir::StringAttr, # userName
int64_t # transposeCount
>
Parameters:
Parameter | C++ type | Description |
---|---|---|
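userName | ::mlir::StringAttr | name of the user-specified computation |
transposeCount | int64_t | number of JAX AD transpose transformations that produced the computation |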
Types
MeshTensorType
Mesh tensor type
Assigns a RankedTensorType to a specific SPMD mesh in the program's MPMD topology of meshes. The type holds an optional sharding that specifies how the tensor is sharded w.r.t. the SPMD mesh. If the sharding is not present, the tensor is fully replicated.
Parameters:
Parameter | C++ type | Description |
---|---|---|
mesh_name | ::llvm::StringRef | mesh name |
ranked_tensor_type | ::mlir::RankedTensorType | ranked tensor type |
sharding | ::mlir::sdy::TensorShardingAttr | |
memory_kind | ::mlir::StringAttr | |
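The local tensor type referred to by mpmd.fragment_call can be pictured as the per-device shard of the global type. The Python sketch below uses a simplified sharding model and is not the sdy::TensorShardingAttr API. It assumes each tensor dimension lists the mesh axes it is sharded along, mesh axis sizes are given as a dict, and every dimension divides evenly. `local_shape` is a hypothetical helper.

```python
from math import prod
from typing import Dict, Optional, Sequence, Tuple

def local_shape(global_shape: Sequence[int],
                mesh_axes: Dict[str, int],
                dim_shardings: Optional[Sequence[Sequence[str]]] = None) -> Tuple[int, ...]:
    """Per-device shape of a mesh tensor under a simplified, evenly divisible sharding."""
    if dim_shardings is None:
        return tuple(global_shape)  # no sharding: fully replicated, local == global
    if len(dim_shardings) != len(global_shape):
        raise ValueError("need one dimension sharding per tensor dimension")
    shape = []
    for size, axes in zip(global_shape, dim_shardings):
        factor = prod(mesh_axes[a] for a in axes)  # empty axis list -> factor 1
        if size % factor != 0:
            raise ValueError(f"dimension of size {size} is not divisible by {factor}")
        shape.append(size // factor)
    return tuple(shape)

mesh = {"x": 2, "y": 4}
assert local_shape((8, 16), mesh) == (8, 16)
assert local_shape((8, 16), mesh, [["x"], ["y"]]) == (4, 4)
```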
Enums
ReductionType
Reduction type attribute
Cases:
Symbol | Value | String |
---|---|---|
kNone | 0 | none |
kAdd | 1 | add |
kMax | 2 | max |
kMin | 3 | min |
kMul | 4 | mul |
kOr | 5 | or |
kAnd | 6 | and |
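The table maps directly onto an integer enum, which can be convenient for tooling that parses or prints this attribute. The following is a Python mirror for illustration only. The class and its `mnemonic`/`from_mnemonic` members are hypothetical and are not part of the dialect's generated C++ API.

```python
from enum import IntEnum

class ReductionType(IntEnum):
    """Python mirror of the ReductionType cases listed above (illustration only)."""
    kNone = 0
    kAdd = 1
    kMax = 2
    kMin = 3
    kMul = 4
    kOr = 5
    kAnd = 6

    @property
    def mnemonic(self) -> str:
        # "kAdd" -> "add", matching the String column of the table
        return self.name[1:].lower()

    @classmethod
    def from_mnemonic(cls, text: str) -> "ReductionType":
        for case in cls:
            if case.mnemonic == text:
                return case
        raise ValueError(f"unknown reduction type {text!r}")

assert ReductionType.from_mnemonic("max") is ReductionType.kMax
assert ReductionType.kNone.mnemonic == "none"
```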
EdgeNodeType
Edge node type enum
Cases:
Symbol | Value | String |
---|---|---|
OPERAND | 0 | operand |
RESULT | 1 | result |
PropagationDirection
Propagation direction enum
Cases:
Symbol | Value | String |
---|---|---|
NONE | 0 | NONE |
FORWARD | 1 | FORWARD |
BACKWARD | 2 | BACKWARD |
BOTH | 3 | BOTH |