Profilowanie obliczeń JAX za pomocą XProf

XProf to świetny sposób na uzyskiwanie i wizualizowanie śladów wydajności i profili programu, w tym aktywności na procesorach GPU i TPU. Wynik końcowy wygląda mniej więcej tak:

Przykład XProf

Automatyzacja

Możesz instrumentować kod, aby rejestrować ślad profilera dla kodu JAX za pomocą metod jax.profiler.start_tracejax.profiler.stop_trace. Wywołaj funkcję jax.profiler.start_trace z katalogiem, w którym mają być zapisywane pliki śledzenia. Powinien to być ten sam katalog --logdir, którego użyto do uruchomienia XProf. Następnie możesz użyć narzędzia XProf, aby wyświetlić ślady.

Na przykład, aby wykonać śledzenie profilera:

import jax

jax.profiler.start_trace("/tmp/profile-data")

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

Zwróć uwagę na połączenie jax.block_until_ready. Używamy tego, aby mieć pewność, że ślad obejmuje wykonanie na urządzeniu. Więcej informacji o tym, dlaczego jest to konieczne, znajdziesz w sekcji Wysyłanie asynchroniczne.

Możesz też użyć menedżera kontekstu jax.profiler.trace zamiast start_tracestop_trace:

import jax

with jax.profiler.trace("/tmp/profile-data"):
  key = jax.random.key(0)
  x = jax.random.normal(key, (5000, 5000))
  y = x @ x
  y.block_until_ready()

Wyświetlanie logu czasu

Po zarejestrowaniu śladu możesz go wyświetlić w interfejsie XProf.

Interfejs profilera możesz uruchomić bezpośrednio za pomocą samodzielnego polecenia XProf, kierując je do katalogu logów:

$ xprof --port=8791 /tmp/profile-data
Attempting to start XProf server:
  Log Directory: /tmp/profile-data
  Port: 8791
  Worker Service Address: 0.0.0.0:50051
  Hide Capture Button: False
XProf at http://localhost:8791/ (Press CTRL+C to quit)

Otwórz podany adres URL (np. http://localhost:8791/) w przeglądarce, aby wyświetlić profil.

Dostępne ślady pojawią się w menu „Sesje” po lewej stronie. Wybierz sesję, która Cię interesuje, a potem w menu „Narzędzia” kliknij „Przeglądarka śladów”. Powinna być teraz widoczna oś czasu wykonania. Do poruszania się po śladzie możesz używać klawiszy WASD, a aby wybrać zdarzenia i wyświetlić więcej szczegółów, klikaj lub przeciągaj. Więcej informacji o korzystaniu z przeglądarki śladów znajdziesz w dokumentacji narzędzia Trace Viewer.

Ręczne przechwytywanie za pomocą XProf

Poniżej znajdziesz instrukcje rejestrowania wywołanego ręcznie śladu trwającego N sekund z działającego programu.

  1. Uruchom serwer XProf:

    xprof --logdir /tmp/profile-data/
    

    XProf powinien być dostępny pod adresem <http://localhost:8791/>. Możesz określić inny port za pomocą flagi --port.

  2. W programie lub procesie w Pythonie, który chcesz profilować, dodaj w pobliżu początku ten kod:

    import jax.profiler
    jax.profiler.start_server(9999)
    

    Spowoduje to uruchomienie serwera profilera, z którym łączy się XProf. Zanim przejdziesz do następnego kroku, serwer profilera musi być uruchomiony. Gdy skończysz korzystać z serwera, możesz go wyłączyć, wywołując funkcję jax.profiler.stop_server().

    Jeśli chcesz profilować fragment długotrwałego programu (np. długą pętlę trenowania), możesz umieścić ten kod na początku programu i uruchomić go jak zwykle. Jeśli chcesz profilować krótki program (np. mikrobenczmark), możesz uruchomić serwer profilera w powłoce IPython i uruchomić krótki program za pomocą polecenia %run po rozpoczęciu rejestrowania w następnym kroku. Możesz też uruchomić serwer profilera na początku programu i użyć time.sleep(), aby mieć wystarczająco dużo czasu na rozpoczęcie rejestrowania.

  3. Otwórz <http://localhost:8791/> i w lewym górnym rogu kliknij przycisk „ZAPISZ PROFIL”. Jako adres URL usługi profilu wpisz „localhost:9999” (jest to adres serwera profilera, który został uruchomiony w poprzednim kroku). Wpisz liczbę milisekund, przez którą chcesz profilować, i kliknij „ZAPISZ”.

  4. Jeśli kod, który chcesz profilować, nie jest jeszcze uruchomiony (np. jeśli serwer profilera został uruchomiony w powłoce Pythona), uruchom go podczas przechwytywania.

  5. Po zakończeniu przechwytywania XProf powinien się automatycznie odświeżyć. (Nie wszystkie funkcje profilowania XProf są połączone z JAX, więc początkowo może się wydawać, że nic nie zostało zarejestrowane). Po lewej stronie w sekcji „Narzędzia” wybierz „Przeglądarka śladów”.

Powinna być teraz widoczna oś czasu wykonania. Do poruszania się po śladzie możesz używać klawiszy WASD. Aby wybrać zdarzenia i wyświetlić więcej szczegółów u dołu, kliknij je lub przeciągnij po nich wskaźnik. Więcej informacji o korzystaniu z przeglądarki śladów znajdziesz w dokumentacji narzędzia Trace Viewer.

XProf i TensorBoard

XProf to podstawowe narzędzie, które umożliwia profilowanie i rejestrowanie śladów w TensorBoardzie. Jeśli xprof jest zainstalowany, w TensorBoard będzie dostępna karta „Profil”. Używanie tej opcji jest identyczne z niezależnym uruchamianiem XProf, o ile jest ono uruchamiane z wskazaniem tego samego katalogu logów. Obejmuje to funkcje przechwytywania, analizowania i wyświetlania profili. XProf zastępuje funkcję tensorboard_plugin_profile, która była wcześniej zalecana.

$ tensorboard --logdir=/tmp/profile-data
[...]
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.19.0 at http://localhost:6006/ (Press CTRL+C to quit)

Dodawanie niestandardowych zdarzeń śledzenia

Domyślnie zdarzenia w przeglądarce śladów to głównie wewnętrzne funkcje JAX niskiego poziomu. Możesz dodawać własne zdarzenia i funkcje, używając w kodzie elementów jax.profiler.TraceAnnotationjax.profiler.annotate_function.

Konfigurowanie opcji profilera

Metoda start_trace akceptuje opcjonalny parametr profiler_options, który umożliwia szczegółową kontrolę nad działaniem profilera. Ten parametr powinien być instancją klasy jax.profiler.ProfileOptions.

Aby na przykład wyłączyć wszystkie ślady Pythona i hosta:

import jax

options = jax.profiler.ProfileOptions()
options.python_tracer_level = 0
options.host_tracer_level = 0
jax.profiler.start_trace("/tmp/profile-data", profiler_options=options)

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

Opcje ogólne

  1. host_tracer_level: ustawia poziom śledzenia aktywności po stronie hosta.

    Obsługiwane wartości:

    • 0: całkowicie wyłącza śledzenie hosta (procesora).
    • 1: włącza śledzenie tylko zdarzeń TraceMe, które zostały zaimplementowane przez użytkownika.
    • 2: obejmuje ślady poziomu 1 oraz szczegóły wykonania programu wysokiego poziomu, takie jak kosztowne operacje XLA (domyślnie).
    • 3: obejmuje ślady poziomu 2 oraz bardziej szczegółowe informacje o wykonywaniu programu na niskim poziomie, np. tanie operacje XLA.
  2. device_tracer_level: określa, czy śledzenie urządzenia jest włączone.

    Obsługiwane wartości:

    • 0: wyłącza śledzenie urządzenia.
    • 1: włącza śledzenie urządzenia (domyślnie).
  3. python_tracer_level: określa, czy śledzenie w Pythonie jest włączone.

    Obsługiwane wartości:

    • 0: wyłącza śledzenie wywołań funkcji Pythona (domyślnie).
    • 1: włącza śledzenie Pythona.

Zaawansowane opcje konfiguracji

Opcje TPU

  1. tpu_trace_mode: określa tryb śledzenia TPU.

    Obsługiwane wartości:

    • TRACE_ONLY_HOST: oznacza to, że śledzone są tylko działania po stronie hosta (procesora), a nie są zbierane żadne ślady z urządzenia (TPU/GPU).
    • TRACE_ONLY_XLA: oznacza to, że śledzone są tylko operacje na poziomie XLA na urządzeniu.
    • TRACE_COMPUTE: śledzi operacje obliczeniowe na urządzeniu.
    • TRACE_COMPUTE_AND_SYNC: śledzi zarówno operacje obliczeniowe, jak i zdarzenia synchronizacji na urządzeniu.

    Jeśli nie podasz wartości „tpu_trace_mode”, domyślnie używana jest wartość trace_mode TRACE_ONLY_XLA.

  2. tpu_num_sparse_cores_to_trace: określa liczbę rdzeni rzadkich do śledzenia w TPU.

  3. tpu_num_sparse_core_tiles_to_trace: określa liczbę kafelków w każdym rdzeniu rzadkim, które mają być śledzone w TPU.

  4. tpu_num_chips_to_profile_per_task: określa liczbę układów TPU do profilowania w przypadku każdego zadania.

Opcje GPU

Dostępne są te opcje profilowania GPU:

  • gpu_max_callback_api_events: Określa maksymalną liczbę zdarzeń zbieranych przez interfejs API wywołania zwrotnego CUPTI. Domyślna wartość to 2*1024*1024.
  • gpu_max_activity_api_events: określa maksymalną liczbę zdarzeń zbieranych przez interfejs CUPTI Activity API. Domyślna wartość to 2*1024*1024.
  • gpu_max_annotation_strings: określa maksymalną liczbę ciągów adnotacji, które można zebrać. Domyślna wartość to 1024*1024.
  • gpu_enable_nvtx_tracking: włącza śledzenie NVTX w CUPTI. Domyślna wartość to False.
  • gpu_enable_cupti_activity_graph_trace: włącza śledzenie wykresu aktywności CUPTI w przypadku wykresów CUDA. Domyślna wartość to False.
  • gpu_pm_sample_counters: ciąg znaków oddzielonych przecinkami zawierający wskaźniki monitorowania wydajności GPU, które mają być zbierane za pomocą funkcji próbkowania PM w CUPTI (np. "sm__cycles_active.avg.pct_of_peak_sustained_elapsed"). Próbkowanie PM jest domyślnie wyłączone. Dostępne dane znajdziesz w dokumentacji CUPTI firmy NVIDIA.
  • gpu_pm_sample_interval_us: ustawia interwał próbkowania w mikrosekundach dla próbkowania CUPTI PM. Domyślna wartość to 500.
  • gpu_pm_sample_buffer_size_per_gpu_mb: określa rozmiar bufora pamięci systemowej na urządzenie w MB na potrzeby próbkowania CUPTI PM. Domyślna wartość to 64 MB. Maksymalna obsługiwana wartość to 4 GB.
  • gpu_num_chips_to_profile_per_task: określa liczbę urządzeń GPU, które mają być profilowane w ramach zadania. Jeśli nie zostanie podana, będzie miała wartość 0 lub nieprawidłową wartość, profilowanie obejmie wszystkie dostępne procesory GPU. Możesz go użyć, aby zmniejszyć rozmiar kolekcji śladów.
  • gpu_dump_graph_node_mapping: jeśli ta opcja jest włączona, do śladu są zapisywane informacje o mapowaniu węzłów wykresu CUDA. Domyślna wartość to False.

Na przykład:

options = ProfileOptions()
options.advanced_configuration = {"tpu_trace_mode" : "TRACE_ONLY_HOST", "tpu_num_sparse_cores_to_trace" : 2}

Zwraca InvalidArgumentError, jeśli znaleziono nierozpoznane klucze lub wartości opcji.