بهینهسازی مدلهای JAX در مقیاس بزرگ روی GPUها نیازمند دید عمیق به گلوگاههای عملکرد است. این راهنما یک گردش کار جامع و سرتاسری (E2E) برای اجرا و پروفایلبندی بارهای کاری JAX روی GPUها (مانند NVIDIA L4) با استفاده از Google Cloud ML Diagnostics و XProf ارائه میدهد. با بهرهگیری از این ابزارها، میتوانید عملیات ناکارآمد را شناسایی کنید، استفاده از منابع محاسباتی را بهینه کنید و اجرای آموزش خود را تسریع بخشید.
با دنبال کردن این راهنما، یاد خواهید گرفت که چگونه:
- یک حلقه آموزشی ساده JAX برای پروفایلسازی بسازید.
- حجم کار را با پشتیبانی مناسب از CUDA، کانتینریزه کنید.
- با استفاده از JobSet، بار کاری را روی Google Kubernetes Engine (GKE) مستقر کنید.
- پروفایلهای عملکرد را به صورت پویا ضبط و تجسم کنید.
پیشنیازها
قبل از شروع، مطمئن شوید که موارد زیر را دارید:
- یک پروژه گوگل کلود با قابلیت پرداخت.
- یک کلاستر GKE با پشتیبانی از پردازنده گرافیکی (مثلاً NVIDIA L4).
- یک مخزن ذخیرهسازی ابری گوگل (GCS) برای ذخیره پروفایلها.
- رابطهای خط فرمان
gcloudوkubectlنصب و پیکربندی شدند. - هویت بار کاری برای کلاستر GKE شما جهت دسترسی به GCS پیکربندی شده است.
مرحله ۱: ابزار دقیق حجم کار JAX
ابتدا، یک اسکریپت آموزشی JAX ایجاد کنید (مثلاً train.py ). ما از SDK google-cloud-mldiagnostics برای تعامل با زیرساخت پروفایل مدیریتشده استفاده میکنیم.
[!هشدار] اسکریپت زیر شامل یک حلقه بینهایت است تا GPU را برای نمایشهای پروفایلینگ بر اساس تقاضا مشغول نگه دارد. به یاد داشته باشید که پس از اتمام کار، برای جلوگیری از هزینههای غیرضروری، کار را به صورت دستی متوقف کنید یا منابع GKE را حذف کنید.
import logging
import os
import time
from google_cloud_mldiagnostics import machinelearning_run
from google_cloud_mldiagnostics import xprof
import jax
import jax.numpy as jnp
import numpy as np
logging.basicConfig(level=logging.INFO)
def main():
logging.info("Starting JAX training job...")
# Coordinates multihost collective operations and healthchecks
jax.distributed.initialize()
logging.info(
f"JAX initialized: process_index={jax.process_index()}, "
f"process_count={jax.process_count()}"
)
# Syncs metadata with the mldiag hook & launches reverse proxy daemons
machinelearning_run(
name=f"jax-gpu-run-{int(time.time())}",
configs={"learning_rate": 1e-5, "batch_size": 8192},
project=os.environ.get("PROJECT_ID", "<your-project-id>"),
region=os.environ.get("REGION", "us-central1"),
gcs_path=os.environ.get("GCS_BUCKET", "gs://<your-gcs-bucket>"),
on_demand_xprof=True,
)
key = jax.random.PRNGKey(0)
size = 4096
matrix = jax.random.normal(key, (size, size), dtype=jnp.float32)
def train_step(x):
return jnp.dot(x, x)
train_step = jax.jit(train_step)
# Triggers XLA compilation ahead of tracing steps so compilation overhead isn't profiled
matrix = train_step(matrix)
matrix.block_until_ready() # Wait for compilation to complete.
prof = xprof()
prof.start(session_id="warmup_phase")
for _ in range(5):
matrix = train_step(matrix)
matrix.block_until_ready()
prof.stop()
logging.info("Programmatic profile capture complete.")
logging.info("Entering training loop. Ready for on-demand profiling...")
try:
while True:
# Continuously pump steps keeping GPUs occupied for on-demand capture triggers
matrix = train_step(matrix)
matrix.block_until_ready() # Ensure GPU work completes before next step.
time.sleep(0.5)
except KeyboardInterrupt:
logging.info("Training loop interrupted.")
if __name__ == "__main__":
main()
مرحله 2: کانتینرسازی (Dockerfile)
یک Dockerfile ایجاد کنید تا اسکریپت JAX خود را به همراه وابستگیهای CUDA مورد نیاز و ML Diagnostics SDK بستهبندی کنید.
# Use an official NVIDIA CUDA base image compatible with JAX
FROM nvidia/cuda:13.2.1-cudnn-devel-ubuntu24.04
# Install Python, venv, and other OS dependencies
RUN apt-get update && apt-get install -y \
python3-pip \
python3-venv \
python3-dev \
git \
curl \
&& rm -rf /var/lib/apt/lists/*
# Set up a virtual environment and update PATH to use it implicitly
ENV VIRTUAL_ENV=/opt/venv
RUN python3 -m venv $VIRTUAL_ENV
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
# At this point, pip and python implicitly map to the virtual env!
# No need for --break-system-packages.
# Upgrade pip inside the venv
RUN pip install --upgrade pip
# Install JAX with CUDA support
RUN pip install --upgrade "jax[cuda13]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Install ML Diagnostics SDK and XProf tools
RUN pip install --no-cache-dir \
google-cloud-mldiagnostics \
xprof-nightly
WORKDIR /app
COPY train.py .
CMD ["python3", "train.py"]
ایمیج را بسازید و به رجیستری مصنوعات منتقل کنید:
docker build -t us-central1-docker.pkg.dev/<project-id>/<repo>/jax-gpu-workload:latest .
docker push us-central1-docker.pkg.dev/<project-id>/<repo>/jax-gpu-workload:latest
مرحله 3: استقرار (مانیفست Kubernetes)
بار کاری را با استفاده از GKE JobSet یا یک Job استاندارد مستقر کنید. برای فعال کردن پلتفرم ML Diagnostics برای تزریق فراداده و درخواستهای پروفایل مسیر، برچسب managed-mldiagnostics-gke: "true" را اعمال کنید. برای جزئیات بیشتر در مورد پیکربندی GKE برای ML Diagnostics، به راهنمای رسمی راهاندازی GKE مراجعه کنید.
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
name: jax-gpu-job
namespace: ai-workloads
labels:
managed-mldiagnostics-gke: "true"
spec:
replicatedJobs:
- name: gpu-nodes
replicas: 1
template:
spec:
parallelism: 1
completions: 1
backoffLimit: 0
template:
metadata:
labels:
managed-mldiagnostics-gke: "true"
spec:
# Must match the GKE Service Account with Workload Identity permissions
serviceAccountName: <your-service-account>
hostNetwork: true
dnsPolicy: ClusterFirstWithHostNet
nodeSelector:
cloud.google.com/gke-accelerator: nvidia-l4 # Or other GPU
containers:
- name: workload
image: us-central1-docker.pkg.dev/<project-id>/<repo>/jax-gpu-workload:latest
imagePullPolicy: Always
# Expose ports required for profile daemons
ports:
- containerPort: 8471 # JAX distributed coordinator port
- containerPort: 8080 # ML Diagnostics agent/proxy port
- containerPort: 9999 # XProf server port for on-demand profiling
resources:
limits:
nvidia.com/gpu: 1
مانیفست را اعمال کنید:
kubectl apply -f deploy.yaml
مرحله ۴: ثبت و تجسم
ضبط برنامهای
اگر prof.start() / prof.stop() در اسکریپت خود گنجانده باشید، آن پروفایلها به طور خودکار در مسیر gs://<your-gcs-bucket>/<run-name>/plugins/profile/<session-id>/ در سطل GCS شما آپلود میشوند.
ضبط بر اساس تقاضا
از آنجا که on_demand_xprof=True در machinelearning_run تنظیم شده است، میتوانید پروفایلها را به صورت پویا در حین اجرای کار ضبط کنید.
برای دستورالعملهای دقیق در مورد نحوه استفاده از رابط کاربری TensorBoard برای راهاندازی پروفایلهای درخواستی، انتخاب پادهای خاص و مشاهده ردپاهای ضبطشده، لطفاً به مستندات عمومی رسمی مراجعه کنید: Google Cloud ML Diagnostics - ضبط پروفایل درخواستی .
همچنین میتوانید پروفایلها را با استفاده از gcloud CLI همانطور که در راهنمای ML Diagnostics CLI توضیح داده شده است، ضبط کنید.
این مستندات عمومی برای هر دو بار کاری TPU و GPU که توسط ML Diagnostics مدیریت میشوند، اعمال میشود.