Профилирование JAX на графических процессорах с помощью XProf и ML Diagnostics.

Оптимизация крупномасштабных моделей JAX на графических процессорах требует глубокого понимания узких мест производительности. Это руководство предоставляет комплексный, сквозной (E2E) рабочий процесс для запуска и профилирования рабочих нагрузок JAX на графических процессорах (таких как NVIDIA L4) с использованием Google Cloud ML Diagnostics и XProf. Используя эти инструменты, вы можете выявлять неэффективные операции, оптимизировать использование вычислительных ресурсов и ускорять процессы обучения.

Следуя этому руководству, вы узнаете, как:

  1. Внедрить простой цикл обучения JAX для профилирования.
  2. Контейнеризуйте рабочую нагрузку с соответствующей поддержкой CUDA.
  3. Разверните рабочую нагрузку в Google Kubernetes Engine (GKE) с помощью JobSet.
  4. Динамическая регистрация и визуализация профилей производительности.

Предварительные требования

Прежде чем начать, убедитесь, что у вас есть:

  • Проект Google Cloud с включенной функцией выставления счетов.
  • Кластер GKE с поддержкой графических процессоров (например, NVIDIA L4).
  • Сегмент Google Cloud Storage (GCS) для хранения профилей.
  • Установлены и настроены интерфейсы командной строки gcloud и kubectl .
  • Для доступа к GCS в вашем кластере GKE настроен идентификатор рабочей нагрузки.

Шаг 1: Настройка параметров рабочей нагрузки JAX

Сначала создайте скрипт для обучения JAX (например, train.py ). Для взаимодействия с управляемой инфраструктурой профилирования мы используем SDK google-cloud-mldiagnostics .

[!ПРЕДУПРЕЖДЕНИЕ] Приведенный ниже скрипт включает бесконечный цикл для поддержания загрузки графического процессора для демонстраций профилирования по запросу. Не забудьте вручную остановить задание или удалить ресурсы GKE после завершения, чтобы избежать ненужных расходов на оплату.

import logging
import os
import time
from google_cloud_mldiagnostics import machinelearning_run
from google_cloud_mldiagnostics import xprof
import jax
import jax.numpy as jnp
import numpy as np

logging.basicConfig(level=logging.INFO)

def main():
    logging.info("Starting JAX training job...")

    # Coordinates multihost collective operations and healthchecks
    jax.distributed.initialize()

    logging.info(
        f"JAX initialized: process_index={jax.process_index()}, "
        f"process_count={jax.process_count()}"
    )

    # Syncs metadata with the mldiag hook & launches reverse proxy daemons
    machinelearning_run(
        name=f"jax-gpu-run-{int(time.time())}",
        configs={"learning_rate": 1e-5, "batch_size": 8192},
        project=os.environ.get("PROJECT_ID", "<your-project-id>"),
        region=os.environ.get("REGION", "us-central1"),
        gcs_path=os.environ.get("GCS_BUCKET", "gs://<your-gcs-bucket>"),
        on_demand_xprof=True,
    )

    key = jax.random.PRNGKey(0)
    size = 4096
    matrix = jax.random.normal(key, (size, size), dtype=jnp.float32)

    def train_step(x):
        return jnp.dot(x, x)

    train_step = jax.jit(train_step)

    # Triggers XLA compilation ahead of tracing steps so compilation overhead isn't profiled
    matrix = train_step(matrix)
    matrix.block_until_ready() # Wait for compilation to complete.

    prof = xprof()
    prof.start(session_id="warmup_phase")

    for _ in range(5):
        matrix = train_step(matrix)
        matrix.block_until_ready()

    prof.stop()
    logging.info("Programmatic profile capture complete.")

    logging.info("Entering training loop. Ready for on-demand profiling...")
    try:
        while True:
            # Continuously pump steps keeping GPUs occupied for on-demand capture triggers
            matrix = train_step(matrix)
            matrix.block_until_ready() # Ensure GPU work completes before next step.
            time.sleep(0.5)
    except KeyboardInterrupt:
        logging.info("Training loop interrupted.")

if __name__ == "__main__":
    main()

Шаг 2: Контейнеризация (Dockerfile)

Создайте Dockerfile для упаковки вашего JAX-скрипта с необходимыми зависимостями CUDA и SDK для диагностики машинного обучения.

# Use an official NVIDIA CUDA base image compatible with JAX
FROM nvidia/cuda:13.2.1-cudnn-devel-ubuntu24.04

# Install Python, venv, and other OS dependencies
RUN apt-get update && apt-get install -y \
    python3-pip \
    python3-venv \
    python3-dev \
    git \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Set up a virtual environment and update PATH to use it implicitly
ENV VIRTUAL_ENV=/opt/venv
RUN python3 -m venv $VIRTUAL_ENV
ENV PATH="$VIRTUAL_ENV/bin:$PATH"

# At this point, pip and python implicitly map to the virtual env!
# No need for --break-system-packages.

# Upgrade pip inside the venv
RUN pip install --upgrade pip

# Install JAX with CUDA support
RUN pip install --upgrade "jax[cuda13]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Install ML Diagnostics SDK and XProf tools
RUN pip install --no-cache-dir \
    google-cloud-mldiagnostics \
    xprof-nightly

WORKDIR /app
COPY train.py .

CMD ["python3", "train.py"]

Соберите и загрузите образ в реестр артефактов:

docker build -t us-central1-docker.pkg.dev/<project-id>/<repo>/jax-gpu-workload:latest .
docker push us-central1-docker.pkg.dev/<project-id>/<repo>/jax-gpu-workload:latest

Шаг 3: Развертывание (манифест Kubernetes)

Разверните рабочую нагрузку, используя набор заданий GKE JobSet или стандартное задание. Чтобы платформа ML Diagnostics могла внедрять метаданные и маршрутизировать запросы профилирования, примените метку managed-mldiagnostics-gke: "true" . Для получения более подробной информации о настройке GKE для ML Diagnostics обратитесь к официальному руководству по настройке GKE .

apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: jax-gpu-job
  namespace: ai-workloads
  labels:
    managed-mldiagnostics-gke: "true"
spec:
  replicatedJobs:
  - name: gpu-nodes
    replicas: 1
    template:
      spec:
        parallelism: 1
        completions: 1
        backoffLimit: 0
        template:
          metadata:
            labels:
              managed-mldiagnostics-gke: "true"
          spec:
            # Must match the GKE Service Account with Workload Identity permissions
            serviceAccountName: <your-service-account>
            hostNetwork: true
            dnsPolicy: ClusterFirstWithHostNet
            nodeSelector:
              cloud.google.com/gke-accelerator: nvidia-l4 # Or other GPU
            containers:
            - name: workload
              image: us-central1-docker.pkg.dev/<project-id>/<repo>/jax-gpu-workload:latest
              imagePullPolicy: Always
              # Expose ports required for profile daemons
              ports:
              - containerPort: 8471 # JAX distributed coordinator port
              - containerPort: 8080 # ML Diagnostics agent/proxy port
              - containerPort: 9999 # XProf server port for on-demand profiling
              resources:
                limits:
                  nvidia.com/gpu: 1

Примените манифест:

kubectl apply -f deploy.yaml

Шаг 4: Захват и визуализация

Программный захват

Если вы включили в свой скрипт функции prof.start() / prof.stop() , эти профили автоматически загружаются в ваш бакет GCS по пути: gs://<your-gcs-bucket>/<run-name>/plugins/profile/<session-id>/

Захват по запросу

Поскольку в machinelearning_run задано значение on_demand_xprof=True , вы можете динамически получать профили во время выполнения задания.

Подробные инструкции по использованию пользовательского интерфейса TensorBoard для запуска профилирования по запросу, выбора конкретных модулей и просмотра захваченных трассировок см. в официальной общедоступной документации: Google Cloud ML Diagnostics - On-demand profile capture .

Также можно создавать профили с помощью интерфейса командной строки gcloud, как описано в руководстве по интерфейсу командной строки ML Diagnostics .

Данная публичная документация применима как к рабочим нагрузкам на TPU, так и на GPU, управляемым ML Diagnostics.