Профилирование рабочих нагрузок PyTorch XLA

Оптимизация производительности — важнейшая часть создания эффективных моделей машинного обучения. Вы можете использовать инструмент профилирования XProf для измерения производительности ваших рабочих нагрузок машинного обучения. XProf позволяет собирать подробные трассировки выполнения вашей модели на устройствах XLA. Эти трассировки помогут вам выявить узкие места в производительности, понять загрузку устройства и оптимизировать ваш код.

В этом руководстве описывается процесс программного захвата трассировки из вашего скрипта PyTorch XLA и визуализации её с помощью XProf.

Программная трассировка.

Вы можете получить трассировку, добавив несколько строк кода в существующий скрипт обучения. Основным инструментом для получения трассировки является модуль torch_xla.debug.profiler , который обычно импортируется с псевдонимом xp .

1. Запустите сервер профилировщика.

Прежде чем начать трассировку, необходимо запустить сервер профилировщика. Этот сервер работает в фоновом режиме вашего скрипта и собирает данные трассировки. Запустить его можно, вызвав метод xp.start_server(<port>) в начале основного блока выполнения.

2. Определите продолжительность трассировки.

Оберните код, который вы хотите профилировать, в вызовы xp.start_trace() и xp.stop_trace() . Функция start_trace принимает путь к каталогу, где сохраняются файлы трассировки.

Обычно основной цикл обучения замыкают, чтобы охватить наиболее важные операции.

import torch_xla.debug.profiler as xp

# The directory where the trace files are stored.
log_dir = '/root/logs/'

# Start tracing
xp.start_trace(log_dir)

# ... your training loop or other code to be profiled ...
train_mnist()

# Stop tracing
xp.stop_trace()

3. Добавьте пользовательские метки трассировки.

По умолчанию, трассировки представляют собой низкоуровневые функции PyTorch XLA, и в них сложно ориентироваться. Вы можете добавить пользовательские метки к определенным участкам кода, используя менеджер контекста xp.Trace() . Эти метки будут отображаться в виде именованных блоков в представлении временной шкалы профилировщика, что значительно упростит идентификацию конкретных операций, таких как подготовка данных, прямой проход или шаг оптимизатора.

В следующем примере показано, как можно добавить контекст к различным частям этапа обучения.

def forward(self, x):
    # This entire block will be labeled 'forward' in the trace
    with xp.Trace('forward'):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 7*7*64)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
    with torch_xla.step():
        with xp.Trace('train_step_data_prep_and_forward'):
            optimizer.zero_grad()
            data, target = data.to(device), target.to(device)
            output = model(data)

        with xp.Trace('train_step_loss_and_backward'):
            loss = loss_fn(output, target)
            loss.backward()

        with xp.Trace('train_step_optimizer_step_host'):
            optimizer.step()

Полный пример

В следующем примере показано, как получить трассировку из скрипта PyTorch XLA на основе файла mnist_xla.py .

import torch
import torch.optim as optim
from torchvision import datasets, transforms

# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

def train_mnist():
    # ... (model definition and data loading code) ...
    print("Starting training...")
    # ... (training loop as defined in the previous section) ...
    print("Training finished!")

if __name__ == '__main__':
    # 1. Start the profiler server
    server = xp.start_server(9012)

    # 2. Start capturing the trace and define the output directory
    xp.start_trace('/root/logs/')

    # Run the training function that contains custom trace labels
    train_mnist()

    # 3. Stop the trace
    xp.stop_trace()

Визуализируйте трассировку

После завершения работы скрипта файлы трассировки сохраняются в указанном вами каталоге (например, /root/logs/ ). Вы можете визуализировать эту трассировку с помощью XProf.

Вы можете запустить пользовательский интерфейс профилировщика напрямую, используя автономную команду XProf, указав ей путь к каталогу с логами:

$ xprof --port=8791 /root/logs/
Attempting to start XProf server:
  Log Directory: /root/logs/
  Port: 8791
  Worker Service Address: 0.0.0.0:50051
  Hide Capture Button: False
XProf at http://localhost:8791/ (Press CTRL+C to quit)

Для просмотра профиля перейдите по указанному URL-адресу (например, http://localhost:8791/) в вашем браузере.

Вы сможете увидеть созданные вами пользовательские метки и проанализировать время выполнения различных частей вашей модели.

Если вы используете Google Cloud для запуска своих рабочих нагрузок, мы рекомендуем инструмент cloud-diagnostics-xprof . Он обеспечивает упрощенный сбор и просмотр профилей с помощью виртуальных машин, на которых запущен XProf.