API de Compiler

Segundo plano

Suponemos que los lectores conocen al menos los aspectos básicos de la representación del particionamiento, que describe cómo se puede expresar el particionamiento de un tensor en Shardy. En este documento, se muestra cómo se pueden usar las representaciones de fragmentación en un programa, p.ej., para adjuntar un fragmento a un tensor específico del programa.

La propagación de fragmentación es el proceso de decidir un fragmento para cada tensor en un programa, dadas las restricciones de fragmentación para un subconjunto de los tensores. La API del compilador de Shardy expone varias formas de influir o controlar la propagación del fragmentación. Además, permite a los usuarios insertar cálculos divididos manualmente en sus programas.

Objetivo

En este documento, se describe el diseño de esos componentes de API en Shardy y se explica su comportamiento y sus invariantes. Ten en cuenta que, si bien esta API se usa para controlar la propagación del fragmento, en este documento NO se analizará nada sobre el comportamiento de la propagación ni cómo está diseñada.

Descripción general

  • Fragmentación de entrada/salida: Adjunta un fragmento a una entrada o salida de la función principal para indicar que así es como se debe fragmentar el tensor de entrada/salida cuando se le pasa a la función o se muestra desde ella.

  • Restricción de fragmentación: Adjunta un fragmento a un tensor intermedio (p.ej., el resultado de un matmul) para indicar que así es como se debe fragmentar ese tensor o un subconjunto de sus usos.

  • Grupo de fragmentación: Agrupa varios tensores por un ID para indicar que se deben fragmentar de la misma manera.

  • Cálculo manual: Encierra un subcálculo que se particiona de forma manual con un subconjunto de ejes de malla, en los que se especifican las particiones a lo largo de esos ejes manuales para todas las entradas y salidas, y dentro del subcálculo, los tipos de tensores son locales en relación con esas particiones.

Diseño detallado

División en fragmentos de entrada y salida

Permite a los usuarios especificar un fragmento para las entradas y salidas de la función principal.

En MLIR, los atributos se pueden adjuntar a los argumentos y resultados de la función y, por lo tanto, los usuarios pueden adjuntar atributos de fragmentación a la función de esta manera.

Por ejemplo:

@mesh_xy = <["x"=2, "y"=2]>

// The 1st input has a sharding specified, but the 2nd input doesn't.
// The output has a sharding specified.
func @main(%arg0: tensor<8x8xf32>
            {sdy.sharding = #sdy.sharding<@mesh_xy, [{"x"}, {}]>},
            %arg1: tensor<8x16xf32>)
    -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_xy, [{}, {"y"}]>}) {
  ...
}

Restricción de fragmentación

Permite que los usuarios adjunten un fragmento a un tensor intermedio en su programa, lo que le indica al particionador que así es como se debe dividir ese tensor, o un subconjunto de sus usos.

Esta es una operación de MLIR que toma el tensor como entrada y tiene un atributo de fragmentación asociado. La operación puede hacer lo siguiente:

  • No tener usos (pendiente), lo que significa que el fragmento adjunto es la forma en que se debe fragmentar el tensor.
  • Tener usos, lo que significa que el fragmentación adjunta es la forma en que se deben fragmentar los usos de la operación de restricción de fragmentación, mientras que otros usos del tensor de entrada pueden tener una fragmentación diferente (si el tensor de entrada no tiene otros usos, el comportamiento es el mismo que el caso sin usos). La propagación determinará el fragmentación del tensor y lo volverá a fragmentar si es necesario.

Puede tener particiones de dimensiones abiertas, lo que significa que el operando se puede particionar aún más a lo largo de los ejes disponibles.

@mesh_xy = <["x"=2, "y"=2]>

%0 = ... : tensor<8x8xf32>
%1 = sdy.sharding_constraint %0 <@mesh_xy, [{"x"}, {?}]> : tensor<8x8xf32>

Grupo de fragmentación

En los casos en que no hay dependencias de datos o no hay dependencias de datos sólidas entre dos o más tensores, mientras que los usuarios saben que esos tensores deben particionarse de la misma manera o de manera similar, la API de Shardy ofrece una forma de especificar esta relación. Esto les da a los usuarios la libertad de especificar explícitamente que los tensores deben particionarse entre sí.

Para lograrlo, presentamos un concepto de grupos de fragmentos, en el que cada grupo contiene cualquier cantidad de instrucciones asociadas con el mismo ID de grupo de fragmentos. Los grupos de fragmentación aplican la fragmentación dentro del mismo grupo para que sea la misma.

Por ejemplo, en un programa de usuario hipotético como el que se muestra a continuación, queremos fragmentar el resultado del programa exactamente igual que la entrada del programa, mientras que no hay dependencias de datos entre ambos.

Si ejecutamos este programa, la propagación del fragmento no podrá inferir en el fragmento de los tensores %1 y %2, y terminarán replicándose. Sin embargo, si adjuntas un atributo shard_group que indica que la entrada %0 y la salida %2 están dentro del mismo shard_group, permitimos que el fragmento @mesh_xy, [{"x"},{"y"}]> se propague de la entrada %0 a la salida %2 y, a su vez, al resto del gráfico, que se transmite como constante %1 aquí. Podemos asignar un valor a un grupo con la operación sdy.sharding_group.

@mesh_xy = <["x"=2, "y"=2]>

module @"jit_zeros_like" {
  func.func @main(%arg0: tensor<8x2xi64> {sdy.sharding = #sdy.sharding<@mesh_xy, [{"x"},{"y"}]>} }) -> (tensor<8x2xi64>) {
    %0 = sdy.sharding_group %arg0, id=0 : tensor<8x2xi64>
    %1 = stablehlo.constant dense<0> : tensor<8x2xi64>
    %2 = sdy.sharding_group %1, id=0 : tensor<8x2xi64>
    return %2 : tensor<8x2xi64>
  }
}

En este ejemplo simple anterior, como alternativa, podríamos haber especificado de forma explícita el mismo fragmento en la salida que en la entrada, lo que lograría el mismo efecto, ya que ya sabemos qué fragmento queremos asignar a la entrada con anticipación. Sin embargo, en casos más realistas, usamos el fragmento para mantener la fragmentación de varios tensores en sincronización sin conocer necesariamente la fragmentación de ninguno de ellos, mientras que Shardy se encargará del resto y encontrará la mejor fragmentación para asignarles.

Cálculo manual

Los usuarios pueden querer tener un control explícito sobre cómo se particionan las partes de su procesamiento y qué colectivos se usan. Por ejemplo, algunos usuarios desean aplicar matmul colectivo de forma manual (desde la API de frontend) en lugar de diferirlo al compilador. Proporcionamos una API de procesamiento manual que les permite hacer eso.

Esta es la operación de MLIR con una sola región para el subcálculo manual. Los usuarios especificarían los fragmentos de entrada y salida para esta subcomputación con un subconjunto (posiblemente todos) de los ejes de malla. El subcálculo sería local o manual en relación con los ejes de malla especificados (también conocidos como ejes manuales) y global o sin particionar en relación con los no especificados (también conocidos como ejes libres). La subcomputación se puede particionar aún más a lo largo de los ejes libres durante la propagación de la misma manera que se puede hacer con la computación fuera de esta operación.

Por ejemplo:

@mesh_name = <["data"=2, "model"=2]>

%0 = ... : tensor<16x32xf32>
%1 = sdy.manual_computation(%0)
    in_shardings=[<@mesh_name, [{"data"}, {"model",?}]>]
    out_shardings=[<@mesh_name, [{"data"}, {?}]>]
    manual_axes={"data"}
    (%arg1: tensor<8x32xf32>) {
  // body
  return %42 : tensor<8x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>

Invarianzas

  1. Todos los in_shardings, out_shardings y manual_axes deben hacer referencia a la misma malla. manual_axes se ordena en función de la malla.

  2. manual_axes se debe usar de forma explícita en todos los particionamientos de entrada y salida, es decir, para cada particionamiento, todos los ejes manuales deben particionar una dimensión o replicarse de forma explícita.

  3. Si existe un eje libre (cualquier eje de malla que no esté en manual_axes) en uno de los fragmentos de entrada o salida, debe ser menor que cualquier eje manual en el mismo fragmento de dimensión (en el ejemplo anterior, un fragmento de dimensión {"model", "data"} no sería válido).

  4. La región o el cuerpo del procesamiento es el procesamiento local (p.ej., incluye los colectivos especificados por el usuario). Debe ser local en relación con el particionamiento de entrada y salida a lo largo de los ejes manuales (consulta la nota anterior).

Anidación de cálculos manuales

Puedes anidar varios cálculos manuales entre sí, siempre que cada uno opere en su propio conjunto único de ejes manuales.