Modelo de dados
Os programas StableHLO são cálculos sobre tensores
(matrizes n-dimensionais), que, no modelo atual, são implementados usando a
classe Tensor
. A classe de armazenamento subjacente de um objeto Tensor
, detail::Buffer
, armazena o mlir::ShapedType
do tensor junto com um objeto mlir::HeapAsmResourceBlob
que representa um blob mutável de dados de tensor organizados como uma matriz de bytes contígua em ordem principal para secundária.
Os objetos detail::Buffer
são contados por referência para simplificar o gerenciamento de memória.
Elementos individuais de um tensor são representados usando a classe Element
, que usa uma união discriminada contendo um dos valores APInt
, APFloat
ou pair<APFloat,APFloat>
para armazenamento. O último é usado para armazenar elementos
com tipos complexos.
A Tensor
tem as seguintes APIs para interagir com os elementos individuais:
Element Tensor::get(llvm::ArrayRef<int64_t> index)
: para extrair um elemento de tensor individual em um índice multidimensionalindex
como um objetoElement
.void Tensor::set(llvm::ArrayRef<int64_t> index, Element element);
: para atualizar um objetoElement
element
em um tensor no índice multidimensionalindex
.
Como o intérprete funciona
A função de entrada para o intérprete é
SmallVector<Tensor> eval(func::FuncOp func, ArrayRef<Tensor> args);
que faz o seguinte:
- Rastreia os argumentos SSA de
func
e os valores deTensor
de ambiente de execução associados, fornecidos emargs
, usando um mapa de tabela de símbolos, M. - Para cada operação em
func
, na ordem SSACFG:- Invoca
eval
na operação. Para cada operando SSA da operação, extraia o valor de ambiente de execução de M a ser fornecido como um argumento para a invocaçãoeval
. - Rastreia os resultados SSA da operação e o valor avaliado em M.
- Invoca
O nível de operação eval
mencionado em (2) é responsável por implementar a semântica de execução da operação. Veja a seguir um exemplo de stablehlo::AddOp
:
No exemplo, elementos individuais dos tensores lhs
e rhs
são extraídos em pares como objetos Element
que são então adicionados. O resultado da adição,
um objeto Element
, é armazenado no tensor result
final.
Tensor eval(AddOp op, const Tensor &lhs, const Tensor &rhs) {
Tensor result(op.getType());
for (auto it = result.index_begin(); it != result.index_end(); ++it)
result.set(*it, lhs.get(*it) + rhs.get(*it));
return result;
}
No geral, o design do intérprete é otimizado para legibilidade de
implementações de funções eval
para operações individuais, porque precisa
servir como uma implementação de referência para o StableHLO. Por exemplo, em vez de
definir eval
como uma função de modelo e parametrizar-a com tipos de elementos,
encapsulamos detalhes sobre como diferentes tipos de elementos são processados em
Element::operator+
etc., simplificando a implementação de eval
.
Como usar o intérprete para dobras constantes
Podemos usar o mecanismo intérprete para dobrar operações com valores de operando constantes. O snippet de código a seguir demonstra uma ideia da implementação
para dobrar stablehlo::AddOp
com operandos digitados de ponto flutuante:
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
auto attrs = adaptor.getOperands();
DenseElementsAttr lhsData = attrs[0].dyn_cast<DenseElementsAttr>();
DenseElementsAttr rhsData = attrs[1].dyn_cast<DenseElementsAttr>();
if (!lhsData || !rhsData) return {};
auto lhs = Tensor(lhsData);
auto rhs = Tensor(rhsData);
auto result = eval(*this, lhs, rhs);
SmallVector<APFloat> values;
for (auto i = 0; i < result.getNumElements(); ++i) {
Element element = result.get(i);
values.push_back(element.getValue().cast<FloatAttr>().getValue());
}
return DenseElementsAttr::get(result.getType(), values);
}
No momento, não estamos trabalhando ativamente para integrar o intérprete em
dobrações constantes, porque não estamos planejando implementar uma pasta para o StableHLO.
No entanto, no futuro, planejamos usar o intérprete para dobrar
constante na MHLO.Nesse ponto, vamos melhorar a ergonomia do snippet de código
acima. Por exemplo, podemos ter uma função auxiliar que empacota operandos constantes em
objetos Tensor
e descompacta resultados de Tensor
em OpFoldResult
.
Como testar o intérprete de StableHLO
O intérprete toma como entradas (A) um programa StableHLO e (B) valores de dados a serem alimentados ao programa e gera valores de dados de saída, que são comparados aos valores de dados esperados fornecidos pelo usuário. Os valores de dados (B) são
codificados no próprio programa usando operações stablehlo.constant
. O
intérprete avalia o programa de entrada. As saídas da operação em teste
são verificadas com verificações (por exemplo, check.expect_eq
, check.expect_almost_eq
), conforme
mostrado abaixo. check.expect_eq
e check.expect_eq_const
verificam a igualdade
bit a bit para qualquer tipo com suporte, e check.expect_almost_eq
e
check.expect_almost_eq_const
verificam a igualdade próxima dentro de uma tolerância,
explicada na diretriz de teste (G6), para pontos flutuantes e tipos complexos.
// CHECK-LABEL: Evaluated results of function: add_op_test_ui4
func.func @add_op_test_ui4() {
%0 = stablehlo.constant dense<[0, 2]> : tensor<2xui4>
%1 = stablehlo.constant dense<[15, 3]> : tensor<2xui4>
%2 = stablehlo.add %0, %1 : tensor<2xui4>
check.expect_eq_const %2, [15, 5] : tensor<2xui4>
func.return
}
Um utilitário de teste stablehlo-translate --interpret
(código)
é responsável por analisar o programa, interpretando cada função, incluindo as
operações que compõem a função. Temos um pacote de testes dedicado, que consiste em vários testes que exercem vários comportamentos de tempo de execução para cada operação do StableHLO. Os testes podem ser encontrados aqui (por exemplo, interpret_*.mlir).
Diretrizes de teste
(G1) Precisamos testar todos os tipos compatíveis em cada operação?
Podemos usar uma combinação das seguintes regras para decidir:
Ao implementar uma operação, se houver um código na função
eval
correspondente para processar um tipo específico, é fundamental ter testes para abranger esse tipo. Por exemplo, para a operaçãoadd
, há um código exclusivo para processar tipos inteiros, booleanos, de ponto flutuante e complexos. Portanto, precisamos de um teste para cada categoria de tipos.Se um conjunto de tipos for processado de maneira uniforme na função
eval
correspondente, um único teste para todos eles será suficiente. Por exemplo, para a operaçãoadd
, todas as variantes de tipos inteiros (si4
,u4
,si8
,u8
e assim por diante) são processadas da mesma forma usando APIsllvm::APInt
. Portanto, podemos pular a adição de testes para cada uma dessas variantes e adicionar um único teste representativo. Para evitar ambiguidade na seleção do representante, precisamos usar as seguintes diretrizes:- Se todos os tipos, tratados de maneira uniforme, tiverem o mesmo tipo primitivo (ou seja, se todos forem inteiros, de ponto flutuante ou complexos), escolha aquele com largura máxima de bits.
- Se todos os tipos, processados de maneira uniforme, tiverem uma combinação de tipos primitivos, escolha aquele com o seguinte tipo primitivo, em ordem decrescente de preferência: inteiro, ponto flutuante, booleano, complexo.
(G2) Como decidimos o número de testes necessários para cobrir o comportamento de uma operação?
O objetivo é abordar de maneira abrangente a lógica do intérprete da operação (ou seja, todos os casos isolados da implementação) com um número mínimo de testes. Para garantir a manutenção, é importante minimizar o número de testes. Quanto menos testes tivermos, mais fácil será revisá-los e garantir que eles cubram a operação de maneira abrangente. Como resultado, esperamos que a maioria das operações mais simples tenha apenas um teste. Se, por algum bom motivo, não for possível fazer uma cobertura abrangente, você poderá parar em >= 90%. Isso será decidido caso a caso durante a análise da solicitação de envio.
(G3) Que tal adicionar testes à infraestrutura do intérprete?
A infraestrutura do intérprete é bastante simples e pode ser adicionada à nossa base de confiança. A única parte não trivial é como vários tipos são empacotados e descompactados
no armazenamento do intérprete subjacente. Como discutido em (G1), testaremos
apenas os tipos de operação que são tratados de maneira diferente. Por
isso, é possível que o código de empacotamento/descompactação, correspondente a diferentes
variantes de tipos de números inteiros/de ponto flutuante, não seja totalmente coberto durante
os testes. Para garantir uma cobertura completa, podemos escolher uma operação como constant
, que
oferece suporte a todos os tipos de elemento do StableHLO e gravar testes abrangentes.
(G4) Se a implementação de uma operação depende de outras operações, é necessário criar testes para elas?
Não. Por exemplo, a implementação de batch_norm_grad
pode ser baseada em
divide
, subtract
, multiply
e outros. Devemos evitar testar as últimas
operações ao testar a primeira.
(G5) Vamos criar testes para exercitar os comportamentos definidos / indefinidos pela implementação?
Não devemos criar testes que pratiquem os comportamentos definidos pela implementação ou indefinidos da operação. Testes com comportamentos definidos pela implementação demonstram um comportamento local do intérprete, que não deve ser generalizado. Testes com comportamento indefinido não contribuem para o entendimento do comportamento da operação.
(G6) Ao criar testes para tipos de ponto flutuante, com que precisão o resultado esperado precisa ser especificado nas verificações?
Para operações básicas (adição, subtração, multiplicação, divisão e
quadrado), uma implementação que segue a especificação IEEE precisa fornecer um
resultado arredondado com até 0,5 ULP do resultado matematicamente exato. Dito isso, podemos
imaginar com segurança que o resultado esperado dessas operações seja de, no máximo,
um ULP de diferença. No entanto, isso pode não funcionar para funções transcendentais (sine
, cosine
etc.) para as quais as garantias de precisão são definidas pela implementação (justificativa).
A implementação atual usa um valor de tolerância de "tamanho único" de 0,0001. O exemplo a seguir demonstra a tolerância acima em ação.
func.func @check_tolerance() {
%0 = stablehlo.constant dense<0.2> : tensor<f32>
// The following check succeeds as %0 is almost equal to the provided
// constant modulo the tolerance, mentioned above.
check.expect_almost_eq_const %0, dense<0.19999> : tensor<f32>
// The following check fails as %0 is not bitwise equal to the provided
// constant.
check.expect_eq_const %0, dense<0.19999> : tensor<f32>
func.return
}
Essa é apenas a primeira etapa para testar a precisão numérica das operações do StableHLO. No momento, essa é uma área não especificada da especificação StableHLO, e há trabalho contínuo para descobrir #1156 (link em inglês) com base na nossa experiência de uso do StableHLO na prática e no feedback das partes interessadas. À medida que isso avançar, atualizaremos a infraestrutura conforme necessário.
(G7) Alguma informação sobre o estilo de programação dos testes?
- Certifique-se de usar o nome real das entradas/saídas em vez de usar os valores SSA como padrão (por exemplo, %0, %1 etc.).
- Confira se os testes usam o formato formatado, se houver.
(G8) Devemos incluir o exemplo já fornecido na especificação? Sim (para a integridade do teste).