Dynamism in StableHLO

El estado actual del dinamismo se describe de forma más formal en la RFC de dinamismo. En esta página, se proporciona una descripción general de alto nivel de la RFC y se analizan las APIs y herramientas importantes para interactuar con programas dinámicos.

Terminología y descripción general de la asistencia de Dynamism

Primero, para abordar algunos términos que aparecerán en este documento, así como una breve introducción a su compatibilidad con StableHLO:

Dimensiones dinámicas

Las dimensiones dinámicas se refieren a cualquier dimensión cuyo tamaño es desconocido. En StableHLO, representamos las dimensiones dinámicas con ?, es decir, tensor<16x?xf32>.

Dinamismo limitado

El dinamismo limitado se refiere a una dimensión dinámica cuyo valor tiene un límite superior conocido. Generalmente, esto es útil para rellenar el tensor durante la ejecución. En StableHLO, representamos el dinamismo limitado con #stablehlo.bounds como una codificación de tensor, es decir, un tensor de rango 2 con una dimensión dinámica limitada en 16 y la otra sin límite que se puede representar como tensor<?x?xf32, #stablehlo.bounds<16, ?>>.

StableHLO puede representar el dinamismo limitado, pero la compatibilidad con el framework es limitada, se origina en TensorFlow y tiene cierta compatibilidad con PyTorch/XLA.

Dinamismo ilimitado

El dinamismo ilimitado, como su nombre lo indica, hace referencia a una dimensión dinámica sin un límite conocido de tamaño. Este tipo de dinamismo es muy común en StableHLO, con compatibilidad con JAX, PyTorch/XLA y TF, que se suele usar para exportar modelos con tamaño de lote o longitud de secuencia dinámicos.

En StableHLO, simplemente eludimos la codificación de límites para esta forma de dinamismo, es decir, tensor<?x?xf32>.

Polimorfismo de forma

El polimorfismo de forma es un término que heredamos de JAX.

Hay dos implicaciones clave para dar forma al polimorfismo:

  1. Todo el dinamismo del programa se remonta a sus argumentos de entrada.
  2. Todo el dinamismo se refiere solo a las formas del tensor, es decir, no depende de los datos.

Con estas dos reglas, una vez que se conocen las formas estáticas de un programa, podemos tomar un programa dinámico y definirlo completamente en un programa estático para la compilación (consulta "Pasos del compilador para definir mejor los programas dinámicos").

En general, el polimorfismo de forma usa un dinamismo ilimitado. Si las formas de argumentos conocidas pueden generar un programa completamente estático, no es necesario adivinar cómo limitar los valores.

Dinamismo dependiente de los datos

El dinamismo dependiente de los datos se refiere a los tamaños de dimensiones dinámicas que pertenecen a los datos dentro de un tensor. El ejemplo canónico es una función nonzeros que muestra los índices de todos los elementos que son 0 en un valor de tensor. La forma no se puede conocer sin evaluar los datos, pero a menudo se puede compilar con dinamismo limitado, lo que consume memoria adicional en el tamaño potencial del tensor de salida.

Muchas operaciones dinámicas dependientes de los datos se pueden modelar con un dinamismo limitado, en el que se especifica un límite superior para el tamaño de un tensor, y el hardware generalmente lo implementa a través del padding de tensores. En la actualidad, hay cierta compatibilidad con el dinamismo dependiente de los datos en PyTorch/XLA y TensorFlow, pero JAX no hace un seguimiento de las operaciones, lo que conduce al dinamismo dependiente de los datos.

Cómo exportar programas con dimensiones dinámicas

Consulta nuestros instructivos de StableHLO para obtener información sobre cómo exportar programas con tamaños de lotes dinámicos o longitudes de secuencias:

Pases del compilador para definir mejor los programas dinámicos

Quita la canalización del pase de dinamismo

Hay algunos pases útiles para definir mejor las formas, y todos están agrupados en una canalización de pases createStablehloRemoveDynamismPipeline:

void createStablehloRemoveDynamismPipeline(OpPassManager &pm,
                                           TypeRange refinedTypes);

Pases individuales para definir mejor el dinamismo

De forma individual, los pases que suelen ser útiles para definir mejor las formas son los siguientes:

Consulta la documentación vinculada para obtener información y ejemplos actualizados.

Ejemplo: ¿Por qué es útil el dinamismo y cómo puedo usarlo?

El dinamismo tiene muchos usos. Aquí, nos centraremos principalmente en el caso de uso común del polimorfismo de forma: crear una representación flexible de modelo exportado, que se usa generalmente para representar el tamaño del lote dinámico o la longitud de la secuencia.

Modelo add_one estático

Para demostrar esto, usaremos el siguiente modelo add_one simple:

def add_one(x):
  return x + 1

Si se hace un seguimiento con un tensor<4xf32>, obtendremos el siguiente programa StableHLO:

// File: add_one.mlir
func.func @add_one(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<4xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<4xf32>
  return %0 : tensor<4xf32>
}

Este modelo funcionará solo para los argumentos de entrada que tengan una forma tensor<4xf32>. Si alguna vez cambiamos el tamaño del lote o la longitud de la secuencia, deberíamos volver a rastrear el código fuente y volver a bajar a StableHLO, y no hay garantía de que aún tengamos acceso al código fuente.

Modelo dinámico add_one

Aquí es donde entra en juego el dinamismo polimórfico de la forma. En su lugar, JAX y PyTorch/XLA pueden emitir el modelo add_one con un IR válido de forma dinámica que transmitirá la constante para que coincida con la forma de entrada dinámica de la siguiente manera:

// File: add_one_dynamic.mlir
func.func public @main(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %cst = stablehlo.constant dense<1.0> : tensor<f32>
  %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf32>) -> tensor<i32>
  %1 = stablehlo.reshape %0 : (tensor<i32>) -> tensor<1xi32>
  %2 = stablehlo.dynamic_broadcast_in_dim %cst, %1, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %3 = stablehlo.add %arg0, %2 : tensor<?xf32>
  return %3 : tensor<?xf32>
}

Esta representación del modelo es mucho más flexible y permite la especificación aplazada de valores como el tamaño del lote o la longitud de la secuencia. Este modelo se puede implementar en plataformas con compatibilidad con formas dinámicas (como AI Edge) o se puede definir mejor con los pases de dinamismo que se mencionan en esta documentación.

Cómo definir mejor el modelo dinámico

Por ejemplo, el siguiente orden de pases puede definir mejor este programa:

stablehlo-opt add_one_dynamic.mlir \
  --stablehlo-refine-arguments='types=tensor<16xf32>' \
  --stablehlo-refine-shapes \
  --stablehlo-canonicalize-dynamism

De manera incremental, así es como se transforma el programa:

// After stablehlo-refine-arguments: Inputs updated, shapes not propagated
func.func public @main(%arg0: tensor<16xf32>) -> tensor<?xf32> {
  %c = stablehlo.constant dense<16> : tensor<1xi64>
  %0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {indices_of_shape_operands = dense<1> : tensor<1xi64>} : (tensor<16xf32>, tensor<1xi64>) -> tensor<?xf32>
  ...
  %3 = stablehlo.dynamic_broadcast_in_dim %cst, %2, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %4 = stablehlo.add %0, %3 : tensor<?xf32>
  return %4 : tensor<?xf32>
}

// After stablehlo-refine-shapes: Shapes propagated, dynamic ops still exist
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %c = stablehlo.constant dense<16> : tensor<1xi32>
  %0 = stablehlo.dynamic_broadcast_in_dim %cst, %c, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// After stablehlo-canonicalize-dynamism: Dynamic ops replaced with static ops
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// (Bonus) Use ` --stablehlo-aggressive-simplification` pass to canonicalize the
// constant broadcast, leaving us with the original static program in this case.
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<16xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<16xf32>
  return %0 : tensor<16xf32>
}