XProf — отличный способ получить и визуализировать трассировки производительности и профили вашей программы, включая активность на GPU и TPU. В итоге получается примерно следующее:

Программный захват
Вы можете инструментировать свой код для захвата трассировки профилировщика для кода JAX с помощью методов jax.profiler.start_trace и jax.profiler.stop_trace . Вызовите jax.profiler.start_trace , указав каталог, в который будут записываться файлы трассировки. Это должен быть тот же каталог --logdir который использовался для запуска XProf. Затем вы можете использовать XProf для просмотра трассировок.
Например, чтобы получить трассировку профилировщика:
import jax
jax.profiler.start_trace("/tmp/profile-data")
# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
jax.profiler.stop_trace()
Обратите внимание на вызов jax.block_until_ready . Мы используем его, чтобы гарантировать, что выполнение на устройстве будет зафиксировано трассировкой. Подробнее о том, почему это необходимо, см. в разделе «Асинхронная диспетчеризация» .
В качестве альтернативы функциям start_trace и stop_trace можно также использовать менеджер контекста jax.profiler.trace :
import jax
with jax.profiler.trace("/tmp/profile-data"):
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
Просмотр трассировки
После захвата трассировки вы можете просмотреть ее с помощью пользовательского интерфейса XProf.
Вы можете запустить пользовательский интерфейс профилировщика напрямую, используя автономную команду XProf, указав ей путь к каталогу с логами:
$ xprof --port=8791 /tmp/profile-data
Attempting to start XProf server:
Log Directory: /tmp/profile-data
Port: 8791
Worker Service Address: 0.0.0.0:50051
Hide Capture Button: False
XProf at http://localhost:8791/ (Press CTRL+C to quit)
Для просмотра профиля перейдите по указанному URL-адресу (например, http://localhost:8791/ ) в вашем браузере.
Доступные трассировки отображаются в раскрывающемся меню «Сессии» слева. Выберите интересующую вас сессию, а затем в раскрывающемся меню «Инструменты» выберите «Просмотр трассировки». Теперь вы должны увидеть временную шкалу выполнения. Вы можете использовать клавиши WASD для навигации по трассировке, а также щелкать или перетаскивать для выбора событий и получения более подробной информации. Дополнительные сведения об использовании средства просмотра трассировки см. в документации к инструменту «Просмотр трассировки ».
Ручной захват с помощью XProf
Ниже приведены инструкции по захвату N-секундной трассировки, запускаемой вручную, из работающей программы.
Запустите сервер XProf:
xprof --logdir /tmp/profile-data/Вы сможете загрузить XProf по адресу
<http://localhost:8791/>. Другой порт можно указать с помощью флага--port.В программу или процесс на Python, который вы хотите профилировать, добавьте следующий код где-нибудь в начале:
import jax.profiler jax.profiler.start_server(9999)Это запускает сервер профилирования, к которому подключается XProf. Сервер профилирования должен быть запущен, прежде чем вы перейдете к следующему шагу. Когда вы закончите использовать сервер, вы можете вызвать
jax.profiler.stop_server()чтобы остановить его.Если вы хотите профилировать фрагмент длительной программы (например, длинный цикл обучения), вы можете поместить это в начало программы и запустить ее как обычно. Если вы хотите профилировать короткую программу (например, микротест), один из вариантов — запустить сервер профилирования в оболочке IPython и запустить короткую программу с помощью
%runпосле начала захвата данных на следующем шаге. Другой вариант — запустить сервер профилирования в начале программы и использоватьtime.sleep(), чтобы дать вам достаточно времени для начала захвата данных.Откройте
<http://localhost:8791/>и нажмите кнопку "CAPTURE PROFILE" в верхнем левом углу. Введите "localhost:9999" в качестве URL-адреса службы профилирования (это адрес сервера профилирования, который вы запустили на предыдущем шаге). Введите количество миллисекунд, в течение которых вы хотите выполнить профилирование, и нажмите "CAPTURE".Если код, который вы хотите профилировать, еще не запущен (например, если вы запустили сервер профилирования в оболочке Python), запустите его во время выполнения захвата данных.
После завершения захвата данных XProf должен автоматически обновиться. (Не все функции профилирования XProf подключены к JAX, поэтому изначально может показаться, что ничего не было захвачено.) В левой части экрана в разделе «Инструменты» выберите «Просмотр трассировки».
Теперь вы должны увидеть временную шкалу выполнения. Вы можете использовать клавиши WASD для навигации по трассировке, а также щелкать или перетаскивать мышью, чтобы выбрать события и просмотреть более подробную информацию внизу. Более подробную информацию об использовании средства просмотра трассировки см. в документации к инструменту просмотра трассировки .
Создание непрерывных снимков профилирования.
Ниже приведены инструкции по созданию непрерывного снимка профилирования в любой момент времени.
- В программу или процесс на Python, который вы хотите профилировать, добавьте следующий код где-нибудь в начале:
import jax.profiler
jax.profiler.start_server(9999)
- Начните непрерывное профилирование независимо от вашей программы:
from xprof.api import continuous_profiling_snapshot
continuous_profiling_snapshot.start_continuous_profiling('localhost:9999', {})
- Сделать снимок:
from xprof.api import continuous_profiling_snapshot
continuous_profiling_snapshot.get_snapshot('localhost:9999', '/tmp/profile-data/')
- Остановить непрерывное профилирование:
from xprof.api import continuous_profiling_snapshot
continuous_profiling_snapshot.stop_continuous_profiling('localhost:9999')
- Запуск XProf:
xprof --port=8791 /tmp/profile-data
You should now see a timeline of the execution. You can use the WASD keys to navigate the trace, and click or drag to select events to see more details at the bottom. See the Trace Viewer Tool documentation for more details on using the trace viewer.
XProf и Tensorboard
XProf — это базовый инструмент, обеспечивающий функциональность профилирования и захвата трассировки в Tensorboard. Пока xprof установлен, в Tensorboard будет присутствовать вкладка «Профиль». Использовать её можно так же, как и запускать XProf отдельно, при условии, что при запуске указывается тот же каталог с логами. Это включает в себя функции захвата, анализа и просмотра профилирования. XProf заменяет ранее рекомендованную функциональность tensorboard_plugin_profile .
$ tensorboard --logdir=/tmp/profile-data
[...]
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.19.0 at http://localhost:6006/ (Press CTRL+C to quit)
Добавление пользовательских событий трассировки
По умолчанию события в окне просмотра трассировки представляют собой в основном низкоуровневые внутренние функции JAX. Вы можете добавить свои собственные события и функции, используя jax.profiler.TraceAnnotation и jax.profiler.annotate_function в своем коде.
Настройка параметров профилировщика
Метод start_trace принимает необязательный параметр profiler_options , который позволяет точно управлять поведением профилировщика. Этот параметр должен быть экземпляром класса jax.profiler.ProfileOptions .
Например, чтобы отключить все трассировки Python и хоста:
import jax
options = jax.profiler.ProfileOptions()
options.python_tracer_level = 0
options.host_tracer_level = 0
jax.profiler.start_trace("/tmp/profile-data", profiler_options=options)
# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
jax.profiler.stop_trace()
Общие варианты
host_tracer_level: Задает уровень трассировки для действий на стороне хоста.Поддерживаемые значения:
-
0: Полностью отключает трассировку хоста (ЦП). -
1: Включает отслеживание только событий TraceMe, заданных пользователем. -
2: Включает трассировку уровня 1, а также высокоуровневые сведения о выполнении программы, такие как ресурсоемкие операции XLA (по умолчанию). -
3: Включает трассировку уровня 2, а также более подробные сведения о выполнении программы на низком уровне, такие как операции XLA с низкими затратами ресурсов.
-
device_tracer_level: Определяет, включена ли трассировка устройства.Поддерживаемые значения:
-
0: Отключает отслеживание устройства. -
1: Включает отслеживание устройства (по умолчанию).
-
python_tracer_level: Определяет, включена ли трассировка Python.Поддерживаемые значения:
-
0: Отключает трассировку вызовов функций Python (по умолчанию). -
1: Включает трассировку Python.
-
Расширенные параметры конфигурации
варианты ТПУ
tpu_trace_mode: Задает режим трассировки TPU.Поддерживаемые значения:
-
TRACE_ONLY_HOST: Это означает, что отслеживаются только действия на стороне хоста (ЦП), и данные об устройствах (ТПУ/ГП) не собираются. -
TRACE_ONLY_XLA: Это означает, что отслеживаются только операции на уровне XLA устройства. -
TRACE_COMPUTE: Эта функция отслеживает вычислительные операции на устройстве. -
TRACE_COMPUTE_AND_SYNC: Эта функция отслеживает как вычислительные операции, так и события синхронизации на устройстве.
Если параметр "tpu_trace_mode" не указан, то по умолчанию используется параметр trace_mode
TRACE_ONLY_XLA.-
tpu_num_sparse_cores_to_trace: Задает количество разреженных ядер для трассировки на TPU.tpu_num_sparse_core_tiles_to_trace: Задает количество тайлов в каждом разреженном ядре, которые необходимо трассировать на TPU.tpu_num_chips_to_profile_per_task: Задает количество микросхем TPU, которые необходимо профилировать для каждой задачи.
параметры графического процессора
Для профилирования графического процессора доступны следующие параметры:
-
gpu_max_callback_api_events: Задает максимальное количество событий, собираемых API обратного вызова CUPTI. По умолчанию —2*1024*1024. -
gpu_max_activity_api_events: Задает максимальное количество событий, собираемых API активности CUPTI. По умолчанию —2*1024*1024. -
gpu_max_annotation_strings: Задает максимальное количество строк аннотаций, которые могут быть собраны. По умолчанию —1024*1024. -
gpu_enable_nvtx_tracking: Включает отслеживание NVTX в CUPTI. По умолчанию —False. -
gpu_enable_cupti_activity_graph_trace: Включает трассировку графа активности CUPTI для графов CUDA. По умолчанию —False. -
gpu_pm_sample_counters: Строка, разделенная запятыми, с метриками мониторинга производительности графического процессора, которые будут собираться с помощью функции выборки PM в CUPTI (например,"sm__cycles_active.avg.pct_of_peak_sustained_elapsed"). Выборка PM по умолчанию отключена. Список доступных метрик см. в документации NVIDIA по CUPTI . -
gpu_pm_sample_interval_us: Задает интервал выборки в микросекундах для выборки CUPTI PM. По умолчанию —500. -
gpu_pm_sample_buffer_size_per_gpu_mb: Задает размер буфера системной памяти на устройство в МБ для выборки CUPTI PM. По умолчанию — 64 МБ. Максимально поддерживаемое значение — 4 ГБ. -
gpu_num_chips_to_profile_per_task: Задает количество графических процессоров, которые необходимо профилировать для каждой задачи. Если значение не указано, установите его равным 0 или недопустимому значению, будут профилированы все доступные графические процессоры. Это можно использовать для уменьшения размера собираемого массива трассировки. -
gpu_dump_graph_node_mapping: Если включено, выводит информацию о сопоставлении узлов графа CUDA в трассировку. По умолчанию —False.
Например:
options = ProfileOptions()
options.advanced_configuration = {"tpu_trace_mode" : "TRACE_ONLY_HOST", "tpu_num_sparse_cores_to_trace" : 2}
Возвращает InvalidArgumentError , если обнаружены какие-либо нераспознанные ключи или значения параметров.