Design do intérprete

Modelo de dados

Os programas StableHLO são cálculos sobre tensores (matrizes n-dimensionais), que, no modelo atual, são implementados usando a classe Tensor. A classe de armazenamento subjacente de um objeto Tensor, detail::Buffer, armazena o mlir::ShapedType do tensor junto com um objeto mlir::HeapAsmResourceBlob que representa um blob mutável de dados de tensor organizados como uma matriz de bytes contígua em ordem principal para secundária. Os objetos detail::Buffer são contados por referência para simplificar o gerenciamento de memória.

Elementos individuais de um tensor são representados usando a classe Element, que usa uma união discriminada contendo um dos valores APInt, APFloat ou pair<APFloat,APFloat> para armazenamento. O último é usado para armazenar elementos com tipos complexos.

A Tensor tem as seguintes APIs para interagir com os elementos individuais:

  • Element Tensor::get(llvm::ArrayRef<int64_t> index): para extrair um elemento de tensor individual em um índice multidimensional index como um objeto Element.
  • void Tensor::set(llvm::ArrayRef<int64_t> index, Element element);: para atualizar um objeto Element element em um tensor no índice multidimensional index.

Como o intérprete funciona

A função de entrada para o intérprete é

SmallVector<Tensor> eval(func::FuncOp func, ArrayRef<Tensor> args);

que faz o seguinte:

  1. Rastreia os argumentos SSA de func e os valores de Tensor de ambiente de execução associados, fornecidos em args, usando um mapa de tabela de símbolos, M.
  2. Para cada operação em func, na ordem SSACFG:
    • Invoca eval na operação. Para cada operando SSA da operação, extraia o valor de ambiente de execução de M a ser fornecido como um argumento para a invocação eval.
    • Rastreia os resultados SSA da operação e o valor avaliado em M.

O nível de operação eval mencionado em (2) é responsável por implementar a semântica de execução da operação. Veja a seguir um exemplo de stablehlo::AddOp: No exemplo, elementos individuais dos tensores lhs e rhs são extraídos em pares como objetos Element que são então adicionados. O resultado da adição, um objeto Element, é armazenado no tensor result final.

Tensor eval(AddOp op, const Tensor &lhs, const Tensor &rhs) {
  Tensor result(op.getType());

  for (auto it = result.index_begin(); it != result.index_end(); ++it)
    result.set(*it, lhs.get(*it) + rhs.get(*it));

  return result;
}

No geral, o design do intérprete é otimizado para legibilidade de implementações de funções eval para operações individuais, porque precisa servir como uma implementação de referência para o StableHLO. Por exemplo, em vez de definir eval como uma função de modelo e parametrizar-a com tipos de elementos, encapsulamos detalhes sobre como diferentes tipos de elementos são processados em Element::operator+ etc., simplificando a implementação de eval.

Como usar o intérprete para dobras constantes

Podemos usar o mecanismo intérprete para dobrar operações com valores de operando constantes. O snippet de código a seguir demonstra uma ideia da implementação para dobrar stablehlo::AddOp com operandos digitados de ponto flutuante:

OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  auto attrs = adaptor.getOperands();
  DenseElementsAttr lhsData = attrs[0].dyn_cast<DenseElementsAttr>();
  DenseElementsAttr rhsData = attrs[1].dyn_cast<DenseElementsAttr>();
  if (!lhsData || !rhsData) return {};

  auto lhs = Tensor(lhsData);
  auto rhs = Tensor(rhsData);
  auto result = eval(*this, lhs, rhs);

  SmallVector<APFloat> values;
  for (auto i = 0; i < result.getNumElements(); ++i) {
    Element element = result.get(i);
    values.push_back(element.getValue().cast<FloatAttr>().getValue());
  }

  return DenseElementsAttr::get(result.getType(), values);
}

No momento, não estamos trabalhando ativamente para integrar o intérprete em dobrações constantes, porque não estamos planejando implementar uma pasta para o StableHLO. No entanto, no futuro, planejamos usar o intérprete para dobrar constante na MHLO.Nesse ponto, vamos melhorar a ergonomia do snippet de código acima. Por exemplo, podemos ter uma função auxiliar que empacota operandos constantes em objetos Tensor e descompacta resultados de Tensor em OpFoldResult.

Como testar o intérprete de StableHLO

O intérprete toma como entradas (A) um programa StableHLO e (B) valores de dados a serem alimentados ao programa e gera valores de dados de saída, que são comparados aos valores de dados esperados fornecidos pelo usuário. Os valores de dados (B) são codificados no próprio programa usando operações stablehlo.constant. O intérprete avalia o programa de entrada. As saídas da operação em teste são verificadas com verificações (por exemplo, check.expect_eq, check.expect_almost_eq), conforme mostrado abaixo. check.expect_eq e check.expect_eq_const verificam a igualdade bit a bit para qualquer tipo com suporte, e check.expect_almost_eq e check.expect_almost_eq_const verificam a igualdade próxima dentro de uma tolerância, explicada na diretriz de teste (G6), para pontos flutuantes e tipos complexos.

// CHECK-LABEL: Evaluated results of function: add_op_test_ui4
func.func @add_op_test_ui4() {
  %0 = stablehlo.constant dense<[0, 2]> : tensor<2xui4>
  %1 = stablehlo.constant dense<[15, 3]> : tensor<2xui4>
  %2 = stablehlo.add %0, %1 : tensor<2xui4>
  check.expect_eq_const %2, [15, 5] : tensor<2xui4>
  func.return
}

Um utilitário de teste stablehlo-translate --interpret (código) é responsável por analisar o programa, interpretando cada função, incluindo as operações que compõem a função. Temos um pacote de testes dedicado, que consiste em vários testes que exercem vários comportamentos de tempo de execução para cada operação do StableHLO. Os testes podem ser encontrados aqui (por exemplo, interpret_*.mlir).

Diretrizes de teste

(G1) Precisamos testar todos os tipos compatíveis em cada operação?

Podemos usar uma combinação das seguintes regras para decidir:

  1. Ao implementar uma operação, se houver um código na função eval correspondente para processar um tipo específico, é fundamental ter testes para abranger esse tipo. Por exemplo, para a operação add, há um código exclusivo para processar tipos inteiros, booleanos, de ponto flutuante e complexos. Portanto, precisamos de um teste para cada categoria de tipos.

  2. Se um conjunto de tipos for processado de maneira uniforme na função eval correspondente, um único teste para todos eles será suficiente. Por exemplo, para a operação add, todas as variantes de tipos inteiros (si4, u4, si8, u8 e assim por diante) são processadas da mesma forma usando APIs llvm::APInt. Portanto, podemos pular a adição de testes para cada uma dessas variantes e adicionar um único teste representativo. Para evitar ambiguidade na seleção do representante, precisamos usar as seguintes diretrizes:

    • Se todos os tipos, tratados de maneira uniforme, tiverem o mesmo tipo primitivo (ou seja, se todos forem inteiros, de ponto flutuante ou complexos), escolha aquele com largura máxima de bits.
    • Se todos os tipos, processados de maneira uniforme, tiverem uma combinação de tipos primitivos, escolha aquele com o seguinte tipo primitivo, em ordem decrescente de preferência: inteiro, ponto flutuante, booleano, complexo.

(G2) Como decidimos o número de testes necessários para cobrir o comportamento de uma operação?

O objetivo é abordar de maneira abrangente a lógica do intérprete da operação (ou seja, todos os casos isolados da implementação) com um número mínimo de testes. Para garantir a manutenção, é importante minimizar o número de testes. Quanto menos testes tivermos, mais fácil será revisá-los e garantir que eles cubram a operação de maneira abrangente. Como resultado, esperamos que a maioria das operações mais simples tenha apenas um teste. Se, por algum bom motivo, não for possível fazer uma cobertura abrangente, você poderá parar em >= 90%. Isso será decidido caso a caso durante a análise da solicitação de envio.

(G3) Que tal adicionar testes à infraestrutura do intérprete?

A infraestrutura do intérprete é bastante simples e pode ser adicionada à nossa base de confiança. A única parte não trivial é como vários tipos são empacotados e descompactados no armazenamento do intérprete subjacente. Como discutido em (G1), testaremos apenas os tipos de operação que são tratados de maneira diferente. Por isso, é possível que o código de empacotamento/descompactação, correspondente a diferentes variantes de tipos de números inteiros/de ponto flutuante, não seja totalmente coberto durante os testes. Para garantir uma cobertura completa, podemos escolher uma operação como constant, que oferece suporte a todos os tipos de elemento do StableHLO e gravar testes abrangentes.

(G4) Se a implementação de uma operação depende de outras operações, é necessário criar testes para elas?

Não. Por exemplo, a implementação de batch_norm_grad pode ser baseada em divide, subtract, multiply e outros. Devemos evitar testar as últimas operações ao testar a primeira.

(G5) Vamos criar testes para exercitar os comportamentos definidos / indefinidos pela implementação?

Não devemos criar testes que pratiquem os comportamentos definidos pela implementação ou indefinidos da operação. Testes com comportamentos definidos pela implementação demonstram um comportamento local do intérprete, que não deve ser generalizado. Testes com comportamento indefinido não contribuem para o entendimento do comportamento da operação.

(G6) Ao criar testes para tipos de ponto flutuante, com que precisão o resultado esperado precisa ser especificado nas verificações?

Para operações básicas (adição, subtração, multiplicação, divisão e quadrado), uma implementação que segue a especificação IEEE precisa fornecer um resultado arredondado com até 0,5 ULP do resultado matematicamente exato. Dito isso, podemos imaginar com segurança que o resultado esperado dessas operações seja de, no máximo, um ULP de diferença. No entanto, isso pode não funcionar para funções transcendentais (sine, cosine etc.) para as quais as garantias de precisão são definidas pela implementação (justificativa).

A implementação atual usa um valor de tolerância de "tamanho único" de 0,0001. O exemplo a seguir demonstra a tolerância acima em ação.

func.func @check_tolerance() {
  %0 = stablehlo.constant dense<0.2> : tensor<f32>

  // The following check succeeds as %0 is almost equal to the provided
  // constant modulo the tolerance, mentioned above.
  check.expect_almost_eq_const %0, dense<0.19999> : tensor<f32>

  // The following check fails as %0 is not bitwise equal to the provided
  // constant.
  check.expect_eq_const %0, dense<0.19999> : tensor<f32>

  func.return
}

Essa é apenas a primeira etapa para testar a precisão numérica das operações do StableHLO. No momento, essa é uma área não especificada da especificação StableHLO, e há trabalho contínuo para descobrir #1156 (link em inglês) com base na nossa experiência de uso do StableHLO na prática e no feedback das partes interessadas. À medida que isso avançar, atualizaremos a infraestrutura conforme necessário.

(G7) Alguma informação sobre o estilo de programação dos testes?

  1. Certifique-se de usar o nome real das entradas/saídas em vez de usar os valores SSA como padrão (por exemplo, %0, %1 etc.).
  2. Confira se os testes usam o formato formatado, se houver.

(G8) Devemos incluir o exemplo já fornecido na especificação? Sim (para a integridade do teste).