Phân tích các phép tính JAX bằng XProf

XProf là một cách tuyệt vời để thu thập và trực quan hoá dấu vết hiệu suất cũng như hồ sơ của chương trình, bao gồm cả hoạt động trên GPU và TPU. Kết quả cuối cùng sẽ có dạng như sau:

Ví dụ về XProf

Chụp ảnh có lập trình

Bạn có thể đo lường mã để ghi lại dấu vết của trình phân tích tài nguyên cho mã JAX thông qua các phương thức jax.profiler.start_tracejax.profiler.stop_trace. Gọi jax.profiler.start_trace bằng thư mục để ghi các tệp theo dõi vào. Đây phải là cùng một thư mục --logdir được dùng để khởi động XProf. Sau đó, bạn có thể dùng XProf để xem các dấu vết.

Ví dụ: để lấy dấu vết của trình phân tích tài nguyên:

import jax

jax.profiler.start_trace("/tmp/profile-data")

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

Lưu ý cuộc gọi jax.block_until_ready. Chúng tôi sử dụng thông tin này để đảm bảo dấu vết ghi lại được quá trình thực thi trên thiết bị. Hãy xem phần Gửi không đồng bộ để biết thông tin chi tiết về lý do cần thiết phải làm như vậy.

Bạn cũng có thể sử dụng trình quản lý bối cảnh jax.profiler.trace thay cho start_tracestop_trace:

import jax

with jax.profiler.trace("/tmp/profile-data"):
  key = jax.random.key(0)
  x = jax.random.normal(key, (5000, 5000))
  y = x @ x
  y.block_until_ready()

Xem dấu vết

Sau khi ghi lại một dấu vết, bạn có thể xem dấu vết đó bằng giao diện người dùng XProf.

Bạn có thể chạy giao diện người dùng của trình phân tích tài nguyên trực tiếp bằng lệnh XProf độc lập bằng cách trỏ lệnh này đến thư mục nhật ký:

$ xprof --port=8791 /tmp/profile-data
Attempting to start XProf server:
  Log Directory: /tmp/profile-data
  Port: 8791
  Worker Service Address: 0.0.0.0:50051
  Hide Capture Button: False
XProf at http://localhost:8791/ (Press CTRL+C to quit)

Truy cập vào URL được cung cấp (ví dụ: http://localhost:8791/) trong trình duyệt để xem hồ sơ.

Các dấu vết có sẵn sẽ xuất hiện trong trình đơn thả xuống "Phiên" ở bên trái. Chọn phiên mà bạn quan tâm, sau đó trong trình đơn thả xuống "Công cụ", hãy chọn "Trình xem dấu vết". Lúc này, bạn sẽ thấy một dòng thời gian thực thi. Bạn có thể dùng các phím WASD để di chuyển trong dấu vết, đồng thời nhấp hoặc kéo để chọn các sự kiện để xem thêm thông tin chi tiết. Hãy xem tài liệu về Công cụ Trình xem dấu vết để biết thêm thông tin chi tiết về cách sử dụng trình xem dấu vết.

Ghi lại theo cách thủ công qua XProf

Sau đây là hướng dẫn về cách ghi lại dấu vết N giây được kích hoạt theo cách thủ công từ một chương trình đang chạy.

  1. Khởi động một máy chủ XProf:

    xprof --logdir /tmp/profile-data/
    

    Bạn có thể tải XProf tại <http://localhost:8791/>. Bạn có thể chỉ định một cổng khác bằng cờ --port.

  2. Trong chương trình hoặc quy trình Python mà bạn muốn lập hồ sơ, hãy thêm đoạn mã sau vào đâu đó gần đầu:

    import jax.profiler
    jax.profiler.start_server(9999)
    

    Thao tác này sẽ khởi động máy chủ hồ sơ mà XProf kết nối. Máy chủ lập hồ sơ phải đang chạy thì bạn mới có thể chuyển sang bước tiếp theo. Khi dùng xong máy chủ, bạn có thể gọi jax.profiler.stop_server() để tắt máy chủ.

    Nếu muốn lập hồ sơ cho một đoạn mã của chương trình chạy trong thời gian dài (ví dụ: một vòng lặp huấn luyện dài), bạn có thể đặt đoạn mã này ở đầu chương trình và bắt đầu chương trình như bình thường. Nếu bạn muốn lập hồ sơ cho một chương trình ngắn (ví dụ: một phép đo vi mô), thì một lựa chọn là khởi động máy chủ trình phân tích tài nguyên trong một trình bao IPython và chạy chương trình ngắn bằng %run sau khi bắt đầu quá trình ghi lại ở bước tiếp theo. Một lựa chọn khác là khởi động máy chủ hồ sơ vào đầu chương trình và sử dụng time.sleep() để có đủ thời gian bắt đầu quá trình ghi lại.

  3. Mở <http://localhost:8791/> rồi nhấp vào nút "CAPTURE PROFILE" (CHỤP HỒ SƠ) ở trên cùng bên trái. Nhập "localhost:9999" làm URL dịch vụ hồ sơ (đây là địa chỉ của máy chủ trình phân tích tài nguyên mà bạn đã khởi động ở bước trước). Nhập số mili giây mà bạn muốn lập hồ sơ, rồi nhấp vào "CAPTURE" (GHI LẠI).

  4. Nếu mã bạn muốn lập hồ sơ chưa chạy (ví dụ: nếu bạn đã khởi động máy chủ trình phân tích tài nguyên trong một trình bao Python), hãy chạy mã đó trong khi quá trình ghi đang diễn ra.

  5. Sau khi quá trình ghi lại kết thúc, XProf sẽ tự động làm mới. (Không phải tất cả các tính năng lập hồ sơ XProf đều được liên kết với JAX, vì vậy, ban đầu có thể bạn sẽ thấy như không có gì được ghi lại.) Ở bên trái, trong phần "Công cụ", hãy chọn "Trình xem dấu vết".

Lúc này, bạn sẽ thấy một dòng thời gian thực thi. Bạn có thể dùng các phím WASD để di chuyển dấu vết, đồng thời nhấp hoặc kéo để chọn các sự kiện nhằm xem thêm thông tin chi tiết ở dưới cùng. Hãy xem tài liệu về Công cụ Trình xem dấu vết để biết thêm thông tin chi tiết về cách sử dụng trình xem dấu vết.

XProf và Tensorboard

XProf là công cụ cơ bản hỗ trợ chức năng lập hồ sơ và ghi lại dấu vết trong Tensorboard. Miễn là bạn đã cài đặt xprof, thẻ "Hồ sơ" sẽ xuất hiện trong TensorBoard. Việc sử dụng công cụ này tương tự như việc khởi chạy XProf một cách độc lập, miễn là công cụ này được khởi chạy trỏ đến cùng một thư mục nhật ký. Quyền này bao gồm chức năng chụp, phân tích và xem hồ sơ. XProf thay thế chức năng tensorboard_plugin_profile mà trước đây được đề xuất.

$ tensorboard --logdir=/tmp/profile-data
[...]
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.19.0 at http://localhost:6006/ (Press CTRL+C to quit)

Thêm sự kiện theo dõi tuỳ chỉnh

Theo mặc định, các sự kiện trong trình xem dấu vết chủ yếu là các hàm JAX nội bộ cấp thấp. Bạn có thể thêm các sự kiện và hàm của riêng mình bằng cách sử dụng jax.profiler.TraceAnnotationjax.profiler.annotate_function trong mã.

Định cấu hình các lựa chọn của trình phân tích tài nguyên

Phương thức start_trace chấp nhận một tham số profiler_options không bắt buộc, cho phép kiểm soát chi tiết hành vi của trình phân tích tài nguyên. Tham số này phải là một thực thể của jax.profiler.ProfileOptions.

Ví dụ: để tắt tất cả các dấu vết python và dấu vết máy chủ lưu trữ:

import jax

options = jax.profiler.ProfileOptions()
options.python_tracer_level = 0
options.host_tracer_level = 0
jax.profiler.start_trace("/tmp/profile-data", profiler_options=options)

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

Tuỳ chọn chung

  1. host_tracer_level: Thiết lập cấp độ theo dõi cho các hoạt động phía máy chủ.

    Giá trị được hỗ trợ:

    • 0: Tắt hoàn toàn tính năng theo dõi máy chủ (CPU).
    • 1: Chỉ cho phép theo dõi các sự kiện TraceMe do người dùng đo lường.
    • 2: Bao gồm các dấu vết cấp 1 cùng với thông tin chi tiết về quá trình thực thi chương trình cấp cao, chẳng hạn như các thao tác XLA tốn kém (mặc định).
    • 3: Bao gồm các dấu vết cấp 2 cùng với thông tin chi tiết hơn về quá trình thực thi chương trình ở cấp thấp, chẳng hạn như các thao tác XLA có chi phí thấp.
  2. device_tracer_level: Kiểm soát việc có bật tính năng theo dõi thiết bị hay không.

    Giá trị được hỗ trợ:

    • 0: Tắt tính năng theo dõi thiết bị.
    • 1: Bật tính năng theo dõi thiết bị (mặc định).
  3. python_tracer_level: Kiểm soát việc có bật tính năng theo dõi Python hay không.

    Giá trị được hỗ trợ:

    • 0: Tắt tính năng theo dõi lệnh gọi hàm Python (mặc định).
    • 1: Bật tính năng theo dõi Python.

Tùy chọn cấu hình nâng cao

Các lựa chọn về TPU

  1. tpu_trace_mode: Chỉ định chế độ theo dõi TPU.

    Giá trị được hỗ trợ:

    • TRACE_ONLY_HOST: Điều này có nghĩa là chỉ các hoạt động phía máy chủ (CPU) được theo dõi và không có dấu vết nào trên thiết bị (TPU/GPU) được thu thập.
    • TRACE_ONLY_XLA: Điều này có nghĩa là chỉ các thao tác ở cấp XLA trên thiết bị mới được theo dõi.
    • TRACE_COMPUTE: Lựa chọn này theo dõi các hoạt động tính toán trên thiết bị.
    • TRACE_COMPUTE_AND_SYNC: Thao tác này theo dõi cả các hoạt động tính toán và sự kiện đồng bộ hoá trên thiết bị.

    Nếu bạn không cung cấp "tpu_trace_mode", thì trace_mode sẽ mặc định là TRACE_ONLY_XLA.

  2. tpu_num_sparse_cores_to_trace: Chỉ định số lượng lõi thưa để theo dõi trên TPU.

  3. tpu_num_sparse_core_tiles_to_trace: Chỉ định số lượng ô trong mỗi lõi thưa để theo dõi trên TPU.

  4. tpu_num_chips_to_profile_per_task: Chỉ định số lượng chip TPU cần lập hồ sơ cho mỗi tác vụ.

Các lựa chọn về GPU

Bạn có thể sử dụng các lựa chọn sau đây để lập hồ sơ GPU:

  • gpu_max_callback_api_events: Đặt số lượng sự kiện tối đa mà API gọi lại CUPTI thu thập. Giá trị mặc định là 2*1024*1024.
  • gpu_max_activity_api_events: Đặt số lượng sự kiện tối đa mà API hoạt động CUPTI thu thập. Giá trị mặc định là 2*1024*1024.
  • gpu_max_annotation_strings: Đặt số lượng tối đa chuỗi chú thích có thể thu thập. Giá trị mặc định là 1024*1024.
  • gpu_enable_nvtx_tracking: Cho phép theo dõi NVTX trong CUPTI. Giá trị mặc định là False.
  • gpu_enable_cupti_activity_graph_trace: Bật tính năng theo dõi biểu đồ hoạt động CUPTI cho biểu đồ CUDA. Giá trị mặc định là False.
  • gpu_pm_sample_counters: Một chuỗi được phân tách bằng dấu phẩy gồm các chỉ số Giám sát hiệu suất GPU để thu thập bằng tính năng lấy mẫu PM của CUPTI (ví dụ: "sm__cycles_active.avg.pct_of_peak_sustained_elapsed"). Tính năng lấy mẫu PM bị tắt theo mặc định. Để biết các chỉ số có sẵn, hãy xem tài liệu CUPTI của NVIDIA.
  • gpu_pm_sample_interval_us: Đặt khoảng thời gian lấy mẫu tính bằng micrô giây cho hoạt động lấy mẫu CUPTI PM. Giá trị mặc định là 500.
  • gpu_pm_sample_buffer_size_per_gpu_mb: Đặt kích thước vùng đệm bộ nhớ hệ thống cho mỗi thiết bị (tính bằng MB) để lấy mẫu CUPTI PM. Giá trị mặc định là 64 MB. Giá trị tối đa được hỗ trợ là 4 GB.
  • gpu_num_chips_to_profile_per_task: Chỉ định số lượng thiết bị GPU cần phân tích tài nguyên cho mỗi tác vụ. Nếu bạn không chỉ định, đặt thành 0 hoặc đặt thành một giá trị không hợp lệ, thì tất cả GPU có sẵn sẽ được lập hồ sơ. Bạn có thể dùng cách này để giảm kích thước của hoạt động thu thập dấu vết.
  • gpu_dump_graph_node_mapping: Nếu được bật, thông tin ánh xạ nút biểu đồ CUDA sẽ được kết xuất vào dấu vết. Giá trị mặc định là False.

Ví dụ:

options = ProfileOptions()
options.advanced_configuration = {"tpu_trace_mode" : "TRACE_ONLY_HOST", "tpu_num_sparse_cores_to_trace" : 2}

Trả về InvalidArgumentError nếu tìm thấy bất kỳ khoá hoặc giá trị tuỳ chọn nào không nhận dạng được.