PyTorch XLA iş yüklerinin profilini oluşturma

Performans optimizasyonu, verimli makine öğrenimi modelleri oluşturmanın önemli bir parçasıdır. Makine öğrenimi iş yüklerinizin performansını ölçmek için XProf profil oluşturma aracını kullanabilirsiniz. XProf, modelinizin XLA cihazlarındaki yürütülmesine ilişkin ayrıntılı izlemeler yakalamanıza olanak tanır. Bu izlemeler, performans darboğazlarını belirlemenize, cihaz kullanımını anlamanıza ve kodunuzu optimize etmenize yardımcı olabilir.

Bu kılavuzda, PyTorch XLA komut dosyanızdan programatik olarak izleme yakalama ve XProf kullanarak görselleştirme süreci açıklanmaktadır.

İzi programatik olarak yakalama

Mevcut eğitim senaryonuza birkaç satır kod ekleyerek izleme yakalayabilirsiniz. İzleme yakalamak için kullanılan birincil araç torch_xla.debug.profiler modülüdür. Bu modül genellikle xp takma adıyla içe aktarılır.

1. Profiler sunucusunu başlatma

İzleme yakalayabilmek için önce profil oluşturucu sunucusunu başlatmanız gerekir. Bu sunucu, komut dosyanızın arka planında çalışır ve izleme verilerini toplar. Ana yürütme bloğunuzun başlangıcına yakın bir yerde xp.start_server(<port>) çağırarak başlatabilirsiniz.

2. İzleme süresini tanımlama

Profillendirmek istediğiniz kodu xp.start_trace() ve xp.stop_trace() çağrıları içine alın. start_trace işlevi, izleme dosyalarının kaydedildiği bir dizinin yolunu alır.

En alakalı işlemleri yakalamak için ana eğitim döngüsünü sarmak yaygın bir uygulamadır.

import torch_xla.debug.profiler as xp

# The directory where the trace files are stored.
log_dir = '/root/logs/'

# Start tracing
xp.start_trace(log_dir)

# ... your training loop or other code to be profiled ...
train_mnist()

# Stop tracing
xp.stop_trace()

3. Özel izleme etiketleri ekleme

Varsayılan olarak, yakalanan izler düşük düzeyli PyTorch XLA işlevleridir ve gezinmesi zor olabilir. xp.Trace()Bağlam yöneticisini kullanarak kodunuzun belirli bölümlerine özel etiketler ekleyebilirsiniz. Bu etiketler, profil oluşturucunun zaman çizelgesi görünümünde adlandırılmış bloklar olarak görünür. Böylece veri hazırlama, ileri geçiş veya optimize edici adımı gibi belirli işlemleri tanımlamak çok daha kolay hale gelir.

Aşağıdaki örnekte, bir eğitim adımının farklı bölümlerine nasıl bağlam ekleyebileceğiniz gösterilmektedir.

def forward(self, x):
    # This entire block will be labeled 'forward' in the trace
    with xp.Trace('forward'):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 7*7*64)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
    with torch_xla.step():
        with xp.Trace('train_step_data_prep_and_forward'):
            optimizer.zero_grad()
            data, target = data.to(device), target.to(device)
            output = model(data)

        with xp.Trace('train_step_loss_and_backward'):
            loss = loss_fn(output, target)
            loss.backward()

        with xp.Trace('train_step_optimizer_step_host'):
            optimizer.step()

Eksiksiz örnek

Aşağıdaki örnekte, mnist_xla.py dosyasına göre bir PyTorch XLA komut dosyasından nasıl iz yakalanacağı gösterilmektedir.

import torch
import torch.optim as optim
from torchvision import datasets, transforms

# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

def train_mnist():
    # ... (model definition and data loading code) ...
    print("Starting training...")
    # ... (training loop as defined in the previous section) ...
    print("Training finished!")

if __name__ == '__main__':
    # 1. Start the profiler server
    server = xp.start_server(9012)

    # 2. Start capturing the trace and define the output directory
    xp.start_trace('/root/logs/')

    # Run the training function that contains custom trace labels
    train_mnist()

    # 3. Stop the trace
    xp.stop_trace()

İzi görselleştirme

Komut dosyanız tamamlandığında izleme dosyaları, belirttiğiniz dizine (örneğin, /root/logs/) kaydedilir. Bu izlemeyi XProf kullanarak görselleştirebilirsiniz.

Günlük dizininize yönlendirerek bağımsız XProf komutunu kullanarak profiler kullanıcı arayüzünü doğrudan başlatabilirsiniz:

$ xprof --port=8791 /root/logs/
Attempting to start XProf server:
  Log Directory: /root/logs/
  Port: 8791
  Worker Service Address: 0.0.0.0:50051
  Hide Capture Button: False
XProf at http://localhost:8791/ (Press CTRL+C to quit)

Profili görüntülemek için tarayıcınızda sağlanan URL'ye (ör. http://localhost:8791/) gidin.

Oluşturduğunuz özel etiketleri görebilir ve modelinizin farklı bölümlerinin yürütme süresini analiz edebilirsiniz.

İş yüklerinizi çalıştırmak için Google Cloud'u kullanıyorsanız cloud-diagnostics-xprof aracını öneririz. XProf çalıştıran VM'leri kullanarak basitleştirilmiş bir profil toplama ve görüntüleme deneyimi sunar.