XProf è un ottimo modo per acquisire e visualizzare tracce e profili delle prestazioni del tuo programma, inclusa l'attività su GPU e TPU. Il risultato finale è simile al seguente:

Acquisizione programmatica
Puoi instrumentare il codice per acquisire una traccia del profiler per il codice JAX tramite i metodi
jax.profiler.start_trace
e jax.profiler.stop_trace. Chiama
jax.profiler.start_trace
con la directory in cui scrivere i file di traccia. Deve essere la stessa directory --logdir
utilizzata per avviare XProf. Poi, puoi utilizzare XProf per visualizzare le tracce.
Ad esempio, per acquisire una traccia del profiler:
import jax
jax.profiler.start_trace("/tmp/profile-data")
# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
jax.profiler.stop_trace()
Prendi nota della chiamata jax.block_until_ready. Utilizziamo questo valore per assicurarci che l'esecuzione sul dispositivo venga acquisita dalla traccia. Consulta
Invio asincrono per
informazioni dettagliate sul motivo per cui è necessario.
Puoi anche utilizzare il gestore del contesto
jax.profiler.trace
come alternativa a start_trace e stop_trace:
import jax
with jax.profiler.trace("/tmp/profile-data"):
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
Visualizzazione della traccia
Dopo aver acquisito una traccia, puoi visualizzarla utilizzando la UI di XProf.
Puoi avviare la UI del profiler direttamente utilizzando il comando XProf autonomo puntandolo alla directory dei log:
$ xprof --port=8791 /tmp/profile-data
Attempting to start XProf server:
Log Directory: /tmp/profile-data
Port: 8791
Worker Service Address: 0.0.0.0:50051
Hide Capture Button: False
XProf at http://localhost:8791/ (Press CTRL+C to quit)
Vai all'URL fornito (ad es. http://localhost:8791/) nel browser
per visualizzare il profilo.
Le tracce disponibili vengono visualizzate nel menu a discesa "Sessioni" a sinistra. Seleziona la sessione che ti interessa, quindi seleziona "Visualizzatore tracce" nel menu a discesa "Strumenti". Ora dovresti vedere una cronologia dell'esecuzione. Puoi utilizzare i tasti WASD per navigare nella traccia e fare clic o trascinare per selezionare gli eventi per ulteriori dettagli. Per ulteriori dettagli sull'utilizzo del visualizzatore di trace, consulta la documentazione dello strumento Trace Viewer.
Acquisizione manuale tramite XProf
Di seguito sono riportate le istruzioni per acquisire una traccia di N secondi attivata manualmente da un programma in esecuzione.
Avvia un server XProf:
xprof --logdir /tmp/profile-data/Dovresti riuscire a caricare XProf all'indirizzo
<http://localhost:8791/>. Puoi specificare una porta diversa con il flag--port.Nel programma o processo Python che vuoi profilare, aggiungi il seguente codice all'inizio:
import jax.profiler jax.profiler.start_server(9999)Viene avviato il server del profiler a cui si connette XProf. Il server del profiler deve essere in esecuzione prima di procedere con il passaggio successivo. Quando hai finito di utilizzare il server, puoi chiamare
jax.profiler.stop_server()per spegnerlo.Se vuoi profilare un frammento di un programma di lunga durata (ad es. un ciclo di addestramento lungo), puoi inserirlo all'inizio del programma e avviare il programma come di consueto. Se vuoi profilare un programma breve (ad es. un microbenchmark), un'opzione è avviare il server del profiler in una shell IPython ed eseguire il programma breve con
%rundopo aver avviato l'acquisizione nel passaggio successivo. Un'altra opzione è avviare il server del profiler all'inizio del programma e utilizzaretime.sleep()per avere abbastanza tempo per avviare l'acquisizione.Apri
<http://localhost:8791/>e fai clic sul pulsante "ACQUISIZIONE PROFILO" in alto a sinistra. Inserisci "localhost:9999" come URL del servizio di profilazione (questo è l'indirizzo del server del profiler che hai avviato nel passaggio precedente). Inserisci il numero di millisecondi per cui vuoi creare il profilo e fai clic su "ACQUISIZIONE".Se il codice che vuoi profilare non è ancora in esecuzione (ad esempio se hai avviato il server del profiler in una shell Python), eseguilo mentre è in corso l'acquisizione.
Al termine dell'acquisizione, XProf dovrebbe aggiornarsi automaticamente. Non tutte le funzionalità di profilazione di XProf sono collegate a JAX, quindi inizialmente potrebbe sembrare che non sia stato acquisito nulla. A sinistra, nella sezione "Strumenti", seleziona "Visualizzatore tracce".
Ora dovresti vedere una cronologia dell'esecuzione. Puoi utilizzare i tasti WASD per navigare nella traccia e fare clic o trascinare per selezionare gli eventi per visualizzare ulteriori dettagli in basso. Per ulteriori dettagli sull'utilizzo del visualizzatore di trace, consulta la documentazione dello strumento Trace Viewer.
XProf e TensorBoard
XProf è lo strumento sottostante che alimenta la funzionalità di profilazione e acquisizione delle tracce in TensorBoard. Se xprof è installato, in TensorBoard sarà presente una scheda "Profilo". L'utilizzo è identico all'avvio di XProf
in modo indipendente, a condizione che venga avviato puntando alla stessa directory dei log.
Sono incluse le funzionalità di acquisizione, analisi e visualizzazione dei profili. XProf
sostituisce la funzionalità tensorboard_plugin_profile precedentemente
consigliata.
$ tensorboard --logdir=/tmp/profile-data
[...]
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.19.0 at http://localhost:6006/ (Press CTRL+C to quit)
Aggiunta di eventi di tracciamento personalizzati
Per impostazione predefinita, gli eventi nel visualizzatore di tracce sono principalmente funzioni JAX interne di basso livello. Puoi aggiungere i tuoi eventi e le tue funzioni utilizzando
jax.profiler.TraceAnnotation
e jax.profiler.annotate_function
nel codice.
Configurazione delle opzioni del profiler
Il metodo start_trace accetta un parametro profiler_options facoltativo, che
consente un controllo granulare sul comportamento del profiler. Questo parametro
deve essere un'istanza di jax.profiler.ProfileOptions.
Ad esempio, per disattivare tutte le tracce di Python e dell'host:
import jax
options = jax.profiler.ProfileOptions()
options.python_tracer_level = 0
options.host_tracer_level = 0
jax.profiler.start_trace("/tmp/profile-data", profiler_options=options)
# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
jax.profiler.stop_trace()
Opzioni generali
host_tracer_level: imposta il livello di traccia per le attività lato host.Valori supportati:
0: disattiva completamente la traccia host (CPU).1: abilita la tracciabilità solo degli eventi TraceMe strumentati dall'utente.2: include le tracce di livello 1 più i dettagli di esecuzione del programma di alto livello, ad esempio operazioni XLA costose (impostazione predefinita).3: include le tracce di livello 2 più dettagliate sull'esecuzione del programma di basso livello, ad esempio le operazioni XLA economiche.
device_tracer_level: controlla se il tracciamento del dispositivo è attivato.Valori supportati:
0: disattiva il tracciamento del dispositivo.1: attiva il tracciamento del dispositivo (impostazione predefinita).
python_tracer_level: controlla se la tracciabilità Python è attivata.Valori supportati:
0: disabilita la traccia delle chiamate di funzioni Python (impostazione predefinita).1: attiva la tracciabilità Python.
Opzioni di configurazione avanzate
Opzioni TPU
tpu_trace_mode: specifica la modalità di tracciamento della TPU.Valori supportati:
TRACE_ONLY_HOST: Ciò significa che vengono tracciate solo le attività lato host (CPU) e non vengono raccolte tracce del dispositivo (TPU/GPU).TRACE_ONLY_XLA: significa che vengono tracciate solo le operazioni a livello di XLA sul dispositivo.TRACE_COMPUTE: traccia le operazioni di calcolo sul dispositivo.TRACE_COMPUTE_AND_SYNC: traccia sia le operazioni di calcolo sia gli eventi di sincronizzazione sul dispositivo.
Se "tpu_trace_mode" non viene fornito, trace_mode viene impostato su
TRACE_ONLY_XLAper impostazione predefinita.tpu_num_sparse_cores_to_trace: specifica il numero di core sparsi da tracciare sulla TPU.tpu_num_sparse_core_tiles_to_trace: specifica il numero di riquadri all'interno di ogni core sparso da tracciare sulla TPU.tpu_num_chips_to_profile_per_task: specifica il numero di chip TPU da profilare per attività.
Opzioni GPU
Per la profilazione della GPU sono disponibili le seguenti opzioni:
gpu_max_callback_api_events: imposta il numero massimo di eventi raccolti dall'API di callback CUPTI. Il valore predefinito è2*1024*1024.gpu_max_activity_api_events: imposta il numero massimo di eventi raccolti dall'API CUPTI activity. Il valore predefinito è2*1024*1024.gpu_max_annotation_strings: imposta il numero massimo di stringhe di annotazione che possono essere raccolte. Il valore predefinito è1024*1024.gpu_enable_nvtx_tracking: attiva il monitoraggio NVTX in CUPTI. Il valore predefinito èFalse.gpu_enable_cupti_activity_graph_trace: consente la tracciatura del grafico delle attività CUPTI per i grafici CUDA. Il valore predefinito èFalse.gpu_pm_sample_counters: una stringa separata da virgole di metriche di monitoraggio delle prestazioni della GPU da raccogliere utilizzando la funzionalità di campionamento PM di CUPTI (ad es."sm__cycles_active.avg.pct_of_peak_sustained_elapsed"). Il campionamento PM è disattivato per impostazione predefinita. Per le metriche disponibili, consulta la documentazione CUPTI di NVIDIA.gpu_pm_sample_interval_us: imposta l'intervallo di campionamento in microsecondi per il campionamento PM di CUPTI. Il valore predefinito è500.gpu_pm_sample_buffer_size_per_gpu_mb: imposta la dimensione del buffer di memoria di sistema per dispositivo in MB per il campionamento CUPTI PM. Il valore predefinito è 64 MB. Il valore massimo supportato è 4 GB.gpu_num_chips_to_profile_per_task: specifica il numero di dispositivi GPU da profilare per attività. Se non specificato, impostato su 0 o su un valore non valido, verranno profilate tutte le GPU disponibili. Può essere utilizzato per ridurre le dimensioni della raccolta delle tracce.gpu_dump_graph_node_mapping: se abilitato, esegue il dump delle informazioni di mappatura dei nodi del grafico CUDA nella traccia. Il valore predefinito èFalse.
Ad esempio:
options = ProfileOptions()
options.advanced_configuration = {"tpu_trace_mode" : "TRACE_ONLY_HOST", "tpu_num_sparse_cores_to_trace" : 2}
Restituisce InvalidArgumentError se vengono trovate chiavi o valori di opzione non riconosciuti.