Mit XProf können Sie Leistungstraces und ‑profile Ihres Programms erfassen und visualisieren, einschließlich der Aktivität auf GPU und TPU. Das Endergebnis sieht in etwa so aus:

Programmatische Erfassung
Sie können Ihren Code instrumentieren, um einen Profiler-Trace für JAX-Code über die Methoden jax.profiler.start_trace und jax.profiler.stop_trace zu erfassen. Rufen Sie jax.profiler.start_trace mit dem Verzeichnis auf, in das die Tracedateien geschrieben werden sollen. Dies sollte dasselbe --logdir-Verzeichnis sein, das zum Starten von XProf verwendet wurde. Anschließend können Sie die Traces mit XProf ansehen.
So erstellen Sie beispielsweise einen Profiler-Trace:
import jax
jax.profiler.start_trace("/tmp/profile-data")
# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
jax.profiler.stop_trace()
Beachten Sie den Aufruf von jax.block_until_ready. So wird sichergestellt, dass die Ausführung auf dem Gerät im Trace erfasst wird. Weitere Informationen dazu, warum dies erforderlich ist, finden Sie unter Asynchrones Senden.
Alternativ zu start_trace und stop_trace können Sie auch den Kontextmanager jax.profiler.trace verwenden:
import jax
with jax.profiler.trace("/tmp/profile-data"):
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
Trace ansehen
Nachdem Sie einen Trace erfasst haben, können Sie ihn über die XProf-Benutzeroberfläche ansehen.
Sie können die Profiler-Benutzeroberfläche direkt mit dem eigenständigen XProf-Befehl starten, indem Sie ihn auf Ihr Logverzeichnis verweisen:
$ xprof --port=8791 /tmp/profile-data
Attempting to start XProf server:
Log Directory: /tmp/profile-data
Port: 8791
Worker Service Address: 0.0.0.0:50051
Hide Capture Button: False
XProf at http://localhost:8791/ (Press CTRL+C to quit)
Rufen Sie die angegebene URL auf (z.B. http://localhost:8791/) in Ihrem Browser aufrufen, um das Profil aufzurufen.
Verfügbare Traces werden links im Drop-down-Menü „Sitzungen“ angezeigt. Wählen Sie die gewünschte Sitzung aus und dann im Drop-down-Menü „Tools“ die Option „Trace Viewer“. Sie sollten jetzt eine Zeitachse der Ausführung sehen. Sie können die WASD-Tasten verwenden, um im Trace zu navigieren, und klicken oder ziehen, um Ereignisse auszuwählen und weitere Details aufzurufen. Weitere Informationen zur Verwendung des Trace-Viewers finden Sie in der Dokumentation zum Trace Viewer-Tool.
Manuelle Erfassung über XProf
Im Folgenden finden Sie eine Anleitung zum Erfassen eines manuell ausgelösten N-Sekunden-Traces aus einem laufenden Programm.
Starten Sie einen XProf-Server:
xprof --logdir /tmp/profile-data/Sie sollten XProf unter
<http://localhost:8791/>laden können. Mit dem Flag--portkönnen Sie einen anderen Port angeben.Fügen Sie dem Python-Programm oder -Prozess, für den Sie ein Profil erstellen möchten, den folgenden Code an einer Stelle in der Nähe des Anfangs hinzu:
import jax.profiler jax.profiler.start_server(9999)Dadurch wird der Profiler-Server gestartet, mit dem XProf eine Verbindung herstellt. Der Profiler-Server muss ausgeführt werden, bevor Sie mit dem nächsten Schritt fortfahren. Wenn Sie den Server nicht mehr benötigen, können Sie ihn mit
jax.profiler.stop_server()herunterfahren.Wenn Sie ein Segment eines lang laufenden Programms (z.B. eine lange Trainingsschleife) profilieren möchten, können Sie es an den Anfang des Programms stellen und das Programm wie gewohnt starten. Wenn Sie ein kurzes Programm (z.B. einen Microbenchmark) profilieren möchten, können Sie den Profiler-Server in einer IPython-Shell starten und das kurze Programm mit
%runausführen, nachdem Sie die Erfassung im nächsten Schritt gestartet haben. Eine weitere Möglichkeit besteht darin, den Profiler-Server am Anfang des Programms zu starten undtime.sleep()zu verwenden, um genügend Zeit zum Starten der Erfassung zu haben.Öffnen Sie
<http://localhost:8791/>und klicken Sie links oben auf die Schaltfläche „CAPTURE PROFILE“. Geben Sie „localhost:9999“ als Profiler-Dienst-URL ein. Das ist die Adresse des Profiler-Servers, den Sie im vorherigen Schritt gestartet haben. Geben Sie die Anzahl der Millisekunden ein, für die Sie ein Profil erstellen möchten, und klicken Sie auf „CAPTURE“ (ERFASSEN).Wenn der Code, den Sie profilieren möchten, noch nicht ausgeführt wird (z.B. wenn Sie den Profiler-Server in einer Python-Shell gestartet haben), führen Sie ihn aus, während die Erfassung läuft.
Nach Abschluss der Erfassung sollte XProf automatisch aktualisiert werden. Nicht alle XProf-Profilerstellungsfunktionen sind mit JAX verknüpft. Es kann also anfangs so aussehen, als wäre nichts erfasst worden. Wählen Sie links unter „Tools“ die Option „Trace Viewer“ aus.
Sie sollten jetzt eine Zeitachse der Ausführung sehen. Mit den WASD-Tasten können Sie im Trace navigieren. Klicken oder ziehen Sie, um Ereignisse auszuwählen und unten weitere Details aufzurufen. Weitere Informationen zur Verwendung des Trace-Viewers finden Sie in der Dokumentation zum Trace-Viewer-Tool.
XProf und TensorBoard
XProf ist das zugrunde liegende Tool, das die Funktionen zur Profilerstellung und zum Erfassen von Traces in TensorBoard unterstützt. Solange xprof installiert ist, ist in TensorBoard ein Tab „Profil“ vorhanden. Die Verwendung ist identisch mit dem unabhängigen Starten von XProf, sofern es auf dasselbe Logverzeichnis verweist.
Dazu gehören Funktionen zum Erfassen, Analysieren und Ansehen von Profilen. XProf ersetzt die tensorboard_plugin_profile-Funktion, die zuvor empfohlen wurde.
$ tensorboard --logdir=/tmp/profile-data
[...]
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.19.0 at http://localhost:6006/ (Press CTRL+C to quit)
Benutzerdefinierte Trace-Ereignisse hinzufügen
Standardmäßig sind die Ereignisse im Trace Viewer meist interne JAX-Funktionen auf niedriger Ebene. Sie können Ihrem Code eigene Ereignisse und Funktionen hinzufügen, indem Sie jax.profiler.TraceAnnotation und jax.profiler.annotate_function verwenden.
Profiler-Optionen konfigurieren
Die Methode start_trace akzeptiert einen optionalen profiler_options-Parameter, mit dem sich das Verhalten des Profilers detailliert steuern lässt. Dieser Parameter sollte eine Instanz von jax.profiler.ProfileOptions sein.
So deaktivieren Sie beispielsweise alle Python- und Host-Traces:
import jax
options = jax.profiler.ProfileOptions()
options.python_tracer_level = 0
options.host_tracer_level = 0
jax.profiler.start_trace("/tmp/profile-data", profiler_options=options)
# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
jax.profiler.stop_trace()
Allgemeine Optionen
host_tracer_level: Legt die Trace-Ebene für hostseitige Aktivitäten fest.Unterstützte Werte:
0: Deaktiviert das Host-Tracing (CPU) vollständig.1: Ermöglicht das Tracing von TraceMe-Ereignissen, die nur vom Nutzer instrumentiert wurden.2: Enthält Traces der Ebene 1 sowie allgemeine Details zur Programmausführung, z. B. rechenintensive XLA-Vorgänge (Standard).3: Enthält Traces der Stufe 2 sowie ausführlichere Details zur Programmausführung auf niedriger Ebene, z. B. kostengünstige XLA-Vorgänge.
device_tracer_level: Steuert, ob die Geräteverfolgung aktiviert ist.Unterstützte Werte:
0: Deaktiviert das Geräte-Tracing.1: Aktiviert die Geräteverfolgung (Standard).
python_tracer_level: Steuert, ob Python-Tracing aktiviert ist.Unterstützte Werte:
0: Deaktiviert das Tracing von Python-Funktionsaufrufen (Standard).1: Aktiviert die Python-Ablaufverfolgung.
Erweiterte Konfigurationsoptionen
TPU-Optionen
tpu_trace_mode: Gibt den Modus für das TPU-Tracing an.Unterstützte Werte:
TRACE_ONLY_HOST: Das bedeutet, dass nur Aktivitäten auf der Hostseite (CPU) erfasst werden und keine Geräte-Traces (TPU/GPU) gesammelt werden.TRACE_ONLY_XLA: Das bedeutet, dass nur Vorgänge auf XLA-Ebene auf dem Gerät erfasst werden.TRACE_COMPUTE: Damit werden Rechenvorgänge auf dem Gerät nachverfolgt.TRACE_COMPUTE_AND_SYNC: Damit werden sowohl Rechenvorgänge als auch Synchronisierungsereignisse auf dem Gerät verfolgt.
Wenn „tpu_trace_mode“ nicht angegeben ist, wird standardmäßig
TRACE_ONLY_XLAverwendet.tpu_num_sparse_cores_to_trace: Gibt die Anzahl der zu verfolgenden spärlichen Kerne auf der TPU an.tpu_num_sparse_core_tiles_to_trace: Gibt die Anzahl der Kacheln in jedem spärlichen Kern an, die auf der TPU verfolgt werden sollen.tpu_num_chips_to_profile_per_task: Gibt die Anzahl der TPU-Chips an, die pro Aufgabe profiliert werden sollen.
GPU-Optionen
Für das GPU-Profiling stehen die folgenden Optionen zur Verfügung:
gpu_max_callback_api_events: Legt die maximale Anzahl von Ereignissen fest, die von der CUPTI-Callback-API erfasst werden. Die Standardeinstellung ist2*1024*1024.gpu_max_activity_api_events: Legt die maximale Anzahl von Ereignissen fest, die von der CUPTI-Aktivitäts-API erfasst werden. Die Standardeinstellung ist2*1024*1024.gpu_max_annotation_strings: Legt die maximale Anzahl der Annotationsstrings fest, die erfasst werden können. Die Standardeinstellung ist1024*1024.gpu_enable_nvtx_tracking: Aktiviert NVTX-Tracking in CUPTI. Die Standardeinstellung istFalse.gpu_enable_cupti_activity_graph_trace: Aktiviert die CUPTI-Aktivitätsgraph-Tracerstellung für CUDA-Graphen. Die Standardeinstellung istFalse.gpu_pm_sample_counters: Eine durch Kommas getrennte Liste von Messwerten für das GPU-Leistungsmonitoring, die mit der PM-Sampling-Funktion von CUPTI erfasst werden sollen (z.B."sm__cycles_active.avg.pct_of_peak_sustained_elapsed"). PM-Sampling ist standardmäßig deaktiviert. Informationen zu den verfügbaren Messwerten finden Sie in der CUPTI-Dokumentation von NVIDIA.gpu_pm_sample_interval_us: Legt das Abtastintervall in Mikrosekunden für die CUPTI-PM-Abtastung fest. Die Standardeinstellung ist500.gpu_pm_sample_buffer_size_per_gpu_mb: Legt die Größe des Systemarbeitsspeicherpuffers pro Gerät in MB für das CUPTI PM-Sampling fest. Die Standardeinstellung ist 64 MB. Der maximal unterstützte Wert beträgt 4 GB.gpu_num_chips_to_profile_per_task: Gibt die Anzahl der GPU-Geräte an, die pro Aufgabe profiliert werden sollen. Wenn nicht angegeben, auf 0 oder auf einen ungültigen Wert festgelegt, werden alle verfügbaren GPUs profiliert. Damit kann die Größe der Trace-Erfassung verringert werden.gpu_dump_graph_node_mapping: Wenn diese Option aktiviert ist, werden Informationen zur Zuordnung von CUDA-Graphknoten in den Trace ausgegeben. Die Standardeinstellung istFalse.
Beispiel:
options = ProfileOptions()
options.advanced_configuration = {"tpu_trace_mode" : "TRACE_ONLY_HOST", "tpu_num_sparse_cores_to_trace" : 2}
Gibt InvalidArgumentError zurück, wenn nicht erkannte Schlüssel oder Optionswerte gefunden werden.