Optimizing large-scale JAX models on GPUs requires deep visibility into performance bottlenecks. This guide provides a comprehensive, end-to-end (E2E) workflow for running and profiling JAX workloads on GPUs (such as NVIDIA L4) using Google Cloud ML Diagnostics and XProf. By leveraging these tools, you can identify inefficient operations, optimize compute resource usage, and accelerate your training runs.
By following this guide, you will learn how to:
- Instrument a simple JAX training loop for profiling.
- Containerize the workload with appropriate CUDA support.
- Deploy the workload on Google Kubernetes Engine (GKE) using JobSet.
- Capture and visualize performance profiles dynamically.
## Prerequisites
Before you begin, ensure you have:
- A Google Cloud project with billing enabled.
- A GKE cluster with GPU support (e.g., NVIDIA L4).
- A Google Cloud Storage (GCS) bucket to store profiles.
- `gcloud` and `kubectl` CLIs installed and configured.
- Workload Identity configured for your GKE cluster to access GCS.
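As a sketch of the Workload Identity prerequisite, the commands below bind a Kubernetes service account to a Google service account that can write to your GCS bucket. All names here are placeholders (the `ml-diag` Google service account in particular is an assumption); substitute your own project, namespace, and accounts.

```shell
# Hypothetical names; replace with your own values.
PROJECT_ID=<your-project-id>
NAMESPACE=ai-workloads
KSA=<your-service-account>   # Kubernetes service account referenced by the JobSet
GSA=ml-diag@${PROJECT_ID}.iam.gserviceaccount.com

# Allow the Kubernetes service account to impersonate the Google service account.
gcloud iam service-accounts add-iam-policy-binding "${GSA}" \
  --role roles/iam.workloadIdentityUser \
  --member "serviceAccount:${PROJECT_ID}.svc.id.goog[${NAMESPACE}/${KSA}]"

# Annotate the Kubernetes service account so GKE maps it to the Google service account.
kubectl annotate serviceaccount "${KSA}" -n "${NAMESPACE}" \
  iam.gke.io/gcp-service-account="${GSA}"
```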
## Step 1: Instrumenting the JAX Workload
First, create a JAX training script (e.g., `train.py`). We use the
`google-cloud-mldiagnostics` SDK to interact with the managed profiling
infrastructure.
> [!WARNING]
> The script below includes an infinite loop to keep the GPU busy for on-demand profiling demonstrations. Remember to manually stop the job or delete the GKE resources when you are done to avoid unnecessary billing costs.
```python
import logging
import os
import time

from google_cloud_mldiagnostics import machinelearning_run
from google_cloud_mldiagnostics import xprof
import jax
import jax.numpy as jnp

logging.basicConfig(level=logging.INFO)


def main():
    logging.info("Starting JAX training job...")

    # Coordinates multi-host collective operations and health checks.
    jax.distributed.initialize()
    logging.info(
        f"JAX initialized: process_index={jax.process_index()}, "
        f"process_count={jax.process_count()}"
    )

    # Syncs metadata with the ML Diagnostics hook and launches reverse proxy daemons.
    machinelearning_run(
        name=f"jax-gpu-run-{int(time.time())}",
        configs={"learning_rate": 1e-5, "batch_size": 8192},
        project=os.environ.get("PROJECT_ID", "<your-project-id>"),
        region=os.environ.get("REGION", "us-central1"),
        gcs_path=os.environ.get("GCS_BUCKET", "gs://<your-gcs-bucket>"),
        on_demand_xprof=True,
    )

    key = jax.random.PRNGKey(0)
    size = 4096
    matrix = jax.random.normal(key, (size, size), dtype=jnp.float32)

    def train_step(x):
        return jnp.dot(x, x)

    train_step = jax.jit(train_step)

    # Trigger XLA compilation ahead of the traced steps so compilation
    # overhead isn't captured in the profile.
    matrix = train_step(matrix)
    matrix.block_until_ready()  # Wait for compilation to complete.

    prof = xprof()
    prof.start(session_id="warmup_phase")
    for _ in range(5):
        matrix = train_step(matrix)
        matrix.block_until_ready()
    prof.stop()
    logging.info("Programmatic profile capture complete.")

    logging.info("Entering training loop. Ready for on-demand profiling...")
    try:
        while True:
            # Continuously run steps to keep the GPU busy for on-demand capture triggers.
            matrix = train_step(matrix)
            matrix.block_until_ready()  # Ensure GPU work completes before the next step.
            time.sleep(0.5)
    except KeyboardInterrupt:
        logging.info("Training loop interrupted.")


if __name__ == "__main__":
    main()
```
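The compile-then-profile pattern used above (run one warm-up step so XLA compilation happens outside the profiled region) can be seen in isolation with a tiny sketch. This assumes only that `jax` is installed; it runs on CPU as well as GPU:

```python
import jax
import jax.numpy as jnp

# jit-compile a step function; the first call triggers tracing and XLA compilation.
step = jax.jit(lambda x: jnp.dot(x, x))

x = jnp.ones((4, 4), dtype=jnp.float32)
x = step(x)            # first call: compiles, then runs
x.block_until_ready()  # wait for compilation + execution to finish

# Subsequent calls reuse the compiled executable, so profiling them measures
# steady-state step time rather than one-off compilation overhead.
y = step(x)
y.block_until_ready()
```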
## Step 2: Containerization (Dockerfile)
Create a `Dockerfile` to package your JAX script with the required CUDA
dependencies and the ML Diagnostics SDK.
```dockerfile
# Use an official NVIDIA CUDA base image compatible with JAX
FROM nvidia/cuda:13.2.1-cudnn-devel-ubuntu24.04

# Install Python, venv, and other OS dependencies
RUN apt-get update && apt-get install -y \
    python3-pip \
    python3-venv \
    python3-dev \
    git \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Set up a virtual environment and update PATH to use it implicitly
ENV VIRTUAL_ENV=/opt/venv
RUN python3 -m venv $VIRTUAL_ENV
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
# From here on, pip and python implicitly resolve to the virtual env,
# so there is no need for --break-system-packages.

# Upgrade pip inside the venv
RUN pip install --upgrade pip

# Install JAX with CUDA support
RUN pip install --upgrade "jax[cuda13]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Install ML Diagnostics SDK and XProf tools
RUN pip install --no-cache-dir \
    google-cloud-mldiagnostics \
    xprof-nightly

WORKDIR /app
COPY train.py .

CMD ["python3", "train.py"]
```
Build and push the image to Artifact Registry:
```shell
docker build -t us-central1-docker.pkg.dev/<project-id>/<repo>/jax-gpu-workload:latest .
docker push us-central1-docker.pkg.dev/<project-id>/<repo>/jax-gpu-workload:latest
```
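If Docker is not yet authenticated to Artifact Registry, configure the credential helper first; the registry hostname below assumes the `us-central1` region used throughout this guide.

```shell
gcloud auth configure-docker us-central1-docker.pkg.dev
```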
## Step 3: Deployment (Kubernetes Manifest)

Deploy the workload using a GKE JobSet or a standard Job. To let the ML
Diagnostics platform inject metadata and route profile requests, apply the
label `managed-mldiagnostics-gke: "true"`. For more details on configuring GKE
for ML Diagnostics, refer to the official
GKE Setup Guide.
```yaml
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: jax-gpu-job
  namespace: ai-workloads
  labels:
    managed-mldiagnostics-gke: "true"
spec:
  replicatedJobs:
    - name: gpu-nodes
      replicas: 1
      template:
        spec:
          parallelism: 1
          completions: 1
          backoffLimit: 0
          template:
            metadata:
              labels:
                managed-mldiagnostics-gke: "true"
            spec:
              # Must match the GKE service account with Workload Identity permissions
              serviceAccountName: <your-service-account>
              hostNetwork: true
              dnsPolicy: ClusterFirstWithHostNet
              nodeSelector:
                cloud.google.com/gke-accelerator: nvidia-l4  # Or another supported GPU
              containers:
                - name: workload
                  image: us-central1-docker.pkg.dev/<project-id>/<repo>/jax-gpu-workload:latest
                  imagePullPolicy: Always
                  # Expose ports required by the profiling daemons
                  ports:
                    - containerPort: 8471  # JAX distributed coordinator port
                    - containerPort: 8080  # ML Diagnostics agent/proxy port
                    - containerPort: 9999  # XProf server port for on-demand profiling
                  resources:
                    limits:
                      nvidia.com/gpu: 1
```
Apply the manifest:

```shell
kubectl apply -f deploy.yaml
```
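To confirm the workload started and the profiling hooks initialized, you can check pod status and logs. The label selector below uses the standard JobSet-assigned pod label together with the JobSet name from the manifest above:

```shell
# List pods created by the JobSet in the workload namespace.
kubectl get pods -n ai-workloads -l jobset.sigs.k8s.io/jobset-name=jax-gpu-job

# Stream logs from the workload container (substitute the actual pod name).
kubectl logs -n ai-workloads <pod-name> -f
```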
## Step 4: Capture & Visualization
### Programmatic Capture

If you included `prof.start()` / `prof.stop()` in your script, those profiles
are automatically uploaded to your GCS bucket under the path:

```
gs://<your-gcs-bucket>/<run-name>/plugins/profile/<session-id>/
```
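As a small sketch of how that layout is composed, the run name and session ID from the script in Step 1 map onto the GCS path like this (the helper function is illustrative, not part of the SDK):

```python
def profile_upload_path(gcs_bucket: str, run_name: str, session_id: str) -> str:
    """Illustrative helper: builds the GCS path where a profile session lands."""
    return f"{gcs_bucket}/{run_name}/plugins/profile/{session_id}/"

# The warm-up capture from Step 1 (session_id="warmup_phase") would land under:
path = profile_upload_path("gs://my-bucket", "jax-gpu-run-1700000000", "warmup_phase")
print(path)
```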
### On-Demand Capture

Because `on_demand_xprof=True` is set in `machinelearning_run`, you can capture
profiles dynamically while the job is running.
For detailed instructions on using the TensorBoard UI to trigger on-demand profiles, select specific pods, and view the captured traces, refer to the official documentation: Google Cloud ML Diagnostics - On-demand profile capture.
You can also capture profiles with the gcloud CLI, as described in the ML Diagnostics CLI Guide.
This documentation applies to both TPU and GPU workloads managed by ML Diagnostics.