Diseño de intérprete

Modelo de datos

Los programas estables son cálculos sobre tensores (arrays de n dimensiones), que, en el modelo actual, se implementan mediante la clase Tensor. La clase de almacenamiento subyacente para un objeto Tensor, detail::Buffer, almacena el mlir::ShapedType del tensor junto con un objeto mlir::HeapAsmResourceBlob, que representa un BLOB mutable de datos del tensor dispuesto como un array de bytes contiguo en orden mayor a menor. Los objetos detail::Buffer tienen un recuento de referencias para simplificar la administración de la memoria.

Los elementos individuales de un tensor se representan con la clase Element, que usa una unión discriminada que contiene APInt, APFloat o pair<APFloat,APFloat> para el almacenamiento. La última se usa para almacenar elementos con tipos complejos.

Tensor tiene las siguientes APIs para interactuar con sus elementos individuales:

  • Element Tensor::get(llvm::ArrayRef<int64_t> index): Para extraer un elemento tensor individual en el índice multidimensional index como objeto Element.
  • void Tensor::set(llvm::ArrayRef<int64_t> index, Element element);: Permite actualizar un objeto Element element en un tensor en el índice multidimensional index.

Cómo funciona el intérprete

La función de entrada para el intérprete es

SmallVector<Tensor> eval(func::FuncOp func, ArrayRef<Tensor> args);

que hace lo siguiente:

  1. Realiza un seguimiento de los argumentos de SSA de func y sus valores Tensor del entorno de ejecución asociados, proporcionados en args, mediante un mapa de tabla de símbolos, M.
  2. Para cada operación dentro de func, en el orden de SSACFG:
    • Invoca eval en la operación. Para cada operando SSA de la op, extrae su valor de entorno de ejecución de M a fin de que se proporcione como un argumento para la invocación eval.
    • Realiza un seguimiento de los resultados de SSA de la operación y el valor evaluado en M.

El eval de nivel de operación mencionado en (2) es responsable de implementar la semántica de ejecución de la operación. A continuación, se muestra un ejemplo de stablehlo::AddOp. En el ejemplo, los elementos individuales de los tensores lhs y rhs se extraen en pares como objetos Element que luego se agregan. El resultado de la adición, un objeto Element, se almacena en el tensor final result.

Tensor eval(AddOp op, const Tensor &lhs, const Tensor &rhs) {
  Tensor result(op.getType());

  for (auto it = result.index_begin(); it != result.index_end(); ++it)
    result.set(*it, lhs.get(*it) + rhs.get(*it));

  return result;
}

En general, el diseño del intérprete está optimizado para facilitar la lectura de las implementaciones de funciones eval en operaciones individuales, ya que está diseñado como una implementación de referencia para StableHLO. Por ejemplo, en lugar de definir eval como una función de plantilla y parametrizarla con tipos de elementos, encapsulamos detalles sobre cómo se manejan los diferentes tipos de elementos en Element::operator+, etc., lo que simplifica la implementación de eval.

Cómo usar el intérprete para plegado constante

Podemos usar el mecanismo de intérprete para plegar operaciones con valores de operandos constantes. En el siguiente fragmento de código, se muestra una idea de la implementación para plegar stablehlo::AddOp con operandos de tipo de punto flotante:

OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  auto attrs = adaptor.getOperands();
  DenseElementsAttr lhsData = attrs[0].dyn_cast<DenseElementsAttr>();
  DenseElementsAttr rhsData = attrs[1].dyn_cast<DenseElementsAttr>();
  if (!lhsData || !rhsData) return {};

  auto lhs = Tensor(lhsData);
  auto rhs = Tensor(rhsData);
  auto result = eval(*this, lhs, rhs);

  SmallVector<APFloat> values;
  for (auto i = 0; i < result.getNumElements(); ++i) {
    Element element = result.get(i);
    values.push_back(element.getValue().cast<FloatAttr>().getValue());
  }

  return DenseElementsAttr::get(result.getType(), values);
}

Por el momento, no estamos trabajando activamente en la integración del intérprete en el plegado constante porque no planeamos implementar la carpeta para StableHLO. Sin embargo, en el futuro, planeamos aprovechar el intérprete para el plegado constante en MHLO. En ese momento, mejoraremos la ergonomía del fragmento de código anterior (p. ej., podríamos tener una función auxiliar que empaquete los operandos constantes en objetos Tensor y descomprima los resultados de Tensor en OpFoldResult).

Prueba el intérprete StableHLO

El intérprete toma como entradas (A) un programa StableHLO y (B) valores de datos que se enviarán al programa, y genera valores de datos de salida que se comparan con los valores de datos esperados que proporciona el usuario. Los valores de datos (B) están hard-coded en el programa mediante operaciones stablehlo.constant. El intérprete evalúa el programa de entrada. Los resultados de la operación que se están probando se verifican mediante verificaciones (p.ej., check.expect_eq y check.expect_almost_eq), como se muestra a continuación. check.expect_eq y check.expect_eq_const verifican la igualdad a nivel de bits para cualquier tipo admitido, y check.expect_almost_eq y check.expect_almost_eq_const comprueban la igualdad cercana dentro de una tolerancia, según se explica en el lineamiento de prueba (G6), para el punto flotante y los tipos complejos.

// CHECK-LABEL: Evaluated results of function: add_op_test_ui4
func.func @add_op_test_ui4() {
  %0 = stablehlo.constant dense<[0, 2]> : tensor<2xui4>
  %1 = stablehlo.constant dense<[15, 3]> : tensor<2xui4>
  %2 = stablehlo.add %0, %1 : tensor<2xui4>
  check.expect_eq_const %2, [15, 5] : tensor<2xui4>
  func.return
}

Una utilidad de prueba stablehlo-translate --interpret (código) es responsable de analizar el programa y de interpretar cada función, incluidas las operaciones que la conforman. Tenemos un paquete de pruebas dedicado, que consta de varias pruebas que ejecutan varios comportamientos de tiempo de ejecución para cada operación StableHLO. Las pruebas se pueden encontrar aquí (p.ej., interpret_*.mlir).

Lineamientos para las pruebas

(G1) ¿Necesitamos probar todos los tipos compatibles para cada operación?

Podemos usar una combinación de las siguientes reglas para decidir:

  1. Durante la implementación de una op, si existe código en la función eval correspondiente para controlar un tipo en particular, es fundamental tener pruebas que cubran ese tipo. A modo de ejemplo, en la op add hay un código exclusivo para controlar tipos enteros, booleanos, de punto flotante y complejos, por lo que necesitamos una prueba para cada categoría de tipos.

  2. Si un conjunto de tipos se maneja de manera uniforme en la función eval correspondiente, una sola prueba para todos esos tipos debería ser suficiente. A modo de ejemplo, para la operación add, todas las variantes de tipos de números enteros (si4, u4, si8, u8, etc.) se manejan de la misma manera con las APIs de llvm::APInt y, por lo tanto, podemos omitir la adición de pruebas para cada una de esas variantes y, en su lugar, agregar una sola prueba representativa. Para evitar ambigüedades al seleccionar al representante, debemos usar los siguientes lineamientos:

    • Si todos los tipos, manejados de manera uniforme, tienen el mismo tipo primitivo (es decir, si todos son de número entero, de punto flotante o tipos complejos), entonces elige el que tenga el ancho de bits máximo.
    • Si todos los tipos, controlados de manera uniforme, tienen una combinación de tipos primitivos, elige el que tenga el siguiente tipo primitivo, en orden descendente de preferencia: número entero, punto flotante, booleano, complejo.

(G2) ¿Cómo decidimos la cantidad de pruebas necesarias para cubrir el comportamiento de una operación?

El objetivo es cubrir de manera exhaustiva la lógica del intérprete de la operación (es decir, todos los casos límite de la implementación) con una cantidad mínima de pruebas. Minimizar la cantidad de pruebas es importante para el mantenimiento. Cuantas menos pruebas tengamos, más fácil será revisarlas y asegurarse de que cubran la operación de manera integral. Como resultado, esperamos que la mayoría de las operaciones más simples acaben teniendo solo una prueba. Si, por alguna buena razón, la cobertura integral no es práctica, entonces está bien establecer el valor en >= 90%. Esto se decidirá caso por caso durante la revisión de la solicitud de extracción.

(G3) ¿Quieres agregar pruebas para la infraestructura de intérprete?

La infraestructura del intérprete es mayormente sencilla y se puede agregar a nuestra base de confianza. La única parte no trivial es cómo se empaquetan y desempaquetan varios tipos desde el almacenamiento de intérprete subyacente. Como se explicó en (G1), probaremos solo los tipos de operaciones que se manejan de manera diferente. De esta manera, es posible que el código de empaquetado o desempaquetado, que corresponde a diferentes variantes de tipos de número entero y punto flotante, no se cubra completamente durante las pruebas. Para garantizar una cobertura completa, podemos elegir una op, como constant, que admita todos los tipos de elementos StableHLO y escribir pruebas exhaustivas.

(G4) Si la implementación de una op depende de otras operaciones, ¿debemos escribir pruebas para la última?

No. Por ejemplo, la implementación de batch_norm_grad se puede basar en divide, subtract, multiply y otros. Debemos evitar probar las últimas operaciones mientras pruebas las primeras.

(G5) ¿Debemos escribir pruebas para ejercer los comportamientos definidos por la implementación o no definidos?

No debemos escribir pruebas que ejemplifiquen los comportamientos definidos por la implementación o indefinidos de la operación. Las pruebas que tienen comportamientos definidos por la implementación muestran un comportamiento local del intérprete que no se debe generalizar. Las pruebas que tienen un comportamiento indefinido no contribuyen a que se comprenda el comportamiento de la operación.

(G6) Mientras se escriben pruebas para los tipos de punto flotante, ¿con qué precisión se debe especificar el resultado esperado en las verificaciones?

Para las operaciones básicas (suma, resta, multiplicación, división y cuadrado), se espera que una implementación después de la especificación IEEE proporcione un resultado redondeado dentro de 0.5 ULP del resultado matemáticamente exacto. Dicho esto, podemos imaginar con seguridad que el resultado esperado de estas operaciones se ubicará a, como máximo, a 1 ULP. Sin embargo, es posible que esto no funcione para funciones trascendentales (sine, cosine, etc.) para las que las garantías de precisión están definidas por la implementación (racional).

La implementación actual utiliza un valor de tolerancia “único para todos” de 0.0001. En el siguiente ejemplo, se muestra en acción la tolerancia anterior.

func.func @check_tolerance() {
  %0 = stablehlo.constant dense<0.2> : tensor<f32>

  // The following check succeeds as %0 is almost equal to the provided
  // constant modulo the tolerance, mentioned above.
  check.expect_almost_eq_const %0, dense<0.19999> : tensor<f32>

  // The following check fails as %0 is not bitwise equal to the provided
  // constant.
  check.expect_eq_const %0, dense<0.19999> : tensor<f32>

  func.return
}

Este es solo el primer paso para probar la exactitud numérica de las operaciones StableHLO. Por el momento, esta es un área poco especificada de la especificación de StableHLO, y aún estamos trabajando para resolverlo, #1156 según nuestra experiencia con StableHLO en la práctica y en los comentarios de las partes interesadas. A medida que el proceso avance, actualizaremos la infraestructura según corresponda.

(G7) ¿Algo sobre el estilo de programación de las pruebas?

  1. Asegúrate de usar el nombre real de las entradas y salidas en lugar de usar los valores SSA de forma predeterminada (p.ej., %0, %1, etcétera).
  2. Asegúrate de que las pruebas usen un formato impreso (si existe).

(G8) ¿Debemos incluir el ejemplo que ya se proporcionó en la especificación? Sí (para lograr la integridad de la prueba).