Профилирование вычислений JAX с помощью XProf

XProf — отличный способ получить и визуализировать трассировки производительности и профили вашей программы, включая активность на GPU и TPU. В итоге получается примерно следующее:

Пример XProf

Программный захват

Вы можете инструментировать свой код для захвата трассировки профилировщика для кода JAX с помощью методов jax.profiler.start_trace и jax.profiler.stop_trace . Вызовите jax.profiler.start_trace , указав каталог, в который будут записываться файлы трассировки. Это должен быть тот же каталог --logdir который использовался для запуска XProf. Затем вы можете использовать XProf для просмотра трассировок.

Например, чтобы получить трассировку профилировщика:

import jax

jax.profiler.start_trace("/tmp/profile-data")

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

Обратите внимание на вызов jax.block_until_ready . Мы используем его, чтобы гарантировать, что выполнение на устройстве будет зафиксировано трассировкой. Подробнее о том, почему это необходимо, см. в разделе «Асинхронная диспетчеризация» .

В качестве альтернативы функциям start_trace и stop_trace можно также использовать менеджер контекста jax.profiler.trace :

import jax

with jax.profiler.trace("/tmp/profile-data"):
  key = jax.random.key(0)
  x = jax.random.normal(key, (5000, 5000))
  y = x @ x
  y.block_until_ready()

Просмотр трассировки

После захвата трассировки вы можете просмотреть ее с помощью пользовательского интерфейса XProf.

Вы можете запустить пользовательский интерфейс профилировщика напрямую, используя автономную команду XProf, указав ей путь к каталогу с логами:

$ xprof --port=8791 /tmp/profile-data
Attempting to start XProf server:
  Log Directory: /tmp/profile-data
  Port: 8791
  Worker Service Address: 0.0.0.0:50051
  Hide Capture Button: False
XProf at http://localhost:8791/ (Press CTRL+C to quit)

Для просмотра профиля перейдите по указанному URL-адресу (например, http://localhost:8791/ ) в вашем браузере.

Доступные трассировки отображаются в раскрывающемся меню «Сессии» слева. Выберите интересующую вас сессию, а затем в раскрывающемся меню «Инструменты» выберите «Просмотр трассировки». Теперь вы должны увидеть временную шкалу выполнения. Вы можете использовать клавиши WASD для навигации по трассировке, а также щелкать или перетаскивать для выбора событий и получения более подробной информации. Дополнительные сведения об использовании средства просмотра трассировки см. в документации к инструменту «Просмотр трассировки ».

Ручной захват с помощью XProf

Ниже приведены инструкции по захвату N-секундной трассировки, запускаемой вручную, из работающей программы.

  1. Запустите сервер XProf:

    xprof --logdir /tmp/profile-data/
    

    Вы сможете загрузить XProf по адресу <http://localhost:8791/> . Другой порт можно указать с помощью флага --port .

  2. В программу или процесс на Python, который вы хотите профилировать, добавьте следующий код где-нибудь в начале:

    import jax.profiler
    jax.profiler.start_server(9999)
    

    Это запускает сервер профилирования, к которому подключается XProf. Сервер профилирования должен быть запущен, прежде чем вы перейдете к следующему шагу. Когда вы закончите использовать сервер, вы можете вызвать jax.profiler.stop_server() чтобы остановить его.

    Если вы хотите профилировать фрагмент длительной программы (например, длинный цикл обучения), вы можете поместить это в начало программы и запустить ее как обычно. Если вы хотите профилировать короткую программу (например, микротест), один из вариантов — запустить сервер профилирования в оболочке IPython и запустить короткую программу с помощью %run после начала захвата данных на следующем шаге. Другой вариант — запустить сервер профилирования в начале программы и использовать time.sleep() , чтобы дать вам достаточно времени для начала захвата данных.

  3. Откройте <http://localhost:8791/> и нажмите кнопку "CAPTURE PROFILE" в верхнем левом углу. Введите "localhost:9999" в качестве URL-адреса службы профилирования (это адрес сервера профилирования, который вы запустили на предыдущем шаге). Введите количество миллисекунд, в течение которых вы хотите выполнить профилирование, и нажмите "CAPTURE".

  4. Если код, который вы хотите профилировать, еще не запущен (например, если вы запустили сервер профилирования в оболочке Python), запустите его во время выполнения захвата данных.

  5. После завершения захвата данных XProf должен автоматически обновиться. (Не все функции профилирования XProf подключены к JAX, поэтому изначально может показаться, что ничего не было захвачено.) В левой части экрана в разделе «Инструменты» выберите «Просмотр трассировки».

Теперь вы должны увидеть временную шкалу выполнения. Вы можете использовать клавиши WASD для навигации по трассировке, а также щелкать или перетаскивать мышью, чтобы выбрать события и просмотреть более подробную информацию внизу. Более подробную информацию об использовании средства просмотра трассировки см. в документации к инструменту просмотра трассировки .

XProf и Tensorboard

XProf — это базовый инструмент, обеспечивающий функциональность профилирования и захвата трассировки в Tensorboard. Пока xprof установлен, в Tensorboard будет присутствовать вкладка «Профиль». Использовать её можно так же, как и запускать XProf отдельно, при условии, что при запуске указывается тот же каталог с логами. Это включает в себя функции захвата, анализа и просмотра профилирования. XProf заменяет ранее рекомендованную функциональность tensorboard_plugin_profile .

$ tensorboard --logdir=/tmp/profile-data
[...]
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.19.0 at http://localhost:6006/ (Press CTRL+C to quit)

Добавление пользовательских событий трассировки

По умолчанию события в окне просмотра трассировки представляют собой в основном низкоуровневые внутренние функции JAX. Вы можете добавить свои собственные события и функции, используя jax.profiler.TraceAnnotation и jax.profiler.annotate_function в своем коде.

Настройка параметров профилировщика

Метод start_trace принимает необязательный параметр profiler_options , который позволяет точно управлять поведением профилировщика. Этот параметр должен быть экземпляром класса jax.profiler.ProfileOptions .

Например, чтобы отключить все трассировки Python и хоста:

import jax

options = jax.profiler.ProfileOptions()
options.python_tracer_level = 0
options.host_tracer_level = 0
jax.profiler.start_trace("/tmp/profile-data", profiler_options=options)

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

Общие варианты

  1. host_tracer_level : Задает уровень трассировки для действий на стороне хоста.

    Поддерживаемые значения:

    • 0 : Полностью отключает трассировку хоста (ЦП).
    • 1 : Включает отслеживание только событий TraceMe, заданных пользователем.
    • 2 : Включает трассировку уровня 1, а также высокоуровневые сведения о выполнении программы, такие как ресурсоемкие операции XLA (по умолчанию).
    • 3 : Включает трассировку уровня 2, а также более подробные сведения о выполнении программы на низком уровне, такие как операции XLA с низкими затратами ресурсов.
  2. device_tracer_level : Определяет, включена ли трассировка устройства.

    Поддерживаемые значения:

    • 0 : Отключает отслеживание устройства.
    • 1 : Включает отслеживание устройства (по умолчанию).
  3. python_tracer_level : Определяет, включена ли трассировка Python.

    Поддерживаемые значения:

    • 0 : Отключает трассировку вызовов функций Python (по умолчанию).
    • 1 : Включает трассировку Python.

Расширенные параметры конфигурации

варианты ТПУ

  1. tpu_trace_mode : Задает режим трассировки TPU.

    Поддерживаемые значения:

    • TRACE_ONLY_HOST : Это означает, что отслеживаются только действия на стороне хоста (ЦП), и данные об устройствах (ТПУ/ГП) не собираются.
    • TRACE_ONLY_XLA : Это означает, что отслеживаются только операции на уровне XLA устройства.
    • TRACE_COMPUTE : Эта функция отслеживает вычислительные операции на устройстве.
    • TRACE_COMPUTE_AND_SYNC : Эта функция отслеживает как вычислительные операции, так и события синхронизации на устройстве.

    Если параметр "tpu_trace_mode" не указан, то по умолчанию используется параметр trace_mode TRACE_ONLY_XLA .

  2. tpu_num_sparse_cores_to_trace : Задает количество разреженных ядер для трассировки на TPU.

  3. tpu_num_sparse_core_tiles_to_trace : Задает количество тайлов в каждом разреженном ядре, которые необходимо трассировать на TPU.

  4. tpu_num_chips_to_profile_per_task : Задает количество микросхем TPU, которые необходимо профилировать для каждой задачи.

параметры графического процессора

Для профилирования графического процессора доступны следующие параметры:

  • gpu_max_callback_api_events : Задает максимальное количество событий, собираемых API обратного вызова CUPTI. По умолчанию — 2*1024*1024 .
  • gpu_max_activity_api_events : Задает максимальное количество событий, собираемых API активности CUPTI. По умолчанию — 2*1024*1024 .
  • gpu_max_annotation_strings : Задает максимальное количество строк аннотаций, которые могут быть собраны. По умолчанию — 1024*1024 .
  • gpu_enable_nvtx_tracking : Включает отслеживание NVTX в CUPTI. По умолчанию — False .
  • gpu_enable_cupti_activity_graph_trace : Включает трассировку графа активности CUPTI для графов CUDA. По умолчанию — False .
  • gpu_pm_sample_counters : Строка, разделенная запятыми, с метриками мониторинга производительности графического процессора, которые будут собираться с помощью функции выборки PM в CUPTI (например, "sm__cycles_active.avg.pct_of_peak_sustained_elapsed" ). Выборка PM по умолчанию отключена. Список доступных метрик см. в документации NVIDIA по CUPTI .
  • gpu_pm_sample_interval_us : Задает интервал выборки в микросекундах для выборки CUPTI PM. По умолчанию — 500 .
  • gpu_pm_sample_buffer_size_per_gpu_mb : Задает размер буфера системной памяти на устройство в МБ для выборки CUPTI PM. По умолчанию — 64 МБ. Максимально поддерживаемое значение — 4 ГБ.
  • gpu_num_chips_to_profile_per_task : Задает количество графических процессоров, которые необходимо профилировать для каждой задачи. Если значение не указано, установите его равным 0 или недопустимому значению, будут профилированы все доступные графические процессоры. Это можно использовать для уменьшения размера собираемого массива трассировки.
  • gpu_dump_graph_node_mapping : Если включено, выводит информацию о сопоставлении узлов графа CUDA в трассировку. По умолчанию — False .

Например:

options = ProfileOptions()
options.advanced_configuration = {"tpu_trace_mode" : "TRACE_ONLY_HOST", "tpu_num_sparse_cores_to_trace" : 2}

Возвращает InvalidArgumentError , если обнаружены какие-либо нераспознанные ключи или значения параметров.