Optymalizacja wydajności jest kluczowym elementem tworzenia wydajnych modeli systemów uczących się. Narzędzie do profilowania XProf umożliwia pomiar wydajności zadań uczenia maszynowego. XProf umożliwia rejestrowanie szczegółowych śladów wykonywania modelu na urządzeniach XLA. Te ślady mogą pomóc w identyfikowaniu wąskich gardeł wydajności, zrozumieniu wykorzystania urządzenia i optymalizacji kodu.
W tym przewodniku opisujemy proces programowego rejestrowania śladu ze skryptu PyTorch XLA i wizualizowania go za pomocą XProf.
Automatyczne rejestrowanie śladu
Aby zarejestrować ślad, dodaj kilka wierszy kodu do istniejącego skryptu trenowania. Głównym narzędziem do rejestrowania śladu jest moduł torch_xla.debug.profiler, który jest zwykle importowany z aliasem xp.
1. Uruchamianie serwera profilera
Zanim zaczniesz rejestrować ślad, musisz uruchomić serwer profilera. Ten serwer działa w tle skryptu i zbiera dane śledzenia. Możesz go uruchomić, wywołując funkcję xp.start_server(<port>) na początku głównego bloku wykonania.
2. Określanie czasu trwania śledzenia
Owiń kod, który chcesz profilować, wywołaniami xp.start_trace() i xp.stop_trace(). Funkcja start_trace przyjmuje ścieżkę do katalogu, w którym są zapisywane pliki śledzenia.
Główną pętlę trenowania zwykle owija się, aby przechwycić najbardziej istotne operacje.
import torch_xla.debug.profiler as xp
# The directory where the trace files are stored.
log_dir = '/root/logs/'
# Start tracing
xp.start_trace(log_dir)
# ... your training loop or other code to be profiled ...
train_mnist()
# Stop tracing
xp.stop_trace()
3. Dodawanie niestandardowych etykiet śledzenia
Domyślnie przechwytywane ślady to funkcje Pytorch XLA niskiego poziomu, które mogą być trudne w nawigacji. Możesz dodawać etykiety własne do określonych sekcji kodu za pomocą xp.Trace()menedżera kontekstu. Te etykiety będą wyświetlane jako nazwane bloki na osi czasu profilera, co znacznie ułatwi identyfikowanie konkretnych operacji, takich jak przygotowanie danych, przejście w przód czy krok optymalizatora.
Poniższy przykład pokazuje, jak dodać kontekst do różnych części kroku szkoleniowego.
def forward(self, x):
# This entire block will be labeled 'forward' in the trace
with xp.Trace('forward'):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2(x), 2))
x = x.view(-1, 7*7*64)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
with torch_xla.step():
with xp.Trace('train_step_data_prep_and_forward'):
optimizer.zero_grad()
data, target = data.to(device), target.to(device)
output = model(data)
with xp.Trace('train_step_loss_and_backward'):
loss = loss_fn(output, target)
loss.backward()
with xp.Trace('train_step_optimizer_step_host'):
optimizer.step()
Kompletny przykład
Poniższy przykład pokazuje, jak przechwycić ślad ze skryptu PyTorch XLA na podstawie pliku mnist_xla.py.
import torch
import torch.optim as optim
from torchvision import datasets, transforms
# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
def train_mnist():
# ... (model definition and data loading code) ...
print("Starting training...")
# ... (training loop as defined in the previous section) ...
print("Training finished!")
if __name__ == '__main__':
# 1. Start the profiler server
server = xp.start_server(9012)
# 2. Start capturing the trace and define the output directory
xp.start_trace('/root/logs/')
# Run the training function that contains custom trace labels
train_mnist()
# 3. Stop the trace
xp.stop_trace()
Wizualizacja logu czasu
Po zakończeniu działania skryptu pliki śledzenia zostaną zapisane w określonym katalogu (np. /root/logs/). Możesz wizualizować ślad za pomocą XProf.
Interfejs profilera możesz uruchomić bezpośrednio za pomocą samodzielnego polecenia XProf, kierując je do katalogu logów:
$ xprof --port=8791 /root/logs/
Attempting to start XProf server:
Log Directory: /root/logs/
Port: 8791
Worker Service Address: 0.0.0.0:50051
Hide Capture Button: False
XProf at http://localhost:8791/ (Press CTRL+C to quit)
Otwórz w przeglądarce podany adres URL (np. http://localhost:8791/), aby wyświetlić profil.
Będziesz mieć możliwość wyświetlania utworzonych przez siebie etykiet niestandardowych i analizowania czasu wykonania różnych części modelu.
Jeśli do uruchamiania zadań używasz Google Cloud, zalecamy narzędzie cloud-diagnostics-xprof. Umożliwia to usprawnione zbieranie i wyświetlanie profili za pomocą maszyn wirtualnych z XProf.