Категория: Время компиляции: HBM OOM
Эта ошибка указывает на то, что программе требуется больше высокоскоростной памяти (HBM), чем физически доступно на устройстве TPU.
Примеры сообщений об ошибках:
RESOURCE_EXHAUSTED: TPU TensorCore Hbm usage: 34.82G, SparseCore Hbm usage 174.10G, exceeding available bytes: 95.74G
RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 49.34G of 32.00G hbm. Exceeded hbm capacity by 17.34G.
Бэкенды XLA: TPU
Обзор
XLA выполняет проверки, чтобы убедиться, что совокупный размер всех необходимых статических выделений помещается в память HBM устройства.
Компилятор управляет фиксированной емкостью HBM процессора TPU для нескольких типов выделений памяти:
- Входные и выходные данные программы: обучающие пакеты, состояния оптимизатора и т. д.
- Временные данные TPU: динамическая память, необходимая для промежуточных вычислений (например, активаций, градиентов и т. д.).
- Скомпилированный двоичный файл: машинный код для TensorCore (TC) и SparseCore (SC).
- Системные накладные расходы: зарезервированное пространство для среды выполнения XLA (например, буферы ввода на более старых поколениях TPU).
- Константы: Постоянные значения, заложенные в HLO IR, выделяются в HBM.
- Внутренние механизмы компилятора: выделение памяти на уровне программы и для каждого HLO (например, информация о маршрутизации узлов в сети).
Эта ошибка возникает, когда компилятор XLA не может разместить все указанные выше выделения памяти в памяти HBM устройства.
Отладка
Внимательно проанализируйте сообщение об ошибке и журналы, чтобы определить, какая из перечисленных ниже категорий ошибки HBM OOM лучше всего описывает вашу ошибку:
- "Использование TC Hbm: X, использование SC Hbm Y": Если ошибка явно указывает на использование HBM, то суммарное использование TensorCore (TC) + SparseCore (SC) превышает лимит HBM. → Перейти к сценарию 1. Сбалансировать использование TC и SC HBM .
- "Не хватило памяти в пространстве памяти HBM" : проверьте журналы, чтобы увидеть список самых больших выделений памяти в HBM.
- В случае наличия одного или нескольких неожиданно больших тензоров (например, > 50% от предела HBM) → Перейти к сценарию 2. Недостаток памяти из-за неожиданно больших выделений памяти .
- Если в логах отсутствуют неожиданно большие тензоры → Перейти к сценарию 3. Недостаток памяти из-за выделения агрегированных памяти .
Сценарий 1. Сбалансировать использование TC и SC HBM.
Если в сообщении об ошибке явно указано использование, например, "Использование TC Hbm: X, использование SC Hbm: Y" , это означает, что суммарное использование TensorCore (TC) + SparseCore (SC) превышает лимит HBM. Сравните эти два значения, чтобы определить узкое место:
- Высокая загрузка SparseCore
- Оптимизация использования стека HBM: потребление памяти стеком HBM масштабируется в зависимости от
feature_width,max_unique_nz_per_rowиlogical_replica_count. Вы можете уменьшить пиковое использование стека, настроив флаг--xla_sc_num_serialized_tables_to_optimize_hbm, который сериализует обработку таблиц. Это происходит за счет снижения параллелизма. - Проверьте избыточность заполнения: SparseCore выравнивает таблицы встраивания по 32 байтам (8 чисел с плавающей запятой). Таблицы с малой шириной признаков (например, < 8 чисел с плавающей запятой) приводят к значительной избыточности заполнения, что нерационально использует HBM.
- Снижение использования кучи: высокие значения параметра
maximum_parallel_iterationsувеличивают объем входных данных, предварительно загружаемых в кучу HBM. Снижение этого значения может освободить значительное количество памяти. - Проверка сегментирования: Убедитесь, что таблицы встраивания правильно сегментированы по всем чипам. См. раздел «Как ограничения преобразуются в таблицы» .
- Дополнительные идеи можно найти в статье SC: Performance and memory bottlenecks .
- Оптимизация использования стека HBM: потребление памяти стеком HBM масштабируется в зависимости от
- Высокая интенсивность использования TensorCore
- Перейдите к сценарию 2 .
- Сбалансированный
- Если ни один из компонентов по отдельности не является избыточным, но их сумма слишком велика, значит, вы достигли предела возможностей чипа. Необходимо попытаться снизить использование обоих компонентов. Следуйте рекомендациям во всех трех разделах.
Сценарий 2. Недостаток памяти из-за неожиданно большого объема выделенной памяти.
Если вы видите сообщение об ошибке "Ran out of memory in memory space HBM" и в логах присутствуют одно или несколько неожиданно больших выделений памяти (> 50% от лимита HBM), это почти никогда не связано с аппаратной нехваткой памяти. Обычно это ошибка конфигурации. Проверьте метку XLA (если она присутствует) больших выделений памяти, чтобы получить подсказки об их исходном коде JAX.
- Удалите артефакты отладки
- Использование jax.debug.print() в масштабируемых приложениях может заставить компилятор материализовать весь тензор в HBM для передачи его на ЦП, что нарушает слияние данных и увеличивает пиковое использование памяти. Удалите все оставшиеся вызовы
jax.debug.print().
- Использование jax.debug.print() в масштабируемых приложениях может заставить компилятор материализовать весь тензор в HBM для передачи его на ЦП, что нарушает слияние данных и увеличивает пиковое использование памяти. Удалите все оставшиеся вызовы
- Исправьте неэффективные формы сетки или сегментирование.
- Неправильная форма сетки или отсутствие аннотаций сегментирования могут привести к тому, что компилятор по умолчанию будет использовать репликацию , заставляя его пытаться разместить очень большие тензоры на одном чипе.
- Проверьте структуру больших выделенных ресурсов и убедитесь, что сегментирование задано и распространяется XLA корректно.
Сценарий 3. Нехватка памяти из-за агрегированного выделения памяти.
Если вы видите сообщение об ошибке «Недостаточно памяти в пространстве памяти HBM» , и в логах нет неожиданно больших тензоров, это означает, что программа исчерпала свою емкость из-за того, что суммарное количество выделений памяти превысило лимит HBM. В этом случае часто полезно визуализировать профиль использования памяти, чтобы определить конкретные буферы, которые вносят вклад в пиковое использование. См. раздел «Отладка ошибок OOM с помощью XProf» для пошагового руководства по определению источников пикового использования памяти.
После того, как вы определили несколько наиболее важных факторов, используйте следующие шаги для оптимизации использования памяти.
Сценарий 3.A. Настройка конфигурации
Часто проблему нехватки памяти можно решить с помощью следующих настроек конфигурации:
- Уменьшите размер пакета: объем памяти, необходимый для промежуточных активаций и градиентов, прямо пропорционален размеру пакета. Уменьшение размера пакета часто помогает снизить потребление памяти, хотя для поддержания стабильности модели может потребоваться перенастроить скорость обучения, момент инерции или гиперпараметры оптимизатора.
- Передача входных буферов: При выполнении вычислений JAX использует буферы на устройстве для всех входных и выходных данных. Если известно, что один из входных данных не нужен после вычислений, и если он соответствует форме и типу элемента одного из выходных данных, можно указать, что соответствующий входной буфер должен быть передан для хранения выходных данных. Это уменьшит объем памяти, необходимый для выполнения, на размер переданного буфера. Этого можно добиться, указав параметр donate_argnums в качестве аргумента при использовании
jax.jit. - Включить смешанную точность (bfloat16): Используйте bfloat16 или квантование (int8 и т. д.) для самых больших тензоров в программе, если это позволяют архитектура модели и требования к качеству. Обратите внимание, что это изменение может повлиять на поведение модели и должно быть тщательно продумано.
Микропакетирование (опционально)
Если уменьшение общего размера партии или увеличение количества чипов нецелесообразно, и размер партии на один чип еще не минимизирован, можно попробовать стратегию микропартийной обработки:
- Разделите каждую партию на
nмикропартий; - Для каждой микропартии обработайте прямой и обратный проходы;
- После этого накопите градиенты и обновите вес в целом.
Этот процесс уменьшает объем памяти активации, поскольку мы разделили каждую партию на n микропартий, так что если исходная партия имела размер M , то размер памяти активации становится M/n .
Возможные проблемы: - Этот процесс увеличивает время выполнения шага, поскольку мы выполняем несколько прямых и обратных проходов. - Если размеры модели и микропакета слишком сильно различаются, вы можете столкнуться с проблемами сходимости вашей модели.
Сценарий 3.B. Оптимизация архитектуры и сегментирования.
Если изменений в конфигурации недостаточно, топология модели может оказаться слишком большой для текущей аппаратной конфигурации.
- Используйте более новые поколения TPU: как правило, более новые TPU обеспечивают больший объем памяти HBM на чип; переходите на более новые поколения TPU, если они доступны.
- Запуск на более крупной топологии чипов: Если веса модели слишком велики для существующей топологии, можно попробовать распределить их по большему количеству чипов.
Внедрите передовые методы сегментирования:
- Изучите более продвинутые подходы к параллельной обработке данных, тензоров или конвейеров.
- Укажите подсказки по сегментированию для промежуточных значений и выходных данных.
Следует отметить, что это может привести к увеличению накладных расходов на сетевую передачу данных из-за распределения тензоров по нескольким чипам.
Используйте разгрузку хоста JAX: методы разгрузки хоста позволяют пользователю переносить большие тензоры в память центрального процессора (например, разгрузка активации и разгрузка состояния оптимизатора ). Обратите внимание, что методы разгрузки хоста могут существенно повлиять на производительность, поскольку эти операции заставят систему постоянно перемещать большие тензоры между памятью HBM TPU и оперативной памятью ЦП.
Сценарий 3.C. Проверка заполнения и выравнивания тензора.
Неэффективные формы тензоров являются распространенной, скрытой причиной ошибок нехватки памяти (OOM) на TPU. Для достижения максимальной производительности на TPU, XLA дополняет размеры тензоров — обычно до значений, кратных 128 для самого младшего измерения и 8 для второго младшего. Это дополнение влияет как на входные массивы, так и на промежуточные тензоры (временные HLO), потенциально значительно увеличивая использование памяти, особенно при малых размерах измерений. См. раздел «Расположение массивов» .
- Проверка структуры больших буферов: (На TPU v5 со стандартной компоновкой)
- При наведении курсора на буфер в программе Xprof Memory Viewer появляется карточка с подробными сведениями о буфере, включая информацию о заполнении.
- Пример : Формат
(129, 1024)может быть дополнен до(256, 1024), что приведет к почти 50% неэффективному использованию памяти. - Исправление: Форма
(128, 1024)не требует заполнения и не приводит к 0% расходованию памяти.
- Выравнивание размеров: Убедитесь, что все большие размеры тензора (размер пакета, размерность встраивания, размер скрытого слоя) кратны 128. Обратите внимание, что это изменение может повлиять на поведение модели и должно быть тщательно продумано.
Сценарий 3.D. Настройка памяти клавиш, влияющая на флаги XLA.
Ключевые параметры памяти можно настроить таким образом, чтобы обеспечить компромисс между производительностью и меньшим использованием памяти. Однако эту стратегию следует использовать только в крайнем случае, поскольку она может негативно повлиять на производительность.
Сценарий 3.E. Настройка процесса рематериализации XLA/ручное создание контрольных точек.
Если модель близка к тому, чтобы поместиться в память, вы можете использовать декоратор jax.checkpoint с jax.grad , чтобы вручную управлять тем, какие промежуточные значения сохраняются при прямом проходе, а какие пересчитываются при обратном проходе. Обратите внимание, что эта операция может повлиять на производительность, поскольку вы явно обмениваете вычислительные циклы на HBM. Для получения дополнительной информации ознакомьтесь с документацией JAX: - Сохранение контрольных точек градиента с помощью jax.checkpoint ( jax.remat ) - Управление сохраненными значениями автодифференциала с помощью jax.checkpoint (также известного как jax.remat ) - Память JAX и разгрузка хоста
В качестве альтернативы можно принудительно установить приоритет экономии памяти для прохода XLA::Rematerialization , что потенциально может привести к замедлению компиляции:
| Флаг | Описание | Влияние / Компромисс |
|---|---|---|
--xla_tpu_max_hbm_size_mib | Вручную устанавливает ограничение на размер HBM, используемый в процессе рематериализации. | Заставляет компилятор прилагать больше усилий, чтобы уместить программу в пределы, меньшие, чем фактический физический размер HBM. |
--xla_tpu_rematerialization_algo=PEAK_PRIORITY | Сосредотачивает усилия на точках пикового использования памяти. | Может быть более эффективным для агрессивного сокращения объема памяти, чем алгоритм по умолчанию. |
--xla_tpu_rematerialization_max_block_size_limit=32 | Контролирует максимальное количество инструкций в блоке, которые могут быть рематериализованы одновременно. | Увеличение этого параметра позволяет экономить память за счет значительного увеличения времени компиляции . |
--xla_tpu_rematerialization_block_effort_factor=10.0 | Определяет объем усилий (время компиляции), затрачиваемых на поиск блоков для повторной материализации. | Более высокие значения позволяют проводить более тщательный поиск способов экономии памяти за счет увеличения времени компиляции . |
--xla_tpu_pre_fusion_remat=true | Позволяет выполнить дополнительный этап рематериализации перед этапом слияния. | Можно добиться большей экономии памяти, но это увеличит время компиляции и потенциально может повлиять на численную стабильность . |
Следует отметить, что изменение флагов XLA следует использовать только в крайнем случае, поскольку это может негативно повлиять на производительность.
Сценарий 3.F. Использование расширенных инструментов профилирования.
В руководстве по отладке ошибок нехватки памяти (OOM) с помощью XProf представлен обзор использования памяти XProf для визуализации данных об использовании HBM, отображаемых компилятором.
Этот инструмент позволяет отслеживать пиковые значения выделения памяти и время жизни буферов, что крайне важно для точного понимания того, что именно потребляет HBM в момент пиковой нагрузки. Общие сведения о настройке профилирования см. в разделе «Начало работы с Xprof и профилированием TensorBoard» .
Сводная таблица
В таблице ниже приведено краткое описание возможных мер по устранению ошибок нехватки памяти, а также информация, которая поможет вам принять решение о дальнейших действиях.
| Вмешательство | Безопасно ли это? (Изменит ли это поведение программы?) | Потенциальная выгода | Характерные признаки (действительно ли это то самое узкое место, с которым вы столкнулись?) |
|---|---|---|---|
| Использование передовых методов сегментирования | Да. Это практически никогда не влияет на численную корректность эксперимента, хотя может вызывать накладные расходы на сетевую передачу данных из-за разделения тензоров между несколькими чипами. | Значительный прирост (снижение стоимости до 256 раз) | Неожиданно большие объемы памяти, выделяемые отдельными блоками, отображаются в окне просмотра памяти (например, один тензор, реплицированный на все TPU, в 256 раз больше остальных). Активные массивы отображаются как несегментированные в хуках TensorBoard. |
| Сокращение размера партии | Нет. Это изменяет динамику обучения и обычно требует повторной настройки скорости обучения. (Примечание: Микропакетная обработка — это безопасная альтернатива, которая уменьшает объем памяти без изменения поведения). | Огромная выгода (можно сэкономить в тысячи раз). | Во время вычислений градиента не удается выделить "временные" данные. В названии операции отображается "JVP", а в профиле использования памяти обнаруживается множество тензоров, соответствующих размеру пакета. |
| Включение смешанной точности (например, Bfloat16) | Рискованно. Это изменяет точность численных расчетов, что может повлиять на результаты эксперимента или привести к тому, что модель вообще не сойдется. | Умеренный прирост производительности (обычно в 2 раза, так как использование памяти сокращается вдвое). | Программа для просмотра памяти подтверждает, что самые большие тензоры в настоящее время используют 32-битные числа с плавающей запятой ( float32 ). |
Ручное создание контрольных точек ( jax.checkpoint ) | Да. Это не меняет поведение; это просто обмен вычислительного времени (флопсов) на экономию памяти за счет пересчета тензоров вместо их хранения. | Значительный выигрыш (например, может привести к тому, что в памяти одновременно будет находиться только половина активаций). | В ходе обратного прохода память заполняется несколькими тензорами одинакового размера. Часто в названии операции присутствует "JVP". |
Предоставление входных буферов ( donate_argnums ) | Да. Обеспечивает целостность эксперимента. При неправильном применении данные не будут повреждены, а просто будет выдано понятное сообщение об ошибке. | Незначительный выигрыш (примерно 1% экономии памяти). | Нет каких-либо конкретных признаков, указывающих на это, но это считается "бесплатной победой", которую всегда стоит попробовать. |
| Изменение размеров модели | Нет. Это напрямую изменяет поведение модели. Изменение входных или выходных параметров может полностью нарушить совместимость с набором данных. | Прирост производительности зависит от того, насколько сильно уменьшены скрытые измерения или слои. | Программа просмотра памяти Xprof показывает, что большой объем памяти расходуется впустую на «заполнение», поскольку размеры массива не являются степенями двойки или кратными 128 (например, размер 2050 вместо 2048). |
| Разгрузка хоста (процессора) | Да (в численном отношении) , но нет (в отношении производительности) . Хотя математически это безопасно, это считается "пистолетом под дулом пистолета", который может вызвать серьезные узкие места в скорости из-за передачи данных между ЦП и ТПУ. | Незначительный прирост производительности (процессор имеет примерно в 3 раза больше памяти, чем TPU). | Как правило, это крайняя мера для обработки большого количества состояний оптимизатора или для ресурсоемких этапов подготовки/предварительной обработки данных. |