Tài liệu này tập trung vào các đề xuất về cách tích hợp với PJRT và cách kiểm thử hoạt động tích hợp PJRT với JAX.
Cách tích hợp với PJRT
Bước 1: Triển khai giao diện API PJRT C
Cách A: Bạn có thể triển khai trực tiếp API PJRT C.
Tuỳ chọn B: Nếu có thể tạo bản dựng dựa trên mã C++ trong kho lưu trữ xla (thông qua tính năng phân nhánh hoặc bazel), bạn cũng có thể triển khai API C++ PJRT và sử dụng trình bao bọc C→C++:
- Triển khai ứng dụng PJRT C++ kế thừa từ ứng dụng PJRT cơ sở (và các lớp PJRT có liên quan). Dưới đây là một số ví dụ về ứng dụng PJRT C++: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
- Triển khai một số phương thức API C không thuộc ứng dụng PJRT C++:
- PJRT_Client_Create. Dưới đây là một số mã giả lập mẫu (giả sử
GetPluginPjRtClient
trả về một ứng dụng PJRT C++ được triển khai ở trên):
- PJRT_Client_Create. Dưới đây là một số mã giả lập mẫu (giả sử
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
args->client = pjrt::CreateWrapperClient(std::move(client));
return nullptr;
}
} // namespace my_plugin
Lưu ý PJRT_Client_Create có thể nhận các tuỳ chọn được truyền từ khung. Dưới đây là ví dụ về cách ứng dụng GPU sử dụng tính năng này.
- [Không bắt buộc] PJRT_TopologyDescription_Create.
- [Không bắt buộc] PJRT_Plugin_Initialize. Đây là chế độ thiết lập trình bổ trợ một lần. Khung sẽ gọi chế độ này trước khi gọi bất kỳ hàm nào khác.
- [Không bắt buộc] PJRT_Plugin_Attributes.
Với trình bao bọc, bạn không cần triển khai các API C còn lại.
Bước 2: Triển khai GetPjRtApi
Bạn cần triển khai một phương thức GetPjRtApi
trả về một PJRT_Api*
chứa con trỏ hàm đến các phương thức triển khai API PJRT C. Dưới đây là ví dụ giả định triển khai thông qua trình bao bọc (tương tự như pjrt_c_api_cpu.cc):
const PJRT_Api* GetPjrtApi() {
static const PJRT_Api pjrt_api =
pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
return &pjrt_api;
}
Bước 3: Kiểm thử việc triển khai API C
Bạn có thể gọi RegisterPjRtCApiTestFactory để chạy một nhóm nhỏ các chương trình kiểm thử cho các hành vi cơ bản của API PJRT C.
Cách sử dụng trình bổ trợ PJRT từ JAX
Bước 1: Thiết lập JAX
Bạn có thể sử dụng JAX hằng đêm
pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>
hoặc tạo JAX từ nguồn.
Hiện tại, bạn cần so khớp phiên bản jaxlib với phiên bản API PJRT C. Thông thường, bạn chỉ cần sử dụng phiên bản jaxlib hằng đêm từ cùng ngày với thay đổi cam kết TF mà bạn đang xây dựng trình bổ trợ, ví dụ:
pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
Bạn cũng có thể tạo một jaxlib từ nguồn tại chính thay đổi XLA mà bạn đang tạo (hướng dẫn).
Chúng tôi sẽ sớm bắt đầu hỗ trợ khả năng tương thích ABI.
Bước 2: Sử dụng không gian tên jax_plugins hoặc thiết lập entry_point
Có hai cách để JAX phát hiện trình bổ trợ của bạn.
- Sử dụng các gói không gian tên (tham chiếu). Xác định một mô-đun duy nhất trên toàn hệ thống trong gói không gian tên
jax_plugins
(tức là chỉ cần tạo một thư mụcjax_plugins
và xác định mô-đun của bạn bên dưới thư mục đó). Dưới đây là ví dụ về cấu trúc thư mục:
jax_plugins/
my_plugin/
__init__.py
my_plugin.so
- Sử dụng siêu dữ liệu gói (tham chiếu). Nếu tạo gói thông qua pyproject.toml hoặc setup.py, hãy quảng cáo tên mô-đun trình bổ trợ bằng cách thêm một điểm truy cập trong nhóm
jax_plugins
trỏ đến tên mô-đun đầy đủ. Dưới đây là ví dụ thông qua pyproject.toml hoặc setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'
# use setup.py
entry_points={
"jax_plugins": [
"my_plugin = my_plugin",
],
}
Sau đây là ví dụ về cách triển khai openxla-pjrt-plugin bằng cách sử dụng Tuỳ chọn 2: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120
Bước 3: Triển khai phương thức initialize()
Bạn cần triển khai phương thức initialize() trong mô-đun python để đăng ký trình bổ trợ, ví dụ:
import os
import jax._src.xla_bridge as xb
def initialize():
path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)
Vui lòng tham khảo tại đây để biết cách sử dụng xla_bridge.register_plugin
. Phương thức này hiện là một phương thức riêng tư. API công khai sẽ được phát hành trong tương lai.
Bạn có thể chạy dòng bên dưới để xác minh rằng trình bổ trợ đã được đăng ký và báo lỗi nếu không tải được trình bổ trợ.
jax.config.update("jax_platforms", "my_plugin")
JAX có thể có nhiều phần phụ trợ/trình bổ trợ. Có một số cách để đảm bảo trình bổ trợ của bạn được dùng làm phần phụ trợ mặc định:
- Cách 1: chạy
jax.config.update("jax_platforms", "my_plugin")
ở đầu chương trình. - Cách 2: đặt ENV
JAX_PLATFORMS=my_plugin
. - Cách 3: đặt mức độ ưu tiên đủ cao khi gọi xb.register_plugin (giá trị mặc định là 400, cao hơn các phần phụ trợ hiện có khác). Lưu ý rằng phần phụ trợ có mức độ ưu tiên cao nhất sẽ chỉ được sử dụng khi
JAX_PLATFORMS=''
. Giá trị mặc định củaJAX_PLATFORMS
là''
nhưng đôi khi giá trị này sẽ bị ghi đè.
Cách kiểm thử bằng JAX
Một số trường hợp kiểm thử cơ bản để thử:
# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2
# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0
# pmap
arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))
# single device: [0]
# 4 devices: [6 7 8 9]
(Chúng tôi sẽ sớm thêm hướng dẫn để chạy kiểm thử đơn vị jax trên trình bổ trợ của bạn!)
Để xem thêm ví dụ về trình bổ trợ PJRT, hãy xem phần Ví dụ về PJRT.