Dinamismo no StableHLO

O estado atual do dinamismo é descrito de maneira mais formal no RFC sobre dinamismo. Esta página vai fornecer uma visão geral de alto nível do RFC e discutir APIs e ferramentas importantes para interagir com programas dinâmicos.

Termos e visão geral do suporte do Dynamism

Primeiro, vamos abordar alguns termos que vão aparecer neste documento, além de uma breve introdução ao suporte deles no StableHLO:

Dimensões dinâmicas

Dimensões dinâmicas são aquelas cujo tamanho é desconhecido. No StableHLO, representamos dimensões dinâmicas usando ?, ou seja, tensor<16x?xf32>.

Dinamismo limitado

O dinamismo limitado se refere a uma dimensão dinâmica cujo valor tem um limite superior conhecido. Geralmente, isso é útil para preencher o tensor durante a execução. No StableHLO, representamos o dinamismo limitado usando #stablehlo.bounds como uma codificação de tensor, ou seja, um tensor de rank 2 com uma dimensão dinâmica limitada a 16 e a outra sem limite pode ser representada como tensor<?x?xf32, #stablehlo.bounds<16, ?>>.

O StableHLO pode representar o dinamismo limitado, mas há suporte limitado ao framework, originado no TensorFlow, e com algum suporte ao PyTorch/XLA.

Dinamismo ilimitado

O dinamismo ilimitado, como o nome indica, se refere a uma dimensão dinâmica sem limite conhecido de tamanho. Esse tipo de dinamismo é muito comum no StableHLO, com suporte a JAX, PyTorch/XLA e TF, geralmente usado para exportar modelos com tamanho de lote dinâmico ou comprimento de sequência.

No StableHLO, simplesmente omitimos a codificação de limites para essa forma de dinamismo, ou seja, tensor<?x?xf32>.

Polimorfismo de forma

O polimorfismo de forma é um termo que herdamos do JAX.

Há duas implicações principais para moldar o polimorfismo:

  1. Todo o dinamismo no programa é rastreado de volta aos argumentos de entrada.
  2. Todo o dinamismo se refere apenas a formas de tensor, ou seja, não depende de dados.

Com essas duas regras, depois que as formas estáticas de um programa são conhecidas, podemos transformar um programa dinâmico em um programa estático para compilação (consulte "Passes de compilador para refinar programas dinâmicos").

Geralmente, o polimorfismo de formas usa dinamismo ilimitado. Se formas de argumento conhecidas puderem levar a um programa totalmente estático, não será necessário adivinhar como vincular os valores.

Dinamismo que depende dos dados

O dinamismo dependente de dados se refere aos tamanhos de dimensões dinâmicas que pertencem aos dados dentro de um tensor. O exemplo canônico é uma função nonzeros que retorna os índices de todos os elementos que são 0 em um valor de tensor. A forma não pode ser conhecida sem avaliar os dados, mas muitas vezes pode ser compilada usando o dinamismo limitado, gastando memória extra no tamanho potencial do tensor de saída.

Muitas operações dinâmicas dependentes de dados podem ser modeladas usando o dinamismo limitado, em que um limite superior em um tamanho de tensor é especificado, e o hardware geralmente implementa isso por meio do preenchimento do tensor. Atualmente, há alguma compatibilidade com o dinamismo dependente de dados no PyTorch/XLA e no TensorFlow, mas o JAX atualmente não rastreia operações que levam ao dinamismo dependente de dados.

Como exportar programas com dimensões dinâmicas

Consulte nossos tutoriais do StableHLO para informações sobre como exportar programas com tamanhos de lote dinâmicos ou comprimentos de sequência:

Passes do compilador para refinar programas dinâmicos

Remover pipeline de passagem de dinamismo

Há alguns passes úteis para refinar formas. Por conveniência, todos eles estão agrupados em um pipeline de passagem createStablehloRemoveDynamismPipeline:

void createStablehloRemoveDynamismPipeline(OpPassManager &pm,
                                           TypeRange refinedTypes);

Cartões individuais para refinar o dinamismo

Individualmente, os passes que tendem a ser úteis para o refinamento de formas são:

Acesse a documentação vinculada para conferir informações atualizadas e exemplos.

Exemplo: como o dinamismo é útil e como posso usá-lo?

O Dynamism tem muitos usos. Aqui, vamos nos concentrar principalmente no caso de uso comum de polimorfismo de forma: criar uma representação flexível de modelo exportado, geralmente usada para representar o tamanho de lote dinâmico ou o comprimento da sequência.

Modelo estático add_one

Usaremos o seguinte modelo simples de add_one para demonstrar isso:

def add_one(x):
  return x + 1

Quando rastreado usando um tensor<4xf32>, vamos receber o seguinte programa StableHLO:

// File: add_one.mlir
func.func @add_one(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<4xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<4xf32>
  return %0 : tensor<4xf32>
}

Esse modelo funciona com argumentos de entrada que têm uma forma tensor<4xf32>. Se mudássemos o tamanho do lote ou o comprimento da sequência, precisaríamos refazer o rastreamento do código-fonte e rebaixar para o StableHLO. Não há garantia de que ainda temos acesso ao código-fonte.

Modelo dinâmico add_one

É aqui que o dinamismo polimórfico de formas entra em jogo. Em vez disso, o JAX e o PyTorch/XLA podem emitir o modelo add_one com IR válido dinamicamente, que vai transmitir a constante para corresponder à forma de entrada dinâmica da seguinte maneira:

// File: add_one_dynamic.mlir
func.func public @main(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %cst = stablehlo.constant dense<1.0> : tensor<f32>
  %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf32>) -> tensor<i32>
  %1 = stablehlo.reshape %0 : (tensor<i32>) -> tensor<1xi32>
  %2 = stablehlo.dynamic_broadcast_in_dim %cst, %1, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %3 = stablehlo.add %arg0, %2 : tensor<?xf32>
  return %3 : tensor<?xf32>
}

Essa representação de modelo é muito mais flexível e permite a especificação adiada de valores como tamanho do lote ou comprimento da sequência. Esse modelo pode ser implantado em plataformas com suporte a formas dinâmicas (como o AI Edge) ou refinado usando as passagens de dinamismo mencionadas nesta documentação.

Aprimorar o modelo dinâmico

Por exemplo, a seguinte ordem de passagem pode refinar totalmente esse programa:

stablehlo-opt add_one_dynamic.mlir \
  --stablehlo-refine-arguments='types=tensor<16xf32>' \
  --stablehlo-refine-shapes \
  --stablehlo-canonicalize-dynamism

Incrementalmente, o programa é transformado da seguinte forma:

// After stablehlo-refine-arguments: Inputs updated, shapes not propagated
func.func public @main(%arg0: tensor<16xf32>) -> tensor<?xf32> {
  %c = stablehlo.constant dense<16> : tensor<1xi64>
  %0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {indices_of_shape_operands = dense<1> : tensor<1xi64>} : (tensor<16xf32>, tensor<1xi64>) -> tensor<?xf32>
  ...
  %3 = stablehlo.dynamic_broadcast_in_dim %cst, %2, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %4 = stablehlo.add %0, %3 : tensor<?xf32>
  return %4 : tensor<?xf32>
}

// After stablehlo-refine-shapes: Shapes propagated, dynamic ops still exist
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %c = stablehlo.constant dense<16> : tensor<1xi32>
  %0 = stablehlo.dynamic_broadcast_in_dim %cst, %c, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// After stablehlo-canonicalize-dynamism: Dynamic ops replaced with static ops
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// (Bonus) Use ` --stablehlo-aggressive-simplification` pass to canonicalize the
// constant broadcast, leaving us with the original static program in this case.
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<16xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<16xf32>
  return %0 : tensor<16xf32>
}