Дамп вычислений HLO

Дамп HLO — это текстовое представление модулей HLO на разных этапах вычислений. Он полезен для отладки и часто используется в отчётах об ошибках. Обычно это текстовый файл, удобочитаемый для человека, содержащий список инструкций HLO и их свойств. Иногда модули HLO выводятся в виде:

  • HloProto: файлы буфера протокола, представляющие собой более структурированный, машиночитаемый формат.
  • HloSnapshot : модуль HLO и его входные данные. Для воспроизведения HLO иногда требуются фактические входные данные, подаваемые в данное вычисление, а не случайные данные.

Для задания и получения дампов можно использовать флаги XLA. В большинстве случаев это можно сделать с помощью переменной окружения. JAX также предлагает программный способ вывода дампа HLO.

Местное исполнение

Использование переменных среды

Вы можете установить переменную окружения XLA_FLAGS с необходимыми флагами для получения дампов. Это работает для JAX, TensorFlow и PyTorch/XLA.

Чтобы выгрузить модули HLO и другую отладочную информацию в определенный каталог, запустите программу с флагом --xla_dump_to :

XLA_FLAGS="--xla_dump_to=DIRECTORY_PATH"

Например, в качестве путей можно использовать /tmp или /tmp/xladump .

По умолчанию модули HLO выводятся в виде текста в самом начале и в конце конвейера оптимизации.

Вы также можете явно указать формат:

  1. Текстовые дампы
XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=DIRECTORY_PATH"
  1. HLO протос
XLA_FLAGS="--xla_dump_hlo_as_proto --xla_dump_to=DIRECTORY_PATH"
  1. Снимки HLO
XLA_FLAGS="--xla_dump_hlo_snapshots --xla_dump_to=DIRECTORY_PATH"
  1. Рендеринг графиков с помощью сервера GraphViz (работает только для небольших графиков)
XLA_FLAGS="--xla_dump_hlo_as_url --xla_dump_to=DIRECTORY_PATH"
  1. Рендеринг графика в HTML-файл (работает только для небольших графиков)
XLA_FLAGS="--xla_dump_hlo_as_html --xla_dump_to=DIRECTORY_PATH"

Для визуализации частей больших графиков можно использовать interactive_graphviz .

Промежуточные проходы для сброса данных

Помимо стандартных предварительно оптимизированных/окончательно оптимизированных HLO, вы также можете вывести состояние HLO после определенного прохода компилятора.

XLA_FLAGS="--xla_dump_hlo_pass_re=regex --xla_dump_to=DIRECTORY_PATH"

Модули HLO будут выгружены для проходов, имена которых соответствуют регулярному выражению (regex). Например, вы можете наблюдать HLO, полученные в проходах, связанных с разбиением SPMD, с помощью:

XLA_FLAGS="--xla_dump_to=DIRECTORY_PATH --xla_dump_hlo_pass_re=spmd|propagation"

Чтобы выводить результат после каждого прохода XLA (это приведет к созданию большого количества файлов), можно установить:

XLA_FLAGS="--xla_dump_to=DIRECTORY_PATH --xla_dump_hlo_pass_re=.*"

Параметры, специфичные для JAX

Программно в JAX

Вместо передачи флагов или переменных окружения вы также можете программно вывести HLO, используя API JAX compile и lower .

Локально получить неоптимизированный исходный пониженный HLO с помощью:

jax.jit(f).lower(*args).as_text('hlo')

Для выполнения дампа в файлы во время проходов компиляции HLO укажите:

compilation_args = {
    'xla_dump_to': DIRECTORY_PATH,
    'xla_dump_hlo_pass_re': 'spmd|propagation', # or some other pass filter
    ...
    }

jax.jit(f).lower(*args).compile(compilation_args)

Дамп jaxprs

jaxpr — это промежуточное представление трассировки программы в JAX. Чтобы вывести её, установите переменные окружения:

JAX_DUMP_IR_TO="DIRECTORY_PATH" JAX_DUMP_IR_MODES=jaxpr

Дополнительную информацию см. в документации JAX по экспорту и сериализации завершенных вычислений: отладка .

Google Colab

Переменные среды

В первой выполняемой ячейке вашего блокнота (поскольку переменные среды и флаги командной строки обычно обрабатываются только один раз, например, во время импорта модуля или во время инициализации бэкэнда XLA) добавьте XLA_FLAGS подробно описанные выше, с помощью os.environ , например:

import os
os.environ['XLA_FLAGS'] = "--xla_dump_to=DIRECTORY_PATH"

Это перенесёт вычисления в DIRECTORY_PATH , например /tmp . В Colab перейдите в браузер «Файлы» на левой боковой панели, чтобы просмотреть и получить доступ к этому каталогу.

Вы можете использовать все флаги, упомянутые в разделе «Локальное выполнение».

Параметры, специфичные для JAX

Подобно локальному выполнению; для живого интерактивного самоанализа вы можете напрямую распечатать предварительно оптимизированный HLO вычисления:

def f(x):
    return jax.numpy.sin(jax.numpy.cos(x))

c = jax.jit(f).lower(3.).compiler_ir('hlo')

print(c.as_hlo_text())

Вы также можете напрямую распечатать оптимизированный HLO вычисления:

def optimized_HLO(f, *args, platform=None):
    print(jax.jit(f).lower(*args).compile().as_text())

def f(x):
    return jax.numpy.sin(jax.numpy.cos(x))

optimized_HLO(f, 1.0)

Сброс всех/небольших вычислений

Если вы хотите увидеть все в дампе, включая все небольшие компиляции, установите переменную среды JAX:

JAX_COMPILER_DETAILED_LOGGING_MIN_OPS=0

Мозаика

Mosaic — это компилятор для бэкенда Pallas TPU и экспериментального бэкенда Pallas GPU. Чтобы вывести данные для мозаичного вычисления, установите следующий флаг:

--xla_mosaic_dump_to=/tmp/mosaic_dumps

Или задайте аргументы инициализации TPU как переменную среды:

export LIBTPU_INIT_ARGS="--xla_mosaic_dump_to=/tmp/mosaic_dumps"

Более подробную информацию можно найти в документации JAX по Pallas и Mosaic .

Ещё с HLO Dumps

Поиск правильного вычисления

Обычно выполняется дамп множества вычислений. Файлы дампов явно именуются с использованием «имени вычисления» JAX, Tensorflow или PyTorch/XLA, которое указывается в журналах, что позволяет легко идентифицировать соответствующие файлы HLO. Например:

1624325116260738.module_0065.pmap__unnamed_wrapped_function_.186875.before_optimizations.txt

В противном случае вы можете использовать ripgrep для быстрого определения того, какой модуль содержит конкретные символы или вычисления.

Совет: включите в свои отчеты об ошибках три интересующих вас файла «до/после/назначения буфера».

Преобразование HLO

Инструмент hlo-opt , который может конвертировать данные между форматами HLOProto и текстом. Он полезен в случаях, когда у вас есть один формат, а другой нужен для отладки.

Изучите его: документация по инструментам XLA: hlo-opt .

Повторить

Вы можете запустить (воспроизвести) дамп вычислений на указанном бэкенде XLA с поддельными данными или снимками входных данных. Это удобный способ воспроизведения, итерации и отладки проблем в XLA.

Следующие команды используют поддельные данные. Если у вас есть сохраненные снимки HLO, вы можете передать их, и будут использоваться данные из снимка. Чтобы использовать поддельные данные при запуске снимка, передайте флаг --force_fake_data .

Процессорный бэкэнд:

bazel run -c opt //xla/hlo/tools:run_hlo_module -- --platform=cpu
 /tmp/xladump/module_4561.before_optimizations.txt

Бэкэнд GPU:

bazel run -c opt //xla/hlo/tools:run_hlo_module -- --platform=CUDA
 /tmp/xladump/module_4561.before_optimizations.txt