API Compiler

Contexto

Assumimos que os leitores estão familiarizados com pelo menos os conceitos básicos de representação de fragmentação, que descreve como a fragmentação de um tensor pode ser expressa no Shardy. Este documento mostra como as representações de divisão podem ser usadas em um programa, por exemplo, para anexar uma divisão a um tensor específico do programa.

A propagação de fragmentação é o processo de decisão sobre um fragmentação para cada tensor em um programa, considerando as restrições de fragmentação para um subconjunto dos tensores. A API do compilador do Shardy expõe várias maneiras de influenciar/controlar a propagação do sharding. Além disso, permite que os usuários insiram cálculos divididos manualmente nos programas.

Objetivo

Este documento descreve o design desses componentes de API no Shardy e explica o comportamento e os invariantes. Embora essa API seja usada para controlar a propagação de fragmentação, este documento NÃO vai discutir nada sobre o comportamento da propagação nem como ela é projetada.

Visão geral

  • Fração de entrada/saída: anexa um fracionamento a uma entrada ou saída da função principal para indicar como o tensor de entrada/saída precisa ser fracionado quando fornecido para/retornado pela função.

  • Restrição de fragmentação: anexe uma fragmentação a um tensor intermediário (por exemplo, o resultado de uma matmul) para indicar como esse tensor ou um subconjunto de usos dele deve ser fragmentado.

  • Grupo de fragmentação: agrupe vários tensores por um ID para indicar que eles precisam ser fragmentados da mesma maneira.

  • Cálculo manual: inclui uma subcomputação que é particionada manualmente usando um subconjunto de eixos de malha, em que os fragmentos ao longo desses eixos manuais são especificados para todas as entradas e saídas, e dentro da subcomputação, os tipos de tensor são locais em relação a esses fragmentos.

Design detalhado

Fragmentações de entrada/saída

Permite que os usuários especifiquem um fragmentação para as entradas e saídas da função principal.

No MLIR, os atributos podem ser anexados a argumentos e resultados de função. Dessa forma, os usuários podem anexar atributos de fragmentação à função.

Exemplo:

@mesh_xy = <["x"=2, "y"=2]>

// The 1st input has a sharding specified, but the 2nd input doesn't.
// The output has a sharding specified.
func @main(%arg0: tensor<8x8xf32>
            {sdy.sharding = #sdy.sharding<@mesh_xy, [{"x"}, {}]>},
            %arg1: tensor<8x16xf32>)
    -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_xy, [{}, {"y"}]>}) {
  ...
}

Restrição de fragmentação

Permite que os usuários associem um sharding a um tensor intermediário no programa, o que informa ao particionador como esse tensor ou um subconjunto de usos dele deve ser dividido.

Esta é uma operação do MLIR que recebe o tensor como entrada e tem um atributo de fragmentação anexado a ele. A operação pode:

  • Não tem usos (solto), o que significa que o particionamento anexado é como o tensor em si precisa ser particionado.
  • Ter usos: significa que o particionamento anexado é como os usos da operação de restrição de particionamento precisam ser particionados, enquanto outros usos do tensor de entrada podem ter um particionamento diferente. Se o tensor de entrada não tiver outros usos, o comportamento será o mesmo do caso sem usos. A propagação vai determinar o sharding do tensor e refazer o sharding, se necessário.

Ele pode ter divisões de dimensão abertas, o que significa que o operando pode ser dividido ainda mais nos eixos disponíveis.

@mesh_xy = <["x"=2, "y"=2]>

%0 = ... : tensor<8x8xf32>
%1 = sdy.sharding_constraint %0 <@mesh_xy, [{"x"}, {?}]> : tensor<8x8xf32>

Grupo de fragmentação

Nos casos em que não há dependências de dados ou dependências fortes de dados entre dois ou mais tensores, enquanto os usuários sabem que esses tensores precisam ser particionados da mesma maneira ou de maneira semelhante, a API Shardy oferece uma maneira de especificar essa relação. Isso dá aos usuários a liberdade de especificar explicitamente que os tensores precisam ser particionados como um ao outro.

Para isso, apresentamos o conceito de grupos de fragmentos, em que cada grupo contém qualquer número de instruções associadas ao mesmo ID de grupo de fragmentos. Os grupos de fragmentação garantem que os fragmentos no mesmo grupo sejam iguais.

Por exemplo, em um programa de usuário hipotético, como mostrado abaixo, queremos fragmentar a saída do programa exatamente como a entrada do programa, enquanto não há dependências de dados entre os dois.

Se executarmos esse programa, a propagação de fragmentação não poderá inferir o fragmento de tensores %1 e %2, e eles serão replicados. No entanto, ao anexar um atributo shard_group que diz que a entrada %0 e a saída %2 estão no mesmo shard_group, permitimos que o sharding @mesh_xy, [{"x"},{"y"}]> seja propagado da entrada %0 para a saída %2 e, por sua vez, para o restante do gráfico, que é transmitido pela constante %1 aqui. Podemos atribuir um valor a um grupo com a operação sdy.sharding_group.

@mesh_xy = <["x"=2, "y"=2]>

module @"jit_zeros_like" {
  func.func @main(%arg0: tensor<8x2xi64> {sdy.sharding = #sdy.sharding<@mesh_xy, [{"x"},{"y"}]>} }) -> (tensor<8x2xi64>) {
    %0 = sdy.sharding_group %arg0, id=0 : tensor<8x2xi64>
    %1 = stablehlo.constant dense<0> : tensor<8x2xi64>
    %2 = sdy.sharding_group %1, id=0 : tensor<8x2xi64>
    return %2 : tensor<8x2xi64>
  }
}

Neste exemplo simples acima, poderíamos especificar explicitamente o mesmo particionamento na saída como a entrada, o que teria o mesmo efeito, já que já sabíamos qual fragmento queríamos atribuir à entrada com antecedência, mas em casos mais realistas, usamos o fragmento para manter o particionamento de vários tensores sincronizados sem necessariamente conhecer o particionamento de nenhum deles, enquanto Shardy cuida do restante e encontra o melhor particionamento para atribuir a eles.

Cálculo manual

Os usuários podem querer controlar explicitamente como partes da computação são particionadas e quais coletivos são usados. Por exemplo, alguns usuários querem aplicar o matmul coletivo manualmente (da API de front-end) em vez de adiar para o compilador. Fornecemos uma API de cálculo manual que permite isso.

Esta é a operação MLIR com uma única região para a subcomputação manual. Os usuários especificam partições de entrada/saída para essa subcomputação usando um subconjunto (possivelmente todos) dos eixos da malha. A subcomputação seria local/manual em relação aos eixos de malha especificados (também conhecidos como eixos manuais) e globais/não particionados em relação aos eixos não especificados (também conhecidos como eixos livres). A subcomputação pode ser dividida ao longo dos eixos livres durante a propagação da mesma forma que a computação fora dessa operação.

Exemplo:

@mesh_name = <["data"=2, "model"=2]>

%0 = ... : tensor<16x32xf32>
%1 = sdy.manual_computation(%0)
    in_shardings=[<@mesh_name, [{"data"}, {"model",?}]>]
    out_shardings=[<@mesh_name, [{"data"}, {?}]>]
    manual_axes={"data"}
    (%arg1: tensor<8x32xf32>) {
  // body
  return %42 : tensor<8x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>

Invariantes

  1. Todos os in_shardings, out_shardings e manual_axes precisam se referir à mesma rede. manual_axes é classificado em relação à malha.

  2. O manual_axes precisa ser usado explicitamente em todos os particionamentos de entrada/saída, ou seja, para cada particionamento, todos os eixos manuais precisam particionar uma dimensão ou ser explicitamente replicados.

  3. Se um eixo livre (qualquer eixo de malha que não esteja em manual_axes) existir em um dos fragmentos de entrada/saída, ele precisa ser menor que qualquer eixo manual no mesmo fragmento de dimensão. No exemplo acima, um fragmento de dimensão {"model", "data"} seria inválido.

  4. A região/corpo da computação é a computação local (por exemplo, incluindo coletivos especificados pelo usuário). Ele precisa ser local em relação ao fragmentação de entrada/saída ao longo dos eixos manuais (consulte a observação acima).

Como aninhar cálculos manuais

É possível aninhar várias computações manuais umas dentro das outras, desde que cada uma opere no próprio conjunto exclusivo de eixos manuais.