パフォーマンスの最適化は、効率的な ML モデルを構築するうえで重要な要素です。XProf プロファイリング ツールを使用すると、ML ワークロードのパフォーマンスを測定できます。XProf では、XLA デバイスでのモデルの実行の詳細なトレースをキャプチャできます。これらのトレースは、パフォーマンスのボトルネックの特定、デバイスの使用率の把握、コードの最適化に役立ちます。
このガイドでは、PyTorch XLA スクリプトからトレースをプログラムでキャプチャし、XProf を使用して可視化するプロセスについて説明します。
プログラムでトレースをキャプチャする
既存のトレーニング スクリプトに数行のコードを追加することで、トレースをキャプチャできます。トレースをキャプチャするための主なツールは torch_xla.debug.profiler モジュールです。通常、このモジュールはエイリアス xp でインポートされます。
1. プロファイラ サーバーを起動する
トレースをキャプチャする前に、プロファイラ サーバーを起動する必要があります。このサーバーはスクリプトのバックグラウンドで実行され、トレースデータを収集します。メインの実行ブロックの先頭付近で xp.start_server(<port>) を呼び出すことで起動できます。
2. トレース期間を定義する
プロファイリングするコードを xp.start_trace() 呼び出しと xp.stop_trace() 呼び出しでラップします。start_trace 関数は、トレース ファイルが保存されるディレクトリのパスを受け取ります。
関連性の最も高いオペレーションをキャプチャするため、メインのトレーニング ループをラップするのが一般的です。
import torch_xla.debug.profiler as xp
# The directory where the trace files are stored.
log_dir = '/root/logs/'
# Start tracing
xp.start_trace(log_dir)
# ... your training loop or other code to be profiled ...
train_mnist()
# Stop tracing
xp.stop_trace()
3. カスタム トレースラベルを追加する
デフォルトでは、キャプチャされたトレースは低レベルの Pytorch XLA 関数であり、ナビゲートが難しい場合があります。xp.Trace() コンテキスト マネージャーを使用すると、コードの特定のセクションにカスタムラベルを追加できます。これらのラベルは、プロファイラのタイムライン ビューに名前付きブロックとして表示されるため、データ準備、フォワードパス、オプティマイザー ステップなど、特定のオペレーションを簡単に識別できます。
次の例は、トレーニング ステップのさまざまな部分にコンテキストを追加する方法を示しています。
def forward(self, x):
# This entire block will be labeled 'forward' in the trace
with xp.Trace('forward'):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2(x), 2))
x = x.view(-1, 7*7*64)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
with torch_xla.step():
with xp.Trace('train_step_data_prep_and_forward'):
optimizer.zero_grad()
data, target = data.to(device), target.to(device)
output = model(data)
with xp.Trace('train_step_loss_and_backward'):
loss = loss_fn(output, target)
loss.backward()
with xp.Trace('train_step_optimizer_step_host'):
optimizer.step()
サンプルコードの全文
次の例は、mnist_xla.py ファイルに基づいて、PyTorch XLA スクリプトからトレースをキャプチャする方法を示しています。
import torch
import torch.optim as optim
from torchvision import datasets, transforms
# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
def train_mnist():
# ... (model definition and data loading code) ...
print("Starting training...")
# ... (training loop as defined in the previous section) ...
print("Training finished!")
if __name__ == '__main__':
# 1. Start the profiler server
server = xp.start_server(9012)
# 2. Start capturing the trace and define the output directory
xp.start_trace('/root/logs/')
# Run the training function that contains custom trace labels
train_mnist()
# 3. Stop the trace
xp.stop_trace()
トレースを可視化する
スクリプトが終了すると、トレース ファイルは指定したディレクトリ(/root/logs/ など)に保存されます。このトレースは、XProf を使用して可視化できます。
ログ ディレクトリを指定して、スタンドアロンの XProf コマンドを使用すると、プロファイラ UI を直接起動できます。
$ xprof --port=8791 /root/logs/
Attempting to start XProf server:
Log Directory: /root/logs/
Port: 8791
Worker Service Address: 0.0.0.0:50051
Hide Capture Button: False
XProf at http://localhost:8791/ (Press CTRL+C to quit)
ブラウザで指定された URL(http://localhost:8791/ など)に移動して、プロファイルを表示します。
作成したカスタムラベルを確認し、モデルのさまざまな部分の実行時間を分析できます。
Google Cloud を使用してワークロードを実行する場合は、cloud-diagnostics-xprof ツールをおすすめします。これは、XProf を実行する VM を使用してプロファイルの収集と表示を効率的に行います。