Dynamik in StableHLO

Der aktuelle Stand der Dynamik wird im Dynamism RFC genauer beschrieben. Auf dieser Seite finden Sie einen allgemeinen Überblick über den RFC sowie wichtige APIs und Tools für die Interaktion mit dynamischen Programmen.

Dynamik – Terminologie und Supportübersicht

Zuerst einige Begriffe, die in diesem Dokument vorkommen, sowie eine kurze Einführung in die Unterstützung in StableHLO:

Dynamische Dimensionen

„Dynamische Dimensionen“ bezieht sich auf Dimensionen, deren Größe unbekannt ist. In StableHLO werden dynamische Dimensionen mit ? dargestellt, d.h. tensor<16x?xf32>.

Begrenzte Dynamik

„Begrenzte Dynamik“ bezieht sich auf eine dynamische Dimension, deren Wert eine bekannte Obergrenze hat. Im Allgemeinen ist dies nützlich, um den Tensor während der Ausführung aufzufüllen. In StableHLO stellen wir die begrenzte Dynamik mit #stablehlo.bounds als Tensorcodierung dar. Ein Tensor vom Rang 2 mit einer dynamischen Dimension, die auf 16 begrenzt ist, und der anderen ohne Begrenzung kann als tensor<?x?xf32, #stablehlo.bounds<16, ?>> dargestellt werden.

StableHLO kann begrenzten Dynamismus darstellen, aber es gibt nur eingeschränkte Framework-Unterstützung, die aus TensorFlow stammt und in PyTorch/XLA teilweise unterstützt wird.

Unbegrenzte Dynamik

Wie der Name schon sagt, bezieht sich „Unbounded dynamism“ auf eine dynamische Dimension ohne bekannte Größenbeschränkung. Diese Art von Dynamik ist in StableHLO sehr häufig und wird mit Unterstützung für JAX, PyTorch/XLA und TF oft zum Exportieren von Modellen mit dynamischer Batchgröße oder Sequenzlänge verwendet.

In StableHLO wird die Grenzenkodierung für diese Form der Dynamik einfach ausgelassen, d.h. tensor<?x?xf32>.

Formpolymorphismus

Die Form des Polymorphismus ist ein Begriff, den wir von JAX übernommen haben.

Die Form von Polymorphismus hat zwei wichtige Auswirkungen:

  1. Die gesamte Dynamik im Programm geht auf die Eingabeargumente zurück.
  2. Die gesamte Dynamik bezieht sich nur auf Tensorformen, d.h., sie ist nicht datenabhängig.

Mit diesen beiden Regeln können wir ein dynamisches Programm vollständig in ein statisches Programm für die Kompilierung umwandeln, sobald die statischen Formen eines Programms bekannt sind (siehe „Compiler-Durchläufe zum Optimieren dynamischer Programme“).

Im Allgemeinen wird bei der Formpolymorphie eine unbegrenzte Dynamik verwendet. Wenn die Argumentformen bekannt sind, kann dies zu einem vollständig statischen Programm führen. Es ist nicht erforderlich, die Werte zu schätzen.

Datengestützte Dynamik

Die datenabhängige Dynamik bezieht sich auf dynamische Dimensionsgrößen, die sich auf die Daten in einem Tensor beziehen. Das kanonische Beispiel ist eine nonzeros-Funktion, die die Indexe aller Elemente zurückgibt, die in einem Tensorwert 0 sind. Die Form kann nicht ohne Auswertung der Daten ermittelt werden. Sie kann jedoch häufig mithilfe von begrenzter Dynamik zusammengestellt werden, wobei zusätzlicher Speicher für die potenzielle Ausgabetensorgröße verwendet wird.

Viele datenabhängige dynamische Vorgänge können mit begrenzter Dynamik modelliert werden. Dabei wird eine Obergrenze für die Tensorgröße angegeben und die Hardware implementiert dies in der Regel durch Tensor-Padding. Derzeit gibt es in PyTorch/XLA und TensorFlow eine gewisse Unterstützung für datenabhängige Dynamik, aber JAX verfolgt derzeit keine Operationen, die zu datenabhängiger Dynamik führen.

Programme mit dynamischen Dimensionen exportieren

In unseren StableHLO-Tutorials finden Sie Informationen zum Exportieren von Programmen mit dynamischen Batchgrößen oder Sequenzlängen:

Compiler-Durchläufe zum Optimieren dynamischer Programme

Dynamik-Pass-Pipeline entfernen

Es gibt einige nützliche Durchgänge zum Verfeinern von Formen. Sie sind alle in einer Pass-Pipeline createStablehloRemoveDynamismPipeline gebündelt:

void createStablehloRemoveDynamismPipeline(OpPassManager &pm,
                                           TypeRange refinedTypes);

Einzelne Durchgänge zum Verfeinern der Dynamik

Die folgenden Durchgänge sind in der Regel nützlich, um die Form zu optimieren:

Aktuelle Informationen und Beispiele finden Sie in der verlinkten Dokumentation.

Beispiel: Wie kann ich Dynamik nutzen?

Dynamismus hat viele Anwendungsfälle. Hier konzentrieren wir uns hauptsächlich auf den häufigen Anwendungsfall für Shape-Polymorphismus: die Erstellung einer flexiblen exportierten Modellrepräsentation, die in der Regel zur Darstellung der dynamischen Batchgröße oder Sequenzlänge verwendet wird.

Statisches „add_one“-Modell

Wir verwenden das folgende einfache add_one-Modell, um dies zu veranschaulichen:

def add_one(x):
  return x + 1

Wenn wir das Programm mit einem tensor<4xf32> verfolgen, erhalten wir das folgende StableHLO-Programm:

// File: add_one.mlir
func.func @add_one(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<4xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<4xf32>
  return %0 : tensor<4xf32>
}

Dieses Modell funktioniert nur für Eingabeargumente mit der Form tensor<4xf32>. Wenn wir jemals unsere Batchgröße oder Sequenzlänge geändert hätten, müssten wir den Quellcode noch einmal durchlaufen und in StableHLO umwandeln. Es gibt keine Garantie, dass wir überhaupt noch Zugriff auf den Quellcode haben.

Dynamisches „add_one“-Modell

Hier kommt die Formpolymorphie ins Spiel. Stattdessen können JAX und PyTorch/XLA das add_one-Modell mit dynamisch gültiger IR ausgeben, die die Konstante so broadcastet, dass sie der dynamischen Eingabeform entspricht:

// File: add_one_dynamic.mlir
func.func public @main(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %cst = stablehlo.constant dense<1.0> : tensor<f32>
  %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf32>) -> tensor<i32>
  %1 = stablehlo.reshape %0 : (tensor<i32>) -> tensor<1xi32>
  %2 = stablehlo.dynamic_broadcast_in_dim %cst, %1, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %3 = stablehlo.add %arg0, %2 : tensor<?xf32>
  return %3 : tensor<?xf32>
}

Diese Modelldarstellung ist viel flexibler und ermöglicht die verzögerte Angabe von Werten wie Batchgröße oder Sequenzlänge. Dieses Modell kann auf Plattformen mit Unterstützung für dynamische Formen (z. B. AI Edge) bereitgestellt oder mithilfe der in dieser Dokumentation erwähnten Dynamismus-Passes optimiert werden.

Dynamisches Modell optimieren

Mit der folgenden Reihenfolge von Durchgängen kann dieses Programm beispielsweise vollständig optimiert werden:

stablehlo-opt add_one_dynamic.mlir \
  --stablehlo-refine-arguments='types=tensor<16xf32>' \
  --stablehlo-refine-shapes \
  --stablehlo-canonicalize-dynamism

So wird das Programm nach und nach transformiert:

// After stablehlo-refine-arguments: Inputs updated, shapes not propagated
func.func public @main(%arg0: tensor<16xf32>) -> tensor<?xf32> {
  %c = stablehlo.constant dense<16> : tensor<1xi64>
  %0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {indices_of_shape_operands = dense<1> : tensor<1xi64>} : (tensor<16xf32>, tensor<1xi64>) -> tensor<?xf32>
  ...
  %3 = stablehlo.dynamic_broadcast_in_dim %cst, %2, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %4 = stablehlo.add %0, %3 : tensor<?xf32>
  return %4 : tensor<?xf32>
}

// After stablehlo-refine-shapes: Shapes propagated, dynamic ops still exist
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %c = stablehlo.constant dense<16> : tensor<1xi32>
  %0 = stablehlo.dynamic_broadcast_in_dim %cst, %c, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// After stablehlo-canonicalize-dynamism: Dynamic ops replaced with static ops
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// (Bonus) Use ` --stablehlo-aggressive-simplification` pass to canonicalize the
// constant broadcast, leaving us with the original static program in this case.
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<16xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<16xf32>
  return %0 : tensor<16xf32>
}