La optimización del rendimiento es una parte fundamental de la compilación de modelos de aprendizaje automático eficientes. Puedes usar la herramienta de generación de perfiles XProf para medir el rendimiento de tus cargas de trabajo de aprendizaje automático. XProf te permite capturar registros detallados de la ejecución de tu modelo en dispositivos XLA. Estos registros pueden ayudarte a identificar cuellos de botella en el rendimiento, comprender el uso del dispositivo y optimizar tu código.
En esta guía, se describe el proceso para capturar de forma programática un registro de tu secuencia de comandos de PyTorch XLA y visualizarlo con XProf.
Cómo capturar un registro de forma programática
Para capturar un seguimiento, agrega algunas líneas de código a tu secuencia de comandos de entrenamiento
existente. La herramienta principal para capturar un seguimiento es el módulo torch_xla.debug.profiler,
que suele importarse con el alias xp.
1. Inicia el servidor del generador de perfiles
Antes de capturar un seguimiento, debes iniciar el servidor del generador de perfiles. Este
servidor se ejecuta en segundo plano en tu secuencia de comandos y recopila los datos de seguimiento. Puedes
iniciarlo llamando a xp.start_server(<port>) cerca del comienzo de tu
bloque de ejecución principal.
2. Define la duración del seguimiento
Encapsula el código que deseas analizar dentro de las llamadas xp.start_trace() y
xp.stop_trace(). La función start_trace toma una ruta de acceso a un directorio
en el que se guardan los archivos de seguimiento.
Es una práctica común encapsular el bucle de entrenamiento principal para capturar las operaciones más relevantes.
import torch_xla.debug.profiler as xp
# The directory where the trace files are stored.
log_dir = '/root/logs/'
# Start tracing
xp.start_trace(log_dir)
# ... your training loop or other code to be profiled ...
train_mnist()
# Stop tracing
xp.stop_trace()
3. Agrega etiquetas de seguimiento personalizadas
De forma predeterminada, los seguimientos capturados son funciones de Pytorch XLA de bajo nivel y
pueden ser difíciles de navegar. Puedes agregar etiquetas personalizadas a secciones específicas de tu código
con el administrador de contexto xp.Trace(). Estas etiquetas aparecerán como bloques
con nombre en la vista de línea de tiempo del generador de perfiles, lo que facilitará la identificación de operaciones específicas,
como la preparación de datos, el pase hacia delante o el paso del optimizador.
En el siguiente ejemplo, se muestra cómo puedes agregar contexto a diferentes partes de un paso de entrenamiento.
def forward(self, x):
# This entire block will be labeled 'forward' in the trace
with xp.Trace('forward'):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2(x), 2))
x = x.view(-1, 7*7*64)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
with torch_xla.step():
with xp.Trace('train_step_data_prep_and_forward'):
optimizer.zero_grad()
data, target = data.to(device), target.to(device)
output = model(data)
with xp.Trace('train_step_loss_and_backward'):
loss = loss_fn(output, target)
loss.backward()
with xp.Trace('train_step_optimizer_step_host'):
optimizer.step()
Ejemplo completo
En el siguiente ejemplo, se muestra cómo capturar un registro de un script de PyTorch XLA basado en el archivo mnist_xla.py.
import torch
import torch.optim as optim
from torchvision import datasets, transforms
# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
def train_mnist():
# ... (model definition and data loading code) ...
print("Starting training...")
# ... (training loop as defined in the previous section) ...
print("Training finished!")
if __name__ == '__main__':
# 1. Start the profiler server
server = xp.start_server(9012)
# 2. Start capturing the trace and define the output directory
xp.start_trace('/root/logs/')
# Run the training function that contains custom trace labels
train_mnist()
# 3. Stop the trace
xp.stop_trace()
Visualiza el seguimiento
Cuando finalice la secuencia de comandos, los archivos de registro se guardarán en el directorio que especificaste (por ejemplo, /root/logs/). Puedes visualizar este registro con XProf.
Puedes iniciar la IU del generador de perfiles directamente con el comando independiente de XProf apuntándolo a tu directorio de registros:
$ xprof --port=8791 /root/logs/
Attempting to start XProf server:
Log Directory: /root/logs/
Port: 8791
Worker Service Address: 0.0.0.0:50051
Hide Capture Button: False
XProf at http://localhost:8791/ (Press CTRL+C to quit)
Navega a la URL proporcionada (p.ej., http://localhost:8791/) en tu navegador para ver el perfil.
Podrás ver las etiquetas personalizadas que creaste y analizar el tiempo de ejecución de diferentes partes de tu modelo.
Si usas Google Cloud para ejecutar tus cargas de trabajo, te recomendamos la herramienta cloud-diagnostics-xprof. Proporciona una experiencia optimizada de recopilación y visualización de perfiles con VMs que ejecutan XProf.