Genera perfiles de los cálculos de JAX con XProf

XProf es una excelente manera de adquirir y visualizar los registros y perfiles de rendimiento de tu programa, incluida la actividad en la GPU y la TPU. El resultado final debería ser similar al siguiente:

Ejemplo de XProf

Captura programática

Puedes instrumentar tu código para capturar un registro del profiler para el código de JAX a través de los métodos jax.profiler.start_trace y jax.profiler.stop_trace. Llama a jax.profiler.start_trace con el directorio en el que se escribirán los archivos de registro. Debe ser el mismo directorio --logdir que se usó para iniciar XProf. Luego, puedes usar XProf para ver los registros.

Por ejemplo, para tomar un registro del generador de perfiles, haz lo siguiente:

import jax

jax.profiler.start_trace("/tmp/profile-data")

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

Observa la llamada jax.block_until_ready. Usamos esto para asegurarnos de que el registro capture la ejecución en el dispositivo. Consulta Envío asíncrono para obtener detalles sobre por qué es necesario.

También puedes usar el administrador de contexto jax.profiler.trace como alternativa a start_trace y stop_trace:

import jax

with jax.profiler.trace("/tmp/profile-data"):
  key = jax.random.key(0)
  x = jax.random.normal(key, (5000, 5000))
  y = x @ x
  y.block_until_ready()

Cómo ver el registro

Después de capturar un registro, puedes verlo con la IU de XProf.

Puedes iniciar la IU del generador de perfiles directamente con el comando independiente de XProf apuntándolo a tu directorio de registros:

$ xprof --port=8791 /tmp/profile-data
Attempting to start XProf server:
  Log Directory: /tmp/profile-data
  Port: 8791
  Worker Service Address: 0.0.0.0:50051
  Hide Capture Button: False
XProf at http://localhost:8791/ (Press CTRL+C to quit)

Navega a la URL proporcionada (p.ej., http://localhost:8791/) en tu navegador para ver el perfil.

Los registros disponibles aparecen en el menú desplegable "Sessions" de la izquierda. Selecciona la sesión que te interesa y, luego, en el menú desplegable "Herramientas", selecciona "Visor de registros". Ahora deberías ver una línea de tiempo de la ejecución. Puedes usar las teclas WASD para navegar por el registro y hacer clic o arrastrar el cursor para seleccionar eventos y obtener más detalles. Consulta la documentación de la herramienta Trace Viewer para obtener más detalles sobre cómo usarla.

Captura manual a través de XProf

A continuación, se incluyen instrucciones para capturar un registro de N segundos activado manualmente desde un programa en ejecución.

  1. Inicia un servidor de XProf:

    xprof --logdir /tmp/profile-data/
    

    Deberías poder cargar XProf en <http://localhost:8791/>. Puedes especificar un puerto diferente con la marca --port.

  2. En el programa o proceso de Python que deseas analizar, agrega lo siguiente cerca del comienzo:

    import jax.profiler
    jax.profiler.start_server(9999)
    

    Esto inicia el servidor del generador de perfiles al que se conecta XProf. El servidor del generador de perfiles debe estar en ejecución antes de que pases al siguiente paso. Cuando termines de usar el servidor, puedes llamar a jax.profiler.stop_server() para apagarlo.

    Si deseas generar un perfil de un fragmento de un programa de larga duración (p.ej., un bucle de entrenamiento largo), puedes colocar esto al principio del programa y ejecutarlo como de costumbre. Si deseas generar un perfil de un programa corto (p.ej., una microcomparativa), una opción es iniciar el servidor del generador de perfiles en un shell de IPython y ejecutar el programa corto con %run después de iniciar la captura en el siguiente paso. Otra opción es iniciar el servidor del generador de perfiles al comienzo del programa y usar time.sleep() para tener tiempo suficiente para iniciar la captura.

  3. Abre <http://localhost:8791/> y haz clic en el botón "CAPTURE PROFILE" en la parte superior izquierda. Ingresa "localhost:9999" como la URL del servicio de perfil (esta es la dirección del servidor del generador de perfiles que iniciaste en el paso anterior). Ingresa la cantidad de milisegundos para los que deseas crear el perfil y haz clic en "CAPTURE".

  4. Si el código que deseas analizar aún no se está ejecutando (p.ej., si iniciaste el servidor del analizador en un shell de Python), ejecútalo mientras se realiza la captura.

  5. Una vez que finalice la captura, XProf debería actualizarse automáticamente. (No todas las funciones de generación de perfiles de XProf están conectadas con JAX, por lo que, en un principio, puede parecer que no se capturó nada). A la izquierda, en "Herramientas", selecciona "Visor de registros".

Ahora deberías ver una línea de tiempo de la ejecución. Puedes usar las teclas WASD para navegar por el registro y hacer clic o arrastrar el cursor para seleccionar eventos y ver más detalles en la parte inferior. Consulta la documentación de la herramienta Trace Viewer para obtener más detalles sobre cómo usar el visualizador de registros.

XProf y TensorBoard

XProf es la herramienta subyacente que impulsa la funcionalidad de generación de perfiles y captura de registros en TensorBoard. Siempre que xprof esté instalado, habrá una pestaña "Perfil" en TensorBoard. Usar esta opción es idéntico a iniciar XProf de forma independiente, siempre y cuando se inicie apuntando al mismo directorio de registros. Esto incluye la funcionalidad de captura, análisis y visualización de perfiles. XProf reemplaza la funcionalidad de tensorboard_plugin_profile que se recomendaba anteriormente.

$ tensorboard --logdir=/tmp/profile-data
[...]
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.19.0 at http://localhost:6006/ (Press CTRL+C to quit)

Cómo agregar eventos de seguimiento personalizados

De forma predeterminada, los eventos en el lector de seguimiento son, en su mayoría, funciones internas de JAX de bajo nivel. Puedes agregar tus propios eventos y funciones con jax.profiler.TraceAnnotation y jax.profiler.annotate_function en tu código.

Cómo configurar las opciones del generador de perfiles

El método start_trace acepta un parámetro profiler_options opcional, que permite un control detallado sobre el comportamiento del generador de perfiles. Este parámetro debe ser una instancia de jax.profiler.ProfileOptions.

Por ejemplo, para inhabilitar todos los registros de Python y del host, haz lo siguiente:

import jax

options = jax.profiler.ProfileOptions()
options.python_tracer_level = 0
options.host_tracer_level = 0
jax.profiler.start_trace("/tmp/profile-data", profiler_options=options)

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

Opciones generales

  1. host_tracer_level: Establece el nivel de registro para las actividades del host.

    Valores admitidos:

    • 0: Inhabilita por completo el registro del host (CPU).
    • 1: Habilita el registro de solo los eventos TraceMe instrumentados por el usuario.
    • 2: Incluye registros de nivel 1 y detalles de ejecución del programa de alto nivel, como operaciones de XLA costosas (opción predeterminada).
    • 3: Incluye registros de nivel 2 y detalles más detallados de la ejecución del programa de bajo nivel, como las operaciones de XLA económicas.
  2. device_tracer_level: Controla si el registro del dispositivo está habilitado.

    Valores admitidos:

    • 0: Inhabilita el registro del dispositivo.
    • 1: Habilita el registro del dispositivo (predeterminado).
  3. python_tracer_level: Controla si el registro de Python está habilitado.

    Valores admitidos:

    • 0: Inhabilita el registro de llamadas a funciones de Python (predeterminado).
    • 1: Habilita el registro de Python.

Opciones de configuración avanzada

Opciones de TPU

  1. tpu_trace_mode: Especifica el modo para el registro de la TPU.

    Valores admitidos:

    • TRACE_ONLY_HOST: Esto significa que solo se registran las actividades del host (CPU) y no se recopilan registros del dispositivo (TPU/GPU).
    • TRACE_ONLY_XLA: Esto significa que solo se rastrean las operaciones a nivel de XLA en el dispositivo.
    • TRACE_COMPUTE: Registra las operaciones de procesamiento en el dispositivo.
    • TRACE_COMPUTE_AND_SYNC: Registra las operaciones de procesamiento y los eventos de sincronización en el dispositivo.

    Si no se proporciona "tpu_trace_mode", el valor predeterminado de trace_mode es TRACE_ONLY_XLA.

  2. tpu_num_sparse_cores_to_trace: Especifica la cantidad de núcleos dispersos que se rastrearán en la TPU.

  3. tpu_num_sparse_core_tiles_to_trace: Especifica la cantidad de mosaicos dentro de cada núcleo disperso que se rastrearán en la TPU.

  4. tpu_num_chips_to_profile_per_task: Especifica la cantidad de chips TPU que se deben analizar por tarea.

Opciones de GPU

Las siguientes opciones están disponibles para la generación de perfiles de la GPU:

  • gpu_max_callback_api_events: Establece la cantidad máxima de eventos recopilados por la API de devolución de llamada de CUPTI. La configuración predeterminada es 2*1024*1024.
  • gpu_max_activity_api_events: Establece la cantidad máxima de eventos recopilados por la API de actividad de CUPTI. La configuración predeterminada es 2*1024*1024.
  • gpu_max_annotation_strings: Establece la cantidad máxima de cadenas de anotación que se pueden recopilar. La configuración predeterminada es 1024*1024.
  • gpu_enable_nvtx_tracking: Habilita el seguimiento de NVTX en CUPTI. La configuración predeterminada es False.
  • gpu_enable_cupti_activity_graph_trace: Habilita el registro del gráfico de actividad de CUPTI para los gráficos de CUDA. La configuración predeterminada es False.
  • gpu_pm_sample_counters: Es una cadena separada por comas de métricas de supervisión del rendimiento de la GPU que se recopilarán con la función de muestreo de PM de CUPTI (p.ej., "sm__cycles_active.avg.pct_of_peak_sustained_elapsed"). El muestreo de PM está inhabilitado de forma predeterminada. Para conocer las métricas disponibles, consulta la documentación de CUPTI de NVIDIA.
  • gpu_pm_sample_interval_us: Establece el intervalo de muestreo en microsegundos para el muestreo de PM de CUPTI. La configuración predeterminada es 500.
  • gpu_pm_sample_buffer_size_per_gpu_mb: Establece el tamaño del búfer de memoria del sistema por dispositivo en MB para el muestreo de PM de CUPTI. El valor predeterminado es 64 MB. El valor máximo admitido es de 4 GB.
  • gpu_num_chips_to_profile_per_task: Especifica la cantidad de dispositivos de GPU que se deben analizar por tarea. Si no se especifica, se establece en 0 o se establece en un valor no válido, se generará el perfil de todas las GPUs disponibles. Esto se puede usar para reducir el tamaño de la recopilación de registros.
  • gpu_dump_graph_node_mapping: Si está habilitado, vuelca la información de la asignación de nodos del grafo de CUDA en el registro. La configuración predeterminada es False.

Por ejemplo:

options = ProfileOptions()
options.advanced_configuration = {"tpu_trace_mode" : "TRACE_ONLY_HOST", "tpu_num_sparse_cores_to_trace" : 2}

Devuelve InvalidArgumentError si se encuentran claves o valores de opciones no reconocidos.