জিপিইউ-তে বৃহৎ আকারের JAX মডেল অপ্টিমাইজ করার জন্য পারফরম্যান্সের প্রতিবন্ধকতাগুলো সম্পর্কে গভীর ধারণা থাকা প্রয়োজন। এই নির্দেশিকাটি Google Cloud ML Diagnostics এবং XProf ব্যবহার করে জিপিইউ-তে (যেমন NVIDIA L4) JAX ওয়ার্কলোড চালানো এবং প্রোফাইলিং করার জন্য একটি ব্যাপক, এন্ড-টু-এন্ড (E2E) কর্মপ্রবাহ প্রদান করে। এই টুলগুলো কাজে লাগিয়ে, আপনি অদক্ষ অপারেশনগুলো শনাক্ত করতে, কম্পিউট রিসোর্সের ব্যবহার অপ্টিমাইজ করতে এবং আপনার ট্রেনিং রানকে ত্বরান্বিত করতে পারবেন।
এই নির্দেশিকা অনুসরণ করে আপনি শিখবেন কীভাবে:
- প্রোফাইলিংয়ের জন্য একটি সাধারণ JAX ট্রেনিং লুপকে ইনস্ট্রুমেন্ট করুন।
- যথাযথ CUDA সমর্থনের মাধ্যমে ওয়ার্কলোডকে কন্টেইনারাইজ করুন।
- JobSet ব্যবহার করে Google Kubernetes Engine (GKE)-তে ওয়ার্কলোডটি ডেপ্লয় করুন।
- গতিশীলভাবে পারফরম্যান্স প্রোফাইল সংগ্রহ ও প্রদর্শন করুন।
পূর্বশর্ত
শুরু করার আগে, নিশ্চিত করুন যে আপনার কাছে নিম্নলিখিত জিনিসগুলো আছে:
- বিলিং সক্ষম একটি গুগল ক্লাউড প্রজেক্ট।
- GPU সমর্থনসহ একটি GKE ক্লাস্টার (যেমন, NVIDIA L4)।
- প্রোফাইল সংরক্ষণের জন্য একটি গুগল ক্লাউড স্টোরেজ (GCS) বাকেট।
-
gcloudএবংkubectlCLI ইনস্টল ও কনফিগার করা হয়েছে। - GCS অ্যাক্সেস করার জন্য আপনার GKE ক্লাস্টারে ওয়ার্কলোড আইডেন্টিটি কনফিগার করা হয়েছে।
ধাপ ১: JAX ওয়ার্কলোডকে ইনস্ট্রুমেন্টেশন করা
প্রথমে, একটি JAX ট্রেনিং স্ক্রিপ্ট তৈরি করুন (যেমন, train.py )। আমরা পরিচালিত প্রোফাইলিং পরিকাঠামোর সাথে সংযোগ স্থাপনের জন্য google-cloud-mldiagnostics SDK ব্যবহার করি।
[!সতর্কবার্তা] নিচের স্ক্রিপ্টটিতে অন-ডিমান্ড প্রোফাইলিং প্রদর্শনের জন্য GPU-কে ব্যস্ত রাখতে একটি অসীম লুপ অন্তর্ভুক্ত রয়েছে। অপ্রয়োজনীয় বিলিং খরচ এড়াতে, কাজ শেষ হয়ে গেলে ম্যানুয়ালি জবটি বন্ধ করুন অথবা GKE রিসোর্সগুলো মুছে ফেলুন।
import logging
import os
import time
from google_cloud_mldiagnostics import machinelearning_run
from google_cloud_mldiagnostics import xprof
import jax
import jax.numpy as jnp
import numpy as np
logging.basicConfig(level=logging.INFO)
def main():
logging.info("Starting JAX training job...")
# Coordinates multihost collective operations and healthchecks
jax.distributed.initialize()
logging.info(
f"JAX initialized: process_index={jax.process_index()}, "
f"process_count={jax.process_count()}"
)
# Syncs metadata with the mldiag hook & launches reverse proxy daemons
machinelearning_run(
name=f"jax-gpu-run-{int(time.time())}",
configs={"learning_rate": 1e-5, "batch_size": 8192},
project=os.environ.get("PROJECT_ID", "<your-project-id>"),
region=os.environ.get("REGION", "us-central1"),
gcs_path=os.environ.get("GCS_BUCKET", "gs://<your-gcs-bucket>"),
on_demand_xprof=True,
)
key = jax.random.PRNGKey(0)
size = 4096
matrix = jax.random.normal(key, (size, size), dtype=jnp.float32)
def train_step(x):
return jnp.dot(x, x)
train_step = jax.jit(train_step)
# Triggers XLA compilation ahead of tracing steps so compilation overhead isn't profiled
matrix = train_step(matrix)
matrix.block_until_ready() # Wait for compilation to complete.
prof = xprof()
prof.start(session_id="warmup_phase")
for _ in range(5):
matrix = train_step(matrix)
matrix.block_until_ready()
prof.stop()
logging.info("Programmatic profile capture complete.")
logging.info("Entering training loop. Ready for on-demand profiling...")
try:
while True:
# Continuously pump steps keeping GPUs occupied for on-demand capture triggers
matrix = train_step(matrix)
matrix.block_until_ready() # Ensure GPU work completes before next step.
time.sleep(0.5)
except KeyboardInterrupt:
logging.info("Training loop interrupted.")
if __name__ == "__main__":
main()
ধাপ ২: কন্টেইনারাইজেশন (ডকারফাইল)
আপনার JAX স্ক্রিপ্টটিকে প্রয়োজনীয় CUDA ডিপেন্ডেন্সি এবং ML ডায়াগনস্টিকস SDK-এর সাথে প্যাকেজ করার জন্য একটি Dockerfile তৈরি করুন।
# Use an official NVIDIA CUDA base image compatible with JAX
FROM nvidia/cuda:13.2.1-cudnn-devel-ubuntu24.04
# Install Python, venv, and other OS dependencies
RUN apt-get update && apt-get install -y \
python3-pip \
python3-venv \
python3-dev \
git \
curl \
&& rm -rf /var/lib/apt/lists/*
# Set up a virtual environment and update PATH to use it implicitly
ENV VIRTUAL_ENV=/opt/venv
RUN python3 -m venv $VIRTUAL_ENV
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
# At this point, pip and python implicitly map to the virtual env!
# No need for --break-system-packages.
# Upgrade pip inside the venv
RUN pip install --upgrade pip
# Install JAX with CUDA support
RUN pip install --upgrade "jax[cuda13]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Install ML Diagnostics SDK and XProf tools
RUN pip install --no-cache-dir \
google-cloud-mldiagnostics \
xprof-nightly
WORKDIR /app
COPY train.py .
CMD ["python3", "train.py"]
ইমেজটি বিল্ড করে আর্টিফ্যাক্ট রেজিস্ট্রি-তে পুশ করুন:
docker build -t us-central1-docker.pkg.dev/<project-id>/<repo>/jax-gpu-workload:latest .
docker push us-central1-docker.pkg.dev/<project-id>/<repo>/jax-gpu-workload:latest
ধাপ ৩: ডেপ্লয়মেন্ট (কুবারনেটিস ম্যানিফেস্ট)
একটি GKE জবসেট বা স্ট্যান্ডার্ড জব ব্যবহার করে ওয়ার্কলোডটি ডেপ্লয় করুন। এমএল ডায়াগনস্টিকস প্ল্যাটফর্মকে মেটাডেটা ইনজেক্ট করতে এবং প্রোফাইল রিকোয়েস্ট রাউট করতে সক্ষম করার জন্য, managed-mldiagnostics-gke: "true" লেবেলটি প্রয়োগ করুন। এমএল ডায়াগনস্টিকসের জন্য GKE কনফিগার করার বিষয়ে আরও বিস্তারিত জানতে, অফিসিয়াল GKE সেটআপ গাইডটি দেখুন।
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
name: jax-gpu-job
namespace: ai-workloads
labels:
managed-mldiagnostics-gke: "true"
spec:
replicatedJobs:
- name: gpu-nodes
replicas: 1
template:
spec:
parallelism: 1
completions: 1
backoffLimit: 0
template:
metadata:
labels:
managed-mldiagnostics-gke: "true"
spec:
# Must match the GKE Service Account with Workload Identity permissions
serviceAccountName: <your-service-account>
hostNetwork: true
dnsPolicy: ClusterFirstWithHostNet
nodeSelector:
cloud.google.com/gke-accelerator: nvidia-l4 # Or other GPU
containers:
- name: workload
image: us-central1-docker.pkg.dev/<project-id>/<repo>/jax-gpu-workload:latest
imagePullPolicy: Always
# Expose ports required for profile daemons
ports:
- containerPort: 8471 # JAX distributed coordinator port
- containerPort: 8080 # ML Diagnostics agent/proxy port
- containerPort: 9999 # XProf server port for on-demand profiling
resources:
limits:
nvidia.com/gpu: 1
ম্যানিফেস্ট প্রয়োগ করুন:
kubectl apply -f deploy.yaml
ধাপ ৪: ধারণ ও প্রদর্শন
প্রোগ্রাম্যাটিক ক্যাপচার
যদি আপনি আপনার স্ক্রিপ্টে prof.start() / prof.stop() অন্তর্ভুক্ত করে থাকেন, তাহলে সেই প্রোফাইলগুলি স্বয়ংক্রিয়ভাবে আপনার GCS বাকেটে gs://<your-gcs-bucket>/<run-name>/plugins/profile/<session-id>/ পাথে আপলোড হয়ে যাবে।
অন-ডিমান্ড ক্যাপচার
যেহেতু machinelearning_run এ on_demand_xprof=True সেট করা আছে, তাই জবটি চলার সময়েও আপনি ডাইনামিকভাবে প্রোফাইল ক্যাপচার করতে পারবেন।
টেনসরবোর্ড UI ব্যবহার করে কীভাবে অন-ডিমান্ড প্রোফাইল ট্রিগার করতে হয়, নির্দিষ্ট পড নির্বাচন করতে হয় এবং ক্যাপচার করা ট্রেস দেখতে হয়, সে সম্পর্কে বিস্তারিত নির্দেশাবলীর জন্য, অনুগ্রহ করে অফিসিয়াল পাবলিক ডকুমেন্টেশন দেখুন: Google Cloud ML Diagnostics - On-demand profile capture ।
এমএল ডায়াগনস্টিকস সিএলআই গাইডে বর্ণিত পদ্ধতি অনুযায়ী আপনি জিক্লাউড সিএলআই ব্যবহার করেও প্রোফাইল ক্যাপচার করতে পারেন।
এই পাবলিক ডকুমেন্টেশনটি এমএল ডায়াগনস্টিকস দ্বারা পরিচালিত টিপিইউ এবং জিপিইউ উভয় ওয়ার্কলোডের ক্ষেত্রেই প্রযোজ্য।