Дамп вычислений HLO

Дамп HLO — это текстовое представление модулей HLO на разных этапах вычислений. Он полезен для отладки, и его часто необходимо включать в отчеты об ошибках. Обычно это удобочитаемый текстовый файл , в котором перечислены инструкции HLO и их свойства. Иногда дамп модулей HLO может выглядеть следующим образом:

  • HloProto: Протокольные буферные файлы, представляющие собой более структурированный, машиночитаемый формат.
  • HloSnapshot : модуль HLO плюс его входные данные. Для воспроизведения HLO иногда требуются фактические входные данные, подаваемые на вычисление, а не случайные данные.

Для указания и получения дампов можно использовать флаги XLA. В большинстве случаев это можно установить с помощью переменной окружения. JAX также предлагает программный способ вывода дампа HLO.

Локальное исполнение

Использование переменных окружающей среды

Для получения дампов можно установить переменную среды XLA_FLAGS с необходимыми флагами. Это работает для JAX, TensorFlow и PyTorch/XLA.

Чтобы сохранить модули HLO и другую отладочную информацию в определенный каталог, запустите программу с флагом --xla_dump_to :

XLA_FLAGS="--xla_dump_to=DIRECTORY_PATH"

Например, в качестве путей можно использовать /tmp или /tmp/xladump .

По умолчанию эта функция выводит модули HLO в текстовом виде в самом начале и в конце конвейера оптимизации.

Вы также можете явно указать формат:

  1. Выгрузки текста
XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=DIRECTORY_PATH"
  1. HLO-прототипы
XLA_FLAGS="--xla_dump_hlo_as_proto --xla_dump_to=DIRECTORY_PATH"
  1. Снимки HLO
XLA_FLAGS="--xla_dump_hlo_snapshots --xla_dump_to=DIRECTORY_PATH"
  1. Визуализация графов с помощью сервера Graphviz (хорошо работает только для небольших графов)
XLA_FLAGS="--xla_dump_hlo_as_url --xla_dump_to=DIRECTORY_PATH"
  1. Отображение графика в HTML-файл (хорошо работает только для небольших графиков)
XLA_FLAGS="--xla_dump_hlo_as_html --xla_dump_to=DIRECTORY_PATH"

Для больших графов можно использовать interactive_graphviz для визуализации отдельных его частей.

Вы также можете настроить дампы HLO таким образом, чтобы в качестве имен операций использовались обертки синтаксического сахара, установив флаг --xla_syntax_sugar_async_ops в true . Это может уменьшить размер дампа примерно на 20%. По умолчанию этот флаг установлен в false , и в дампе используются фактические имена операций.

XLA_FLAGS="--xla_dump_to=DIRECTORY_PATH --xla_syntax_sugar_async_ops=true"

Сбросить промежуточные проходы

Помимо стандартных предварительно оптимизированных/окончательно оптимизированных HLO, вы также можете выводить состояние HLO после определенного прохода компилятора.

XLA_FLAGS="--xla_dump_hlo_pass_re=regex --xla_dump_to=DIRECTORY_PATH"

Модули HLO будут выгружены для проходов, имена которых соответствуют регулярному выражению (regex). Например, вы можете просмотреть модули HLO, полученные в результате проходов, связанных с разделением SPMD, с помощью:

XLA_FLAGS="--xla_dump_to=DIRECTORY_PATH --xla_dump_hlo_pass_re=spmd|propagation"

Чтобы сохранять результаты после каждого прохода XLA (это приведет к созданию большого количества файлов), можно установить следующие параметры:

XLA_FLAGS="--xla_dump_to=DIRECTORY_PATH --xla_dump_hlo_pass_re=.*"

Специфические параметры JAX

Программным способом в JAX

Вместо передачи флагов или переменных окружения вы также можете программно выводить HLO, используя API JAX для работы lower доступа и compile .

Получите локально неоптимизированный исходный пониженный HLO с помощью:

jax.jit(f).lower(*args).as_text('hlo')

Для сохранения данных в файлы во время компиляции HLO укажите:

compilation_args = {
    'xla_dump_to': DIRECTORY_PATH,
    'xla_dump_hlo_pass_re': 'spmd|propagation', # or some other pass filter
    ...
    }

jax.jit(f).lower(*args).compile(compilation_args)

Dump jaxprs

jaxpr — это промежуточное представление трассировки программы в JAX. Чтобы вывести его содержимое, установите переменные среды:

JAX_DUMP_IR_TO="DIRECTORY_PATH" JAX_DUMP_IR_MODES=jaxpr

Подробнее см. в документации JAX по экспорту и сериализации подготовленных вычислений: Отладка .

Google Colab

переменные окружающей среды

В первой выполненной ячейке вашего ноутбука (поскольку переменные среды и флаги командной строки обычно обрабатываются только один раз, например, во время импорта модуля или инициализации бэкенда XLA) добавьте описанные выше XLA_FLAGS с помощью os.environ , например:

import os
os.environ['XLA_FLAGS'] = "--xla_dump_to=DIRECTORY_PATH"

Это позволит сохранить результаты вычислений в DIRECTORY_PATH , например /tmp . В Colab перейдите в меню «Файлы» в левой боковой панели, чтобы просмотреть и получить доступ к этому каталогу.

Вы можете использовать все флаги, упомянутые в разделе «Локальное выполнение».

Параметры, специфичные для JAX

Аналогично локальному выполнению, для интерактивного анализа в реальном времени можно напрямую вывести предварительно оптимизированный HLO вычисления:

def f(x):
    return jax.numpy.sin(jax.numpy.cos(x))

c = jax.jit(f).lower(3.).compiler_ir('hlo')

print(c.as_hlo_text())

Вы также можете напрямую распечатать оптимизированный HLO вычисления:

def optimized_HLO(f, *args, platform=None):
    print(jax.jit(f).lower(*args).compile().as_text())

def f(x):
    return jax.numpy.sin(jax.numpy.cos(x))

optimized_HLO(f, 1.0)

Сброс всех/небольших вычислений

Если вы хотите увидеть все содержимое дампа, включая все небольшие компиляции, установите переменную среды JAX:

JAX_COMPILER_DETAILED_LOGGING_MIN_OPS=0

Мозаика

Mosaic — это компилятор для бэкенда Pallas TPU и экспериментального бэкенда Pallas GPU. Для вывода результатов вычислений мозаики установите следующий флаг:

--xla_mosaic_dump_to=/tmp/mosaic_dumps

Или же установите аргументы инициализации TPU в качестве переменной окружения:

export LIBTPU_INIT_ARGS="--xla_mosaic_dump_to=/tmp/mosaic_dumps"

Для получения более подробной информации ознакомьтесь с документацией JAX по Pallas и Mosaic .

Больше информации с помощью HLO Dumps

Поиск подходящего вычисления

Обычно выгружается большое количество вычислительных данных. Файлы с выгруженными данными имеют явное имя, указывающее на "имя вычисления" в JAX, Tensorflow или PyTorch/XLA, что упрощает идентификацию соответствующих файлов HLO. Например:

1624325116260738.module_0065.pmap__unnamed_wrapped_function_.186875.before_optimizations.txt

В противном случае, вы можете использовать ripgrep для быстрого определения того, какой модуль содержит определенные символы или вычисления.

Совет: В отчеты об ошибках включите три файла с дампом до, после и назначением буфера, которые вас интересуют.

Преобразование HLO

Инструмент hlo-opt , который может преобразовывать данные между форматами HLOProto и текстовым форматом. Он полезен в случаях, когда у вас есть один формат, но для отладки нужен другой.

Научитесь им пользоваться: документация по инструментам XLA: hlo-opt .

Повтор

Вы можете запустить (воспроизвести) результаты вычислений, полученные с помощью дампа памяти, на указанном бэкенде XLA, используя фиктивные данные или снимки входных данных. Это удобный способ воспроизведения, итерации и отладки проблем в XLA.

Следующие команды используют фиктивные данные. Если у вас есть сохраненные снимки HLO, вы можете передать их вместо данных, и будут использованы данные из снимка. Чтобы по-прежнему использовать фиктивные данные при запуске снимка, передайте флаг --force_fake_data .

Процессорная часть:

bazel run -c opt //xla/hlo/tools:run_hlo_module -- --platform=cpu
 /tmp/xladump/module_4561.before_optimizations.txt

GPU-бэкэнд:

bazel run -c opt //xla/hlo/tools:run_hlo_module -- --platform=CUDA
 /tmp/xladump/module_4561.before_optimizations.txt