El dialecto Shardy (SDY) define una representación de fragmentación de tensor basada en eje y componentes de API adicionales para adjuntar fragmentaciones a los tensores.
Operaciones
sdy.constant
(sdy::ConstantOp)
Operación constante
Produce un tensor output
a partir de una constante value
.
Consulta: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#constant
Ejemplo:
%output = sdy.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
Características: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable
, InferTypeOpInterface
, NoMemoryEffect (MemoryEffectOpInterface)
Efectos: MemoryEffects::Effect{}
Atributos:
Atributo | Tipo de MLIR | Descripción |
---|---|---|
value | ::mlir::ElementsAttr | atributo de vector o tensor constante |
Resultados:
Resultado | Descripción |
---|---|
output |
tensor de cualquier tipo de valores |
sdy.data_flow_edge
(sdy::DataFlowEdgeOp)
Operación perimetral del flujo de datos.
Sintaxis:
operation ::= `sdy.data_flow_edge` $input (`sharding````=``` $sharding^)? attr-dict `:` type($result)
Un borde de flujo de datos de alguna operación X define un puente entre un conjunto de fuentes (cada una es un operando de X o un operando del terminador de bloque de X) y un conjunto de destinos (cada uno es un resultado de X o un argumento de bloque de X), de modo que todas las fuentes y los destinos se deben dividir de la misma manera.
Una operación puede tener varios bordes de flujo de datos que son ortogonales entre sí.
Por ejemplo:
y_0, ..., y_n = while (x_0, ..., x_n)
((pred_arg_0,... , pred_arg_n) { ... })
((body_arg_0,..., body_arg_n) {
...
return return_value_0, ..., return_value_n
})
Esta operación while tiene n aristas de flujo de datos, y la arista de flujo de datos en el paso i está entre las fuentes x_i
, return_value_i
y los destinos y_i
, pred_arg_i
y body_arg_i
.
Un sdy.data_flow_edge
toma como entrada el destino raíz de un borde (puede ser cualquiera de los destinos, pero preferiblemente un resultado de operación en lugar de un argumento de bloque), que no debería tener ningún otro uso. Esta operación no es pura porque puede tomar una entrada que, en principio, no tenía ningún uso.
sdy.data_flow_edge
también contiene un fragmentación opcional para todos los destinos del borde, y esa fragmentación se debe actualizar en lugar de la fragmentación de los destinos (si se puede adjuntar) durante la propagación. Esto es útil cuando una operación tiene muchos bordes, ya que es mucho más eficiente hacer lo siguiente:
- propagarse a través de cada borde por separado.
- Actualizar la fragmentación de cada extremo por separado en lugar de todos los destinos a la vez
(p. ej., una operación tiene un solo
TensorShardingPerValueAttr
inmutable para la fragmentación de resultados) - Agrega cada borde a la lista de tareas por separado cuando cambie el fragmento de una fuente.
La propagación propagará los particionados entre todas las fuentes y los destinos de un sdy.data_flow_edge
como si fuera una operación normal con las fuentes como operandos y los destinos como resultados, y una identidad sdy.op_sharding_rule
. Eso significa que la propagación hacia adelante es de las fuentes a los destinos y la propagación hacia atrás es de los destinos a las fuentes.
No permitimos que una op SdyDialect
defina la entrada de un sdy.data_flow_edge
, por lo que podemos suponer que está definida por una op que tiene un atributo sdy.sharding
no registrado.
Características: SameOperandsAndResultType
Interfaces: InferTypeOpInterface
Atributos:
Atributo | Tipo de MLIR | Descripción |
---|---|---|
sharding | ::mlir::sdy::TensorShardingAttr | División de tensores |
Operandos:
Operando | Descripción |
---|---|
input |
con forma de cualquier tipo de valores |
Resultados:
Resultado | Descripción |
---|---|
result |
con forma de cualquier tipo de valores |
sdy.manual_computation
(sdy::ManualComputationOp)
Operación de paralelismo multidispositivo con colectivos manuales
Sintaxis:
operation ::= `sdy.manual_computation` `(`operands`)`
`in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)
`out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)
`manual_axes````=```$manual_axes
custom<SingleBlockRegionNoBlockId>($body)
attr-dict
`:`
functional-type(operands, results)
Accede a una región escrita en términos de código local por dispositivo con colectivos explícitos, en los que las formas lógicas coinciden con las formas de búfer físico locales por dispositivo y los colectivos corresponden exactamente a la comunicación física entre dispositivos.
El cuerpo es local en relación con los ejes manuales. La propagación se realizará a través del cuerpo en cualquier eje libre (los que no están en la lista manual_axes).
Atributos: IsolatedFromAbove
, RecursiveMemoryEffects
, SingleBlockImplicitTerminator<ReturnOp>
y SingleBlock
Atributos:
Atributo | Tipo de MLIR | Descripción |
---|---|---|
in_shardings | ::mlir::sdy::TensorShardingPerValueAttr | Fragmentación de tensores por operando/resultado de una operación |
out_shardings | ::mlir::sdy::TensorShardingPerValueAttr | Fragmentación de tensores por operando/resultado de una operación |
manual_axes | ::mlir::sdy::ManualAxesAttr |
Operandos:
Operando | Descripción |
---|---|
tensors |
variádico de tensor clasificado de cualquier tipo de valores |
Resultados:
Resultado | Descripción |
---|---|
results |
variádico de tensor clasificado de cualquier tipo de valores |
sdy.mesh
(sdy::MeshOp)
Malla con nombre
Sintaxis:
operation ::= `sdy.mesh` $sym_name `=` $mesh attr-dict
Define una nueva malla con nombre. Todas las mallas de un módulo deben tener la misma cantidad de dispositivos (excepto las mallas con un solo device_id).
La malla es una operación Symbol
que aparece en el SymbolTable
del módulo y a la que puede hacer referencia su name
.
Rasgos: HasParent<ModuleOp>
Interfaces: Symbol
Atributos:
Atributo | Tipo de MLIR | Descripción |
---|---|---|
sym_name | ::mlir::StringAttr | atributo de cadena |
mesh | ::mlir::sdy::MeshAttr | Malla de ejes y una lista de dispositivos |
sdy.named_computation
(sdy::NamedComputationOp)
Operación de procesamiento con nombre
Sintaxis:
operation ::= `sdy.named_computation` `<`$name`>` `` `(` $operands `)`
(`in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)^)?
(`out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)^)?
custom<SingleBlockRegionNoBlockId>($body)
attr-dict
`:` functional-type($operands, results)
Agrupa un cálculo, es decir, un bloque de operaciones, y le asigna un nombre. La propagación fluirá dentro o fuera de la región como si todo estuviera intercalado.
Esto se puede usar para controlar la propagación a través de instrucciones de llamada a otras
funciones. Cualquier usuario de Shardy debe escribir un pase de importación/exportación que
convierta sus operaciones de llamada en operaciones sdy.named_computation
, duplicando o copiando
el cuerpo de la función llamada en el cuerpo de named_computation
.
El tipo de cada argumento de bloque y los valores que se muestran en la región deben ser los mismos que el tipo de los operandos y el tipo de resultados de la operación.
Ejemplo:
%1 = sdy.named_computation<"foo">(%0) (%arg1: tensor<16x32xf32>) {
sdy.return %arg1 : tensor<16x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>
Atributos: IsolatedFromAbove
, RecursiveMemoryEffects
, RecursivelySpeculatableImplTrait
, SingleBlockImplicitTerminator<ReturnOp>
y SingleBlock
Interfaces: ConditionallySpeculatable
, ShardableDataFlowOpInterface
Atributos:
Atributo | Tipo de MLIR | Descripción |
---|---|---|
name | ::mlir::StringAttr | atributo de cadena |
in_shardings | ::mlir::sdy::TensorShardingPerValueAttr | Fragmentación de tensores por operando/resultado de una operación |
out_shardings | ::mlir::sdy::TensorShardingPerValueAttr | División de tensores por operando o resultado de una operación |
Operandos:
Operando | Descripción |
---|---|
operands |
variadic de cualquier tipo |
Resultados:
Resultado | Descripción |
---|---|
"unnamed" | variadic de cualquier tipo |
sdy.propagation_barrier
(sdy::PropagationBarrierOp)
Operación de barrera de propagación
Sintaxis:
operation ::= `sdy.propagation_barrier` $input `allowed_direction````=```$allowed_direction attr-dict `:` type($input)
Esta op funciona como una de identidad y da como resultado el mismo valor que tomó como entrada. Sin embargo, en términos de propagación, esto solo permitirá que la propagación fluya a través de ella en una dirección determinada.
Esto evita que las fragmentaciones se propaguen entre los usos del resultado de la operación de barrera y su operando.
FORWARD
significa que los particionados solo pueden fluir del operando al resultado.BACKWARD
significa que las fragmentaciones solo pueden fluir del resultado al operando.NONE
significa que no se puede propagar ningún fragmento a través de esta operación.- No se puede especificar
BOTH
, ya que esta operación sería redundante.
Atributos: AlwaysSpeculatableImplTrait
, Elementwise
y SameOperandsAndResultType
Interfaces: ConditionallySpeculatable
, InferTypeOpInterface
, NoMemoryEffect (MemoryEffectOpInterface)
Efectos: MemoryEffects::Effect{}
Atributos:
Atributo | Tipo de MLIR | Descripción |
---|---|---|
allowed_direction | ::mlir::sdy::PropagationDirectionAttr | enum de dirección de propagación |
Operandos:
Operando | Descripción |
---|---|
input |
Tensor clasificado de cualquier tipo de valores |
Resultados:
Resultado | Descripción |
---|---|
result |
Tensor clasificado de cualquier tipo de valores |
sdy.reshard
(sdy::ReshardOp)
Reasigna un tensor a un particionado diferente
Sintaxis:
operation ::= `sdy.reshard` $input $sharding attr-dict `:` type($result)
Vuelve a fragmentar el tensor de entrada con la fragmentación especificada, que es diferente de la fragmentación existente del tensor de entrada.
Tanto ShardingConstraintOp como ReshardOp adjuntan un particionado a un tensor. Su vida útil es la siguiente:
- Antes de la propagación del particionamiento, los usuarios agregan ShardingConstraintOp.
- La propagación de fragmentación consume ShardingConstraintOp. No hay ShardingConstraintOp en los resultados de la propagación del particionamiento. En su lugar, se puede agregar ReshardOp si es necesario.
- Un particionador convierte un ReshardOp en una operación colectiva (o una operación de identidad). No debe haber ReshardOp en los resultados del particionador.
// TODO(b/331680067). Agrega un patrón de canonicalización para quitar las operaciones redundantes.
Atributos: AlwaysSpeculatableImplTrait
, Elementwise
y SameOperandsAndResultType
Interfaces: ConditionallySpeculatable
, InferTypeOpInterface
, NoMemoryEffect (MemoryEffectOpInterface)
Efectos: MemoryEffects::Effect{}
Atributos:
Atributo | Tipo de MLIR | Descripción |
---|---|---|
sharding | ::mlir::sdy::TensorShardingAttr | División de tensores |
Operandos:
Operando | Descripción |
---|---|
input |
tensor de cualquier tipo |
Resultados:
Resultado | Descripción |
---|---|
result |
tensor de cualquier tipo de valores |
sdy.return
(sdy::ReturnOp)
La operación sdy.return
finaliza las regiones adjuntas a las operaciones sdy
basadas en regiones y a cualquier otra operación basada en regiones de Shardy. Es
variable: toma como argumentos una lista de valores cuyos tipos pueden ser cualquiera (pero
del mismo tipo, p.ej., AnyTensor
) y, por lo tanto, se puede volver a usar en varios
niveles de la pila IR de Shardy.
Sintaxis:
operation ::= `sdy.return` attr-dict ($results^ `:` type($results))?
Rasgos: AlwaysSpeculatableImplTrait
y Terminator
Interfaces: ConditionallySpeculatable
, NoMemoryEffect (MemoryEffectOpInterface)
Efectos: MemoryEffects::Effect{}
Operandos:
Operando | Descripción |
---|---|
results |
variadic de cualquier tipo |
sdy.sharding_constraint
(sdy::ShardingConstraintOp)
Restringe un tensor al particionado especificado.
Sintaxis:
operation ::= `sdy.sharding_constraint` $input $sharding attr-dict `:` type($result)
Conecta un fragmento a un tensor intermedio (p.ej., el resultado de un matmul) para indicar que así es como se debe dividir ese tensor, o un subconjunto de sus usos.
Si el particionamiento tiene dimensiones abiertas y ejes sin restricciones, significa que el tensor se puede particionar aún más a lo largo de las dimensiones abiertas.
Esta operación puede hacer lo siguiente:
- No tienen usos (vinculación), lo que significa que la fragmentación adjunta es la forma en que se debe fragmentar el tensor de entrada.
- Tener usos, lo que significa que el fragmentación adjunta es la forma en que se deben fragmentar los usos de la operación de restricción de fragmentación, mientras que otros usos del tensor de entrada pueden tener una fragmentación diferente (si el tensor de entrada no tiene otros usos, el comportamiento es el mismo que el caso sin usos).
Características: Elementwise
, SameOperandsAndResultType
Interfaces: InferTypeOpInterface
Atributos:
Atributo | Tipo de MLIR | Descripción |
---|---|---|
sharding | ::mlir::sdy::TensorShardingAttr | División de tensores |
Operandos:
Operando | Descripción |
---|---|
input |
tensor de cualquier tipo |
Resultados:
Resultado | Descripción |
---|---|
result |
tensor de cualquier tipo de valores |
sdy.sharding_group
(sdy::ShardingGroupOp)
Operación de grupo de fragmentación
Sintaxis:
operation ::= `sdy.sharding_group` $input `group_id````=```$group_id attr-dict `:` type($input)
Esta operación proporciona una interfaz para asignar tensores a grupos de fragmentación (grupos de tensores que se aplicarán para tener fragmentaciones idénticas). Durante la propagación, en cuanto se fragmente un elemento del grupo, todos los demás miembros se fragmentarán exactamente de la misma manera. Esta operación toma el ID del grupo de argumentos y no muestra ningún resultado, pero modifica la representación interna del grupo de fragmentación para agregar el tensor de entrada al grupo con el ID dado.
Atributos:
Atributo | Tipo de MLIR | Descripción |
---|---|---|
group_id | ::mlir::IntegerAttr | Atributo de número entero de 64 bits sin signo |
Operandos:
Operando | Descripción |
---|---|
input |
Tensor clasificado de cualquier tipo de valores |
Atributos
AxisRefAttr
Referencia a un eje completo o a un subeje dividido
Sintaxis:
#sdy.axis_ref<
::llvm::StringRef, # name
SubAxisInfoAttr # sub_axis_info
>
Parámetros:
Parámetro | Tipo de C++ | Descripción |
---|---|---|
nombre | ::llvm::StringRef |
nombre |
sub_axis_info | SubAxisInfoAttr |
DimMappingAttr
Lista de índices de factores para una dimensión
Todos los índices de factores deben estar en el rango [0, num_factors) y una lista vacía indica que se trata de una asignación nula (se analiza o imprime con *
), es decir, la dimensión no se asigna a ningún factor.
Parámetros:
Parámetro | Tipo de C++ | Descripción |
---|---|---|
factor_indices | ::llvm::ArrayRef<int64_t> |
DimensionShardingAttr
División de dimensiones
Es una lista de nombres de ejes para dividir una dimensión de tensor de mayor a menor, un valor booleano que indica si la dimensión se puede dividir aún más y un número entero opcional que indica la prioridad de esta división de dimensión, que se respetará durante la propagación de la división. Las prioridades provienen de las anotaciones de fragmentación de usuarios, y un valor más bajo indica una prioridad más alta. Se asume la prioridad más alta cuando falta la prioridad en la anotación.
Parámetros:
Parámetro | Tipo de C++ | Descripción |
---|---|---|
ejes | ::llvm::ArrayRef<AxisRefAttr> |
lista de referencias de eje |
is_closed | bool |
|
priority | std::optional<int64_t> |
ManualAxesAttr
Sintaxis:
#sdy.manual_axes<
::llvm::ArrayRef<StringAttr> # value
>
Parámetros:
Parámetro | Tipo de C++ | Descripción |
---|---|---|
valor | ::llvm::ArrayRef<StringAttr> |
MeshAttr
Malla de ejes y una lista de dispositivos
Sintaxis:
#sdy.mesh<
::llvm::ArrayRef<MeshAxisAttr>, # axes
::llvm::ArrayRef<int64_t> # device_ids
>
Una malla es una lista de ejes y una lista opcional de IDs de dispositivos que especifica el orden de los dispositivos.
Si la lista de ejes está vacía, la malla tiene un eje implícito sin nombre de tamaño 1. En este caso, si no se proporciona una lista de IDs de dispositivos, la lista implícita de IDs de dispositivos es [0]; si se proporciona una lista de IDs de dispositivos, debe contener un solo número entero de cualquier valor no negativo. A esto lo llamamos caso de fragmentación máxima.
Para todos los casos de fragmentación no máxima, si se especifica una lista de IDs de dispositivos, el producto de los tamaños de los ejes debe coincidir con la cantidad de dispositivos. Si no se especifica una lista de IDs de dispositivos, la lista implícita de IDs de dispositivos es iota(product(axes)). Para simplificar, tampoco permitimos especificar una lista de IDs de dispositivos que sea igual a iota(product(axes)); en este caso, no se debe especificar una lista de IDs de dispositivos.
Estos son algunos ejemplos de mallas:
- Una malla vacía representa una malla de marcadores de posición que se puede reemplazar durante la propagación: <[]>
- Una malla con un eje sin nombre y un ID de dispositivo explícito, que suele usarse para representar la fragmentación máxima: <[], device_ids=[3]>
- Una malla con dos ejes e IDs de dispositivos implícitos iota(6): <["a"=2, "b"=3]>
- Una malla con dos ejes y IDs de dispositivos explícitos que especifican el orden de los dispositivos: <["a"=3, "b"=2], device_ids=[0, 2, 4, 1, 3, 5]>
Parámetros:
Parámetro | Tipo de C++ | Descripción |
---|---|---|
ejes | ::llvm::ArrayRef<MeshAxisAttr> |
|
device_ids | ::llvm::ArrayRef<int64_t> |
MeshAxisAttr
Eje con nombre en una malla
Sintaxis:
#sdy.mesh_axis<
::llvm::StringRef, # name
int64_t # size
>
Parámetros:
Parámetro | Tipo de C++ | Descripción |
---|---|---|
nombre | ::llvm::StringRef |
nombre |
tamaño | int64_t |
OpShardingRuleAttr
Especifica cómo se puede particionar una operación.
Sintaxis:
#sdy.op_sharding_rule<
::llvm::ArrayRef<int64_t>, # factor_sizes
::llvm::ArrayRef<TensorMappingAttr>, # operand_mappings
::llvm::ArrayRef<TensorMappingAttr>, # result_mappings
bool # is_custom_rule
>
Una regla de fragmentación especifica cómo se puede particionar una operación según varias propiedades de la operación: cualquier atributo, la forma de los operandos, la forma de los resultados, etcétera. Por ejemplo:
%0 = stablehlo.add %arg0, %arg1 {
sdy.sharding_rule = #sdy.op_sharding_rule<
([i, j],[i, j])->([i, j])
{i=8, j=8}>
} : tensor<8x8xf32>
%1 = stablehlo.dot_general %arg2, %arg3, contracting_dims = [1] x [0] {
sdy.sharding_rule = #sdy.op_sharding_rule<
([i, k],[k, j])->([i, j])
{i=8, j=16, k=8}>
}: (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
Ten en cuenta que permitimos factores con un tamaño de 1 a pesar de que no se pueden fragmentar, esto es principalmente para la integridad, ya que muchas operaciones, como las de punto, tienen un tamaño de una dimensión que se corresponde entre operandos y resultados.
is_custom_rule
describe si esta es una regla definida por un usuario para una operación stablehlo.custom_call
. El particionador no sabe cómo particionar estas operaciones, por lo que un usuario debe indicarle cómo hacerlo. Cuando se trata de una regla personalizada, esta siempre se conserva o nunca se quita. is_custom_rule
solo puede ser verdadero para las operaciones stablehlo.custom_call
.
Parámetros:
Parámetro | Tipo de C++ | Descripción |
---|---|---|
factor_sizes | ::llvm::ArrayRef<int64_t> |
|
operand_mappings | ::llvm::ArrayRef<TensorMappingAttr> |
|
result_mappings | ::llvm::ArrayRef<TensorMappingAttr> |
|
is_custom_rule | bool |
SubAxisInfoAttr
Información sobre cómo se deriva este subeje a partir del eje completo
Sintaxis:
#sdy.sub_axis_info<
int64_t, # pre_size
int64_t # size
>
Cuando se divide un eje completo en n subejes, el eje se modifica en [k_1,…,k_n], y el subeje enésimo se puede expresar como el producto de todos los tamaños del eje a su izquierda m=prod(k_1,...,k_(i-1))
(también conocido como tamaño previo) y el tamaño k_i. Por lo tanto, el atributo sub-axis-info contiene esos dos números y se representa de la siguiente manera: (m)k
para el tamaño previo m y el tamaño k.
Parámetros:
Parámetro | Tipo de C++ | Descripción |
---|---|---|
pre_size | int64_t |
|
tamaño | int64_t |
TensorMappingAttr
Asignaciones de factores para cada dimensión de un tensor.
Sintaxis:
#sdy.tensor_mapping<
::llvm::ArrayRef<DimMappingAttr> # dim_mappings
>
Parámetros:
Parámetro | Tipo de C++ | Descripción |
---|---|---|
dim_mappings | ::llvm::ArrayRef<DimMappingAttr> |
TensorShardingAttr
División de tensores
Sintaxis:
#sdy.sharding<
::mlir::Attribute, # mesh_or_ref
::llvm::ArrayRef<DimensionShardingAttr>, # dim_shardings
::llvm::ArrayRef<AxisRefAttr> # replicated_axes
>
La fragmentación de tensores está vinculada a una malla específica y solo puede hacer referencia a los nombres de ejes de esa malla. La fragmentación de dimensiones nos indica para cada dimensión del tensor, a lo largo de los cuales se fragmenta de mayor a menor. Todos los demás ejes que no fragmentan una dimensión se replican de forma implícita o explícita (si aparecen en la lista de ejes replicados).
La malla a la que está vinculado este fragmentación se puede especificar con un nombre de símbolo, que hace referencia a un símbolo MeshOp
correspondiente, o un MeshAttr
intercalado.
Parámetros:
Parámetro | Tipo de C++ | Descripción |
---|---|---|
mesh_or_ref | ::mlir::Attribute |
Atribución de malla o atribución de referencia de símbolo de malla plana |
dim_shardings | ::llvm::ArrayRef<DimensionShardingAttr> |
|
replicated_axes | ::llvm::ArrayRef<AxisRefAttr> |
lista de referencias de eje |
TensorShardingPerValueAttr
División de tensores por operando o resultado de una operación
Sintaxis:
#sdy.sharding_per_value<
::llvm::ArrayRef<TensorShardingAttr> # shardings
>
Parámetros:
Parámetro | Tipo de C++ | Descripción |
---|---|---|
fragmentación | ::llvm::ArrayRef<TensorShardingAttr> |
Enumeraciones
PropagationDirection
enum de dirección de propagación
Casos:
Símbolo | Valor | String |
---|---|---|
NINGUNO | 0 |
NINGUNO |
FORWARD | 1 |
FORWARD |
BACKWARD | 2 |
BACKWARD |
AMBOS | 3 |
AMBOS |