پروفایل‌سازی JAX روی GPUها با XProf و ML Diagnostics

بهینه‌سازی مدل‌های JAX در مقیاس بزرگ روی GPUها نیازمند دید عمیق به گلوگاه‌های عملکرد است. این راهنما یک گردش کار جامع و سرتاسری (E2E) برای اجرا و پروفایل‌بندی بارهای کاری JAX روی GPUها (مانند NVIDIA L4) با استفاده از Google Cloud ML Diagnostics و XProf ارائه می‌دهد. با بهره‌گیری از این ابزارها، می‌توانید عملیات ناکارآمد را شناسایی کنید، استفاده از منابع محاسباتی را بهینه کنید و اجرای آموزش خود را تسریع بخشید.

با دنبال کردن این راهنما، یاد خواهید گرفت که چگونه:

  1. یک حلقه آموزشی ساده JAX برای پروفایل‌سازی بسازید.
  2. حجم کار را با پشتیبانی مناسب از CUDA، کانتینریزه کنید.
  3. با استفاده از JobSet، بار کاری را روی Google Kubernetes Engine (GKE) مستقر کنید.
  4. پروفایل‌های عملکرد را به صورت پویا ضبط و تجسم کنید.

پیش‌نیازها

قبل از شروع، مطمئن شوید که موارد زیر را دارید:

  • یک پروژه گوگل کلود با قابلیت پرداخت.
  • یک کلاستر GKE با پشتیبانی از پردازنده گرافیکی (مثلاً NVIDIA L4).
  • یک مخزن ذخیره‌سازی ابری گوگل (GCS) برای ذخیره پروفایل‌ها.
  • رابط‌های خط فرمان gcloud و kubectl نصب و پیکربندی شدند.
  • هویت بار کاری برای کلاستر GKE شما جهت دسترسی به GCS پیکربندی شده است.

مرحله ۱: ابزار دقیق حجم کار JAX

ابتدا، یک اسکریپت آموزشی JAX ایجاد کنید (مثلاً train.py ). ما از SDK google-cloud-mldiagnostics برای تعامل با زیرساخت پروفایل مدیریت‌شده استفاده می‌کنیم.

[!هشدار] اسکریپت زیر شامل یک حلقه بی‌نهایت است تا GPU را برای نمایش‌های پروفایلینگ بر اساس تقاضا مشغول نگه دارد. به یاد داشته باشید که پس از اتمام کار، برای جلوگیری از هزینه‌های غیرضروری، کار را به صورت دستی متوقف کنید یا منابع GKE را حذف کنید.

import logging
import os
import time
from google_cloud_mldiagnostics import machinelearning_run
from google_cloud_mldiagnostics import xprof
import jax
import jax.numpy as jnp
import numpy as np

logging.basicConfig(level=logging.INFO)

def main():
    logging.info("Starting JAX training job...")

    # Coordinates multihost collective operations and healthchecks
    jax.distributed.initialize()

    logging.info(
        f"JAX initialized: process_index={jax.process_index()}, "
        f"process_count={jax.process_count()}"
    )

    # Syncs metadata with the mldiag hook & launches reverse proxy daemons
    machinelearning_run(
        name=f"jax-gpu-run-{int(time.time())}",
        configs={"learning_rate": 1e-5, "batch_size": 8192},
        project=os.environ.get("PROJECT_ID", "<your-project-id>"),
        region=os.environ.get("REGION", "us-central1"),
        gcs_path=os.environ.get("GCS_BUCKET", "gs://<your-gcs-bucket>"),
        on_demand_xprof=True,
    )

    key = jax.random.PRNGKey(0)
    size = 4096
    matrix = jax.random.normal(key, (size, size), dtype=jnp.float32)

    def train_step(x):
        return jnp.dot(x, x)

    train_step = jax.jit(train_step)

    # Triggers XLA compilation ahead of tracing steps so compilation overhead isn't profiled
    matrix = train_step(matrix)
    matrix.block_until_ready() # Wait for compilation to complete.

    prof = xprof()
    prof.start(session_id="warmup_phase")

    for _ in range(5):
        matrix = train_step(matrix)
        matrix.block_until_ready()

    prof.stop()
    logging.info("Programmatic profile capture complete.")

    logging.info("Entering training loop. Ready for on-demand profiling...")
    try:
        while True:
            # Continuously pump steps keeping GPUs occupied for on-demand capture triggers
            matrix = train_step(matrix)
            matrix.block_until_ready() # Ensure GPU work completes before next step.
            time.sleep(0.5)
    except KeyboardInterrupt:
        logging.info("Training loop interrupted.")

if __name__ == "__main__":
    main()

مرحله 2: کانتینرسازی (Dockerfile)

یک Dockerfile ایجاد کنید تا اسکریپت JAX خود را به همراه وابستگی‌های CUDA مورد نیاز و ML Diagnostics SDK بسته‌بندی کنید.

# Use an official NVIDIA CUDA base image compatible with JAX
FROM nvidia/cuda:13.2.1-cudnn-devel-ubuntu24.04

# Install Python, venv, and other OS dependencies
RUN apt-get update && apt-get install -y \
    python3-pip \
    python3-venv \
    python3-dev \
    git \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Set up a virtual environment and update PATH to use it implicitly
ENV VIRTUAL_ENV=/opt/venv
RUN python3 -m venv $VIRTUAL_ENV
ENV PATH="$VIRTUAL_ENV/bin:$PATH"

# At this point, pip and python implicitly map to the virtual env!
# No need for --break-system-packages.

# Upgrade pip inside the venv
RUN pip install --upgrade pip

# Install JAX with CUDA support
RUN pip install --upgrade "jax[cuda13]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Install ML Diagnostics SDK and XProf tools
RUN pip install --no-cache-dir \
    google-cloud-mldiagnostics \
    xprof-nightly

WORKDIR /app
COPY train.py .

CMD ["python3", "train.py"]

ایمیج را بسازید و به رجیستری مصنوعات منتقل کنید:

docker build -t us-central1-docker.pkg.dev/<project-id>/<repo>/jax-gpu-workload:latest .
docker push us-central1-docker.pkg.dev/<project-id>/<repo>/jax-gpu-workload:latest

مرحله 3: استقرار (مانیفست Kubernetes)

بار کاری را با استفاده از GKE JobSet یا یک Job استاندارد مستقر کنید. برای فعال کردن پلتفرم ML Diagnostics برای تزریق فراداده و درخواست‌های پروفایل مسیر، برچسب managed-mldiagnostics-gke: "true" را اعمال کنید. برای جزئیات بیشتر در مورد پیکربندی GKE برای ML Diagnostics، به راهنمای رسمی راه‌اندازی GKE مراجعه کنید.

apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: jax-gpu-job
  namespace: ai-workloads
  labels:
    managed-mldiagnostics-gke: "true"
spec:
  replicatedJobs:
  - name: gpu-nodes
    replicas: 1
    template:
      spec:
        parallelism: 1
        completions: 1
        backoffLimit: 0
        template:
          metadata:
            labels:
              managed-mldiagnostics-gke: "true"
          spec:
            # Must match the GKE Service Account with Workload Identity permissions
            serviceAccountName: <your-service-account>
            hostNetwork: true
            dnsPolicy: ClusterFirstWithHostNet
            nodeSelector:
              cloud.google.com/gke-accelerator: nvidia-l4 # Or other GPU
            containers:
            - name: workload
              image: us-central1-docker.pkg.dev/<project-id>/<repo>/jax-gpu-workload:latest
              imagePullPolicy: Always
              # Expose ports required for profile daemons
              ports:
              - containerPort: 8471 # JAX distributed coordinator port
              - containerPort: 8080 # ML Diagnostics agent/proxy port
              - containerPort: 9999 # XProf server port for on-demand profiling
              resources:
                limits:
                  nvidia.com/gpu: 1

مانیفست را اعمال کنید:

kubectl apply -f deploy.yaml

مرحله ۴: ثبت و تجسم

ضبط برنامه‌ای

اگر prof.start() / prof.stop() در اسکریپت خود گنجانده باشید، آن پروفایل‌ها به طور خودکار در مسیر gs://<your-gcs-bucket>/<run-name>/plugins/profile/<session-id>/ در سطل GCS شما آپلود می‌شوند.

ضبط بر اساس تقاضا

از آنجا که on_demand_xprof=True در machinelearning_run تنظیم شده است، می‌توانید پروفایل‌ها را به صورت پویا در حین اجرای کار ضبط کنید.

برای دستورالعمل‌های دقیق در مورد نحوه استفاده از رابط کاربری TensorBoard برای راه‌اندازی پروفایل‌های درخواستی، انتخاب پادهای خاص و مشاهده ردپاهای ضبط‌شده، لطفاً به مستندات عمومی رسمی مراجعه کنید: Google Cloud ML Diagnostics - ضبط پروفایل درخواستی .

همچنین می‌توانید پروفایل‌ها را با استفاده از gcloud CLI همانطور که در راهنمای ML Diagnostics CLI توضیح داده شده است، ضبط کنید.

این مستندات عمومی برای هر دو بار کاری TPU و GPU که توسط ML Diagnostics مدیریت می‌شوند، اعمال می‌شود.