Lo stato attuale del dinamismo è descritto in modo più formale nel RFC sul dinamismo. Questa pagina fornisce una panoramica generale del RFC e illustra API e strumenti importanti per interagire con i programmi dinamici.
Terminologia e panoramica dell'assistenza di Dinamismo
In primo luogo, per esaminare alcuni termini che verranno riportati in questo documento, nonché una breve introduzione al loro supporto in StableHLO:
Dimensioni dinamiche
Per dimensioni dinamiche si intendono tutte le dimensioni di cui è sconosciuta la dimensione.
In StableHLO rappresentiamo le dimensioni dinamiche utilizzando ?
, ovvero tensor<16x?xf32>
.
Dinamismo limitato
Dinamismo limitato si riferisce a una dimensione dinamica il cui valore ha un limite superiore noto. In genere, è utile per aggiungere spaziatura al tensore durante l'esecuzione.
In StableHLO rappresentiamo il dinamismo limitato utilizzando #stablehlo.bounds
come codifica del tensore, ovvero un tensore di rango 2 con una dimensione dinamica limitata a 16 e l'altra senza un limite può essere rappresentata come tensor<?x?xf32, #stablehlo.bounds<16, ?>>
.
StableHLO è in grado di rappresentare il dinamismo limitato, ma il supporto del framework è limitato, in quanto è nato in TensorFlow e ha un supporto parziale in PyTorch/XLA.
Dinamismo illimitato
Dinamismo illimitato, come suggerisce il nome, si riferisce a una dimensione dinamica senza limiti noti alla dimensione. Questo tipo di dinamismo è molto comune in StableHLO, con il supporto di JAX, PyTorch/XLA e TF, spesso utilizzato per esportare modelli con dimensioni batch o lunghezza di sequenza dinamiche.
In StableHLO, semplicemente eliminiamo la codifica dei limiti per questa forma di dinamismo, ovvero
tensor<?x?xf32>
.
Polimorfismo delle forme
Il polimorfismo della forma è un termine che abbiamo ereditato dalla JAX.
Esistono due implicazioni chiave per il polimorfismo della forma:
- Tutto il dinamismo del programma si basa sugli argomenti di input.
- Tutto il dinamismo riguarda solo le forme dei tensori, ovvero non dipende dai dati.
Con queste due regole, una volta note le forme statiche di un programma, siamo in grado di prendere un programma dinamico e perfezionarlo completamente in un programma statico per la compilazione (consulta la sezione "Passaggi del compilatore per perfezionare i programmi dinamici").
In genere, il polimorfismo di forma utilizza un dinamismo illimitato. Se le forme degli argomenti noti possono portare a un programma completamente statico, non è necessario indovinare come limitare i valori.
Dinamismo basato sui dati
Il dinamismo dipendente dai dati si riferisce alle dimensioni delle dimensioni dinamiche relative ai dati all'interno di un tensore. L'esempio canonico è una funzione nonzeros
che
restituisce gli indici di tutti gli elementi che sono 0
in un valore del tensore. La forma
non può essere conosciuta senza valutare i dati, ma spesso può essere compilata utilizzando
dinamismo limitato, spendendo più memoria per la potenziale dimensione del tensore di output.
Molte operazioni dinamiche dipendenti dai dati possono essere modellate utilizzando il dinamismo limitato, in cui viene specificato un limite superiore per le dimensioni di un tensore e l'hardware in genere lo implementa tramite il padding del tensore. Attualmente esiste un certo supporto per il dinamismo dipendente dai dati in PyTorch/XLA e TensorFlow, ma JAX al momento non esegue il monitoraggio delle operazioni che portano a un dinamismo dipendente dai dati.
Esportazione di programmi con dimensioni dinamiche
Consulta i nostri tutorial su StableHLO per informazioni su come esportare programmi con dimensioni dei batch o lunghezze di sequenza dinamiche:
- Tutorial JavaScript > Esportazione con dimensione del batch dinamica
- Tutorial PyTorch/XLA > Esportazione con dimensione batch dinamica
Passaggi del compilatore per perfezionare i programmi dinamici
Rimuovere la pipeline di passaggio del dinamismo
Esistono alcuni passaggi utili per perfezionare le forme, che sono tutti riuniti in una pipeline di passaggi createStablehloRemoveDynamismPipeline
:
void createStablehloRemoveDynamismPipeline(OpPassManager &pm,
TypeRange refinedTypes);
Singole tessere per perfezionare il dinamismo
Individualmente, le tessere che tendono a essere utili per il perfezionamento della forma sono:
stablehlo-refine-arguments
per sostituire gli argomenti di input con tipi di tensori specifici.stablehlo-refine-shapes
per propagare le informazioni sulla forma del nuovo argomento di input nell'intero programma.stablehlo-canonicalize-dynamism
per sostituire le opzioni dinamiche con le relative varianti statiche.
Per informazioni e esempi aggiornati, consulta la documentazione collegata.
Esempio: in che modo il dinamismo è utile e come posso utilizzarlo?
Il dinamismo ha molti utilizzi. Qui ci concentreremo principalmente sul caso d'uso comune del polimorfismo della forma: creare una rappresentazione flessibile del modello esportato, generalmente utilizzata per rappresentare la dimensione dinamica del batch o la lunghezza della sequenza.
Modello add_one statico
Per la dimostrazione, utilizzeremo il seguente semplice modello add_one
:
def add_one(x):
return x + 1
Se viene tracciato utilizzando un tensor<4xf32>
, otteniamo il seguente programma StableHLO:
// File: add_one.mlir
func.func @add_one(%arg0: tensor<4xf32>) -> tensor<4xf32> {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<4xf32>
%0 = stablehlo.add %arg0, %cst : tensor<4xf32>
return %0 : tensor<4xf32>
}
Questo modello funzionerà solo per gli argomenti di input che hanno una forma tensor<4xf32>
. Se modificassimo la dimensione del batch o la lunghezza della sequenza, dovremmo
ritracciare il codice sorgente e riabbassarlo a StableHLO, senza alcuna garanzia
che abbiamo ancora accesso al codice sorgente.
Modello Add_one dinamico
È qui che entra in gioco il dinamismo polimorfo della forma. Invece, JAX e
PyTorch/XLA possono emettere il modello add_one
con un IR valido dinamicamente che
trasmetterà la costante in modo che corrisponda alla forma di input dinamica come segue:
// File: add_one_dynamic.mlir
func.func public @main(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%cst = stablehlo.constant dense<1.0> : tensor<f32>
%0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf32>) -> tensor<i32>
%1 = stablehlo.reshape %0 : (tensor<i32>) -> tensor<1xi32>
%2 = stablehlo.dynamic_broadcast_in_dim %cst, %1, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
%3 = stablehlo.add %arg0, %2 : tensor<?xf32>
return %3 : tensor<?xf32>
}
Questa rappresentazione del modello è molto più flessibile e consente la specifica differita di valori come la dimensione del batch o la lunghezza della sequenza. Questo modello può essere implementato su piattaforme con supporto delle forme dinamiche (come AI Edge) o perfezionato utilizzando i passaggi di dinamismo descritti in questa documentazione.
Perfezionare il modello dinamico
Ad esempio, l'ordinamento dei permessi seguente può perfezionare completamente questo programma:
stablehlo-opt add_one_dynamic.mlir \
--stablehlo-refine-arguments='types=tensor<16xf32>' \
--stablehlo-refine-shapes \
--stablehlo-canonicalize-dynamism
Ecco come viene trasformato il programma in modo incrementale:
// After stablehlo-refine-arguments: Inputs updated, shapes not propagated
func.func public @main(%arg0: tensor<16xf32>) -> tensor<?xf32> {
%c = stablehlo.constant dense<16> : tensor<1xi64>
%0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {indices_of_shape_operands = dense<1> : tensor<1xi64>} : (tensor<16xf32>, tensor<1xi64>) -> tensor<?xf32>
...
%3 = stablehlo.dynamic_broadcast_in_dim %cst, %2, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
%4 = stablehlo.add %0, %3 : tensor<?xf32>
return %4 : tensor<?xf32>
}
// After stablehlo-refine-shapes: Shapes propagated, dynamic ops still exist
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%c = stablehlo.constant dense<16> : tensor<1xi32>
%0 = stablehlo.dynamic_broadcast_in_dim %cst, %c, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<16xf32>
%1 = stablehlo.add %arg0, %0 : tensor<16xf32>
return %1 : tensor<16xf32>
}
// After stablehlo-canonicalize-dynamism: Dynamic ops replaced with static ops
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<16xf32>
%1 = stablehlo.add %arg0, %0 : tensor<16xf32>
return %1 : tensor<16xf32>
}
// (Bonus) Use ` --stablehlo-aggressive-simplification` pass to canonicalize the
// constant broadcast, leaving us with the original static program in this case.
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<16xf32>
%0 = stablehlo.add %arg0, %cst : tensor<16xf32>
return %0 : tensor<16xf32>
}