Category: Compile Time: HBM OOM
This error indicates that the program requires more High Bandwidth Memory (HBM) than is physically available on the TPU device.
Sample Error Messages:
RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 49.34G of 32.00G hbm. Exceeded hbm capacity by 17.34G.
RESOURCE_EXHAUSTED: TPU TensorCore Hbm usage: 34.82G, SparseCore Hbm usage 174.10G, exceeding available bytes: 95.74G
XLA Backends: TPU
Overview
XLA performs checks to ensure that the aggregate size of all necessary static allocations fits in the device's HBM.
The compiler manages the TPU's fixed HBM capacity for several types of allocations:
- Program Inputs and Outputs: Training batches, optimizer states, etc.
- TensorCore + SparseCore Temporaries: Dynamic memory required for intermediate calculations (e.g. activations, gradients, etc).
- Compiled Binary: The machine code for both TensorCore (TC) and SparseCore (SC).
- System Overhead: Reserved space for the XLA Runtime (e.g. infeed buffers on older TPU generations).
- Constants: Constant values embedded in the HLO IR are allocated on HBM.
- Compiler Internals: Program-level and per-HLO allocations (e.g., routing info for nodes in the mesh).
This error occurs when the XLA compiler cannot fit all of the above allocations into the device HBM.
Debugging
Carefully analyze the error message and logs to determine which category of HBM OOM below best describes your error:
- TensorCore (TC) + SparseCore (SC) HBM Usage Exceeds Limit: If the error explicitly breaks down usage, e.g., "TC Hbm usage: X, SC Hbm usage Y". → Jump to Section 1. Balance TC and SC HBM usage.
- Unexpectedly Large Allocations: If the error reads "Ran out of memory in memory space hbm", check the logs for an enumeration of the largest allocations on HBM. If one or more unexpectedly large tensors (e.g., > 50% of the HBM limit) are present → Jump to Section 2. Unexpectedly Large Allocations.
- Aggregate Allocations Exceed HBM Limit: If the error reads "Ran out of memory in memory space HBM" but no unexpectedly large tensors are present in the logs → Jump to Section 3. Aggregate Allocations Exceed HBM Limit.
Section 1. Balance TC and SC HBM usage
If the error explicitly breaks down usage, e.g., "TC Hbm usage: X, SC Hbm usage Y", compare the two values to identify the bottleneck.
- High SparseCore Usage:
  - Optimize HBM Stack Usage: HBM stack memory consumption scales with feature_width, max_unique_nz_per_row, and logical_replica_count. You can reduce peak stack usage by tuning the --xla_sc_num_serialized_tables_to_optimize_hbm flag, which serializes the processing of tables at the cost of reduced parallelism (see the sketch after this list).
  - Check Padding Overhead: SparseCore aligns embedding tables to 32B (8 floats). Tables with small feature widths (e.g., < 8 floats) incur significant padding overhead, wasting HBM.
  - Reduce Heap Usage: High values for maximum_parallel_iterations increase the amount of input data prefetched into the HBM heap. Lowering this value can free up significant memory.
  - Verify Sharding: Ensure embedding tables are properly mod-sharded across all chips. See How limits translate to tables.
  - See SC: Performance and memory bottlenecks for more ideas.
- High TensorCore Usage:
  - Proceed to Section 2.
- Balanced:
  - If neither is individually excessive but the sum is too high, you are at the chip's capacity; you must lower usage of both components. Follow the recommendations in all three sections.
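A minimal sketch of tuning the serialization flag mentioned above, assuming TPU-side XLA flags reach the runtime via the LIBTPU_INIT_ARGS environment variable; the value 4 is purely illustrative, not a recommendation:

```python
import os

# Hedged sketch: TPU-side XLA flags are commonly injected through the
# LIBTPU_INIT_ARGS environment variable, which must be set before JAX
# initializes the TPU. A higher serialization count lowers peak SC stack
# usage at the cost of reduced parallelism.
os.environ["LIBTPU_INIT_ARGS"] = (
    "--xla_sc_num_serialized_tables_to_optimize_hbm=4"
)

import jax  # import (and TPU initialization) after the flag is in place
```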
Section 2. Unexpectedly Large Allocations
If one or more unexpectedly large allocations (> 50% of the HBM limit) are present in the logs, it is almost never a hardware capacity issue; it is typically a configuration error. Check the XLA label (if present) on the large allocations for hints about the originating JAX source code.
- Remove Debugging Artifacts: Using jax.debug.print() in large-scale runs can force the compiler to materialize the full tensor in HBM to transfer it to the CPU, breaking fusion and increasing peak memory usage. Remove any leftover jax.debug.print() calls.
- Fix Inefficient Mesh Shapes or Sharding:
- Incorrect mesh shapes or missing sharding annotations can cause the compiler to default to replication, forcing it to fit very large tensors on a single chip.
- Check the shapes of the large allocations and verify that sharding is correctly specified and propagated by XLA (see the sketch below).
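A minimal sketch of explicitly sharding a large tensor so XLA does not fall back to replication; the 4x2 mesh, axis names, and shapes are assumptions for an 8-chip slice:

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumed 8-chip slice arranged as a 4x2 mesh; adjust to your topology.
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=("data", "model"))

# Shard the large weight over the "model" axis instead of replicating it,
# so each chip holds only half of the tensor.
weights = jax.device_put(
    jnp.zeros((8192, 32768), dtype=jnp.bfloat16),
    NamedSharding(mesh, P(None, "model")),
)

# Inspect what XLA will actually see; a fully replicated spec here means
# every chip stores the entire tensor.
print(weights.sharding)
jax.debug.visualize_array_sharding(weights)
```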
Section 3. Aggregate Allocations Exceed HBM Limit
If the program runs out of capacity due to the aggregate sum of allocations exceeding the HBM limit, it is often helpful to visualize the memory profile to identify the specific buffers contributing to the peak usage. See Debug OOM errors with XProf for a step-by-step guide on identifying peak memory contributors.
Once you have identified some of the top contributors, use the following steps to optimize the memory footprint.
A. Check tensor padding and alignment
Inefficient tensor shapes are a common, silent cause of OOMs on TPUs. To achieve peak performance on TPUs, XLA pads tensor dimensions, typically to multiples of 128 for the minor-most dimension and 8 for the second-minor one. This padding affects both input arrays and intermediate tensors (HLO temporaries), potentially inflating memory usage significantly, especially for small dimension sizes. See Array Layouts.
- Audit shapes of large buffers (on TPU v5 with default layouts):
  - Hovering over a buffer in the XProf Memory Viewer brings up the buffer details card, which includes padding information.
  - Example: A shape of (129, 1024) might be padded to (256, 1024), resulting in nearly 50% memory waste.
  - Correction: A shape of (128, 1024) requires no padding and incurs 0% memory waste.
- Align dimensions: Ensure all large tensor dimensions (batch size, embedding dimension, hidden size) are multiples of 128.
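To quantify padding overhead, a small helper (hypothetical, not part of any library) can compare a logical shape against the padded shape reported by the Memory Viewer:

```python
import math

def padding_waste_pct(logical_shape, padded_shape):
    """Percentage of the padded buffer occupied by padding alone."""
    logical = math.prod(logical_shape)
    padded = math.prod(padded_shape)
    return 100.0 * (padded - logical) / padded

# The (129, 1024) -> (256, 1024) example above wastes nearly half the
# buffer, while the aligned (128, 1024) shape wastes nothing.
print(f"{padding_waste_pct((129, 1024), (256, 1024)):.1f}%")  # 49.6%
print(f"{padding_waste_pct((128, 1024), (128, 1024)):.1f}%")  # 0.0%
```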
B. Adjust configuration
You can often resolve OOMs with these configuration adjustments:
- Reduce Batch Size: The memory needed for intermediate activations and gradients is directly proportional to the batch size. Reducing the batch size can often help reduce memory usage.
- Donate Input Buffers: When using jax.jit, specify donate_argnums for your model parameters. This allows XLA to overwrite the input memory with the output, as shown in the sketch after this list.
- Enable Mixed Precision (bfloat16): Use bfloat16 or quantization (int8, etc.) for the largest tensors in the program if the model architecture and quality requirements allow.
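A minimal sketch of buffer donation with jax.jit; the toy update step and shapes are illustrative:

```python
import jax
import jax.numpy as jnp

def train_step(params, batch):
    grads = jax.grad(lambda p: jnp.mean((batch @ p) ** 2))(params)
    return params - 1e-3 * grads

# Donating argument 0 tells XLA it may overwrite the old params buffer with
# the updated params, avoiding a second full-size copy in HBM.
train_step = jax.jit(train_step, donate_argnums=(0,))

params = jnp.zeros((4096, 4096), dtype=jnp.bfloat16)
batch = jnp.ones((128, 4096), dtype=jnp.bfloat16)
params = train_step(params, batch)  # the old params buffer is consumed
```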
C. Optimize architecture and sharding
If configuration changes are insufficient, the model topology might be too large for the current hardware setup.
- Use Newer TPU Generations: Newer TPUs generally offer more HBM per chip; switch to newer TPU generations if available.
- Run on a larger chip topology: If the model weights are too large for the existing topology, you can try sharding them across more chips.
- Implement advanced sharding techniques:
- Explore more advanced data, tensor, or pipeline parallelism approaches.
- Specify sharding hints for intermediate values and outputs.
- Use JAX Host Offloading: Offload large tensors to host CPU memory, e.g., activation offloading and optimizer state offloading (see the sketch below).
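A hedged sketch of host offloading via memory kinds, which recent JAX versions expose on shardings; the exact API surface may differ in your version:

```python
import jax
import jax.numpy as jnp

x = jnp.zeros((8192, 8192))  # a large buffer currently occupying HBM

# Request host-resident placement by switching the sharding's memory kind.
device_sharding = x.sharding
host_sharding = device_sharding.with_memory_kind("pinned_host")

x_host = jax.device_put(x, host_sharding)         # now lives in host RAM
x_back = jax.device_put(x_host, device_sharding)  # fetched back on demand
```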
D. Tune key memory-impacting XLA flags
Key memory flags can be tuned to trade performance for lower memory usage, but treat them as a last resort: they can adversely affect performance.
E. Tune XLA Rematerialization Pass / Manual Checkpointing
If the model is close to fitting into memory, you can force the XLA::Rematerialization pass to prioritize memory savings, potentially at the cost of slower compilation:
| Flag | Description | Impact / Trade-off |
|---|---|---|
| --xla_tpu_max_hbm_size_mib | Manually sets the HBM size limit used by the Rematerialization pass. | Forces the compiler to work harder to fit the program into a limit smaller than the actual physical HBM. |
| --xla_tpu_rematerialization_algo=PEAK_PRIORITY | Focuses efforts at the points of peak memory usage. | Can be more efficient for aggressive memory reduction than the default algorithm. |
| --xla_tpu_rematerialization_max_block_size_limit=32 | Controls the maximum number of instructions in a block that can be rematerialized at once. | Increasing this allows more memory savings at the cost of significantly increased compile time. |
| --xla_tpu_rematerialization_block_effort_factor=10.0 | Defines the amount of effort (compile time) spent searching for blocks to rematerialize. | Higher values allow a more exhaustive search for memory savings at the cost of increased compile times. |
| --xla_tpu_pre_fusion_remat=true | Enables an additional Rematerialization pass before the fusion pass. | Can find more memory savings, but increases compile times and may impact numerical stability. |
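As a hedged sketch, TPU-specific XLA flags such as those in the table are commonly passed through the LIBTPU_INIT_ARGS environment variable before JAX initializes the TPU; the exact mechanism depends on your launcher and runtime version:

```python
import os

# Set before any JAX/TPU initialization. The values mirror the table above.
os.environ["LIBTPU_INIT_ARGS"] = " ".join([
    "--xla_tpu_rematerialization_algo=PEAK_PRIORITY",
    "--xla_tpu_rematerialization_max_block_size_limit=32",
    "--xla_tpu_rematerialization_block_effort_factor=10.0",
])

import jax  # import after the flags are in place
```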
Alternatively, use the jax.checkpoint decorator with jax.grad to manually control which intermediates are saved on the forward pass versus recomputed on the backward pass, trading compute cycles for HBM.
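A minimal sketch of manual checkpointing, using an illustrative two-layer toy network (names and shapes are assumptions):

```python
import jax
import jax.numpy as jnp

@jax.checkpoint  # don't save the tanh activations; recompute them on backward
def block(x, w):
    return jnp.tanh(x @ w)

def loss(x, w1, w2):
    return jnp.sum(block(block(x, w1), w2))

x = jnp.ones((128, 256))
w1 = jnp.ones((256, 256))
w2 = jnp.ones((256, 256))

# Gradients recompute each block's intermediates instead of storing them,
# trading extra FLOPs for lower peak HBM usage.
grads = jax.grad(loss, argnums=(1, 2))(x, w1, w2)
```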
F. Use advanced profiling tools
Debug OOM errors with XProf provides a tutorial on using the XProf Memory Viewer to visualize the compiler's view of HBM usage.
This tool allows you to see peak memory allocation and buffer lifetimes, which is crucial for understanding exactly what consumes HBM at the point of peak utilization. For general profiling setup, see Getting started with Xprof and TensorBoard Profiling.