Category: Buffer allocation failure - TPU
Type: Runtime
Error log example
ValueError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 8.00M. That was not possible. There are 6.43M free.; (0x0x1_HBM0)
Why do these happen?
The XLA:TPU runtime’s memory allocator failed to find a suitable block of memory in
the accelerator’s HBM for the requested operation. These operations are typically
user-initiated buffer allocations (for example via jax.device_put) or allocations
for program outputs. These failures stem from one of two causes:
- Out of Memory (OOM): the requested chunk of memory is larger than the total amount of free memory available in the TPU’s HBM.
- Memory Fragmentation: the allocation fails because no single contiguous free block is large enough to satisfy the requested size. The total amount of free memory would be sufficient for the allocation, but it is scattered across the memory space in small, non-contiguous blocks.
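For example, a user-initiated allocation of the kind described above might look like the following minimal sketch (the array shape and device choice are illustrative assumptions); if HBM cannot provide a large enough block, this is where the RESOURCE_EXHAUSTED error would surface.

```python
import jax
import numpy as np

# Copy a host array into TPU HBM. If no sufficiently large (contiguous) free
# block exists on the target device, this raises the RESOURCE_EXHAUSTED error
# shown in the example log above.
host_batch = np.zeros((4096, 4096), dtype=np.float32)  # ~64 MiB
device_buffer = jax.device_put(host_batch, jax.devices()[0])
```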
The TPU runtime has a number of mechanisms in place to retry failed allocations:
- If there are queued deallocations, the runtime retries failed allocations.
- On OOMs caused by fragmentation, the runtime can automatically trigger a defragmentation pass and retry.
- The TPU runtime prioritizes buffer allocations over keeping programs loaded: if a buffer allocation fails due to insufficient HBM, the system evicts loaded TPU programs until enough memory is available for the buffer.
So an error that persists after the above mitigations typically requires user action.
How can a user fix their program when these failures do happen?
- Reduce your model's memory footprint:
- Decrease Batch Size: Reducing the batch size directly lowers memory usage.
- Parameter Sharding: For very large models, use techniques like model parallelism or sharding to distribute parameters across the HBM of multiple TPU cores or hosts (see the sharding sketch after this list).
- Shorten Sequence/Context Length: For models that operate on sequences (like language models), reducing the input sequence length can significantly decrease the memory footprint.
- Buffer Donation: Use framework features (such as jax.jit(..., donate_argnums=...)) to signal to XLA that certain input buffers can be overwritten and reused for outputs (see the donation sketch after this list).
- Optimize Checkpoint Strategy: Instead of saving the entire model state at once, consider saving only the model weights or using a sharded checkpointing strategy.
- Address Memory Layout and Padding:
- TPU memory is allocated in tiled chunks, and padding can increase the actual size of tensors beyond their logical size; preferring shapes that align with the hardware tiling reduces this overhead (see the padding sketch after this list).
- Ensure no memory leaks:
- Ensure references to jax.Array objects are not being held longer than intended (see the leak sketch after this list). Holding on to jax.Array objects might prevent automatic de-allocation even after program compilation has completed.
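A minimal sketch of parameter sharding with jax.sharding.NamedSharding; the mesh layout, array shape, and variable names are illustrative assumptions, not part of the original text.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1D mesh over all local devices and shard the leading axis of a
# large parameter across it, so each core holds only 1/N of the buffer in HBM.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
sharding = NamedSharding(mesh, PartitionSpec("data"))

big_param = jax.device_put(np.zeros((8192, 8192), dtype=np.float32), sharding)
print(big_param.sharding)  # confirms the per-device layout
```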
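A minimal sketch of buffer donation with jax.jit; the update function, shapes, and learning rate are illustrative assumptions.

```python
import functools

import jax
import jax.numpy as jnp

# donate_argnums=(0,) tells XLA that the buffers backing `params` may be
# reused for the output, so old and new parameters need not coexist in HBM.
@functools.partial(jax.jit, donate_argnums=(0,))
def apply_update(params, grads, lr):
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

params = {"w": jnp.ones((4096, 4096), jnp.float32)}
grads = {"w": jnp.ones((4096, 4096), jnp.float32)}
params = apply_update(params, grads, 1e-3)  # the donated input buffers are now invalid
```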
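As a rough illustration of the padding point, the following hypothetical helper estimates the padded size of a float32 array, assuming the commonly documented (8, 128) TPU tile for the last two dimensions; the real layout depends on dtype and runtime version, so treat it as an estimate only.

```python
import math

def padded_bytes(shape, tile=(8, 128), itemsize=4):
    # Hypothetical estimate: round the last two dimensions up to the assumed
    # (8, 128) float32 tile; actual TPU padding rules vary by dtype/runtime.
    *leading, sublanes, lanes = shape
    padded_sub = math.ceil(sublanes / tile[0]) * tile[0]
    padded_lane = math.ceil(lanes / tile[1]) * tile[1]
    return math.prod(leading) * padded_sub * padded_lane * itemsize

print(padded_bytes((1024, 3)))    # 524288 bytes, vs. 12288 bytes of raw data
print(padded_bytes((1024, 128)))  # 524288 bytes, already tile-aligned
```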
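A hypothetical illustration of the reference-holding leak: keeping device arrays in a long-lived container pins their HBM, while dropping the last reference lets the runtime free it. Function names and shapes are invented for illustration.

```python
import jax.numpy as jnp

activations = []  # long-lived, module-level cache

def leaky_step(x):
    y = jnp.dot(x, x.T)
    activations.append(y)  # lingering reference keeps the device buffer alive
    return float(y.sum())

def tidy_step(x):
    y = jnp.dot(x, x.T)
    total = float(y.sum())
    del y  # dropping the last reference lets the runtime free the HBM buffer
    return total
```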
How can a user debug these failures?
Enable the tpu_log_allocations_on_oom flag: when an OOM occurs, the allocator will
dump a detailed report of all current allocations, which can be invaluable for
debugging.
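One plausible way to enable the flag from a JAX program, assuming it is exposed through the LIBTPU_INIT_ARGS environment variable; consult your TPU runtime documentation for the exact mechanism and flag spelling.

```python
import os

# Assumption: libtpu runtime flags can be passed via LIBTPU_INIT_ARGS and must
# be set before the TPU backend initializes.
os.environ["LIBTPU_INIT_ARGS"] = (
    os.environ.get("LIBTPU_INIT_ARGS", "") + " --tpu_log_allocations_on_oom=true"
).strip()

import jax  # import (and use) jax only after the flag is set

print(jax.devices())
```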