XProf to świetny sposób na uzyskiwanie i wizualizowanie śladów wydajności i profili programu, w tym aktywności na procesorach GPU i TPU. Wynik końcowy wygląda mniej więcej tak:

Automatyzacja
Możesz instrumentować kod, aby rejestrować ślad profilera dla kodu JAX za pomocą metod jax.profiler.start_trace i jax.profiler.stop_trace. Wywołaj funkcję
jax.profiler.start_trace
z katalogiem, w którym mają być zapisywane pliki śledzenia. Powinien to być ten sam katalog --logdir, którego użyto do uruchomienia XProf. Następnie możesz użyć narzędzia XProf, aby wyświetlić ślady.
Na przykład, aby wykonać śledzenie profilera:
import jax
jax.profiler.start_trace("/tmp/profile-data")
# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
jax.profiler.stop_trace()
Zwróć uwagę na połączenie jax.block_until_ready. Używamy tego, aby mieć pewność, że ślad obejmuje wykonanie na urządzeniu. Więcej informacji o tym, dlaczego jest to konieczne, znajdziesz w sekcji Wysyłanie asynchroniczne.
Możesz też użyć menedżera kontekstu jax.profiler.trace zamiast start_trace i stop_trace:
import jax
with jax.profiler.trace("/tmp/profile-data"):
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
Wyświetlanie logu czasu
Po zarejestrowaniu śladu możesz go wyświetlić w interfejsie XProf.
Interfejs profilera możesz uruchomić bezpośrednio za pomocą samodzielnego polecenia XProf, kierując je do katalogu logów:
$ xprof --port=8791 /tmp/profile-data
Attempting to start XProf server:
Log Directory: /tmp/profile-data
Port: 8791
Worker Service Address: 0.0.0.0:50051
Hide Capture Button: False
XProf at http://localhost:8791/ (Press CTRL+C to quit)
Otwórz podany adres URL (np. http://localhost:8791/) w przeglądarce, aby wyświetlić profil.
Dostępne ślady pojawią się w menu „Sesje” po lewej stronie. Wybierz sesję, która Cię interesuje, a potem w menu „Narzędzia” kliknij „Przeglądarka śladów”. Powinna być teraz widoczna oś czasu wykonania. Do poruszania się po śladzie możesz używać klawiszy WASD, a aby wybrać zdarzenia i wyświetlić więcej szczegółów, klikaj lub przeciągaj. Więcej informacji o korzystaniu z przeglądarki śladów znajdziesz w dokumentacji narzędzia Trace Viewer.
Ręczne przechwytywanie za pomocą XProf
Poniżej znajdziesz instrukcje rejestrowania wywołanego ręcznie śladu trwającego N sekund z działającego programu.
Uruchom serwer XProf:
xprof --logdir /tmp/profile-data/XProf powinien być dostępny pod adresem
<http://localhost:8791/>. Możesz określić inny port za pomocą flagi--port.W programie lub procesie w Pythonie, który chcesz profilować, dodaj w pobliżu początku ten kod:
import jax.profiler jax.profiler.start_server(9999)Spowoduje to uruchomienie serwera profilera, z którym łączy się XProf. Zanim przejdziesz do następnego kroku, serwer profilera musi być uruchomiony. Gdy skończysz korzystać z serwera, możesz go wyłączyć, wywołując funkcję
jax.profiler.stop_server().Jeśli chcesz profilować fragment długotrwałego programu (np. długą pętlę trenowania), możesz umieścić ten kod na początku programu i uruchomić go jak zwykle. Jeśli chcesz profilować krótki program (np. mikrobenczmark), możesz uruchomić serwer profilera w powłoce IPython i uruchomić krótki program za pomocą polecenia
%runpo rozpoczęciu rejestrowania w następnym kroku. Możesz też uruchomić serwer profilera na początku programu i użyćtime.sleep(), aby mieć wystarczająco dużo czasu na rozpoczęcie rejestrowania.Otwórz
<http://localhost:8791/>i w lewym górnym rogu kliknij przycisk „ZAPISZ PROFIL”. Jako adres URL usługi profilu wpisz „localhost:9999” (jest to adres serwera profilera, który został uruchomiony w poprzednim kroku). Wpisz liczbę milisekund, przez którą chcesz profilować, i kliknij „ZAPISZ”.Jeśli kod, który chcesz profilować, nie jest jeszcze uruchomiony (np. jeśli serwer profilera został uruchomiony w powłoce Pythona), uruchom go podczas przechwytywania.
Po zakończeniu przechwytywania XProf powinien się automatycznie odświeżyć. (Nie wszystkie funkcje profilowania XProf są połączone z JAX, więc początkowo może się wydawać, że nic nie zostało zarejestrowane). Po lewej stronie w sekcji „Narzędzia” wybierz „Przeglądarka śladów”.
Powinna być teraz widoczna oś czasu wykonania. Do poruszania się po śladzie możesz używać klawiszy WASD. Aby wybrać zdarzenia i wyświetlić więcej szczegółów u dołu, kliknij je lub przeciągnij po nich wskaźnik. Więcej informacji o korzystaniu z przeglądarki śladów znajdziesz w dokumentacji narzędzia Trace Viewer.
XProf i TensorBoard
XProf to podstawowe narzędzie, które umożliwia profilowanie i rejestrowanie śladów w TensorBoardzie. Jeśli xprof jest zainstalowany, w TensorBoard będzie dostępna karta „Profil”. Używanie tej opcji jest identyczne z niezależnym uruchamianiem XProf, o ile jest ono uruchamiane z wskazaniem tego samego katalogu logów.
Obejmuje to funkcje przechwytywania, analizowania i wyświetlania profili. XProf zastępuje funkcję tensorboard_plugin_profile, która była wcześniej zalecana.
$ tensorboard --logdir=/tmp/profile-data
[...]
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.19.0 at http://localhost:6006/ (Press CTRL+C to quit)
Dodawanie niestandardowych zdarzeń śledzenia
Domyślnie zdarzenia w przeglądarce śladów to głównie wewnętrzne funkcje JAX niskiego poziomu. Możesz dodawać własne zdarzenia i funkcje, używając w kodzie elementów jax.profiler.TraceAnnotation i jax.profiler.annotate_function.
Konfigurowanie opcji profilera
Metoda start_trace akceptuje opcjonalny parametr profiler_options, który umożliwia szczegółową kontrolę nad działaniem profilera. Ten parametr powinien być instancją klasy jax.profiler.ProfileOptions.
Aby na przykład wyłączyć wszystkie ślady Pythona i hosta:
import jax
options = jax.profiler.ProfileOptions()
options.python_tracer_level = 0
options.host_tracer_level = 0
jax.profiler.start_trace("/tmp/profile-data", profiler_options=options)
# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
jax.profiler.stop_trace()
Opcje ogólne
host_tracer_level: ustawia poziom śledzenia aktywności po stronie hosta.Obsługiwane wartości:
0: całkowicie wyłącza śledzenie hosta (procesora).1: włącza śledzenie tylko zdarzeń TraceMe, które zostały zaimplementowane przez użytkownika.2: obejmuje ślady poziomu 1 oraz szczegóły wykonania programu wysokiego poziomu, takie jak kosztowne operacje XLA (domyślnie).3: obejmuje ślady poziomu 2 oraz bardziej szczegółowe informacje o wykonywaniu programu na niskim poziomie, np. tanie operacje XLA.
device_tracer_level: określa, czy śledzenie urządzenia jest włączone.Obsługiwane wartości:
0: wyłącza śledzenie urządzenia.1: włącza śledzenie urządzenia (domyślnie).
python_tracer_level: określa, czy śledzenie w Pythonie jest włączone.Obsługiwane wartości:
0: wyłącza śledzenie wywołań funkcji Pythona (domyślnie).1: włącza śledzenie Pythona.
Zaawansowane opcje konfiguracji
Opcje TPU
tpu_trace_mode: określa tryb śledzenia TPU.Obsługiwane wartości:
TRACE_ONLY_HOST: oznacza to, że śledzone są tylko działania po stronie hosta (procesora), a nie są zbierane żadne ślady z urządzenia (TPU/GPU).TRACE_ONLY_XLA: oznacza to, że śledzone są tylko operacje na poziomie XLA na urządzeniu.TRACE_COMPUTE: śledzi operacje obliczeniowe na urządzeniu.TRACE_COMPUTE_AND_SYNC: śledzi zarówno operacje obliczeniowe, jak i zdarzenia synchronizacji na urządzeniu.
Jeśli nie podasz wartości „tpu_trace_mode”, domyślnie używana jest wartość trace_mode
TRACE_ONLY_XLA.tpu_num_sparse_cores_to_trace: określa liczbę rdzeni rzadkich do śledzenia w TPU.tpu_num_sparse_core_tiles_to_trace: określa liczbę kafelków w każdym rdzeniu rzadkim, które mają być śledzone w TPU.tpu_num_chips_to_profile_per_task: określa liczbę układów TPU do profilowania w przypadku każdego zadania.
Opcje GPU
Dostępne są te opcje profilowania GPU:
gpu_max_callback_api_events: Określa maksymalną liczbę zdarzeń zbieranych przez interfejs API wywołania zwrotnego CUPTI. Domyślna wartość to2*1024*1024.gpu_max_activity_api_events: określa maksymalną liczbę zdarzeń zbieranych przez interfejs CUPTI Activity API. Domyślna wartość to2*1024*1024.gpu_max_annotation_strings: określa maksymalną liczbę ciągów adnotacji, które można zebrać. Domyślna wartość to1024*1024.gpu_enable_nvtx_tracking: włącza śledzenie NVTX w CUPTI. Domyślna wartość toFalse.gpu_enable_cupti_activity_graph_trace: włącza śledzenie wykresu aktywności CUPTI w przypadku wykresów CUDA. Domyślna wartość toFalse.gpu_pm_sample_counters: ciąg znaków oddzielonych przecinkami zawierający wskaźniki monitorowania wydajności GPU, które mają być zbierane za pomocą funkcji próbkowania PM w CUPTI (np."sm__cycles_active.avg.pct_of_peak_sustained_elapsed"). Próbkowanie PM jest domyślnie wyłączone. Dostępne dane znajdziesz w dokumentacji CUPTI firmy NVIDIA.gpu_pm_sample_interval_us: ustawia interwał próbkowania w mikrosekundach dla próbkowania CUPTI PM. Domyślna wartość to500.gpu_pm_sample_buffer_size_per_gpu_mb: określa rozmiar bufora pamięci systemowej na urządzenie w MB na potrzeby próbkowania CUPTI PM. Domyślna wartość to 64 MB. Maksymalna obsługiwana wartość to 4 GB.gpu_num_chips_to_profile_per_task: określa liczbę urządzeń GPU, które mają być profilowane w ramach zadania. Jeśli nie zostanie podana, będzie miała wartość 0 lub nieprawidłową wartość, profilowanie obejmie wszystkie dostępne procesory GPU. Możesz go użyć, aby zmniejszyć rozmiar kolekcji śladów.gpu_dump_graph_node_mapping: jeśli ta opcja jest włączona, do śladu są zapisywane informacje o mapowaniu węzłów wykresu CUDA. Domyślna wartość toFalse.
Na przykład:
options = ProfileOptions()
options.advanced_configuration = {"tpu_trace_mode" : "TRACE_ONLY_HOST", "tpu_num_sparse_cores_to_trace" : 2}
Zwraca InvalidArgumentError, jeśli znaleziono nierozpoznane klucze lub wartości opcji.