Membuat profil workload PyTorch XLA

Pengoptimalan performa adalah bagian penting dalam membangun model machine learning yang efisien. Anda dapat menggunakan alat pembuatan profil XProf untuk mengukur performa workload machine learning Anda. XProf memungkinkan Anda merekam rekaman aktivitas mendetail dari eksekusi model di perangkat XLA. Rekaman aktivitas ini dapat membantu Anda mengidentifikasi hambatan performa, memahami pemanfaatan perangkat, dan mengoptimalkan kode.

Panduan ini menjelaskan proses pengambilan rekaman aktivitas secara terprogram dari skrip PyTorch XLA dan visualisasi menggunakan XProf.

Merekam aktivitas secara terprogram

Anda dapat merekam rekaman aktivitas dengan menambahkan beberapa baris kode ke skrip pelatihan yang ada. Alat utama untuk merekam aktivitas adalah modul torch_xla.debug.profiler, yang biasanya diimpor dengan alias xp.

1. Mulai server profiler

Sebelum dapat merekam aktivitas, Anda harus memulai server profiler. Server ini berjalan di latar belakang skrip Anda dan mengumpulkan data rekaman aktivitas. Anda dapat memulainya dengan memanggil xp.start_server(<port>) di dekat awal blok eksekusi utama.

2. Menentukan durasi rekaman aktivitas

Gabungkan kode yang ingin Anda buat profilnya dalam panggilan xp.start_trace() dan xp.stop_trace(). Fungsi start_trace mengambil jalur ke direktori tempat file rekaman aktivitas disimpan.

Praktik umumnya adalah membungkus loop pelatihan utama untuk merekam operasi yang paling relevan.

import torch_xla.debug.profiler as xp

# The directory where the trace files are stored.
log_dir = '/root/logs/'

# Start tracing
xp.start_trace(log_dir)

# ... your training loop or other code to be profiled ...
train_mnist()

# Stop tracing
xp.stop_trace()

3. Menambahkan label rekaman aktivitas kustom

Secara default, rekaman aktivitas yang diambil adalah fungsi Pytorch XLA tingkat rendah dan mungkin sulit dinavigasi. Anda dapat menambahkan label kustom ke bagian tertentu pada kode menggunakan pengelola konteks xp.Trace(). Label ini akan muncul sebagai blok bernama di tampilan linimasa profiler, sehingga memudahkan identifikasi operasi tertentu seperti penyiapan data, forward pass, atau langkah pengoptimal.

Contoh berikut menunjukkan cara menambahkan konteks ke berbagai bagian langkah pelatihan.

def forward(self, x):
    # This entire block will be labeled 'forward' in the trace
    with xp.Trace('forward'):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 7*7*64)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
    with torch_xla.step():
        with xp.Trace('train_step_data_prep_and_forward'):
            optimizer.zero_grad()
            data, target = data.to(device), target.to(device)
            output = model(data)

        with xp.Trace('train_step_loss_and_backward'):
            loss = loss_fn(output, target)
            loss.backward()

        with xp.Trace('train_step_optimizer_step_host'):
            optimizer.step()

Contoh lengkap

Contoh berikut menunjukkan cara merekam aktivitas dari skrip PyTorch XLA, berdasarkan file mnist_xla.py.

import torch
import torch.optim as optim
from torchvision import datasets, transforms

# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

def train_mnist():
    # ... (model definition and data loading code) ...
    print("Starting training...")
    # ... (training loop as defined in the previous section) ...
    print("Training finished!")

if __name__ == '__main__':
    # 1. Start the profiler server
    server = xp.start_server(9012)

    # 2. Start capturing the trace and define the output directory
    xp.start_trace('/root/logs/')

    # Run the training function that contains custom trace labels
    train_mnist()

    # 3. Stop the trace
    xp.stop_trace()

Memvisualisasikan rekaman aktivitas

Setelah skrip Anda selesai, file rekaman aktivitas akan disimpan di direktori yang Anda tentukan (misalnya, /root/logs/). Anda dapat memvisualisasikan rekaman aktivitas ini menggunakan XProf.

Anda dapat meluncurkan UI profiler secara langsung menggunakan perintah XProf mandiri dengan mengarahkan perintah tersebut ke direktori log Anda:

$ xprof --port=8791 /root/logs/
Attempting to start XProf server:
  Log Directory: /root/logs/
  Port: 8791
  Worker Service Address: 0.0.0.0:50051
  Hide Capture Button: False
XProf at http://localhost:8791/ (Press CTRL+C to quit)

Buka URL yang diberikan (misalnya, http://localhost:8791/) di browser Anda untuk melihat profil.

Anda akan dapat melihat label kustom yang Anda buat dan menganalisis waktu eksekusi berbagai bagian model Anda.

Jika Anda menggunakan Google Cloud untuk menjalankan workload, sebaiknya gunakan alat cloud-diagnostics-xprof. Fitur ini memberikan pengalaman pengumpulan dan penayangan profil yang lancar menggunakan VM yang menjalankan XProf.