Pengoptimalan performa adalah bagian penting dalam membangun model machine learning yang efisien. Anda dapat menggunakan alat pembuatan profil XProf untuk mengukur performa workload machine learning Anda. XProf memungkinkan Anda merekam rekaman aktivitas mendetail dari eksekusi model di perangkat XLA. Rekaman aktivitas ini dapat membantu Anda mengidentifikasi hambatan performa, memahami pemanfaatan perangkat, dan mengoptimalkan kode.
Panduan ini menjelaskan proses pengambilan rekaman aktivitas secara terprogram dari skrip PyTorch XLA dan visualisasi menggunakan XProf.
Merekam aktivitas secara terprogram
Anda dapat merekam rekaman aktivitas dengan menambahkan beberapa baris kode ke skrip pelatihan yang ada. Alat utama untuk merekam aktivitas adalah modul torch_xla.debug.profiler, yang biasanya diimpor dengan alias xp.
1. Mulai server profiler
Sebelum dapat merekam aktivitas, Anda harus memulai server profiler. Server
ini berjalan di latar belakang skrip Anda dan mengumpulkan data rekaman aktivitas. Anda dapat memulainya dengan memanggil xp.start_server(<port>) di dekat awal blok eksekusi utama.
2. Menentukan durasi rekaman aktivitas
Gabungkan kode yang ingin Anda buat profilnya dalam panggilan xp.start_trace() dan
xp.stop_trace(). Fungsi start_trace mengambil jalur ke direktori tempat file rekaman aktivitas disimpan.
Praktik umumnya adalah membungkus loop pelatihan utama untuk merekam operasi yang paling relevan.
import torch_xla.debug.profiler as xp
# The directory where the trace files are stored.
log_dir = '/root/logs/'
# Start tracing
xp.start_trace(log_dir)
# ... your training loop or other code to be profiled ...
train_mnist()
# Stop tracing
xp.stop_trace()
3. Menambahkan label rekaman aktivitas kustom
Secara default, rekaman aktivitas yang diambil adalah fungsi Pytorch XLA tingkat rendah dan mungkin sulit dinavigasi. Anda dapat menambahkan label kustom ke bagian tertentu pada kode menggunakan pengelola konteks xp.Trace(). Label ini akan muncul sebagai blok bernama
di tampilan linimasa profiler, sehingga memudahkan identifikasi operasi tertentu
seperti penyiapan data, forward pass, atau langkah pengoptimal.
Contoh berikut menunjukkan cara menambahkan konteks ke berbagai bagian langkah pelatihan.
def forward(self, x):
# This entire block will be labeled 'forward' in the trace
with xp.Trace('forward'):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2(x), 2))
x = x.view(-1, 7*7*64)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
with torch_xla.step():
with xp.Trace('train_step_data_prep_and_forward'):
optimizer.zero_grad()
data, target = data.to(device), target.to(device)
output = model(data)
with xp.Trace('train_step_loss_and_backward'):
loss = loss_fn(output, target)
loss.backward()
with xp.Trace('train_step_optimizer_step_host'):
optimizer.step()
Contoh lengkap
Contoh berikut menunjukkan cara merekam aktivitas dari skrip PyTorch XLA,
berdasarkan file mnist_xla.py.
import torch
import torch.optim as optim
from torchvision import datasets, transforms
# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
def train_mnist():
# ... (model definition and data loading code) ...
print("Starting training...")
# ... (training loop as defined in the previous section) ...
print("Training finished!")
if __name__ == '__main__':
# 1. Start the profiler server
server = xp.start_server(9012)
# 2. Start capturing the trace and define the output directory
xp.start_trace('/root/logs/')
# Run the training function that contains custom trace labels
train_mnist()
# 3. Stop the trace
xp.stop_trace()
Memvisualisasikan rekaman aktivitas
Setelah skrip Anda selesai, file rekaman aktivitas akan disimpan di direktori yang Anda tentukan (misalnya, /root/logs/). Anda dapat memvisualisasikan rekaman aktivitas ini menggunakan XProf.
Anda dapat meluncurkan UI profiler secara langsung menggunakan perintah XProf mandiri dengan mengarahkan perintah tersebut ke direktori log Anda:
$ xprof --port=8791 /root/logs/
Attempting to start XProf server:
Log Directory: /root/logs/
Port: 8791
Worker Service Address: 0.0.0.0:50051
Hide Capture Button: False
XProf at http://localhost:8791/ (Press CTRL+C to quit)
Buka URL yang diberikan (misalnya, http://localhost:8791/) di browser Anda untuk melihat profil.
Anda akan dapat melihat label kustom yang Anda buat dan menganalisis waktu eksekusi berbagai bagian model Anda.
Jika Anda menggunakan Google Cloud untuk menjalankan workload, sebaiknya gunakan alat cloud-diagnostics-xprof. Fitur ini memberikan pengalaman pengumpulan dan penayangan profil yang lancar menggunakan VM yang menjalankan XProf.