El estado actual del dinamismo se explica de manera más formal en el RFC de dinamismo. En esta página, se proporciona una descripción general de alto nivel del RFC y se analizan las APIs y las herramientas importantes para interactuar con programas dinámicos.
Descripción general de la terminología y la asistencia de dinamismo
Primero, para abarcar algunos términos que aparecerán en este documento, así como una breve introducción a su compatibilidad en StableHLO, ten en cuenta lo siguiente:
Dimensiones dinámicas
Las dimensiones dinámicas se refieren a cualquier dimensión cuyo tamaño es desconocido.
En StableHLO, representamos las dimensiones dinámicas con ?, es decir, tensor<16x?xf32>.
Dinamismo limitado
El dinamismo acotado hace referencia a una dimensión dinámica cuyo valor tiene un límite superior conocido. En general, esto es útil para agregar padding al tensor durante la ejecución.
En StableHLO, representamos el dinamismo acotado con #stablehlo.bounds como una codificación de tensor, es decir, un tensor de rango 2 con una dimensión dinámica acotada en 16 y la otra sin una cota se puede representar como tensor<?x?xf32, #stablehlo.bounds<16, ?>>.
StableHLO puede representar el dinamismo limitado, pero hay una compatibilidad limitada con el framework, que se originó en TensorFlow y tiene cierta compatibilidad en PyTorch/XLA.
Dinamismo ilimitado
El dinamismo no acotado, como su nombre lo indica, se refiere a una dimensión dinámica sin un límite conocido en el tamaño. Este tipo de dinamismo es muy común en StableHLO, con compatibilidad con JAX, PyTorch/XLA y TF, y se suele usar para exportar modelos con tamaño de lote o longitud de secuencia dinámicos.
En StableHLO, simplemente omitimos la codificación de límites para esta forma de dinamismo, es decir, tensor<?x?xf32>.
Polimorfismo de formas
El polimorfismo de formas es un término que heredamos de JAX.
El polimorfismo de formas tiene dos implicaciones clave:
- Todo el dinamismo del programa se remonta a sus argumentos de entrada.
- Todo el dinamismo se relaciona solo con las formas de los tensores, es decir, no depende de los datos.
Con estas dos reglas, una vez que se conocen las formas estáticas de un programa, podemos tomar un programa dinámico y refinarlo por completo en un programa estático para la compilación (consulta "Compiler passes for refining dynamic programs").
En general, el polimorfismo de formas usa dinamismo no vinculado. Si las formas de los argumentos conocidos pueden generar un programa completamente estático, no es necesario adivinar cómo vincular los valores.
Dinamismo dependiente de los datos
El dinamismo dependiente de los datos se refiere a los tamaños de dimensiones dinámicas que pertenecen a los datos dentro de un tensor. El ejemplo canónico es una función nonzeros que devuelve los índices de todos los elementos que son 0 en un valor de tensor. La forma no se puede conocer sin evaluar los datos, pero a menudo se puede compilar con dinamismo limitado, lo que implica gastar memoria adicional en el tamaño potencial del tensor de salida.
Muchas operaciones dinámicas que dependen de los datos se pueden modelar con dinamismo limitado, en el que se especifica un límite superior para el tamaño de un tensor y, por lo general, el hardware implementará esto a través del padding de tensores. Actualmente, PyTorch/XLA y TensorFlow admiten cierto dinamismo dependiente de los datos, pero JAX no realiza un seguimiento de las operaciones que generan dinamismo dependiente de los datos.
Exporta programas con dimensiones dinámicas
Consulta nuestros instructivos de StableHLO para obtener información sobre cómo exportar programas con tamaños de lotes o longitudes de secuencia dinámicos:
- Instructivo de JAX > Exportación con tamaño de lote dinámico
- Instructivo de PyTorch/XLA > Export with Dynamic Batch Size
Pases del compilador para refinar programas dinámicos
Quita la canalización de paso de dinamismo
Existen algunos pases útiles para refinar formas, y todos se incluyen en una canalización de pases createStablehloRemoveDynamismPipeline:
void createStablehloRemoveDynamismPipeline(OpPassManager &pm,
TypeRange refinedTypes);
Pases individuales para refinar el dinamismo
Individualmente, los pases que suelen ser útiles para el perfeccionamiento de la forma son los siguientes:
stablehlo-refine-argumentspara reemplazar argumentos de entrada con tipos de tensores concretosstablehlo-refine-shapespara propagar la nueva información de la forma del argumento de entrada en todo el programa.stablehlo-canonicalize-dynamismpara reemplazar las operaciones dinámicas por sus variantes estáticasstablehlo-check-shape-assertionspara verificar y quitar las llamadas personalizadas de aserciones de formas.
Consulta la documentación vinculada para obtener información y ejemplos actualizados.
Ejemplo: ¿Cómo es útil el dinamismo y cómo puedo usarlo?
El dinamismo tiene muchos usos. Aquí nos centraremos principalmente en el caso de uso común del polimorfismo de formas: crear una representación flexible del modelo exportado, que generalmente se usa para representar el tamaño del lote o la longitud de la secuencia dinámicos.
Modelo estático de add_one
Usaremos el siguiente modelo add_one simple para demostrar esto:
def add_one(x):
return x + 1
Cuando se rastrea con un tensor<4xf32>, obtenemos el siguiente programa de StableHLO:
// File: add_one.mlir
func.func @add_one(%arg0: tensor<4xf32>) -> tensor<4xf32> {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<4xf32>
%0 = stablehlo.add %arg0, %cst : tensor<4xf32>
return %0 : tensor<4xf32>
}
Este modelo solo funcionará para los argumentos de entrada que tengan una forma tensor<4xf32>. Si alguna vez cambiáramos el tamaño del lote o la longitud de la secuencia, tendríamos que volver a rastrear el código fuente y volver a reducirlo a StableHLO, y no hay garantía de que aún tengamos acceso al código fuente.
Modelo dinámico de add_one
Aquí es donde entra en juego el dinamismo polimórfico de formas. En cambio, JAX y PyTorch/XLA pueden emitir el modelo add_one con IR válida de forma dinámica, que transmitirá la constante para que coincida con la forma de entrada dinámica de la siguiente manera:
// File: add_one_dynamic.mlir
func.func public @main(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%cst = stablehlo.constant dense<1.0> : tensor<f32>
%0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf32>) -> tensor<i32>
%1 = stablehlo.reshape %0 : (tensor<i32>) -> tensor<1xi32>
%2 = stablehlo.dynamic_broadcast_in_dim %cst, %1, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
%3 = stablehlo.add %arg0, %2 : tensor<?xf32>
return %3 : tensor<?xf32>
}
Esta representación del modelo es mucho más flexible y permite la especificación diferida de valores como el tamaño del lote o la longitud de la secuencia. Este modelo se puede implementar en plataformas que admiten formas dinámicas (como AI Edge) o se puede perfeccionar con los pases de dinamismo que se mencionan en esta documentación.
Cómo perfeccionar el modelo dinámico
Por ejemplo, el siguiente orden de pases puede refinar por completo este programa:
stablehlo-opt add_one_dynamic.mlir \
--stablehlo-refine-arguments='types=tensor<16xf32>' \
--stablehlo-refine-shapes \
--stablehlo-canonicalize-dynamism
De forma incremental, así se transforma el programa:
// After stablehlo-refine-arguments: Inputs updated, shapes not propagated
func.func public @main(%arg0: tensor<16xf32>) -> tensor<?xf32> {
%c = stablehlo.constant dense<16> : tensor<1xi64>
%0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {indices_of_shape_operands = dense<1> : tensor<1xi64>} : (tensor<16xf32>, tensor<1xi64>) -> tensor<?xf32>
...
%3 = stablehlo.dynamic_broadcast_in_dim %cst, %2, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
%4 = stablehlo.add %0, %3 : tensor<?xf32>
return %4 : tensor<?xf32>
}
// After stablehlo-refine-shapes: Shapes propagated, dynamic ops still exist
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%c = stablehlo.constant dense<16> : tensor<1xi32>
%0 = stablehlo.dynamic_broadcast_in_dim %cst, %c, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<16xf32>
%1 = stablehlo.add %arg0, %0 : tensor<16xf32>
return %1 : tensor<16xf32>
}
// After stablehlo-canonicalize-dynamism: Dynamic ops replaced with static ops
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<16xf32>
%1 = stablehlo.add %arg0, %0 : tensor<16xf32>
return %1 : tensor<16xf32>
}
// (Bonus) Use ` --stablehlo-aggressive-simplification` pass to canonicalize the
// constant broadcast, leaving us with the original static program in this case.
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<16xf32>
%0 = stablehlo.add %arg0, %cst : tensor<16xf32>
return %0 : tensor<16xf32>
}