XProf를 사용한 JAX 계산 프로파일링

XProf는 GPU 및 TPU의 활동을 비롯한 프로그램의 성능 트레이스 및 프로필을 획득하고 시각화하는 좋은 방법입니다. 최종 결과는 다음과 같습니다.

XProf 예시

프로그래매틱 캡처

jax.profiler.start_tracejax.profiler.stop_trace 메서드를 통해 JAX 코드의 프로파일러 트레이스를 캡처하도록 코드를 계측할 수 있습니다. 추적 파일을 쓸 디렉터리를 사용하여 jax.profiler.start_trace를 호출합니다. XProf를 시작하는 데 사용한 것과 동일한 --logdir 디렉터리여야 합니다. 그런 다음 XProf를 사용하여 트레이스를 볼 수 있습니다.

예를 들어 프로파일러 트레이스를 가져오려면 다음을 실행합니다.

import jax

jax.profiler.start_trace("/tmp/profile-data")

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

jax.block_until_ready 호출을 확인합니다. 온디바이스 실행이 트레이스에 포착되도록 하는 데 사용됩니다. 이것이 필요한 이유에 관한 자세한 내용은 비동기 디스패치를 참고하세요.

start_tracestop_trace 대신 jax.profiler.trace 컨텍스트 관리자를 사용할 수도 있습니다.

import jax

with jax.profiler.trace("/tmp/profile-data"):
  key = jax.random.key(0)
  x = jax.random.normal(key, (5000, 5000))
  y = x @ x
  y.block_until_ready()

trace 보기

트레이스를 캡처한 후 XProf UI를 사용하여 트레이스를 볼 수 있습니다.

로그 디렉터리를 가리키는 독립형 XProf 명령어를 사용하여 프로파일러 UI를 직접 실행할 수 있습니다.

$ xprof --port=8791 /tmp/profile-data
Attempting to start XProf server:
  Log Directory: /tmp/profile-data
  Port: 8791
  Worker Service Address: 0.0.0.0:50051
  Hide Capture Button: False
XProf at http://localhost:8791/ (Press CTRL+C to quit)

제공된 URL (예: http://localhost:8791/)을 클릭하여 프로필을 확인합니다.

사용 가능한 트레이스는 왼쪽에 있는 '세션' 드롭다운 메뉴에 표시됩니다. 관심 있는 세션을 선택한 다음 '도구' 드롭다운에서 'Trace Viewer'를 선택합니다. 이제 실행 타임라인이 표시됩니다. WASD 키를 사용하여 트레이스를 탐색하고 클릭하거나 드래그하여 이벤트에 관한 자세한 내용을 확인할 수 있습니다. Trace Viewer 사용에 관한 자세한 내용은 Trace Viewer 도구 문서를 참고하세요.

XProf를 통한 수동 캡처

다음은 실행 중인 프로그램에서 수동으로 트리거된 N초 트레이스를 캡처하는 방법입니다.

  1. XProf 서버를 시작합니다.

    xprof --logdir /tmp/profile-data/
    

    <http://localhost:8791/>에서 XProf를 로드할 수 있습니다. --port 플래그를 사용하여 다른 포트를 지정할 수 있습니다.

  2. 프로파일링하려는 Python 프로그램 또는 프로세스에서 시작 부분 근처에 다음을 추가합니다.

    import jax.profiler
    jax.profiler.start_server(9999)
    

    이렇게 하면 XProf가 연결되는 프로파일러 서버가 시작됩니다. 다음 단계로 이동하기 전에 프로파일러 서버가 실행 중이어야 합니다. 서버 사용이 끝나면 jax.profiler.stop_server()를 호출하여 서버를 종료할 수 있습니다.

    장기 실행 프로그램 (예: 긴 학습 루프)의 스니펫을 프로파일링하려면 프로그램 시작 부분에 이를 배치하고 평소와 같이 프로그램을 시작하면 됩니다. 짧은 프로그램 (예: 마이크로벤치마크)을 프로파일링하려면 IPython 셸에서 프로파일러 서버를 시작하고 다음 단계에서 캡처를 시작한 후 %run로 짧은 프로그램을 실행하면 됩니다. 또 다른 옵션은 프로그램 시작 시 프로파일러 서버를 시작하고 time.sleep()를 사용하여 캡처를 시작할 충분한 시간을 확보하는 것입니다.

  3. <http://localhost:8791/>를 열고 왼쪽 상단의 'CAPTURE PROFILE'(프로필 캡처) 버튼을 클릭합니다. 프로필 서비스 URL로 'localhost:9999'를 입력합니다 (이전 단계에서 시작한 프로파일러 서버의 주소임). 프로파일링할 밀리초 수를 입력하고 'CAPTURE'를 클릭합니다.

  4. 프로파일링하려는 코드가 아직 실행되고 있지 않다면 (예: Python 셸에서 프로파일러 서버를 시작한 경우) 캡처가 실행되는 동안 코드를 실행합니다.

  5. 캡처가 완료되면 XProf가 자동으로 새로고침됩니다. (일부 XProf 프로파일링 기능은 JAX와 연결되어 있지 않으므로 처음에는 캡처된 항목이 없는 것처럼 보일 수 있습니다.) 왼쪽의 '도구'에서 'Trace Viewer'를 선택합니다.

이제 실행 타임라인이 표시됩니다. WASD 키를 사용하여 트레이스를 탐색하고 클릭하거나 드래그하여 이벤트를 선택하여 하단에서 자세한 내용을 확인할 수 있습니다. trace 뷰어 사용에 관한 자세한 내용은 Trace Viewer 도구 문서를 참고하세요.

XProf 및 Tensorboard

XProf는 텐서보드에서 프로파일링 및 트레이스 캡처 기능을 지원하는 기본 도구입니다. xprof가 설치되어 있으면 Tensorboard 내에 '프로필' 탭이 표시됩니다. 동일한 로그 디렉터리를 가리키도록 실행되는 한, 이를 사용하는 것은 XProf를 독립적으로 실행하는 것과 동일합니다. 여기에는 프로필 캡처, 분석, 보기 기능이 포함됩니다. XProf는 이전에 권장되었던 tensorboard_plugin_profile 기능을 대체합니다.

$ tensorboard --logdir=/tmp/profile-data
[...]
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.19.0 at http://localhost:6006/ (Press CTRL+C to quit)

맞춤 추적 이벤트 추가

기본적으로 트레이스 뷰어의 이벤트는 대부분 하위 수준의 내부 JAX 함수입니다. 코드에서 jax.profiler.TraceAnnotationjax.profiler.annotate_function를 사용하여 자체 이벤트와 함수를 추가할 수 있습니다.

프로파일러 옵션 구성

start_trace 메서드는 프로파일러의 동작을 세밀하게 제어할 수 있는 선택적 profiler_options 매개변수를 허용합니다. 이 매개변수는 jax.profiler.ProfileOptions의 인스턴스여야 합니다.

예를 들어 모든 Python 및 호스트 트레이스를 사용 중지하려면 다음을 실행합니다.

import jax

options = jax.profiler.ProfileOptions()
options.python_tracer_level = 0
options.host_tracer_level = 0
jax.profiler.start_trace("/tmp/profile-data", profiler_options=options)

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

일반 옵션

  1. host_tracer_level: 호스트 측 활동의 추적 수준을 설정합니다.

    지원되는 값:

    • 0: 호스트 (CPU) 추적을 완전히 사용 중지합니다.
    • 1: 사용자 계측 TraceMe 이벤트만 추적할 수 있습니다.
    • 2: 수준 1 트레이스와 함께 비용이 많이 드는 XLA 작업과 같은 대략적인 프로그램 실행 세부정보를 포함합니다 (기본값).
    • 3: 레벨 2 트레이스와 저렴한 XLA 작업과 같은 더 자세한 하위 수준 프로그램 실행 세부정보를 포함합니다.
  2. device_tracer_level: 기기 추적 사용 설정 여부를 제어합니다.

    지원되는 값:

    • 0: 기기 추적을 사용 중지합니다.
    • 1: 기기 추적을 사용 설정합니다 (기본값).
  3. python_tracer_level: Python 추적을 사용 설정할지 여부를 제어합니다.

    지원되는 값:

    • 0: Python 함수 호출 추적을 사용 중지합니다 (기본값).
    • 1: Python 추적을 사용 설정합니다.

고급 구성 옵션

TPU 옵션

  1. tpu_trace_mode: TPU 추적 모드를 지정합니다.

    지원되는 값:

    • TRACE_ONLY_HOST: 호스트 측 (CPU) 활동만 추적되고 기기 (TPU/GPU) 트레이스는 수집되지 않습니다.
    • TRACE_ONLY_XLA: 기기에서 XLA 수준 작업만 추적됩니다.
    • TRACE_COMPUTE: 기기에서 컴퓨팅 작업을 추적합니다.
    • TRACE_COMPUTE_AND_SYNC: 기기에서 컴퓨팅 작업과 동기화 이벤트를 모두 추적합니다.

    'tpu_trace_mode'가 제공되지 않으면 trace_mode는 기본적으로 TRACE_ONLY_XLA입니다.

  2. tpu_num_sparse_cores_to_trace: TPU에서 추적할 스파스 코어 수를 지정합니다.

  3. tpu_num_sparse_core_tiles_to_trace: TPU에서 추적할 각 스파스 코어 내 타일 수를 지정합니다.

  4. tpu_num_chips_to_profile_per_task: 작업당 프로파일링할 TPU 칩 수를 지정합니다.

GPU 옵션

GPU 프로파일링에 사용할 수 있는 옵션은 다음과 같습니다.

  • gpu_max_callback_api_events: CUPTI 콜백 API에서 수집하는 이벤트의 최대 수를 설정합니다. 기본값은 2*1024*1024입니다.
  • gpu_max_activity_api_events: CUPTI 활동 API에서 수집하는 이벤트의 최대 개수를 설정합니다. 기본값은 2*1024*1024입니다.
  • gpu_max_annotation_strings: 수집할 수 있는 주석 문자열의 최대 개수를 설정합니다. 기본값은 1024*1024입니다.
  • gpu_enable_nvtx_tracking: CUPTI에서 NVTX 추적을 사용 설정합니다. 기본값은 False입니다.
  • gpu_enable_cupti_activity_graph_trace: CUDA 그래프의 CUPTI 활동 그래프 추적을 사용 설정합니다. 기본값은 False입니다.
  • gpu_pm_sample_counters: CUPTI의 PM 샘플링 기능을 사용하여 수집할 GPU 성능 모니터링 측정항목의 쉼표로 구분된 문자열입니다 (예: "sm__cycles_active.avg.pct_of_peak_sustained_elapsed"). PM 샘플링은 기본적으로 사용 중지되어 있습니다. 사용 가능한 측정항목은 NVIDIA CUPTI 문서를 참고하세요.
  • gpu_pm_sample_interval_us: CUPTI PM 샘플링의 샘플링 간격을 마이크로초 단위로 설정합니다. 기본값은 500입니다.
  • gpu_pm_sample_buffer_size_per_gpu_mb: CUPTI PM 샘플링을 위해 기기당 시스템 메모리 버퍼 크기를 MB 단위로 설정합니다. 기본값은 64MB입니다. 지원되는 최댓값은 4GB입니다.
  • gpu_num_chips_to_profile_per_task: 작업당 프로파일링할 GPU 기기 수를 지정합니다. 지정하지 않거나 0으로 설정하거나 유효하지 않은 값으로 설정하면 사용 가능한 모든 GPU가 프로파일링됩니다. 이를 사용하여 트레이스 수집 크기를 줄일 수 있습니다.
  • gpu_dump_graph_node_mapping: 사용 설정된 경우 CUDA 그래프 노드 매핑 정보를 트레이스에 덤프합니다. 기본값은 False입니다.

예를 들면 다음과 같습니다.

options = ProfileOptions()
options.advanced_configuration = {"tpu_trace_mode" : "TRACE_ONLY_HOST", "tpu_num_sparse_cores_to_trace" : 2}

인식할 수 없는 키나 옵션 값이 발견되면 InvalidArgumentError를 반환합니다.