Profilare i carichi di lavoro PyTorch XLA

L'ottimizzazione delle prestazioni è una parte fondamentale della creazione di modelli di machine learning efficienti. Puoi utilizzare lo strumento di profilazione XProf per misurare il rendimento dei tuoi carichi di lavoro di machine learning. XProf ti consente di acquisire tracce dettagliate dell'esecuzione del modello sui dispositivi XLA. Queste tracce possono aiutarti a identificare i colli di bottiglia delle prestazioni, comprendere l'utilizzo del dispositivo e ottimizzare il codice.

Questa guida descrive la procedura per acquisire in modo programmatico una traccia dallo script PyTorch XLA e visualizzarla utilizzando XProf.

Acquisire una traccia in modo programmatico

Puoi acquisire una traccia aggiungendo alcune righe di codice allo script di addestramento esistente. Lo strumento principale per acquisire una traccia è il modulo torch_xla.debug.profiler, che in genere viene importato con l'alias xp.

1. Avvia il server del profiler

Prima di poter acquisire una traccia, devi avviare il server del profiler. Questo server viene eseguito in background nello script e raccoglie i dati di tracciamento. Puoi avviarlo chiamando xp.start_server(<port>) vicino all'inizio del blocco di esecuzione principale.

2. Definisci la durata della traccia

Racchiudi il codice che vuoi profilare all'interno delle chiamate xp.start_trace() e xp.stop_trace(). La funzione start_trace accetta un percorso a una directory in cui vengono salvati i file di traccia.

È prassi comune eseguire il wrapping del ciclo di addestramento principale per acquisire le operazioni più pertinenti.

import torch_xla.debug.profiler as xp

# The directory where the trace files are stored.
log_dir = '/root/logs/'

# Start tracing
xp.start_trace(log_dir)

# ... your training loop or other code to be profiled ...
train_mnist()

# Stop tracing
xp.stop_trace()

3. Aggiungere etichette di traccia personalizzate

Per impostazione predefinita, le tracce acquisite sono funzioni Pytorch XLA di basso livello e può essere difficile navigare. Puoi aggiungere etichette personalizzate a sezioni specifiche del codice utilizzando il gestore di contesto xp.Trace(). Queste etichette verranno visualizzate come blocchi denominati nella visualizzazione della sequenza temporale del profiler, semplificando notevolmente l'identificazione di operazioni specifiche come la preparazione dei dati, il forward pass o il passaggio dell'ottimizzatore.

Il seguente esempio mostra come aggiungere contesto a diverse parti di un passaggio di addestramento.

def forward(self, x):
    # This entire block will be labeled 'forward' in the trace
    with xp.Trace('forward'):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 7*7*64)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
    with torch_xla.step():
        with xp.Trace('train_step_data_prep_and_forward'):
            optimizer.zero_grad()
            data, target = data.to(device), target.to(device)
            output = model(data)

        with xp.Trace('train_step_loss_and_backward'):
            loss = loss_fn(output, target)
            loss.backward()

        with xp.Trace('train_step_optimizer_step_host'):
            optimizer.step()

Esempio completo

L'esempio seguente mostra come acquisire una traccia da uno script PyTorch XLA, in base al file mnist_xla.py.

import torch
import torch.optim as optim
from torchvision import datasets, transforms

# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

def train_mnist():
    # ... (model definition and data loading code) ...
    print("Starting training...")
    # ... (training loop as defined in the previous section) ...
    print("Training finished!")

if __name__ == '__main__':
    # 1. Start the profiler server
    server = xp.start_server(9012)

    # 2. Start capturing the trace and define the output directory
    xp.start_trace('/root/logs/')

    # Run the training function that contains custom trace labels
    train_mnist()

    # 3. Stop the trace
    xp.stop_trace()

Visualizzare la traccia

Al termine dello script, i file di traccia vengono salvati nella directory che hai specificato (ad esempio /root/logs/). Puoi visualizzare questa traccia utilizzando XProf.

Puoi avviare la UI del profiler direttamente utilizzando il comando XProf autonomo puntandolo alla directory dei log:

$ xprof --port=8791 /root/logs/
Attempting to start XProf server:
  Log Directory: /root/logs/
  Port: 8791
  Worker Service Address: 0.0.0.0:50051
  Hide Capture Button: False
XProf at http://localhost:8791/ (Press CTRL+C to quit)

Nel browser, vai all'URL fornito (ad es. http://localhost:8791/) per visualizzare il profilo.

Potrai visualizzare le etichette personalizzate che hai creato e analizzare il tempo di esecuzione di diverse parti del modello.

Se utilizzi Google Cloud per eseguire i tuoi carichi di lavoro, ti consigliamo lo strumento cloud-diagnostics-xprof. Offre un'esperienza semplificata di raccolta e visualizzazione dei profili utilizzando VM che eseguono XProf.