XProf est un excellent moyen d'acquérir et de visualiser les traces et les profils de performances de votre programme, y compris l'activité sur GPU et TPU. Le résultat final ressemble à ceci :

Capture programmatique
Vous pouvez instrumenter votre code pour capturer une trace de profileur pour le code JAX à l'aide des méthodes jax.profiler.start_trace et jax.profiler.stop_trace. Appelez jax.profiler.start_trace avec le répertoire dans lequel écrire les fichiers de trace. Il doit s'agir du même répertoire --logdir que celui utilisé pour démarrer XProf. Vous pouvez ensuite utiliser XProf pour afficher les traces.
Par exemple, pour effectuer une trace du profileur :
import jax
jax.profiler.start_trace("/tmp/profile-data")
# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
jax.profiler.stop_trace()
Notez l'appel jax.block_until_ready. Nous l'utilisons pour nous assurer que l'exécution sur l'appareil est capturée par la trace. Pour en savoir plus sur la raison pour laquelle cela est nécessaire, consultez Envoi asynchrone.
Vous pouvez également utiliser le gestionnaire de contexte jax.profiler.trace comme alternative à start_trace et stop_trace :
import jax
with jax.profiler.trace("/tmp/profile-data"):
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
Afficher la trace
Une fois la trace capturée, vous pouvez l'afficher à l'aide de l'interface utilisateur XProf.
Vous pouvez lancer l'interface utilisateur du profileur directement à l'aide de la commande XProf autonome en la pointant vers votre répertoire de journaux :
$ xprof --port=8791 /tmp/profile-data
Attempting to start XProf server:
Log Directory: /tmp/profile-data
Port: 8791
Worker Service Address: 0.0.0.0:50051
Hide Capture Button: False
XProf at http://localhost:8791/ (Press CTRL+C to quit)
Accédez à l'URL fournie (par exemple, http://localhost:8791/) dans votre navigateur pour afficher le profil.
Les traces disponibles s'affichent dans le menu déroulant "Sessions" à gauche. Sélectionnez la session qui vous intéresse, puis dans le menu déroulant "Outils", sélectionnez "Trace Viewer". Vous devriez maintenant voir une chronologie de l'exécution. Vous pouvez utiliser les touches WASD pour parcourir la trace, et cliquer ou faire glisser pour sélectionner des événements et obtenir plus de détails. Pour en savoir plus sur l'utilisation de l'outil Trace Viewer, consultez sa documentation.
Capture manuelle via XProf
Vous trouverez ci-dessous les instructions pour capturer une trace de N secondes déclenchée manuellement à partir d'un programme en cours d'exécution.
Démarrez un serveur XProf :
xprof --logdir /tmp/profile-data/Vous devriez pouvoir charger XProf à l'adresse
<http://localhost:8791/>. Vous pouvez spécifier un autre port avec l'option--port.Dans le programme ou processus Python que vous souhaitez profiler, ajoutez le code suivant près du début :
import jax.profiler jax.profiler.start_server(9999)Cela démarre le serveur du profileur auquel XProf se connecte. Le serveur du profileur doit être en cours d'exécution avant de passer à l'étape suivante. Lorsque vous avez terminé d'utiliser le serveur, vous pouvez appeler
jax.profiler.stop_server()pour l'arrêter.Si vous souhaitez profiler un extrait d'un programme de longue durée (par exemple, une longue boucle d'entraînement), vous pouvez le placer au début du programme et démarrer votre programme comme d'habitude. Si vous souhaitez profiler un programme court (par exemple, un microbenchmark), vous pouvez démarrer le serveur de profileur dans un shell IPython et exécuter le programme court avec
%runaprès avoir démarré la capture à l'étape suivante. Une autre option consiste à démarrer le serveur du profileur au début du programme et à utilisertime.sleep()pour vous donner suffisamment de temps pour démarrer la capture.Ouvrez
<http://localhost:8791/>, puis cliquez sur le bouton "CAPTURE PROFILE" (Capturer le profil) en haut à gauche. Saisissez "localhost:9999" comme URL du service de profil (il s'agit de l'adresse du serveur du profileur que vous avez démarré à l'étape précédente). Saisissez le nombre de millisecondes pour lequel vous souhaitez profiler, puis cliquez sur "CAPTURE".Si le code que vous souhaitez profiler n'est pas déjà en cours d'exécution (par exemple, si vous avez démarré le serveur de profileur dans un shell Python), exécutez-le pendant que la capture est en cours.
Une fois la capture terminée, XProf devrait s'actualiser automatiquement. (Toutes les fonctionnalités de profilage XProf ne sont pas connectées à JAX. Il peut donc sembler au premier abord qu'aucune donnée n'a été capturée.) Sur la gauche, sous "Outils", sélectionnez "Trace Viewer".
Vous devriez maintenant voir une chronologie de l'exécution. Vous pouvez utiliser les touches WASD pour parcourir la trace, et cliquer ou faire glisser pour sélectionner des événements afin d'afficher plus de détails en bas. Pour en savoir plus sur l'utilisation du traceur, consultez la documentation de l'outil Trace Viewer.
XProf et TensorBoard
XProf est l'outil sous-jacent qui alimente les fonctionnalités de profilage et de capture de trace dans TensorBoard. Tant que xprof est installé, un onglet "Profil" est présent dans TensorBoard. Son utilisation est identique au lancement indépendant de XProf, à condition qu'il soit lancé en pointant vers le même répertoire de journaux.
Cela inclut les fonctionnalités de capture, d'analyse et d'affichage des profils. XProf remplace la fonctionnalité tensorboard_plugin_profile qui était auparavant recommandée.
$ tensorboard --logdir=/tmp/profile-data
[...]
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.19.0 at http://localhost:6006/ (Press CTRL+C to quit)
Ajouter des événements de trace personnalisés
Par défaut, les événements du lecteur de trace sont principalement des fonctions JAX internes de bas niveau. Vous pouvez ajouter vos propres événements et fonctions à l'aide de jax.profiler.TraceAnnotation et jax.profiler.annotate_function dans votre code.
Configurer les options du profileur
La méthode start_trace accepte un paramètre profiler_options facultatif, qui permet un contrôle précis du comportement du profileur. Ce paramètre doit être une instance de jax.profiler.ProfileOptions.
Par exemple, pour désactiver toutes les traces Python et d'hôte :
import jax
options = jax.profiler.ProfileOptions()
options.python_tracer_level = 0
options.host_tracer_level = 0
jax.profiler.start_trace("/tmp/profile-data", profiler_options=options)
# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
jax.profiler.stop_trace()
Options générales
host_tracer_level: définit le niveau de trace pour les activités côté hôte.Valeurs autorisées :
0: désactive complètement le traçage de l'hôte (CPU).1: active le traçage des événements TraceMe instrumentés par l'utilisateur uniquement.2: inclut les traces de niveau 1 ainsi que des informations générales sur l'exécution du programme, comme les opérations XLA coûteuses (par défaut).3: inclut les traces de niveau 2 ainsi que des informations plus détaillées sur l'exécution du programme de bas niveau, telles que les opérations XLA peu coûteuses.
device_tracer_level: contrôle si le traçage des appareils est activé.Valeurs autorisées :
0: désactive le traçage des appareils.1: active le traçage des appareils (par défaut).
python_tracer_level: contrôle si le traçage Python est activé.Valeurs autorisées :
0: désactive le traçage des appels de fonction Python (valeur par défaut).1: active le traçage Python.
Options de configuration avancées
Options de TPU
tpu_trace_mode: spécifie le mode de traçage du TPU.Valeurs autorisées :
TRACE_ONLY_HOST: cela signifie que seules les activités côté hôte (CPU) sont tracées, et qu'aucune trace d'appareil (TPU/GPU) n'est collectée.TRACE_ONLY_XLA: cela signifie que seules les opérations au niveau XLA sur l'appareil sont tracées.TRACE_COMPUTE: cette option permet de suivre les opérations de calcul sur l'appareil.TRACE_COMPUTE_AND_SYNC: cela permet de suivre les opérations de calcul et les événements de synchronisation sur l'appareil.
Si "tpu_trace_mode" n'est pas fourni, trace_mode est défini par défaut sur
TRACE_ONLY_XLA.tpu_num_sparse_cores_to_trace: spécifie le nombre de cœurs épars à tracer sur le TPU.tpu_num_sparse_core_tiles_to_trace: spécifie le nombre de blocs dans chaque cœur creux à tracer sur le TPU.tpu_num_chips_to_profile_per_task: spécifie le nombre de puces TPU à profiler par tâche.
Options de GPU
Les options suivantes sont disponibles pour le profilage GPU :
gpu_max_callback_api_events: définit le nombre maximal d'événements collectés par l'API de rappel CUPTI. La valeur par défaut est2*1024*1024.gpu_max_activity_api_events: définit le nombre maximal d'événements collectés par l'API d'activité CUPTI. La valeur par défaut est2*1024*1024.gpu_max_annotation_strings: définit le nombre maximal de chaînes d'annotation pouvant être collectées. La valeur par défaut est1024*1024.gpu_enable_nvtx_tracking: active le suivi NVTX dans CUPTI. La valeur par défaut estFalse.gpu_enable_cupti_activity_graph_trace: active le traçage du graphique d'activité CUPTI pour les graphiques CUDA. La valeur par défaut estFalse.gpu_pm_sample_counters: chaîne de métriques de surveillance des performances du GPU à collecter à l'aide de la fonctionnalité d'échantillonnage PM de CUPTI (par exemple,"sm__cycles_active.avg.pct_of_peak_sustained_elapsed"). L'échantillonnage PM est désactivé par défaut. Pour connaître les métriques disponibles, consultez la documentation CUPTI de NVIDIA.gpu_pm_sample_interval_us: définit l'intervalle d'échantillonnage en microsecondes pour l'échantillonnage CUPTI PM. La valeur par défaut est500.gpu_pm_sample_buffer_size_per_gpu_mb: définit la taille du tampon de mémoire système par appareil en Mo pour l'échantillonnage CUPTI PM. La valeur par défaut est de 64 Mo. La valeur maximale autorisée est de 4 Go.gpu_num_chips_to_profile_per_task: spécifie le nombre d'appareils GPU à profiler par tâche. Si cette valeur n'est pas spécifiée, est définie sur 0 ou est définie sur une valeur non valide, tous les GPU disponibles seront profilés. Cela peut être utilisé pour réduire la taille de la collecte de traces.gpu_dump_graph_node_mapping: si cette option est activée, les informations de mappage des nœuds du graphique CUDA sont incluses dans la trace. La valeur par défaut estFalse.
Exemple :
options = ProfileOptions()
options.advanced_configuration = {"tpu_trace_mode" : "TRACE_ONLY_HOST", "tpu_num_sparse_cores_to_trace" : 2}
Renvoie InvalidArgumentError si des clés ou des valeurs d'option non reconnues sont trouvées.