Dynamik in StableHLO

Der aktuelle Stand der Dynamik wird im Dynamism RFC förmlicher beschrieben. Diese Seite bietet einen allgemeinen Überblick über den RFC und erörtert wichtige APIs und Tools für die Interaktion mit dynamischen Programmen.

Dynamismus – Begriffe und Supportübersicht

Zuerst einige Begriffe, die in diesem Dokument vorkommen, sowie eine kurze Einführung in ihre Unterstützung in StableHLO:

Dynamische Dimensionen

Dynamische Dimensionen sind Dimensionen, deren Größe unbekannt ist. In StableHLO werden dynamische Dimensionen mit ? dargestellt, also tensor<16x?xf32>.

Begrenzte Dynamik

Begrenzte Dynamik bezieht sich auf eine dynamische Dimension, deren Wert eine bekannte Obergrenze hat. Im Allgemeinen ist dies nützlich, um den Tensor während der Ausführung zu füllen. In StableHLO wird begrenzter Dynamismus mit #stablehlo.bounds als Tensorcodierung dargestellt, d.h. ein Tensor mit Rang 2 mit einer dynamischen Dimension, die auf 16 beschränkt ist, und die andere ohne Grenze kann als tensor<?x?xf32, #stablehlo.bounds<16, ?>> dargestellt werden.

StableHLO kann begrenzte Dynamik darstellen, aber es gibt nur begrenzte Framework-Unterstützung, die aus TensorFlow stammt, und es gibt einige Unterstützung in PyTorch/XLA.

Unbegrenzte Dynamik

„Unbegrenzt dynamisch“ bezieht sich, wie der Name schon sagt, auf eine dynamische Dimension ohne bekannte Obergrenze für die Größe. Diese Art von Dynamik ist in StableHLO sehr häufig. Mit JAX-, PyTorch/XLA- und TF-Unterstützung wird sie häufig zum Exportieren von Modellen mit dynamischer Batchgröße oder Sequenzlänge verwendet.

In StableHLO wird die Begrenzungscodierung für diese Form der Dynamik einfach entfernt, d.h. tensor<?x?xf32>.

Formpolymorphismus

Der Begriff „Formpolymorphismus“ stammt aus JAX.

Der Formpolymorphismus hat zwei wichtige Auswirkungen:

  1. Die gesamte Dynamik des Programms geht auf seine Eingabeargumente zurück.
  2. Die Dynamik bezieht sich nur auf die Form von Tensoren, d.h. sie ist nicht datenabhängig.

Sobald die statischen Formen eines Programms bekannt sind, können wir mit diesen beiden Regeln ein dynamisches Programm vollständig zu einem statischen Programm für die Kompilierung optimieren (siehe „Compilerdurchläufe zum Optimieren dynamischer Programme“).

Im Allgemeinen wird bei der Formpolymorphie eine unbegrenzte Dynamik verwendet. Wenn bekannte Argumentformen zu einem vollständig statischen Programm führen können, müssen die Werte nicht gebunden werden.

Datenabhängige Dynamik

Datenabhängige Dynamik bezieht sich auf die Größe dynamischer Dimensionen, die sich auf die Daten in einem Tensor beziehen. Das kanonische Beispiel ist eine nonzeros-Funktion, die die Indizes aller Elemente zurückgibt, die in einem Tensorwert 0 sind. Die Form kann nicht ohne Auswertung der Daten bekannt sein, kann aber oft mit begrenzter Dynamik kompiliert werden, wobei zusätzlicher Arbeitsspeicher für die potenzielle Größe des Ausgabetensors verbraucht wird.

Viele datenabhängige dynamische Operationen können mit begrenzter Dynamik modelliert werden, wobei eine Obergrenze für eine Tensorgröße angegeben ist. Die Hardware implementiert dies im Allgemeinen über Tensor-Padding. Derzeit gibt es in PyTorch/XLA und TensorFlow eine gewisse Unterstützung für datenabhängige Dynamik. In JAX werden derzeit jedoch keine Vorgänge erfasst, die zu datenabhängiger Dynamik führen.

Programme mit dynamischen Dimensionen exportieren

In unseren StableHLO-Anleitungen erfahren Sie, wie Sie Programme mit dynamischen Batchgrößen oder Sequenzlängen exportieren:

Compilerdurchläufe zur Optimierung dynamischer Programme

Dynamische Karten/Tickets-Pipeline entfernen

Es gibt einige nützliche Pässe zum Optimieren von Formen. Sie sind alle in einer Pass-Pipeline createStablehloRemoveDynamismPipeline zusammengefasst:

void createStablehloRemoveDynamismPipeline(OpPassManager &pm,
                                           TypeRange refinedTypes);

Einzelne Karten/Tickets für eine optimierte Dynamik

Für die Formoptimierung eignen sich folgende Pässe:

Aktuelle Informationen und Beispiele finden Sie in der verlinkten Dokumentation.

Beispiel: Wie ist Dynamik nützlich und wie kann ich sie nutzen?

Dynamik hat viele Anwendungsfälle. Hier konzentrieren wir uns hauptsächlich auf den gängigen Anwendungsfall für den Shape-Polymorphismus: das Erstellen einer flexiblen exportierten Modelldarstellung, die in der Regel zur Darstellung der dynamischen Batchgröße oder Sequenzlänge verwendet wird.

Statisches Modell "add_one"

Wir verwenden dazu das folgende einfache add_one-Modell:

def add_one(x):
  return x + 1

Wenn wir mit einer tensor<4xf32> verfolgen, erhalten wir das folgende StableHLO-Programm:

// File: add_one.mlir
func.func @add_one(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<4xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<4xf32>
  return %0 : tensor<4xf32>
}

Dieses Modell funktioniert nur für Eingabeargumente mit der Form tensor<4xf32>. Wenn wir die Batchgröße oder die Sequenzlänge ändern, müssten wir den Quellcode noch einmal zurückverfolgen und wieder auf StableHLO herunterstufen. Es gibt keine Garantie dafür, dass wir noch Zugriff auf den Quellcode haben.

Dynamisches Modell „add_one“

Hier kommt die polymorphe Dynamik der Form ins Spiel. Stattdessen können JAX und PyTorch/XLA das add_one-Modell mit dynamisch gültiger IR ausgeben, die die Konstante so überträgt, dass sie der dynamischen Eingabeform entspricht:

// File: add_one_dynamic.mlir
func.func public @main(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %cst = stablehlo.constant dense<1.0> : tensor<f32>
  %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf32>) -> tensor<i32>
  %1 = stablehlo.reshape %0 : (tensor<i32>) -> tensor<1xi32>
  %2 = stablehlo.dynamic_broadcast_in_dim %cst, %1, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %3 = stablehlo.add %arg0, %2 : tensor<?xf32>
  return %3 : tensor<?xf32>
}

Diese Modelldarstellung ist viel flexibler und ermöglicht die spätere Angabe von Werten wie Batchgröße oder Sequenzlänge. Dieses Modell kann auf Plattformen mit dynamischer Formunterstützung (z. B. AI Edge) bereitgestellt oder mithilfe der in dieser Dokumentation erwähnten Dynamik-Tickets optimiert werden.

Dynamisches Modell optimieren

Mit der folgenden Karten-/Ticketreihenfolge kann dieses Programm beispielsweise vollständig optimiert werden:

stablehlo-opt add_one_dynamic.mlir \
  --stablehlo-refine-arguments='types=tensor<16xf32>' \
  --stablehlo-refine-shapes \
  --stablehlo-canonicalize-dynamism

So wird das Programm schrittweise umgewandelt:

// After stablehlo-refine-arguments: Inputs updated, shapes not propagated
func.func public @main(%arg0: tensor<16xf32>) -> tensor<?xf32> {
  %c = stablehlo.constant dense<16> : tensor<1xi64>
  %0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {indices_of_shape_operands = dense<1> : tensor<1xi64>} : (tensor<16xf32>, tensor<1xi64>) -> tensor<?xf32>
  ...
  %3 = stablehlo.dynamic_broadcast_in_dim %cst, %2, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %4 = stablehlo.add %0, %3 : tensor<?xf32>
  return %4 : tensor<?xf32>
}

// After stablehlo-refine-shapes: Shapes propagated, dynamic ops still exist
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %c = stablehlo.constant dense<16> : tensor<1xi32>
  %0 = stablehlo.dynamic_broadcast_in_dim %cst, %c, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// After stablehlo-canonicalize-dynamism: Dynamic ops replaced with static ops
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// (Bonus) Use ` --stablehlo-aggressive-simplification` pass to canonicalize the
// constant broadcast, leaving us with the original static program in this case.
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<16xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<16xf32>
  return %0 : tensor<16xf32>
}