Dynamizm w StableHLO

Aktualny stan dynamizmu jest opisany w bardziej formalny sposób w dokumentacji Dynamism RFC. Na tej stronie znajdziesz ogólny przegląd dokumentu RFC oraz informacje o ważnych interfejsach API i narzędziach do interakcji z programami dynamicznymi.

Terminologia i przegląd pomocy dotyczącej Dynamism

Najpierw omówimy kilka terminów, które pojawią się w tym dokumencie, oraz przedstawimy krótką charakterystykę pomocy w StableHLO:

Wymiary dynamiczne

Wymiary dynamiczne to dowolne wymiary, których rozmiar jest nieznany. W StableHLO wymiary dynamiczne są reprezentowane za pomocą ?, czyli tensor<16x?xf32>.

Dynamizm ograniczony

Dynamizm ograniczony odnosi się do wymiaru dynamicznego, którego wartość ma znaną górną granicę. Zwykle jest to przydatne do wypełniania tensora podczas wykonywania. W StableHLO ograniczony dynamizm jest reprezentowany za pomocą #stablehlo.bounds jako kodowania tensora, czyli tensora 2-rzędowego z jednym wymiarem dynamicznym ograniczonym do 16 i drugim bez ograniczenia, który może być reprezentowany jako tensor<?x?xf32, #stablehlo.bounds<16, ?>>.

StableHLO może reprezentować ograniczony dynamizm, ale w ramach TensorFlow dostępna jest ograniczona obsługa tej platformy, a w PyTorch/XLA – ograniczona obsługa.

Nieograniczona dynamika

Jak sama nazwa wskazuje, dynamiczny dynamizm odnosi się do wymiaru dynamicznego bez znanej granicy rozmiaru. Taka dynamika jest bardzo powszechna w StableHLO z obsługą JAX, PyTorch/XLA i TF, która jest często używana przy eksportowaniu modeli o dynamicznym rozmiarze wsadu lub długości sekwencji.

W StableHLO pomijamy kodowanie granic dla tej formy dynamizmu, czyli:tensor<?x?xf32>.

Polimorfizm kształtu

Polimorfizm kształtów to termin odziedziczony z języka JAX.

Polimorfizm kształtu ma 2 kluczowe konsekwencje:

  1. Cała dynamika w programie sięga do argumentów wejściowych.
  2. Cały dynamizm dotyczy tylko kształtów tensora, czyli nie zależy od danych.

Dzięki tym dwóm regułom, gdy znane są już statyczne kształty programu, możemy przekształcić program dynamiczny w program statyczny do skompilowania (patrz „Przebiegi kompilatora służące do udoskonalania programów dynamicznych”).

Ogólnie polimorfizm kształtu używa nieograniczonego dynamizmu. Jeśli znane kształty argumentów mogą prowadzić do programu w pełni statycznego, nie trzeba się zastanawiać, jak ograniczyć wartości.

Dynamizm zależny od danych

Dynamizm zależny od danych odnosi się do rozmiarów wymiarów dynamicznych, które dotyczą danych wewnątrz tensora. Przykładem kanonicznym jest funkcja nonzeros, która zwraca indeksy wszystkich elementów, które są 0 w wartości tensora. Kształt nie jest znany bez oceny danych, ale często można go skompilować przy użyciu ograniczonego dynamizmu, zużywając dodatkową pamięć na potencjalny rozmiar tensora wyjściowego.

Wiele operacji dynamicznych zależnych od danych można modelować za pomocą ograniczonego dynamizmu, w którym określa się górną granicę rozmiaru tensora, a sprzęt będzie to wdrażać za pomocą wypełnienia tensora. Obecnie w PyTorch/XLA i TensorFlow istnieje ograniczone wsparcie dla dynamiki zależnej od danych, ale JAX nie śledzi obecnie operacji, które prowadzą do dynamiki zależnej od danych.

Eksportowanie programów z wymiarami dynamicznymi

Informacje na temat eksportowania programów z dynamicznymi rozmiarami wsadu lub długościami sekwencji znajdziesz w naszych samouczkach dotyczących StableHLO:

Przebiegi kompilatora służące do ulepszania programów dynamicznych

Usuwanie potoku przetwarzania danych o dynamizmie

Jest kilka przydatnych kart do zawężania kształtów. W wygodny sposób są one połączone w potoku kart createStablehloRemoveDynamismPipeline:

void createStablehloRemoveDynamismPipeline(OpPassManager &pm,
                                           TypeRange refinedTypes);

pojedyncze karty do dopracowywania dynamiki;

Indywidualnie karty, które zwykle są przydatne do zawężania kształtu, to:

Aktualne informacje i przykłady znajdziesz w linkowanej dokumentacji.

Przykład: Jakie zalety ma dynamizm i jak mogę go wykorzystać?

Dynamizm ma wiele zastosowań. Tutaj skupimy się głównie na typowych zastosowaniach polimorfizmu kształtu, czyli tworzeniu elastycznej, wyeksportowanej reprezentacji modelu, stosowanej zwykle do reprezentowania dynamicznej wielkości wsadu lub długości sekwencji.

Statyczny model add_one

Na potrzeby tego przykładu użyjemy prostego modelu add_one:

def add_one(x):
  return x + 1

Gdy śledzenie odbywa się za pomocą tensor<4xf32>, otrzymujemy następujący program StableHLO:

// File: add_one.mlir
func.func @add_one(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<4xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<4xf32>
  return %0 : tensor<4xf32>
}

Ten model działa tylko z argumentami wejściowymi o kształcie tensor<4xf32>. Gdybyśmy zmienili rozmiar wsadu lub długość sekwencji, musielibyśmy jeszcze raz śledzić kod źródłowy i z powrotem przejść do wersji StableHLO. Nie ma gwarancji, że nadal będziemy mieli dostęp do kodu źródłowego.

Dynamiczny model dodatkowy

Właśnie w takich sytuacjach przydaje się dynamiczna polimorfizm kształtu. Zamiast tego JAX i PyTorch/XLA mogą emitować model add_one z dynamicznie prawidłowym interfejsem IR, który będzie nadawać stałą wartość, aby dopasować ją do dynamicznego kształtu wejścia w ten sposób:

// File: add_one_dynamic.mlir
func.func public @main(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %cst = stablehlo.constant dense<1.0> : tensor<f32>
  %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf32>) -> tensor<i32>
  %1 = stablehlo.reshape %0 : (tensor<i32>) -> tensor<1xi32>
  %2 = stablehlo.dynamic_broadcast_in_dim %cst, %1, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %3 = stablehlo.add %arg0, %2 : tensor<?xf32>
  return %3 : tensor<?xf32>
}

Ta reprezentacja modelu jest znacznie bardziej elastyczna i umożliwia odroczone specyfikacje wartości, takich jak rozmiar wsadu lub długość sekwencji. Ten model można wdrożyć na platformach z obsługą dynamicznych kształtów (takich jak AI Edge) lub doprecyzować za pomocą kart dynamiki opisanych w tej dokumentacji.

Ulepszanie modelu dynamicznego

Na przykład kolejność kart może w pełni zmienić działanie programu:

stablehlo-opt add_one_dynamic.mlir \
  --stablehlo-refine-arguments='types=tensor<16xf32>' \
  --stablehlo-refine-shapes \
  --stablehlo-canonicalize-dynamism

Oto jak program jest stopniowo przekształcany:

// After stablehlo-refine-arguments: Inputs updated, shapes not propagated
func.func public @main(%arg0: tensor<16xf32>) -> tensor<?xf32> {
  %c = stablehlo.constant dense<16> : tensor<1xi64>
  %0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {indices_of_shape_operands = dense<1> : tensor<1xi64>} : (tensor<16xf32>, tensor<1xi64>) -> tensor<?xf32>
  ...
  %3 = stablehlo.dynamic_broadcast_in_dim %cst, %2, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %4 = stablehlo.add %0, %3 : tensor<?xf32>
  return %4 : tensor<?xf32>
}

// After stablehlo-refine-shapes: Shapes propagated, dynamic ops still exist
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %c = stablehlo.constant dense<16> : tensor<1xi32>
  %0 = stablehlo.dynamic_broadcast_in_dim %cst, %c, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// After stablehlo-canonicalize-dynamism: Dynamic ops replaced with static ops
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// (Bonus) Use ` --stablehlo-aggressive-simplification` pass to canonicalize the
// constant broadcast, leaving us with the original static program in this case.
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<16xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<16xf32>
  return %0 : tensor<16xf32>
}