Дамп HLO — это текстовое представление модулей HLO на разных этапах вычислений. Он полезен для отладки, и его часто необходимо включать в отчеты об ошибках. Обычно это удобочитаемый текстовый файл , в котором перечислены инструкции HLO и их свойства. Иногда дамп модулей HLO может выглядеть следующим образом:
- HloProto: Протокольные буферные файлы, представляющие собой более структурированный, машиночитаемый формат.
- HloSnapshot : модуль HLO плюс его входные данные. Для воспроизведения HLO иногда требуются фактические входные данные, подаваемые на вычисление, а не случайные данные.
Для указания и получения дампов можно использовать флаги XLA. В большинстве случаев это можно установить с помощью переменной окружения. JAX также предлагает программный способ вывода дампа HLO.
Локальное исполнение
Использование переменных окружающей среды
Для получения дампов можно установить переменную среды XLA_FLAGS с необходимыми флагами. Это работает для JAX, TensorFlow и PyTorch/XLA.
Чтобы сохранить модули HLO и другую отладочную информацию в определенный каталог, запустите программу с флагом --xla_dump_to :
XLA_FLAGS="--xla_dump_to=DIRECTORY_PATH"
Например, в качестве путей можно использовать /tmp или /tmp/xladump .
По умолчанию эта функция выводит модули HLO в текстовом виде в самом начале и в конце конвейера оптимизации.
Вы также можете явно указать формат:
- Выгрузки текста
XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=DIRECTORY_PATH"
- HLO-прототипы
XLA_FLAGS="--xla_dump_hlo_as_proto --xla_dump_to=DIRECTORY_PATH"
- Снимки HLO
XLA_FLAGS="--xla_dump_hlo_snapshots --xla_dump_to=DIRECTORY_PATH"
- Визуализация графов с помощью сервера Graphviz (хорошо работает только для небольших графов)
XLA_FLAGS="--xla_dump_hlo_as_url --xla_dump_to=DIRECTORY_PATH"
- Отображение графика в HTML-файл (хорошо работает только для небольших графиков)
XLA_FLAGS="--xla_dump_hlo_as_html --xla_dump_to=DIRECTORY_PATH"
Для больших графов можно использовать interactive_graphviz для визуализации отдельных его частей.
Вы также можете настроить дампы HLO таким образом, чтобы в качестве имен операций использовались обертки синтаксического сахара, установив флаг --xla_syntax_sugar_async_ops в true . Это может уменьшить размер дампа примерно на 20%. По умолчанию этот флаг установлен в false , и в дампе используются фактические имена операций.
XLA_FLAGS="--xla_dump_to=DIRECTORY_PATH --xla_syntax_sugar_async_ops=true"
Сбросить промежуточные проходы
Помимо стандартных предварительно оптимизированных/окончательно оптимизированных HLO, вы также можете выводить состояние HLO после определенного прохода компилятора.
XLA_FLAGS="--xla_dump_hlo_pass_re=regex --xla_dump_to=DIRECTORY_PATH"
Модули HLO будут выгружены для проходов, имена которых соответствуют регулярному выражению (regex). Например, вы можете просмотреть модули HLO, полученные в результате проходов, связанных с разделением SPMD, с помощью:
XLA_FLAGS="--xla_dump_to=DIRECTORY_PATH --xla_dump_hlo_pass_re=spmd|propagation"
Чтобы сохранять результаты после каждого прохода XLA (это приведет к созданию большого количества файлов), можно установить следующие параметры:
XLA_FLAGS="--xla_dump_to=DIRECTORY_PATH --xla_dump_hlo_pass_re=.*"
Специфические параметры JAX
Программным способом в JAX
Вместо передачи флагов или переменных окружения вы также можете программно выводить HLO, используя API JAX для работы lower доступа и compile .
Получите локально неоптимизированный исходный пониженный HLO с помощью:
jax.jit(f).lower(*args).as_text('hlo')
Для сохранения данных в файлы во время компиляции HLO укажите:
compilation_args = {
'xla_dump_to': DIRECTORY_PATH,
'xla_dump_hlo_pass_re': 'spmd|propagation', # or some other pass filter
...
}
jax.jit(f).lower(*args).compile(compilation_args)
Dump jaxprs
jaxpr — это промежуточное представление трассировки программы в JAX. Чтобы вывести его содержимое, установите переменные среды:
JAX_DUMP_IR_TO="DIRECTORY_PATH" JAX_DUMP_IR_MODES=jaxpr
Подробнее см. в документации JAX по экспорту и сериализации подготовленных вычислений: Отладка .
Google Colab
переменные окружающей среды
В первой выполненной ячейке вашего ноутбука (поскольку переменные среды и флаги командной строки обычно обрабатываются только один раз, например, во время импорта модуля или инициализации бэкенда XLA) добавьте описанные выше XLA_FLAGS с помощью os.environ , например:
import os
os.environ['XLA_FLAGS'] = "--xla_dump_to=DIRECTORY_PATH"
Это позволит сохранить результаты вычислений в DIRECTORY_PATH , например /tmp . В Colab перейдите в меню «Файлы» в левой боковой панели, чтобы просмотреть и получить доступ к этому каталогу.
Вы можете использовать все флаги, упомянутые в разделе «Локальное выполнение».
Параметры, специфичные для JAX
Аналогично локальному выполнению, для интерактивного анализа в реальном времени можно напрямую вывести предварительно оптимизированный HLO вычисления:
def f(x):
return jax.numpy.sin(jax.numpy.cos(x))
c = jax.jit(f).lower(3.).compiler_ir('hlo')
print(c.as_hlo_text())
Вы также можете напрямую распечатать оптимизированный HLO вычисления:
def optimized_HLO(f, *args, platform=None):
print(jax.jit(f).lower(*args).compile().as_text())
def f(x):
return jax.numpy.sin(jax.numpy.cos(x))
optimized_HLO(f, 1.0)
Сброс всех/небольших вычислений
Если вы хотите увидеть все содержимое дампа, включая все небольшие компиляции, установите переменную среды JAX:
JAX_COMPILER_DETAILED_LOGGING_MIN_OPS=0
Мозаика
Mosaic — это компилятор для бэкенда Pallas TPU и экспериментального бэкенда Pallas GPU. Для вывода результатов вычислений мозаики установите следующий флаг:
--xla_mosaic_dump_to=/tmp/mosaic_dumps
Или же установите аргументы инициализации TPU в качестве переменной окружения:
export LIBTPU_INIT_ARGS="--xla_mosaic_dump_to=/tmp/mosaic_dumps"
Для получения более подробной информации ознакомьтесь с документацией JAX по Pallas и Mosaic .
Больше информации с помощью HLO Dumps
Поиск подходящего вычисления
Обычно выгружается большое количество вычислительных данных. Файлы с выгруженными данными имеют явное имя, указывающее на "имя вычисления" в JAX, Tensorflow или PyTorch/XLA, что упрощает идентификацию соответствующих файлов HLO. Например:
1624325116260738.module_0065.pmap__unnamed_wrapped_function_.186875.before_optimizations.txt
В противном случае, вы можете использовать ripgrep для быстрого определения того, какой модуль содержит определенные символы или вычисления.
Совет: В отчеты об ошибках включите три файла с дампом до, после и назначением буфера, которые вас интересуют.
Преобразование HLO
Инструмент hlo-opt , который может преобразовывать данные между форматами HLOProto и текстовым форматом. Он полезен в случаях, когда у вас есть один формат, но для отладки нужен другой.
Научитесь им пользоваться: документация по инструментам XLA: hlo-opt .
Повтор
Вы можете запустить (воспроизвести) результаты вычислений, полученные с помощью дампа памяти, на указанном бэкенде XLA, используя фиктивные данные или снимки входных данных. Это удобный способ воспроизведения, итерации и отладки проблем в XLA.
Следующие команды используют фиктивные данные. Если у вас есть сохраненные снимки HLO, вы можете передать их вместо данных, и будут использованы данные из снимка. Чтобы по-прежнему использовать фиктивные данные при запуске снимка, передайте флаг --force_fake_data .
Процессорная часть:
bazel run -c opt //xla/hlo/tools:run_hlo_module -- --platform=cpu
/tmp/xladump/module_4561.before_optimizations.txt
GPU-бэкэнд:
bazel run -c opt //xla/hlo/tools:run_hlo_module -- --platform=CUDA
/tmp/xladump/module_4561.before_optimizations.txt