Profiler les calculs JAX avec XProf

XProf est un excellent moyen d'acquérir et de visualiser les traces et les profils de performances de votre programme, y compris l'activité sur GPU et TPU. Le résultat final ressemble à ceci :

Exemple XProf

Capture programmatique

Vous pouvez instrumenter votre code pour capturer une trace de profileur pour le code JAX à l'aide des méthodes jax.profiler.start_trace et jax.profiler.stop_trace. Appelez jax.profiler.start_trace avec le répertoire dans lequel écrire les fichiers de trace. Il doit s'agir du même répertoire --logdir que celui utilisé pour démarrer XProf. Vous pouvez ensuite utiliser XProf pour afficher les traces.

Par exemple, pour effectuer une trace du profileur :

import jax

jax.profiler.start_trace("/tmp/profile-data")

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

Notez l'appel jax.block_until_ready. Nous l'utilisons pour nous assurer que l'exécution sur l'appareil est capturée par la trace. Pour en savoir plus sur la raison pour laquelle cela est nécessaire, consultez Envoi asynchrone.

Vous pouvez également utiliser le gestionnaire de contexte jax.profiler.trace comme alternative à start_trace et stop_trace :

import jax

with jax.profiler.trace("/tmp/profile-data"):
  key = jax.random.key(0)
  x = jax.random.normal(key, (5000, 5000))
  y = x @ x
  y.block_until_ready()

Afficher la trace

Une fois la trace capturée, vous pouvez l'afficher à l'aide de l'interface utilisateur XProf.

Vous pouvez lancer l'interface utilisateur du profileur directement à l'aide de la commande XProf autonome en la pointant vers votre répertoire de journaux :

$ xprof --port=8791 /tmp/profile-data
Attempting to start XProf server:
  Log Directory: /tmp/profile-data
  Port: 8791
  Worker Service Address: 0.0.0.0:50051
  Hide Capture Button: False
XProf at http://localhost:8791/ (Press CTRL+C to quit)

Accédez à l'URL fournie (par exemple, http://localhost:8791/) dans votre navigateur pour afficher le profil.

Les traces disponibles s'affichent dans le menu déroulant "Sessions" à gauche. Sélectionnez la session qui vous intéresse, puis dans le menu déroulant "Outils", sélectionnez "Trace Viewer". Vous devriez maintenant voir une chronologie de l'exécution. Vous pouvez utiliser les touches WASD pour parcourir la trace, et cliquer ou faire glisser pour sélectionner des événements et obtenir plus de détails. Pour en savoir plus sur l'utilisation de l'outil Trace Viewer, consultez sa documentation.

Capture manuelle via XProf

Vous trouverez ci-dessous les instructions pour capturer une trace de N secondes déclenchée manuellement à partir d'un programme en cours d'exécution.

  1. Démarrez un serveur XProf :

    xprof --logdir /tmp/profile-data/
    

    Vous devriez pouvoir charger XProf à l'adresse <http://localhost:8791/>. Vous pouvez spécifier un autre port avec l'option --port.

  2. Dans le programme ou processus Python que vous souhaitez profiler, ajoutez le code suivant près du début :

    import jax.profiler
    jax.profiler.start_server(9999)
    

    Cela démarre le serveur du profileur auquel XProf se connecte. Le serveur du profileur doit être en cours d'exécution avant de passer à l'étape suivante. Lorsque vous avez terminé d'utiliser le serveur, vous pouvez appeler jax.profiler.stop_server() pour l'arrêter.

    Si vous souhaitez profiler un extrait d'un programme de longue durée (par exemple, une longue boucle d'entraînement), vous pouvez le placer au début du programme et démarrer votre programme comme d'habitude. Si vous souhaitez profiler un programme court (par exemple, un microbenchmark), vous pouvez démarrer le serveur de profileur dans un shell IPython et exécuter le programme court avec %run après avoir démarré la capture à l'étape suivante. Une autre option consiste à démarrer le serveur du profileur au début du programme et à utiliser time.sleep() pour vous donner suffisamment de temps pour démarrer la capture.

  3. Ouvrez <http://localhost:8791/>, puis cliquez sur le bouton "CAPTURE PROFILE" (Capturer le profil) en haut à gauche. Saisissez "localhost:9999" comme URL du service de profil (il s'agit de l'adresse du serveur du profileur que vous avez démarré à l'étape précédente). Saisissez le nombre de millisecondes pour lequel vous souhaitez profiler, puis cliquez sur "CAPTURE".

  4. Si le code que vous souhaitez profiler n'est pas déjà en cours d'exécution (par exemple, si vous avez démarré le serveur de profileur dans un shell Python), exécutez-le pendant que la capture est en cours.

  5. Une fois la capture terminée, XProf devrait s'actualiser automatiquement. (Toutes les fonctionnalités de profilage XProf ne sont pas connectées à JAX. Il peut donc sembler au premier abord qu'aucune donnée n'a été capturée.) Sur la gauche, sous "Outils", sélectionnez "Trace Viewer".

Vous devriez maintenant voir une chronologie de l'exécution. Vous pouvez utiliser les touches WASD pour parcourir la trace, et cliquer ou faire glisser pour sélectionner des événements afin d'afficher plus de détails en bas. Pour en savoir plus sur l'utilisation du traceur, consultez la documentation de l'outil Trace Viewer.

XProf et TensorBoard

XProf est l'outil sous-jacent qui alimente les fonctionnalités de profilage et de capture de trace dans TensorBoard. Tant que xprof est installé, un onglet "Profil" est présent dans TensorBoard. Son utilisation est identique au lancement indépendant de XProf, à condition qu'il soit lancé en pointant vers le même répertoire de journaux. Cela inclut les fonctionnalités de capture, d'analyse et d'affichage des profils. XProf remplace la fonctionnalité tensorboard_plugin_profile qui était auparavant recommandée.

$ tensorboard --logdir=/tmp/profile-data
[...]
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.19.0 at http://localhost:6006/ (Press CTRL+C to quit)

Ajouter des événements de trace personnalisés

Par défaut, les événements du lecteur de trace sont principalement des fonctions JAX internes de bas niveau. Vous pouvez ajouter vos propres événements et fonctions à l'aide de jax.profiler.TraceAnnotation et jax.profiler.annotate_function dans votre code.

Configurer les options du profileur

La méthode start_trace accepte un paramètre profiler_options facultatif, qui permet un contrôle précis du comportement du profileur. Ce paramètre doit être une instance de jax.profiler.ProfileOptions.

Par exemple, pour désactiver toutes les traces Python et d'hôte :

import jax

options = jax.profiler.ProfileOptions()
options.python_tracer_level = 0
options.host_tracer_level = 0
jax.profiler.start_trace("/tmp/profile-data", profiler_options=options)

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

Options générales

  1. host_tracer_level : définit le niveau de trace pour les activités côté hôte.

    Valeurs autorisées :

    • 0 : désactive complètement le traçage de l'hôte (CPU).
    • 1 : active le traçage des événements TraceMe instrumentés par l'utilisateur uniquement.
    • 2 : inclut les traces de niveau 1 ainsi que des informations générales sur l'exécution du programme, comme les opérations XLA coûteuses (par défaut).
    • 3 : inclut les traces de niveau 2 ainsi que des informations plus détaillées sur l'exécution du programme de bas niveau, telles que les opérations XLA peu coûteuses.
  2. device_tracer_level : contrôle si le traçage des appareils est activé.

    Valeurs autorisées :

    • 0 : désactive le traçage des appareils.
    • 1 : active le traçage des appareils (par défaut).
  3. python_tracer_level : contrôle si le traçage Python est activé.

    Valeurs autorisées :

    • 0 : désactive le traçage des appels de fonction Python (valeur par défaut).
    • 1 : active le traçage Python.

Options de configuration avancées

Options de TPU

  1. tpu_trace_mode : spécifie le mode de traçage du TPU.

    Valeurs autorisées :

    • TRACE_ONLY_HOST : cela signifie que seules les activités côté hôte (CPU) sont tracées, et qu'aucune trace d'appareil (TPU/GPU) n'est collectée.
    • TRACE_ONLY_XLA : cela signifie que seules les opérations au niveau XLA sur l'appareil sont tracées.
    • TRACE_COMPUTE : cette option permet de suivre les opérations de calcul sur l'appareil.
    • TRACE_COMPUTE_AND_SYNC : cela permet de suivre les opérations de calcul et les événements de synchronisation sur l'appareil.

    Si "tpu_trace_mode" n'est pas fourni, trace_mode est défini par défaut sur TRACE_ONLY_XLA.

  2. tpu_num_sparse_cores_to_trace : spécifie le nombre de cœurs épars à tracer sur le TPU.

  3. tpu_num_sparse_core_tiles_to_trace : spécifie le nombre de blocs dans chaque cœur creux à tracer sur le TPU.

  4. tpu_num_chips_to_profile_per_task : spécifie le nombre de puces TPU à profiler par tâche.

Options de GPU

Les options suivantes sont disponibles pour le profilage GPU :

  • gpu_max_callback_api_events : définit le nombre maximal d'événements collectés par l'API de rappel CUPTI. La valeur par défaut est 2*1024*1024.
  • gpu_max_activity_api_events : définit le nombre maximal d'événements collectés par l'API d'activité CUPTI. La valeur par défaut est 2*1024*1024.
  • gpu_max_annotation_strings : définit le nombre maximal de chaînes d'annotation pouvant être collectées. La valeur par défaut est 1024*1024.
  • gpu_enable_nvtx_tracking : active le suivi NVTX dans CUPTI. La valeur par défaut est False.
  • gpu_enable_cupti_activity_graph_trace : active le traçage du graphique d'activité CUPTI pour les graphiques CUDA. La valeur par défaut est False.
  • gpu_pm_sample_counters : chaîne de métriques de surveillance des performances du GPU à collecter à l'aide de la fonctionnalité d'échantillonnage PM de CUPTI (par exemple, "sm__cycles_active.avg.pct_of_peak_sustained_elapsed"). L'échantillonnage PM est désactivé par défaut. Pour connaître les métriques disponibles, consultez la documentation CUPTI de NVIDIA.
  • gpu_pm_sample_interval_us : définit l'intervalle d'échantillonnage en microsecondes pour l'échantillonnage CUPTI PM. La valeur par défaut est 500.
  • gpu_pm_sample_buffer_size_per_gpu_mb : définit la taille du tampon de mémoire système par appareil en Mo pour l'échantillonnage CUPTI PM. La valeur par défaut est de 64 Mo. La valeur maximale autorisée est de 4 Go.
  • gpu_num_chips_to_profile_per_task : spécifie le nombre d'appareils GPU à profiler par tâche. Si cette valeur n'est pas spécifiée, est définie sur 0 ou est définie sur une valeur non valide, tous les GPU disponibles seront profilés. Cela peut être utilisé pour réduire la taille de la collecte de traces.
  • gpu_dump_graph_node_mapping : si cette option est activée, les informations de mappage des nœuds du graphique CUDA sont incluses dans la trace. La valeur par défaut est False.

Exemple :

options = ProfileOptions()
options.advanced_configuration = {"tpu_trace_mode" : "TRACE_ONLY_HOST", "tpu_num_sparse_cores_to_trace" : 2}

Renvoie InvalidArgumentError si des clés ou des valeurs d'option non reconnues sont trouvées.