Dinamismo no StableHLO

O estado atual do dinamismo está descrito de maneira mais formal na RFC de dinamismo. Esta página oferece uma visão geral de alto nível da RFC e discute APIs e ferramentas importantes para interagir com programas dinâmicos.

Visão geral da terminologia e do suporte do dinamismo

Primeiro, para abordar alguns termos que vão aparecer neste documento, além de uma breve introdução ao suporte deles no StableHLO:

Dimensões dinâmicas

Dimensões dinâmicas são aquelas cujo tamanho é desconhecido. No StableHLO, representamos dimensões dinâmicas usando ?, ou seja, tensor<16x?xf32>.

Dinâmica limitada

O dinamismo limitado se refere a uma dimensão dinâmica cujo valor tem um limite superior conhecido. Em geral, isso é útil para fazer padding no tensor durante a execução. No StableHLO, representamos o dinamismo limitado usando #stablehlo.bounds como uma codificação de tensor, ou seja, um tensor de classificação 2 com uma dimensão dinâmica limitada em 16 e a outra sem limite pode ser representada como tensor<?x?xf32, #stablehlo.bounds<16, ?>>.

O StableHLO consegue representar o dinamismo limitado, mas há suporte limitado do framework, originado no TensorFlow e com algum suporte no PyTorch/XLA.

Dinamismo ilimitado

O dinamismo sem limites, como o nome sugere, se refere a uma dimensão dinâmica sem um limite conhecido de tamanho. Esse tipo de dinamismo é muito comum no StableHLO, com suporte para JAX, PyTorch/XLA e TF, geralmente usado para exportar modelos com tamanho de lote ou comprimento de sequência dinâmicos.

No StableHLO, simplesmente omitimos a codificação de limites para essa forma de dinamismo, ou seja, tensor<?x?xf32>.

Polimorfismo de forma

O polimorfismo de forma é um termo que herdamos do JAX.

Há duas implicações principais para o polimorfismo de forma:

  1. Todo o dinamismo do programa remonta aos argumentos de entrada.
  2. Todo o dinamismo se refere apenas a formatos de tensor, ou seja, não depende de dados.

Com essas duas regras, assim que as formas estáticas de um programa são conhecidas, podemos pegar um programa dinâmico e refiná-lo completamente em um programa estático para compilação (consulte "Passagens do compilador para refinar programas dinâmicos").

Em geral, o polimorfismo de forma usa dinamismo ilimitado. Se as formas de argumento conhecidas puderem levar a um programa totalmente estático, não será necessário adivinhar como limitar os valores.

Dinamismo dependente de dados

O dinamismo dependente de dados se refere a tamanhos de dimensões dinâmicas relacionados aos dados dentro de um tensor. O exemplo canônico é uma função nonzeros que retorna os índices de todos os elementos que são 0 em um valor de tensor. Não é possível conhecer a forma sem avaliar os dados, mas ela pode ser compilada usando dinamismo limitado, gastando memória extra no tamanho potencial do tensor de saída.

Muitas operações dinâmicas dependentes de dados podem ser modeladas usando dinamismo limitado, em que um limite superior no tamanho de um tensor é especificado, e o hardware geralmente implementa isso usando padding de tensor. Atualmente, há suporte para dinamismo dependente de dados em PyTorch/XLA e TensorFlow, mas o JAX não rastreia operações que levam a esse tipo de dinamismo.

Exportar programas com dimensões dinâmicas

Consulte nossos tutoriais do StableHLO para saber como exportar programas com tamanhos de lote ou comprimentos de sequência dinâmicos:

Passagens do compilador para refinar programas dinâmicos

Remover pipeline de transmissão de dinamismo

Há algumas transmissões úteis para refinar formas. Todas elas estão agrupadas em um pipeline de transmissão createStablehloRemoveDynamismPipeline:

void createStablehloRemoveDynamismPipeline(OpPassManager &pm,
                                           TypeRange refinedTypes);

Passagens individuais para refinar o dinamismo

Individualmente, as transmissões que tendem a ser úteis para o refinamento de forma são:

Consulte a documentação vinculada para informações e exemplos atualizados.

Exemplo: como o dinamismo é útil e como posso usá-lo?

O dinamismo tem muitos usos. Aqui, vamos nos concentrar principalmente no caso de uso comum para polimorfismo de forma: criar uma representação de modelo exportado flexível, geralmente usada para representar tamanho de lote ou comprimento de sequência dinâmicos.

Modelo estático add_one

Vamos usar o seguinte modelo add_one simples para demonstrar isso:

def add_one(x):
  return x + 1

Quando rastreado usando um tensor<4xf32>, vamos receber o seguinte programa StableHLO:

// File: add_one.mlir
func.func @add_one(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<4xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<4xf32>
  return %0 : tensor<4xf32>
}

Esse modelo funciona apenas para argumentos de entrada que têm uma forma tensor<4xf32>. Se mudarmos o tamanho do lote ou o comprimento da sequência, precisaremos rastrear novamente o código-fonte e reduzir para StableHLO. Não há garantia de que ainda teremos acesso ao código-fonte.

Modelo dinâmico add_one

É aí que entra o dinamismo polimórfico de forma. Em vez disso, o JAX e o PyTorch/XLA podem emitir o modelo add_one com IR dinamicamente válido, que vai transmitir a constante para corresponder ao formato de entrada dinâmica da seguinte maneira:

// File: add_one_dynamic.mlir
func.func public @main(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %cst = stablehlo.constant dense<1.0> : tensor<f32>
  %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf32>) -> tensor<i32>
  %1 = stablehlo.reshape %0 : (tensor<i32>) -> tensor<1xi32>
  %2 = stablehlo.dynamic_broadcast_in_dim %cst, %1, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %3 = stablehlo.add %arg0, %2 : tensor<?xf32>
  return %3 : tensor<?xf32>
}

Essa representação de modelo é muito mais flexível e permite a especificação adiada de valores como tamanho do lote ou comprimento da sequência. Esse modelo pode ser implantado em plataformas com suporte a formas dinâmicas (como o AI Edge) ou refinado usando as transmissões de dinamismo mencionadas nesta documentação.

Aprimorar o modelo dinâmico

Por exemplo, a seguinte ordenação de transmissões pode refinar totalmente este programa:

stablehlo-opt add_one_dynamic.mlir \
  --stablehlo-refine-arguments='types=tensor<16xf32>' \
  --stablehlo-refine-shapes \
  --stablehlo-canonicalize-dynamism

De forma incremental, é assim que o programa é transformado:

// After stablehlo-refine-arguments: Inputs updated, shapes not propagated
func.func public @main(%arg0: tensor<16xf32>) -> tensor<?xf32> {
  %c = stablehlo.constant dense<16> : tensor<1xi64>
  %0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {indices_of_shape_operands = dense<1> : tensor<1xi64>} : (tensor<16xf32>, tensor<1xi64>) -> tensor<?xf32>
  ...
  %3 = stablehlo.dynamic_broadcast_in_dim %cst, %2, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %4 = stablehlo.add %0, %3 : tensor<?xf32>
  return %4 : tensor<?xf32>
}

// After stablehlo-refine-shapes: Shapes propagated, dynamic ops still exist
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %c = stablehlo.constant dense<16> : tensor<1xi32>
  %0 = stablehlo.dynamic_broadcast_in_dim %cst, %c, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// After stablehlo-canonicalize-dynamism: Dynamic ops replaced with static ops
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// (Bonus) Use ` --stablehlo-aggressive-simplification` pass to canonicalize the
// constant broadcast, leaving us with the original static program in this case.
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<16xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<16xf32>
  return %0 : tensor<16xf32>
}