สร้างโปรไฟล์เวิร์กโหลด PyTorch XLA

การเพิ่มประสิทธิภาพเป็นส่วนสําคัญของการสร้างโมเดลแมชชีนเลิร์นนิงที่มีประสิทธิภาพ คุณสามารถใช้เครื่องมือสร้างโปรไฟล์ XProf เพื่อวัดประสิทธิภาพของภาระงานแมชชีนเลิร์นนิงได้ XProf ช่วยให้คุณบันทึกการติดตามโดยละเอียดของการดำเนินการโมเดลในอุปกรณ์ XLA ได้ การติดตามเหล่านี้จะช่วยให้คุณระบุ จุดคอขวดด้านประสิทธิภาพ ทำความเข้าใจการใช้อุปกรณ์ และเพิ่มประสิทธิภาพโค้ดได้

คู่มือนี้อธิบายกระบวนการบันทึกการติดตามจากสคริปต์ PyTorch XLA และแสดงภาพโดยใช้ XProf โดยอัตโนมัติ

บันทึกการติดตามแบบเป็นโปรแกรม

คุณบันทึกการติดตามได้โดยเพิ่มโค้ด 2-3 บรรทัดลงในสคริปต์การฝึกที่มีอยู่ เครื่องมือหลักในการบันทึกการติดตามคือโมดูล torch_xla.debug.profiler ซึ่งโดยปกติจะนำเข้าด้วยนามแฝง xp

1. เริ่มเซิร์ฟเวอร์โปรไฟล์

คุณต้องเริ่มเซิร์ฟเวอร์โปรไฟล์ก่อนจึงจะบันทึกการติดตามได้ เซิร์ฟเวอร์นี้ทำงานในเบื้องหลังของสคริปต์และรวบรวมข้อมูลการติดตาม คุณ สามารถเริ่มต้นได้โดยการเรียกใช้ xp.start_server(<port>) ใกล้กับจุดเริ่มต้นของ บล็อกการดำเนินการหลัก

2. กำหนดระยะเวลาการติดตาม

ห่อหุ้มโค้ดที่ต้องการสร้างโปรไฟล์ไว้ภายในการเรียกใช้ xp.start_trace() และ xp.stop_trace() ฟังก์ชัน start_trace จะใช้เส้นทางไปยังไดเรกทอรี ที่บันทึกไฟล์การติดตาม

แนวทางปฏิบัติทั่วไปคือการรวมลูปการฝึกหลักเพื่อบันทึกการดำเนินการที่เกี่ยวข้องมากที่สุด

import torch_xla.debug.profiler as xp

# The directory where the trace files are stored.
log_dir = '/root/logs/'

# Start tracing
xp.start_trace(log_dir)

# ... your training loop or other code to be profiled ...
train_mnist()

# Stop tracing
xp.stop_trace()

3. เพิ่มป้ายกำกับร่องรอยที่กำหนดเอง

โดยค่าเริ่มต้น การติดตามที่บันทึกไว้จะเป็นฟังก์ชัน Pytorch XLA ระดับต่ำและอาจ นำทางได้ยาก คุณเพิ่มป้ายกำกับที่กำหนดเองลงในส่วนที่เฉพาะเจาะจงของโค้ดได้ โดยใช้เครื่องมือจัดการบริบท xp.Trace() ป้ายกำกับเหล่านี้จะปรากฏเป็นบล็อกที่มีชื่อ ในมุมมองไทม์ไลน์ของโปรไฟล์เลอร์ ซึ่งช่วยให้ระบุการดำเนินการที่เฉพาะเจาะจงได้ง่ายขึ้นมาก เช่น การเตรียมข้อมูล การส่งต่อ หรือขั้นตอนของเครื่องมือเพิ่มประสิทธิภาพ

ตัวอย่างต่อไปนี้แสดงวิธีเพิ่มบริบทให้กับส่วนต่างๆ ของขั้นตอนการฝึก

def forward(self, x):
    # This entire block will be labeled 'forward' in the trace
    with xp.Trace('forward'):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 7*7*64)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
    with torch_xla.step():
        with xp.Trace('train_step_data_prep_and_forward'):
            optimizer.zero_grad()
            data, target = data.to(device), target.to(device)
            output = model(data)

        with xp.Trace('train_step_loss_and_backward'):
            loss = loss_fn(output, target)
            loss.backward()

        with xp.Trace('train_step_optimizer_step_host'):
            optimizer.step()

ตัวอย่างที่สมบูรณ์

ตัวอย่างต่อไปนี้แสดงวิธีกำหนดการติดตามจากสคริปต์ PyTorch XLA โดยอิงตามไฟล์ mnist_xla.py

import torch
import torch.optim as optim
from torchvision import datasets, transforms

# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

def train_mnist():
    # ... (model definition and data loading code) ...
    print("Starting training...")
    # ... (training loop as defined in the previous section) ...
    print("Training finished!")

if __name__ == '__main__':
    # 1. Start the profiler server
    server = xp.start_server(9012)

    # 2. Start capturing the trace and define the output directory
    xp.start_trace('/root/logs/')

    # Run the training function that contains custom trace labels
    train_mnist()

    # 3. Stop the trace
    xp.stop_trace()

แสดงภาพการติดตาม

เมื่อสคริปต์ทำงานเสร็จแล้ว ระบบจะบันทึกไฟล์การติดตามในไดเรกทอรีที่คุณ ระบุ (เช่น /root/logs/) คุณสามารถแสดงภาพการติดตามนี้ได้โดยใช้ XProf

คุณเปิดใช้ UI ของโปรไฟล์เลอร์ได้โดยตรงโดยใช้คำสั่ง XProf แบบสแตนด์อโลนโดย ชี้ไปยังไดเรกทอรีบันทึก

$ xprof --port=8791 /root/logs/
Attempting to start XProf server:
  Log Directory: /root/logs/
  Port: 8791
  Worker Service Address: 0.0.0.0:50051
  Hide Capture Button: False
XProf at http://localhost:8791/ (Press CTRL+C to quit)

ไปที่ URL ที่ระบุ (เช่น http://localhost:8791/) ในเบราว์เซอร์ เพื่อดูโปรไฟล์

คุณจะเห็นป้ายกำกับที่กำหนดเองซึ่งคุณสร้างขึ้น และวิเคราะห์เวลาในการดำเนินการ ของส่วนต่างๆ ในโมเดล

หากคุณใช้ Google Cloud เพื่อเรียกใช้ภาระงาน เราขอแนะนำให้ใช้เครื่องมือ cloud-diagnostics-xprof โดยจะมอบประสบการณ์การรวบรวมและดูโปรไฟล์ที่มีประสิทธิภาพโดยใช้ VM ที่เรียกใช้ XProf