Tích hợp trình bổ trợ PJRT

Thông tin khái quát

PJRT là API thiết bị thống nhất mà chúng tôi muốn thêm vào hệ sinh thái học máy. Tầm nhìn dài hạn là: (1) các khung (JAX, TF, v.v.) sẽ gọi PJRT, có các triển khai dành riêng cho thiết bị được che khuất đối với khung; (2) mỗi thiết bị tập trung vào việc triển khai các API PJRT và có thể không rõ ràng đối với các khung.

Tài liệu này tập trung vào các đề xuất về cách tích hợp với PJRT và cách kiểm thử việc tích hợp PJRT với JAX.

Cách tích hợp với PJRT

Bước 1: Triển khai giao diện PJRT C API

Cách A: Bạn có thể triển khai trực tiếp PJRT C API.

Lựa chọn B: Nếu có thể tạo bản dựng dựa trên mã C++ trong kho lưu trữ xla (thông qua forking hoặc bazel), thì bạn cũng có thể triển khai API PJRT C++ và sử dụng trình bao bọc C→C++:

  1. Triển khai một ứng dụng PJRT C++ kế thừa từ ứng dụng PJRT cơ sở (và các lớp PJRT có liên quan). Dưới đây là một số ví dụ về ứng dụng PJRT C++: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
  2. Triển khai một số phương thức API C không thuộc ứng dụng C++ PJRT:
    • PJRT_Client_Create (Tạo). Dưới đây là một số mã giả mẫu (giả sử GetPluginPjRtClient trả về một ứng dụng C++ PJRT được triển khai ở trên):
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;
}
}  // namespace my_plugin

Lưu ý: PJRT_Client_Create có thể lấy các tuỳ chọn được chuyển từ khung. Đây là ví dụ về cách ứng dụng GPU sử dụng tính năng này.

Với trình bao bọc, bạn không cần triển khai các API C còn lại.

Bước 2: Triển khai GetPjRtApi

Bạn cần triển khai một phương thức GetPjRtApi để trả về một PJRT_Api* chứa các con trỏ hàm cho các hoạt động triển khai API PJRT C. Dưới đây là một ví dụ giả định việc triển khai thông qua trình bao bọc (tương tự như pjrt_c_api_cpu.cc):

const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}

Bước 3: Kiểm thử việc triển khai API C

Bạn có thể gọi RegisterPjRtCApiTestFactory để chạy một tập hợp nhỏ các bài kiểm thử cho các hành vi cơ bản của API PJRT C.

Cách sử dụng trình bổ trợ PJRT của JAX

Bước 1: Thiết lập JAX

Bạn có thể sử dụng JAX mỗi đêm

pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>

hoặc tạo JAX từ nguồn.

Hiện tại, bạn cần so khớp phiên bản jaxlib với phiên bản PJRT C API. Thông thường, bạn chỉ nên sử dụng phiên bản jaxlib ban đêm từ cùng ngày với cam kết TF mà bạn đang xây dựng trình bổ trợ, ví dụ:

pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

Bạn cũng có thể tạo jaxlib từ nguồn theo đúng cam kết XLA mà bạn đang xây dựng (instructions).

Chúng tôi sẽ sớm bắt đầu hỗ trợ khả năng tương thích với ABI.

Bước 2: Sử dụng không gian tên jax_plugins hoặc thiết lập enter_point

Có hai cách để JAX phát hiện trình bổ trợ của bạn.

  1. Sử dụng gói không gian tên (ref). Định nghĩa một mô-đun duy nhất trên toàn hệ thống trong gói không gian tên jax_plugins (tức là chỉ cần tạo một thư mục jax_plugins và xác định mô-đun bên dưới mô-đun đó). Dưới đây là ví dụ về cấu trúc thư mục:
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
  1. Sử dụng siêu dữ liệu của gói (ref). Nếu xây dựng một gói thông qua pyproject.toml hoặc setup.py, hãy quảng cáo tên mô-đun trình bổ trợ của bạn bằng cách đưa một điểm truy cập vào trong nhóm jax_plugins trỏ đến tên mô-đun đầy đủ của bạn. Dưới đây là ví dụ thông qua pyproject.toml hoặc setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
  "jax_plugins": [
    "my_plugin = my_plugin",
  ],
}

Dưới đây là các ví dụ về cách triển khai openxla-pjrt-plugin bằng cách sử dụng Cách 2: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120

Bước 3: Triển khai phương thức launch()

Bạn cần triển khai phương thứcinitialize() trong mô-đun python của mình để đăng ký trình bổ trợ, ví dụ:

import os
import jax._src.xla_bridge as xb

def initialize():
  path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
  xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)

Vui lòng tham khảo tại đây để biết cách sử dụng xla_bridge.register_plugin. Hiện tại, đây là một phương thức riêng tư. Một API công khai sẽ được phát hành trong tương lai.

Bạn có thể chạy dòng bên dưới để xác minh rằng trình bổ trợ đã được đăng ký và báo lỗi nếu không tải được.

jax.config.update("jax_platforms", "my_plugin")

JAX có thể có nhiều phần phụ trợ/trình bổ trợ. Có một số cách giúp đảm bảo trình bổ trợ của bạn được dùng làm phần phụ trợ mặc định:

  • Lựa chọn 1: chạy jax.config.update("jax_platforms", "my_plugin") ở đầu chương trình.
  • Cách 2: đặt ENV JAX_PLATFORMS=my_plugin.
  • Lựa chọn 3: đặt mức độ ưu tiên đủ cao khi gọi xb.register_plugin (giá trị mặc định là 400, cao hơn các phần phụ trợ hiện có khác). Xin lưu ý rằng phần phụ trợ có mức độ ưu tiên cao nhất sẽ chỉ được dùng khi JAX_PLATFORMS=''. Giá trị mặc định của JAX_PLATFORMS'' nhưng đôi khi sẽ bị ghi đè.

Cách kiểm thử bằng JAX

Một số trường hợp kiểm thử cơ bản nên thử:

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap

arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))

# single device: [0]

# 4 devices: [6 7 8 9]

(Chúng tôi sẽ sớm bổ sung hướng dẫn để chạy kiểm thử đơn vị jax đối với trình bổ trợ của bạn!)

Ví dụ: Trình bổ trợ JAX CUDA

  1. Triển khai API PJRT C thông qua trình bao bọc (pjrt_c_api_gpu.h).
  2. Thiết lập điểm truy cập cho gói (setup.py).
  3. Triển khai phương thức launch() (__init__.py).
  4. Có thể kiểm thử với bất kỳ kiểm thử jax nào cho CUDA. ```