Hồ sơ khối lượng công việc PyTorch XLA

Tối ưu hoá hiệu suất là một phần quan trọng trong việc xây dựng các mô hình học máy hiệu quả. Bạn có thể sử dụng công cụ phân tích hiệu suất XProf để đo lường hiệu suất của các khối lượng công việc học máy. XProf cho phép bạn ghi lại các dấu vết chi tiết về quá trình thực thi mô hình trên các thiết bị XLA. Các dấu vết này có thể giúp bạn xác định nút thắt cổ chai về hiệu suất, hiểu rõ mức sử dụng thiết bị và tối ưu hoá mã của bạn.

Hướng dẫn này mô tả quy trình tự động hoá việc ghi lại dấu vết từ tập lệnh PyTorch XLA và trực quan hoá bằng XProf.

Ghi lại dấu vết theo phương thức lập trình

Bạn có thể ghi lại dấu vết bằng cách thêm một vài dòng mã vào tập lệnh huấn luyện hiện có. Công cụ chính để ghi lại dấu vết là mô-đun torch_xla.debug.profiler. Mô-đun này thường được nhập bằng bí danh xp.

1. Khởi động máy chủ trình phân tích tài nguyên

Trước khi có thể ghi lại dấu vết, bạn cần khởi động máy chủ hồ sơ. Máy chủ này chạy ở chế độ nền của tập lệnh và thu thập dữ liệu theo dõi. Bạn có thể bắt đầu bằng cách gọi xp.start_server(<port>) gần đầu khối thực thi chính.

2. Xác định thời lượng dấu vết

Gói mã mà bạn muốn lập hồ sơ trong các lệnh gọi xp.start_trace()xp.stop_trace(). Hàm start_trace lấy một đường dẫn đến thư mục nơi các tệp theo dõi được lưu.

Bạn nên bao bọc vòng lặp huấn luyện chính để nắm bắt các thao tác phù hợp nhất.

import torch_xla.debug.profiler as xp

# The directory where the trace files are stored.
log_dir = '/root/logs/'

# Start tracing
xp.start_trace(log_dir)

# ... your training loop or other code to be profiled ...
train_mnist()

# Stop tracing
xp.stop_trace()

3. Thêm nhãn dấu vết tuỳ chỉnh

Theo mặc định, các dấu vết được ghi lại là các hàm Pytorch XLA cấp thấp và có thể khó điều hướng. Bạn có thể thêm nhãn tuỳ chỉnh vào các phần cụ thể trong mã bằng cách sử dụng trình quản lý bối cảnh xp.Trace(). Các nhãn này sẽ xuất hiện dưới dạng các khối được đặt tên trong chế độ xem dòng thời gian của trình phân tích tài nguyên, giúp bạn dễ dàng xác định các thao tác cụ thể như chuẩn bị dữ liệu, truyền xuôi hoặc bước trình tối ưu hoá.

Ví dụ sau đây cho thấy cách bạn có thể thêm bối cảnh vào các phần khác nhau của một bước huấn luyện.

def forward(self, x):
    # This entire block will be labeled 'forward' in the trace
    with xp.Trace('forward'):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 7*7*64)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
    with torch_xla.step():
        with xp.Trace('train_step_data_prep_and_forward'):
            optimizer.zero_grad()
            data, target = data.to(device), target.to(device)
            output = model(data)

        with xp.Trace('train_step_loss_and_backward'):
            loss = loss_fn(output, target)
            loss.backward()

        with xp.Trace('train_step_optimizer_step_host'):
            optimizer.step()

Ví dụ đầy đủ

Ví dụ sau đây cho thấy cách ghi lại dấu vết từ một tập lệnh PyTorch XLA, dựa trên tệp mnist_xla.py.

import torch
import torch.optim as optim
from torchvision import datasets, transforms

# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

def train_mnist():
    # ... (model definition and data loading code) ...
    print("Starting training...")
    # ... (training loop as defined in the previous section) ...
    print("Training finished!")

if __name__ == '__main__':
    # 1. Start the profiler server
    server = xp.start_server(9012)

    # 2. Start capturing the trace and define the output directory
    xp.start_trace('/root/logs/')

    # Run the training function that contains custom trace labels
    train_mnist()

    # 3. Stop the trace
    xp.stop_trace()

Trực quan hoá dấu vết

Khi tập lệnh của bạn hoàn tất, các tệp theo dõi sẽ được lưu trong thư mục mà bạn đã chỉ định (ví dụ: /root/logs/). Bạn có thể trực quan hoá dấu vết này bằng XProf.

Bạn có thể chạy giao diện người dùng của trình phân tích tài nguyên trực tiếp bằng lệnh XProf độc lập bằng cách trỏ lệnh này đến thư mục nhật ký:

$ xprof --port=8791 /root/logs/
Attempting to start XProf server:
  Log Directory: /root/logs/
  Port: 8791
  Worker Service Address: 0.0.0.0:50051
  Hide Capture Button: False
XProf at http://localhost:8791/ (Press CTRL+C to quit)

Chuyển đến URL được cung cấp (ví dụ: http://localhost:8791/) trong trình duyệt để xem hồ sơ.

Bạn sẽ có thể xem các nhãn tuỳ chỉnh mà mình đã tạo và phân tích thời gian thực thi của các phần khác nhau trong mô hình.

Nếu sử dụng Google Cloud để chạy tải công việc, bạn nên dùng công cụ cloud-diagnostics-xprof. Công cụ này mang đến trải nghiệm xem và thu thập hồ sơ tinh giản bằng cách sử dụng các máy ảo chạy XProf.