XProf — отличный способ получить и визуализировать трассировки производительности и профили вашей программы, включая активность на GPU и TPU. В итоге получается примерно следующее:

Программный захват
Вы можете инструментировать свой код для захвата трассировки профилировщика для кода JAX с помощью методов jax.profiler.start_trace и jax.profiler.stop_trace . Вызовите jax.profiler.start_trace , указав каталог, в который будут записываться файлы трассировки. Это должен быть тот же каталог --logdir который использовался для запуска XProf. Затем вы можете использовать XProf для просмотра трассировок.
Например, чтобы получить трассировку профилировщика:
import jax
jax.profiler.start_trace("/tmp/profile-data")
# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
jax.profiler.stop_trace()
Обратите внимание на вызов jax.block_until_ready . Мы используем его, чтобы гарантировать, что выполнение на устройстве будет зафиксировано трассировкой. Подробнее о том, почему это необходимо, см. в разделе «Асинхронная диспетчеризация» .
В качестве альтернативы функциям start_trace и stop_trace можно также использовать менеджер контекста jax.profiler.trace :
import jax
with jax.profiler.trace("/tmp/profile-data"):
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
Просмотр трассировки
После захвата трассировки вы можете просмотреть ее с помощью пользовательского интерфейса XProf.
Вы можете запустить пользовательский интерфейс профилировщика напрямую, используя автономную команду XProf, указав ей путь к каталогу с логами:
$ xprof --port=8791 /tmp/profile-data
Attempting to start XProf server:
Log Directory: /tmp/profile-data
Port: 8791
Worker Service Address: 0.0.0.0:50051
Hide Capture Button: False
XProf at http://localhost:8791/ (Press CTRL+C to quit)
Для просмотра профиля перейдите по указанному URL-адресу (например, http://localhost:8791/ ) в вашем браузере.
Доступные трассировки отображаются в раскрывающемся меню «Сессии» слева. Выберите интересующую вас сессию, а затем в раскрывающемся меню «Инструменты» выберите «Просмотр трассировки». Теперь вы должны увидеть временную шкалу выполнения. Вы можете использовать клавиши WASD для навигации по трассировке, а также щелкать или перетаскивать для выбора событий и получения более подробной информации. Дополнительные сведения об использовании средства просмотра трассировки см. в документации к инструменту «Просмотр трассировки ».
Ручной захват с помощью XProf
Ниже приведены инструкции по захвату N-секундной трассировки, запускаемой вручную, из работающей программы.
Запустите сервер XProf:
xprof --logdir /tmp/profile-data/Вы сможете загрузить XProf по адресу
<http://localhost:8791/>. Другой порт можно указать с помощью флага--port.В программу или процесс на Python, который вы хотите профилировать, добавьте следующий код где-нибудь в начале:
import jax.profiler jax.profiler.start_server(9999)Это запускает сервер профилирования, к которому подключается XProf. Сервер профилирования должен быть запущен, прежде чем вы перейдете к следующему шагу. Когда вы закончите использовать сервер, вы можете вызвать
jax.profiler.stop_server()чтобы остановить его.Если вы хотите профилировать фрагмент длительной программы (например, длинный цикл обучения), вы можете поместить это в начало программы и запустить ее как обычно. Если вы хотите профилировать короткую программу (например, микротест), один из вариантов — запустить сервер профилирования в оболочке IPython и запустить короткую программу с помощью
%runпосле начала захвата данных на следующем шаге. Другой вариант — запустить сервер профилирования в начале программы и использоватьtime.sleep(), чтобы дать вам достаточно времени для начала захвата данных.Откройте
<http://localhost:8791/>и нажмите кнопку "CAPTURE PROFILE" в верхнем левом углу. Введите "localhost:9999" в качестве URL-адреса службы профилирования (это адрес сервера профилирования, который вы запустили на предыдущем шаге). Введите количество миллисекунд, в течение которых вы хотите выполнить профилирование, и нажмите "CAPTURE".Если код, который вы хотите профилировать, еще не запущен (например, если вы запустили сервер профилирования в оболочке Python), запустите его во время выполнения захвата данных.
После завершения захвата данных XProf должен автоматически обновиться. (Не все функции профилирования XProf подключены к JAX, поэтому изначально может показаться, что ничего не было захвачено.) В левой части экрана в разделе «Инструменты» выберите «Просмотр трассировки».
Теперь вы должны увидеть временную шкалу выполнения. Вы можете использовать клавиши WASD для навигации по трассировке, а также щелкать или перетаскивать мышью, чтобы выбрать события и просмотреть более подробную информацию внизу. Более подробную информацию об использовании средства просмотра трассировки см. в документации к инструменту просмотра трассировки .
XProf и Tensorboard
XProf — это базовый инструмент, обеспечивающий функциональность профилирования и захвата трассировки в Tensorboard. Пока xprof установлен, в Tensorboard будет присутствовать вкладка «Профиль». Использовать её можно так же, как и запускать XProf отдельно, при условии, что при запуске указывается тот же каталог с логами. Это включает в себя функции захвата, анализа и просмотра профилирования. XProf заменяет ранее рекомендованную функциональность tensorboard_plugin_profile .
$ tensorboard --logdir=/tmp/profile-data
[...]
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.19.0 at http://localhost:6006/ (Press CTRL+C to quit)
Добавление пользовательских событий трассировки
По умолчанию события в окне просмотра трассировки представляют собой в основном низкоуровневые внутренние функции JAX. Вы можете добавить свои собственные события и функции, используя jax.profiler.TraceAnnotation и jax.profiler.annotate_function в своем коде.
Настройка параметров профилировщика
Метод start_trace принимает необязательный параметр profiler_options , который позволяет точно управлять поведением профилировщика. Этот параметр должен быть экземпляром класса jax.profiler.ProfileOptions .
Например, чтобы отключить все трассировки Python и хоста:
import jax
options = jax.profiler.ProfileOptions()
options.python_tracer_level = 0
options.host_tracer_level = 0
jax.profiler.start_trace("/tmp/profile-data", profiler_options=options)
# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
jax.profiler.stop_trace()
Общие варианты
host_tracer_level: Задает уровень трассировки для действий на стороне хоста.Поддерживаемые значения:
-
0: Полностью отключает трассировку хоста (ЦП). -
1: Включает отслеживание только событий TraceMe, заданных пользователем. -
2: Включает трассировку уровня 1, а также высокоуровневые сведения о выполнении программы, такие как ресурсоемкие операции XLA (по умолчанию). -
3: Включает трассировку уровня 2, а также более подробные сведения о выполнении программы на низком уровне, такие как операции XLA с низкими затратами ресурсов.
-
device_tracer_level: Определяет, включена ли трассировка устройства.Поддерживаемые значения:
-
0: Отключает отслеживание устройства. -
1: Включает отслеживание устройства (по умолчанию).
-
python_tracer_level: Определяет, включена ли трассировка Python.Поддерживаемые значения:
-
0: Отключает трассировку вызовов функций Python (по умолчанию). -
1: Включает трассировку Python.
-
Расширенные параметры конфигурации
варианты ТПУ
tpu_trace_mode: Задает режим трассировки TPU.Поддерживаемые значения:
-
TRACE_ONLY_HOST: Это означает, что отслеживаются только действия на стороне хоста (ЦП), и данные об устройствах (ТПУ/ГП) не собираются. -
TRACE_ONLY_XLA: Это означает, что отслеживаются только операции на уровне XLA устройства. -
TRACE_COMPUTE: Эта функция отслеживает вычислительные операции на устройстве. -
TRACE_COMPUTE_AND_SYNC: Эта функция отслеживает как вычислительные операции, так и события синхронизации на устройстве.
Если параметр "tpu_trace_mode" не указан, то по умолчанию используется параметр trace_mode
TRACE_ONLY_XLA.-
tpu_num_sparse_cores_to_trace: Задает количество разреженных ядер для трассировки на TPU.tpu_num_sparse_core_tiles_to_trace: Задает количество тайлов в каждом разреженном ядре, которые необходимо трассировать на TPU.tpu_num_chips_to_profile_per_task: Задает количество микросхем TPU, которые необходимо профилировать для каждой задачи.
параметры графического процессора
Для профилирования графического процессора доступны следующие параметры:
-
gpu_max_callback_api_events: Задает максимальное количество событий, собираемых API обратного вызова CUPTI. По умолчанию —2*1024*1024. -
gpu_max_activity_api_events: Задает максимальное количество событий, собираемых API активности CUPTI. По умолчанию —2*1024*1024. -
gpu_max_annotation_strings: Задает максимальное количество строк аннотаций, которые могут быть собраны. По умолчанию —1024*1024. -
gpu_enable_nvtx_tracking: Включает отслеживание NVTX в CUPTI. По умолчанию —False. -
gpu_enable_cupti_activity_graph_trace: Включает трассировку графа активности CUPTI для графов CUDA. По умолчанию —False. -
gpu_pm_sample_counters: Строка, разделенная запятыми, с метриками мониторинга производительности графического процессора, которые будут собираться с помощью функции выборки PM в CUPTI (например,"sm__cycles_active.avg.pct_of_peak_sustained_elapsed"). Выборка PM по умолчанию отключена. Список доступных метрик см. в документации NVIDIA по CUPTI . -
gpu_pm_sample_interval_us: Задает интервал выборки в микросекундах для выборки CUPTI PM. По умолчанию —500. -
gpu_pm_sample_buffer_size_per_gpu_mb: Задает размер буфера системной памяти на устройство в МБ для выборки CUPTI PM. По умолчанию — 64 МБ. Максимально поддерживаемое значение — 4 ГБ. -
gpu_num_chips_to_profile_per_task: Задает количество графических процессоров, которые необходимо профилировать для каждой задачи. Если значение не указано, установите его равным 0 или недопустимому значению, будут профилированы все доступные графические процессоры. Это можно использовать для уменьшения размера собираемого массива трассировки. -
gpu_dump_graph_node_mapping: Если включено, выводит информацию о сопоставлении узлов графа CUDA в трассировку. По умолчанию —False.
Например:
options = ProfileOptions()
options.advanced_configuration = {"tpu_trace_mode" : "TRACE_ONLY_HOST", "tpu_num_sparse_cores_to_trace" : 2}
Возвращает InvalidArgumentError , если обнаружены какие-либо нераспознанные ключи или значения параметров.