Profilazione dei calcoli JAX con XProf

XProf è un ottimo modo per acquisire e visualizzare tracce e profili delle prestazioni del tuo programma, inclusa l'attività su GPU e TPU. Il risultato finale è simile al seguente:

Esempio di XProf

Acquisizione programmatica

Puoi instrumentare il codice per acquisire una traccia del profiler per il codice JAX tramite i metodi jax.profiler.start_trace e jax.profiler.stop_trace. Chiama jax.profiler.start_trace con la directory in cui scrivere i file di traccia. Deve essere la stessa directory --logdir utilizzata per avviare XProf. Poi, puoi utilizzare XProf per visualizzare le tracce.

Ad esempio, per acquisire una traccia del profiler:

import jax

jax.profiler.start_trace("/tmp/profile-data")

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

Prendi nota della chiamata jax.block_until_ready. Utilizziamo questo valore per assicurarci che l'esecuzione sul dispositivo venga acquisita dalla traccia. Consulta Invio asincrono per informazioni dettagliate sul motivo per cui è necessario.

Puoi anche utilizzare il gestore del contesto jax.profiler.trace come alternativa a start_trace e stop_trace:

import jax

with jax.profiler.trace("/tmp/profile-data"):
  key = jax.random.key(0)
  x = jax.random.normal(key, (5000, 5000))
  y = x @ x
  y.block_until_ready()

Visualizzazione della traccia

Dopo aver acquisito una traccia, puoi visualizzarla utilizzando la UI di XProf.

Puoi avviare la UI del profiler direttamente utilizzando il comando XProf autonomo puntandolo alla directory dei log:

$ xprof --port=8791 /tmp/profile-data
Attempting to start XProf server:
  Log Directory: /tmp/profile-data
  Port: 8791
  Worker Service Address: 0.0.0.0:50051
  Hide Capture Button: False
XProf at http://localhost:8791/ (Press CTRL+C to quit)

Vai all'URL fornito (ad es. http://localhost:8791/) nel browser per visualizzare il profilo.

Le tracce disponibili vengono visualizzate nel menu a discesa "Sessioni" a sinistra. Seleziona la sessione che ti interessa, quindi seleziona "Visualizzatore tracce" nel menu a discesa "Strumenti". Ora dovresti vedere una cronologia dell'esecuzione. Puoi utilizzare i tasti WASD per navigare nella traccia e fare clic o trascinare per selezionare gli eventi per ulteriori dettagli. Per ulteriori dettagli sull'utilizzo del visualizzatore di trace, consulta la documentazione dello strumento Trace Viewer.

Acquisizione manuale tramite XProf

Di seguito sono riportate le istruzioni per acquisire una traccia di N secondi attivata manualmente da un programma in esecuzione.

  1. Avvia un server XProf:

    xprof --logdir /tmp/profile-data/
    

    Dovresti riuscire a caricare XProf all'indirizzo <http://localhost:8791/>. Puoi specificare una porta diversa con il flag --port.

  2. Nel programma o processo Python che vuoi profilare, aggiungi il seguente codice all'inizio:

    import jax.profiler
    jax.profiler.start_server(9999)
    

    Viene avviato il server del profiler a cui si connette XProf. Il server del profiler deve essere in esecuzione prima di procedere con il passaggio successivo. Quando hai finito di utilizzare il server, puoi chiamare jax.profiler.stop_server() per spegnerlo.

    Se vuoi profilare un frammento di un programma di lunga durata (ad es. un ciclo di addestramento lungo), puoi inserirlo all'inizio del programma e avviare il programma come di consueto. Se vuoi profilare un programma breve (ad es. un microbenchmark), un'opzione è avviare il server del profiler in una shell IPython ed eseguire il programma breve con %run dopo aver avviato l'acquisizione nel passaggio successivo. Un'altra opzione è avviare il server del profiler all'inizio del programma e utilizzare time.sleep() per avere abbastanza tempo per avviare l'acquisizione.

  3. Apri <http://localhost:8791/> e fai clic sul pulsante "ACQUISIZIONE PROFILO" in alto a sinistra. Inserisci "localhost:9999" come URL del servizio di profilazione (questo è l'indirizzo del server del profiler che hai avviato nel passaggio precedente). Inserisci il numero di millisecondi per cui vuoi creare il profilo e fai clic su "ACQUISIZIONE".

  4. Se il codice che vuoi profilare non è ancora in esecuzione (ad esempio se hai avviato il server del profiler in una shell Python), eseguilo mentre è in corso l'acquisizione.

  5. Al termine dell'acquisizione, XProf dovrebbe aggiornarsi automaticamente. Non tutte le funzionalità di profilazione di XProf sono collegate a JAX, quindi inizialmente potrebbe sembrare che non sia stato acquisito nulla. A sinistra, nella sezione "Strumenti", seleziona "Visualizzatore tracce".

Ora dovresti vedere una cronologia dell'esecuzione. Puoi utilizzare i tasti WASD per navigare nella traccia e fare clic o trascinare per selezionare gli eventi per visualizzare ulteriori dettagli in basso. Per ulteriori dettagli sull'utilizzo del visualizzatore di trace, consulta la documentazione dello strumento Trace Viewer.

XProf e TensorBoard

XProf è lo strumento sottostante che alimenta la funzionalità di profilazione e acquisizione delle tracce in TensorBoard. Se xprof è installato, in TensorBoard sarà presente una scheda "Profilo". L'utilizzo è identico all'avvio di XProf in modo indipendente, a condizione che venga avviato puntando alla stessa directory dei log. Sono incluse le funzionalità di acquisizione, analisi e visualizzazione dei profili. XProf sostituisce la funzionalità tensorboard_plugin_profile precedentemente consigliata.

$ tensorboard --logdir=/tmp/profile-data
[...]
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.19.0 at http://localhost:6006/ (Press CTRL+C to quit)

Aggiunta di eventi di tracciamento personalizzati

Per impostazione predefinita, gli eventi nel visualizzatore di tracce sono principalmente funzioni JAX interne di basso livello. Puoi aggiungere i tuoi eventi e le tue funzioni utilizzando jax.profiler.TraceAnnotation e jax.profiler.annotate_function nel codice.

Configurazione delle opzioni del profiler

Il metodo start_trace accetta un parametro profiler_options facoltativo, che consente un controllo granulare sul comportamento del profiler. Questo parametro deve essere un'istanza di jax.profiler.ProfileOptions.

Ad esempio, per disattivare tutte le tracce di Python e dell'host:

import jax

options = jax.profiler.ProfileOptions()
options.python_tracer_level = 0
options.host_tracer_level = 0
jax.profiler.start_trace("/tmp/profile-data", profiler_options=options)

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

Opzioni generali

  1. host_tracer_level: imposta il livello di traccia per le attività lato host.

    Valori supportati:

    • 0: disattiva completamente la traccia host (CPU).
    • 1: abilita la tracciabilità solo degli eventi TraceMe strumentati dall'utente.
    • 2: include le tracce di livello 1 più i dettagli di esecuzione del programma di alto livello, ad esempio operazioni XLA costose (impostazione predefinita).
    • 3: include le tracce di livello 2 più dettagliate sull'esecuzione del programma di basso livello, ad esempio le operazioni XLA economiche.
  2. device_tracer_level: controlla se il tracciamento del dispositivo è attivato.

    Valori supportati:

    • 0: disattiva il tracciamento del dispositivo.
    • 1: attiva il tracciamento del dispositivo (impostazione predefinita).
  3. python_tracer_level: controlla se la tracciabilità Python è attivata.

    Valori supportati:

    • 0: disabilita la traccia delle chiamate di funzioni Python (impostazione predefinita).
    • 1: attiva la tracciabilità Python.

Opzioni di configurazione avanzate

Opzioni TPU

  1. tpu_trace_mode: specifica la modalità di tracciamento della TPU.

    Valori supportati:

    • TRACE_ONLY_HOST: Ciò significa che vengono tracciate solo le attività lato host (CPU) e non vengono raccolte tracce del dispositivo (TPU/GPU).
    • TRACE_ONLY_XLA: significa che vengono tracciate solo le operazioni a livello di XLA sul dispositivo.
    • TRACE_COMPUTE: traccia le operazioni di calcolo sul dispositivo.
    • TRACE_COMPUTE_AND_SYNC: traccia sia le operazioni di calcolo sia gli eventi di sincronizzazione sul dispositivo.

    Se "tpu_trace_mode" non viene fornito, trace_mode viene impostato su TRACE_ONLY_XLA per impostazione predefinita.

  2. tpu_num_sparse_cores_to_trace: specifica il numero di core sparsi da tracciare sulla TPU.

  3. tpu_num_sparse_core_tiles_to_trace: specifica il numero di riquadri all'interno di ogni core sparso da tracciare sulla TPU.

  4. tpu_num_chips_to_profile_per_task: specifica il numero di chip TPU da profilare per attività.

Opzioni GPU

Per la profilazione della GPU sono disponibili le seguenti opzioni:

  • gpu_max_callback_api_events: imposta il numero massimo di eventi raccolti dall'API di callback CUPTI. Il valore predefinito è 2*1024*1024.
  • gpu_max_activity_api_events: imposta il numero massimo di eventi raccolti dall'API CUPTI activity. Il valore predefinito è 2*1024*1024.
  • gpu_max_annotation_strings: imposta il numero massimo di stringhe di annotazione che possono essere raccolte. Il valore predefinito è 1024*1024.
  • gpu_enable_nvtx_tracking: attiva il monitoraggio NVTX in CUPTI. Il valore predefinito è False.
  • gpu_enable_cupti_activity_graph_trace: consente la tracciatura del grafico delle attività CUPTI per i grafici CUDA. Il valore predefinito è False.
  • gpu_pm_sample_counters: una stringa separata da virgole di metriche di monitoraggio delle prestazioni della GPU da raccogliere utilizzando la funzionalità di campionamento PM di CUPTI (ad es. "sm__cycles_active.avg.pct_of_peak_sustained_elapsed"). Il campionamento PM è disattivato per impostazione predefinita. Per le metriche disponibili, consulta la documentazione CUPTI di NVIDIA.
  • gpu_pm_sample_interval_us: imposta l'intervallo di campionamento in microsecondi per il campionamento PM di CUPTI. Il valore predefinito è 500.
  • gpu_pm_sample_buffer_size_per_gpu_mb: imposta la dimensione del buffer di memoria di sistema per dispositivo in MB per il campionamento CUPTI PM. Il valore predefinito è 64 MB. Il valore massimo supportato è 4 GB.
  • gpu_num_chips_to_profile_per_task: specifica il numero di dispositivi GPU da profilare per attività. Se non specificato, impostato su 0 o su un valore non valido, verranno profilate tutte le GPU disponibili. Può essere utilizzato per ridurre le dimensioni della raccolta delle tracce.
  • gpu_dump_graph_node_mapping: se abilitato, esegue il dump delle informazioni di mappatura dei nodi del grafico CUDA nella traccia. Il valore predefinito è False.

Ad esempio:

options = ProfileOptions()
options.advanced_configuration = {"tpu_trace_mode" : "TRACE_ONLY_HOST", "tpu_num_sparse_cores_to_trace" : 2}

Restituisce InvalidArgumentError se vengono trovate chiavi o valori di opzione non riconosciuti.