Tích hợp trình bổ trợ PJRT

Tài liệu này tập trung vào các đề xuất về cách tích hợp với PJRT và cách kiểm thử hoạt động tích hợp PJRT với JAX.

Cách tích hợp với PJRT

Bước 1: Triển khai giao diện API PJRT C

Cách A: Bạn có thể triển khai trực tiếp API PJRT C.

Tuỳ chọn B: Nếu có thể tạo bản dựng dựa trên mã C++ trong kho lưu trữ xla (thông qua tính năng phân nhánh hoặc bazel), bạn cũng có thể triển khai API C++ PJRT và sử dụng trình bao bọc C→C++:

  1. Triển khai ứng dụng PJRT C++ kế thừa từ ứng dụng PJRT cơ sở (và các lớp PJRT có liên quan). Dưới đây là một số ví dụ về ứng dụng PJRT C++: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
  2. Triển khai một số phương thức API C không thuộc ứng dụng PJRT C++:
    • PJRT_Client_Create. Dưới đây là một số mã giả lập mẫu (giả sử GetPluginPjRtClient trả về một ứng dụng PJRT C++ được triển khai ở trên):
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;
}
}  // namespace my_plugin

Lưu ý PJRT_Client_Create có thể nhận các tuỳ chọn được truyền từ khung. Dưới đây là ví dụ về cách ứng dụng GPU sử dụng tính năng này.

Với trình bao bọc, bạn không cần triển khai các API C còn lại.

Bước 2: Triển khai GetPjRtApi

Bạn cần triển khai một phương thức GetPjRtApi trả về một PJRT_Api* chứa con trỏ hàm đến các phương thức triển khai API PJRT C. Dưới đây là ví dụ giả định triển khai thông qua trình bao bọc (tương tự như pjrt_c_api_cpu.cc):

const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}

Bước 3: Kiểm thử việc triển khai API C

Bạn có thể gọi RegisterPjRtCApiTestFactory để chạy một nhóm nhỏ các chương trình kiểm thử cho các hành vi cơ bản của API PJRT C.

Cách sử dụng trình bổ trợ PJRT từ JAX

Bước 1: Thiết lập JAX

Bạn có thể sử dụng JAX hằng đêm

pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>

hoặc tạo JAX từ nguồn.

Hiện tại, bạn cần so khớp phiên bản jaxlib với phiên bản API PJRT C. Thông thường, bạn chỉ cần sử dụng phiên bản jaxlib hằng đêm từ cùng ngày với thay đổi cam kết TF mà bạn đang xây dựng trình bổ trợ, ví dụ:

pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

Bạn cũng có thể tạo một jaxlib từ nguồn tại chính thay đổi XLA mà bạn đang tạo (hướng dẫn).

Chúng tôi sẽ sớm bắt đầu hỗ trợ khả năng tương thích ABI.

Bước 2: Sử dụng không gian tên jax_plugins hoặc thiết lập entry_point

Có hai cách để JAX phát hiện trình bổ trợ của bạn.

  1. Sử dụng các gói không gian tên (tham chiếu). Xác định một mô-đun duy nhất trên toàn hệ thống trong gói không gian tên jax_plugins (tức là chỉ cần tạo một thư mục jax_plugins và xác định mô-đun của bạn bên dưới thư mục đó). Dưới đây là ví dụ về cấu trúc thư mục:
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
  1. Sử dụng siêu dữ liệu gói (tham chiếu). Nếu tạo gói thông qua pyproject.toml hoặc setup.py, hãy quảng cáo tên mô-đun trình bổ trợ bằng cách thêm một điểm truy cập trong nhóm jax_plugins trỏ đến tên mô-đun đầy đủ. Dưới đây là ví dụ thông qua pyproject.toml hoặc setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
  "jax_plugins": [
    "my_plugin = my_plugin",
  ],
}

Sau đây là ví dụ về cách triển khai openxla-pjrt-plugin bằng cách sử dụng Tuỳ chọn 2: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120

Bước 3: Triển khai phương thức initialize()

Bạn cần triển khai phương thức initialize() trong mô-đun python để đăng ký trình bổ trợ, ví dụ:

import os
import jax._src.xla_bridge as xb

def initialize():
  path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
  xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)

Vui lòng tham khảo tại đây để biết cách sử dụng xla_bridge.register_plugin. Phương thức này hiện là một phương thức riêng tư. API công khai sẽ được phát hành trong tương lai.

Bạn có thể chạy dòng bên dưới để xác minh rằng trình bổ trợ đã được đăng ký và báo lỗi nếu không tải được trình bổ trợ.

jax.config.update("jax_platforms", "my_plugin")

JAX có thể có nhiều phần phụ trợ/trình bổ trợ. Có một số cách để đảm bảo trình bổ trợ của bạn được dùng làm phần phụ trợ mặc định:

  • Cách 1: chạy jax.config.update("jax_platforms", "my_plugin") ở đầu chương trình.
  • Cách 2: đặt ENV JAX_PLATFORMS=my_plugin.
  • Cách 3: đặt mức độ ưu tiên đủ cao khi gọi xb.register_plugin (giá trị mặc định là 400, cao hơn các phần phụ trợ hiện có khác). Lưu ý rằng phần phụ trợ có mức độ ưu tiên cao nhất sẽ chỉ được sử dụng khi JAX_PLATFORMS=''. Giá trị mặc định của JAX_PLATFORMS'' nhưng đôi khi giá trị này sẽ bị ghi đè.

Cách kiểm thử bằng JAX

Một số trường hợp kiểm thử cơ bản để thử:

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap

arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))

# single device: [0]

# 4 devices: [6 7 8 9]

(Chúng tôi sẽ sớm thêm hướng dẫn để chạy kiểm thử đơn vị jax trên trình bổ trợ của bạn!)

Để xem thêm ví dụ về trình bổ trợ PJRT, hãy xem phần Ví dụ về PJRT.