אופטימיזציה של הביצועים היא חלק חשוב בבנייה של מודלים יעילים של למידת מכונה. אתם יכולים להשתמש בכלי ליצירת פרופילים XProf כדי למדוד את הביצועים של עומסי העבודה של למידת המכונה. XProf מאפשר לכם לתעד עקבות מפורטים של ההרצה של המודל במכשירי XLA. הנתונים האלה יכולים לעזור לכם לזהות צווארי בקבוק בביצועים, להבין את ניצול המכשיר ולבצע אופטימיזציה של הקוד.
במדריך הזה מוסבר איך ללכוד באופן פרוגרמטי נתוני מעקב מסקריפט PyTorch XLA ולהציג אותם באמצעות XProf.
תיעוד של נתוני מעקב באופן פרוגרמטי
כדי לתעד את הנתונים, מוסיפים כמה שורות קוד לסקריפט האימון הקיים. הכלי העיקרי ללכידת נתוני מעקב הוא המודול torch_xla.debug.profiler, שבדרך כלל מייבאים אותו עם הכינוי xp.
1. הפעלת שרת הפרופיל
כדי ללכוד נתונים למעקב, צריך להפעיל את שרת הפרופיל. השרת הזה פועל ברקע של הסקריפט ואוסף את נתוני המעקב. אפשר להתחיל אותו על ידי קריאה ל-xp.start_server(<port>) קרוב לתחילת בלוק הביצוע הראשי.
2. הגדרת משך המעקב
עוטפים את הקוד שרוצים ליצור לו פרופיל בקריאות xp.start_trace() ו-xp.stop_trace(). הפונקציה start_trace מקבלת נתיב לספרייה שבה נשמרים קובצי המעקב.
מקובל לעטוף את לולאת האימון הראשית כדי לתעד את הפעולות הרלוונטיות ביותר.
import torch_xla.debug.profiler as xp
# The directory where the trace files are stored.
log_dir = '/root/logs/'
# Start tracing
xp.start_trace(log_dir)
# ... your training loop or other code to be profiled ...
train_mnist()
# Stop tracing
xp.stop_trace()
3. הוספת תוויות מותאמות אישית למעקב
כברירת מחדל, העקבות שמתועדים הן פונקציות ברמה נמוכה של Pytorch XLA, ולפעמים קשה לנווט בהן. אפשר להוסיף תוויות מותאמות אישית לקטעים ספציפיים בקוד באמצעות xp.Trace() מנהל ההקשר. התוויות האלה יופיעו כבלוקים עם שמות בתצוגת ציר הזמן של כלי הפרופיל, וכך יהיה קל יותר לזהות פעולות ספציפיות כמו הכנת נתונים, העברה קדימה או שלב האופטימיזציה.
בדוגמה הבאה אפשר לראות איך מוסיפים הקשר לחלקים שונים בשלב ההכשרה.
def forward(self, x):
# This entire block will be labeled 'forward' in the trace
with xp.Trace('forward'):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2(x), 2))
x = x.view(-1, 7*7*64)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
with torch_xla.step():
with xp.Trace('train_step_data_prep_and_forward'):
optimizer.zero_grad()
data, target = data.to(device), target.to(device)
output = model(data)
with xp.Trace('train_step_loss_and_backward'):
loss = loss_fn(output, target)
loss.backward()
with xp.Trace('train_step_optimizer_step_host'):
optimizer.step()
דוגמה מלאה
בדוגמה הבאה אפשר לראות איך ללכוד מעקב מסקריפט PyTorch XLA, על סמך הקובץ mnist_xla.py.
import torch
import torch.optim as optim
from torchvision import datasets, transforms
# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
def train_mnist():
# ... (model definition and data loading code) ...
print("Starting training...")
# ... (training loop as defined in the previous section) ...
print("Training finished!")
if __name__ == '__main__':
# 1. Start the profiler server
server = xp.start_server(9012)
# 2. Start capturing the trace and define the output directory
xp.start_trace('/root/logs/')
# Run the training function that contains custom trace labels
train_mnist()
# 3. Stop the trace
xp.stop_trace()
הדמיה של העקבות
בסיום הסקריפט, קובצי העקבות נשמרים בספרייה שציינתם (לדוגמה, /root/logs/). אפשר להציג את העקבות האלה באמצעות XProf.
אפשר להפעיל את ממשק המשתמש של הכלי ליצירת פרופילים ישירות באמצעות הפקודה העצמאית XProf, על ידי הפניה שלה לספריית היומן:
$ xprof --port=8791 /root/logs/
Attempting to start XProf server:
Log Directory: /root/logs/
Port: 8791
Worker Service Address: 0.0.0.0:50051
Hide Capture Button: False
XProf at http://localhost:8791/ (Press CTRL+C to quit)
בדפדפן, עוברים לכתובת ה-URL שסופקה (למשל, http://localhost:8791/) כדי לראות את הפרופיל.
תוכלו לראות את התוויות המותאמות אישית שיצרתם ולנתח את זמן הביצוע של חלקים שונים במודל.
אם אתם משתמשים ב-Google Cloud כדי להריץ את עומסי העבודה, מומלץ להשתמש בכלי cloud-diagnostics-xprof. היא מספקת חוויה יעילה של איסוף פרופילים וצפייה בהם באמצעות מכונות וירטואליות (VM) שמריצות XProf.