XProf এবং ML Diagnostics ব্যবহার করে GPU-তে JAX-এর প্রোফাইলিং

জিপিইউ-তে বৃহৎ আকারের JAX মডেল অপ্টিমাইজ করার জন্য পারফরম্যান্সের প্রতিবন্ধকতাগুলো সম্পর্কে গভীর ধারণা থাকা প্রয়োজন। এই নির্দেশিকাটি Google Cloud ML Diagnostics এবং XProf ব্যবহার করে জিপিইউ-তে (যেমন NVIDIA L4) JAX ওয়ার্কলোড চালানো এবং প্রোফাইলিং করার জন্য একটি ব্যাপক, এন্ড-টু-এন্ড (E2E) কর্মপ্রবাহ প্রদান করে। এই টুলগুলো কাজে লাগিয়ে, আপনি অদক্ষ অপারেশনগুলো শনাক্ত করতে, কম্পিউট রিসোর্সের ব্যবহার অপ্টিমাইজ করতে এবং আপনার ট্রেনিং রানকে ত্বরান্বিত করতে পারবেন।

এই নির্দেশিকা অনুসরণ করে আপনি শিখবেন কীভাবে:

  1. প্রোফাইলিংয়ের জন্য একটি সাধারণ JAX ট্রেনিং লুপকে ইনস্ট্রুমেন্ট করুন।
  2. যথাযথ CUDA সমর্থনের মাধ্যমে ওয়ার্কলোডকে কন্টেইনারাইজ করুন।
  3. JobSet ব্যবহার করে Google Kubernetes Engine (GKE)-তে ওয়ার্কলোডটি ডেপ্লয় করুন।
  4. গতিশীলভাবে পারফরম্যান্স প্রোফাইল সংগ্রহ ও প্রদর্শন করুন।

পূর্বশর্ত

শুরু করার আগে, নিশ্চিত করুন যে আপনার কাছে নিম্নলিখিত জিনিসগুলো আছে:

  • বিলিং সক্ষম একটি গুগল ক্লাউড প্রজেক্ট।
  • GPU সমর্থনসহ একটি GKE ক্লাস্টার (যেমন, NVIDIA L4)।
  • প্রোফাইল সংরক্ষণের জন্য একটি গুগল ক্লাউড স্টোরেজ (GCS) বাকেট।
  • gcloud এবং kubectl CLI ইনস্টল ও কনফিগার করা হয়েছে।
  • GCS অ্যাক্সেস করার জন্য আপনার GKE ক্লাস্টারে ওয়ার্কলোড আইডেন্টিটি কনফিগার করা হয়েছে।

ধাপ ১: JAX ওয়ার্কলোডকে ইনস্ট্রুমেন্টেশন করা

প্রথমে, একটি JAX ট্রেনিং স্ক্রিপ্ট তৈরি করুন (যেমন, train.py )। আমরা পরিচালিত প্রোফাইলিং পরিকাঠামোর সাথে সংযোগ স্থাপনের জন্য google-cloud-mldiagnostics SDK ব্যবহার করি।

[!সতর্কবার্তা] নিচের স্ক্রিপ্টটিতে অন-ডিমান্ড প্রোফাইলিং প্রদর্শনের জন্য GPU-কে ব্যস্ত রাখতে একটি অসীম লুপ অন্তর্ভুক্ত রয়েছে। অপ্রয়োজনীয় বিলিং খরচ এড়াতে, কাজ শেষ হয়ে গেলে ম্যানুয়ালি জবটি বন্ধ করুন অথবা GKE রিসোর্সগুলো মুছে ফেলুন।

import logging
import os
import time
from google_cloud_mldiagnostics import machinelearning_run
from google_cloud_mldiagnostics import xprof
import jax
import jax.numpy as jnp
import numpy as np

logging.basicConfig(level=logging.INFO)

def main():
    logging.info("Starting JAX training job...")

    # Coordinates multihost collective operations and healthchecks
    jax.distributed.initialize()

    logging.info(
        f"JAX initialized: process_index={jax.process_index()}, "
        f"process_count={jax.process_count()}"
    )

    # Syncs metadata with the mldiag hook & launches reverse proxy daemons
    machinelearning_run(
        name=f"jax-gpu-run-{int(time.time())}",
        configs={"learning_rate": 1e-5, "batch_size": 8192},
        project=os.environ.get("PROJECT_ID", "<your-project-id>"),
        region=os.environ.get("REGION", "us-central1"),
        gcs_path=os.environ.get("GCS_BUCKET", "gs://<your-gcs-bucket>"),
        on_demand_xprof=True,
    )

    key = jax.random.PRNGKey(0)
    size = 4096
    matrix = jax.random.normal(key, (size, size), dtype=jnp.float32)

    def train_step(x):
        return jnp.dot(x, x)

    train_step = jax.jit(train_step)

    # Triggers XLA compilation ahead of tracing steps so compilation overhead isn't profiled
    matrix = train_step(matrix)
    matrix.block_until_ready() # Wait for compilation to complete.

    prof = xprof()
    prof.start(session_id="warmup_phase")

    for _ in range(5):
        matrix = train_step(matrix)
        matrix.block_until_ready()

    prof.stop()
    logging.info("Programmatic profile capture complete.")

    logging.info("Entering training loop. Ready for on-demand profiling...")
    try:
        while True:
            # Continuously pump steps keeping GPUs occupied for on-demand capture triggers
            matrix = train_step(matrix)
            matrix.block_until_ready() # Ensure GPU work completes before next step.
            time.sleep(0.5)
    except KeyboardInterrupt:
        logging.info("Training loop interrupted.")

if __name__ == "__main__":
    main()

ধাপ ২: কন্টেইনারাইজেশন (ডকারফাইল)

আপনার JAX স্ক্রিপ্টটিকে প্রয়োজনীয় CUDA ডিপেন্ডেন্সি এবং ML ডায়াগনস্টিকস SDK-এর সাথে প্যাকেজ করার জন্য একটি Dockerfile তৈরি করুন।

# Use an official NVIDIA CUDA base image compatible with JAX
FROM nvidia/cuda:13.2.1-cudnn-devel-ubuntu24.04

# Install Python, venv, and other OS dependencies
RUN apt-get update && apt-get install -y \
    python3-pip \
    python3-venv \
    python3-dev \
    git \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Set up a virtual environment and update PATH to use it implicitly
ENV VIRTUAL_ENV=/opt/venv
RUN python3 -m venv $VIRTUAL_ENV
ENV PATH="$VIRTUAL_ENV/bin:$PATH"

# At this point, pip and python implicitly map to the virtual env!
# No need for --break-system-packages.

# Upgrade pip inside the venv
RUN pip install --upgrade pip

# Install JAX with CUDA support
RUN pip install --upgrade "jax[cuda13]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Install ML Diagnostics SDK and XProf tools
RUN pip install --no-cache-dir \
    google-cloud-mldiagnostics \
    xprof-nightly

WORKDIR /app
COPY train.py .

CMD ["python3", "train.py"]

ইমেজটি বিল্ড করে আর্টিফ্যাক্ট রেজিস্ট্রি-তে পুশ করুন:

docker build -t us-central1-docker.pkg.dev/<project-id>/<repo>/jax-gpu-workload:latest .
docker push us-central1-docker.pkg.dev/<project-id>/<repo>/jax-gpu-workload:latest

ধাপ ৩: ডেপ্লয়মেন্ট (কুবারনেটিস ম্যানিফেস্ট)

একটি GKE জবসেট বা স্ট্যান্ডার্ড জব ব্যবহার করে ওয়ার্কলোডটি ডেপ্লয় করুন। এমএল ডায়াগনস্টিকস প্ল্যাটফর্মকে মেটাডেটা ইনজেক্ট করতে এবং প্রোফাইল রিকোয়েস্ট রাউট করতে সক্ষম করার জন্য, managed-mldiagnostics-gke: "true" লেবেলটি প্রয়োগ করুন। এমএল ডায়াগনস্টিকসের জন্য GKE কনফিগার করার বিষয়ে আরও বিস্তারিত জানতে, অফিসিয়াল GKE সেটআপ গাইডটি দেখুন।

apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: jax-gpu-job
  namespace: ai-workloads
  labels:
    managed-mldiagnostics-gke: "true"
spec:
  replicatedJobs:
  - name: gpu-nodes
    replicas: 1
    template:
      spec:
        parallelism: 1
        completions: 1
        backoffLimit: 0
        template:
          metadata:
            labels:
              managed-mldiagnostics-gke: "true"
          spec:
            # Must match the GKE Service Account with Workload Identity permissions
            serviceAccountName: <your-service-account>
            hostNetwork: true
            dnsPolicy: ClusterFirstWithHostNet
            nodeSelector:
              cloud.google.com/gke-accelerator: nvidia-l4 # Or other GPU
            containers:
            - name: workload
              image: us-central1-docker.pkg.dev/<project-id>/<repo>/jax-gpu-workload:latest
              imagePullPolicy: Always
              # Expose ports required for profile daemons
              ports:
              - containerPort: 8471 # JAX distributed coordinator port
              - containerPort: 8080 # ML Diagnostics agent/proxy port
              - containerPort: 9999 # XProf server port for on-demand profiling
              resources:
                limits:
                  nvidia.com/gpu: 1

ম্যানিফেস্ট প্রয়োগ করুন:

kubectl apply -f deploy.yaml

ধাপ ৪: ধারণ ও প্রদর্শন

প্রোগ্রাম্যাটিক ক্যাপচার

যদি আপনি আপনার স্ক্রিপ্টে prof.start() / prof.stop() অন্তর্ভুক্ত করে থাকেন, তাহলে সেই প্রোফাইলগুলি স্বয়ংক্রিয়ভাবে আপনার GCS বাকেটে gs://<your-gcs-bucket>/<run-name>/plugins/profile/<session-id>/ পাথে আপলোড হয়ে যাবে।

অন-ডিমান্ড ক্যাপচার

যেহেতু machinelearning_runon_demand_xprof=True সেট করা আছে, তাই জবটি চলার সময়েও আপনি ডাইনামিকভাবে প্রোফাইল ক্যাপচার করতে পারবেন।

টেনসরবোর্ড UI ব্যবহার করে কীভাবে অন-ডিমান্ড প্রোফাইল ট্রিগার করতে হয়, নির্দিষ্ট পড নির্বাচন করতে হয় এবং ক্যাপচার করা ট্রেস দেখতে হয়, সে সম্পর্কে বিস্তারিত নির্দেশাবলীর জন্য, অনুগ্রহ করে অফিসিয়াল পাবলিক ডকুমেন্টেশন দেখুন: Google Cloud ML Diagnostics - On-demand profile capture

এমএল ডায়াগনস্টিকস সিএলআই গাইডে বর্ণিত পদ্ধতি অনুযায়ী আপনি জিক্লাউড সিএলআই ব্যবহার করেও প্রোফাইল ক্যাপচার করতে পারেন।

এই পাবলিক ডকুমেন্টেশনটি এমএল ডায়াগনস্টিকস দ্বারা পরিচালিত টিপিইউ এবং জিপিইউ উভয় ওয়ার্কলোডের ক্ষেত্রেই প্রযোজ্য।