L'ottimizzazione delle prestazioni è una parte fondamentale della creazione di modelli di machine learning efficienti. Puoi utilizzare lo strumento di profilazione XProf per misurare il rendimento dei tuoi carichi di lavoro di machine learning. XProf ti consente di acquisire tracce dettagliate dell'esecuzione del modello sui dispositivi XLA. Queste tracce possono aiutarti a identificare i colli di bottiglia delle prestazioni, comprendere l'utilizzo del dispositivo e ottimizzare il codice.
Questa guida descrive la procedura per acquisire in modo programmatico una traccia dallo script PyTorch XLA e visualizzarla utilizzando XProf.
Acquisire una traccia in modo programmatico
Puoi acquisire una traccia aggiungendo alcune righe di codice allo script di addestramento
esistente. Lo strumento principale per acquisire una traccia è il modulo torch_xla.debug.profiler, che in genere viene importato con l'alias xp.
1. Avvia il server del profiler
Prima di poter acquisire una traccia, devi avviare il server del profiler. Questo
server viene eseguito in background nello script e raccoglie i dati di tracciamento. Puoi
avviarlo chiamando xp.start_server(<port>) vicino all'inizio del
blocco di esecuzione principale.
2. Definisci la durata della traccia
Racchiudi il codice che vuoi profilare all'interno delle chiamate xp.start_trace() e
xp.stop_trace(). La funzione start_trace accetta un percorso a una directory
in cui vengono salvati i file di traccia.
È prassi comune eseguire il wrapping del ciclo di addestramento principale per acquisire le operazioni più pertinenti.
import torch_xla.debug.profiler as xp
# The directory where the trace files are stored.
log_dir = '/root/logs/'
# Start tracing
xp.start_trace(log_dir)
# ... your training loop or other code to be profiled ...
train_mnist()
# Stop tracing
xp.stop_trace()
3. Aggiungere etichette di traccia personalizzate
Per impostazione predefinita, le tracce acquisite sono funzioni Pytorch XLA di basso livello e può essere difficile navigare. Puoi aggiungere etichette personalizzate a sezioni specifiche del codice
utilizzando il gestore di contesto xp.Trace(). Queste etichette verranno visualizzate come blocchi denominati
nella visualizzazione della sequenza temporale del profiler, semplificando notevolmente l'identificazione di operazioni specifiche
come la preparazione dei dati, il forward pass o il passaggio dell'ottimizzatore.
Il seguente esempio mostra come aggiungere contesto a diverse parti di un passaggio di addestramento.
def forward(self, x):
# This entire block will be labeled 'forward' in the trace
with xp.Trace('forward'):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2(x), 2))
x = x.view(-1, 7*7*64)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
with torch_xla.step():
with xp.Trace('train_step_data_prep_and_forward'):
optimizer.zero_grad()
data, target = data.to(device), target.to(device)
output = model(data)
with xp.Trace('train_step_loss_and_backward'):
loss = loss_fn(output, target)
loss.backward()
with xp.Trace('train_step_optimizer_step_host'):
optimizer.step()
Esempio completo
L'esempio seguente mostra come acquisire una traccia da uno script PyTorch XLA,
in base al file mnist_xla.py.
import torch
import torch.optim as optim
from torchvision import datasets, transforms
# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
def train_mnist():
# ... (model definition and data loading code) ...
print("Starting training...")
# ... (training loop as defined in the previous section) ...
print("Training finished!")
if __name__ == '__main__':
# 1. Start the profiler server
server = xp.start_server(9012)
# 2. Start capturing the trace and define the output directory
xp.start_trace('/root/logs/')
# Run the training function that contains custom trace labels
train_mnist()
# 3. Stop the trace
xp.stop_trace()
Visualizzare la traccia
Al termine dello script, i file di traccia vengono salvati nella directory che hai specificato (ad esempio /root/logs/). Puoi visualizzare questa traccia utilizzando XProf.
Puoi avviare la UI del profiler direttamente utilizzando il comando XProf autonomo puntandolo alla directory dei log:
$ xprof --port=8791 /root/logs/
Attempting to start XProf server:
Log Directory: /root/logs/
Port: 8791
Worker Service Address: 0.0.0.0:50051
Hide Capture Button: False
XProf at http://localhost:8791/ (Press CTRL+C to quit)
Nel browser, vai all'URL fornito (ad es. http://localhost:8791/) per visualizzare il profilo.
Potrai visualizzare le etichette personalizzate che hai creato e analizzare il tempo di esecuzione di diverse parti del modello.
Se utilizzi Google Cloud per eseguire i tuoi carichi di lavoro, ti consigliamo lo strumento cloud-diagnostics-xprof. Offre un'esperienza semplificata di raccolta e visualizzazione dei profili utilizzando VM che eseguono XProf.