Thông tin khái quát
PJRT là API thiết bị thống nhất mà chúng tôi muốn thêm vào hệ sinh thái học máy. Tầm nhìn dài hạn là: (1) các khung (JAX, TF, v.v.) sẽ gọi PJRT, có các triển khai dành riêng cho thiết bị được che khuất đối với khung; (2) mỗi thiết bị tập trung vào việc triển khai các API PJRT và có thể không rõ ràng đối với các khung.
Tài liệu này tập trung vào các đề xuất về cách tích hợp với PJRT và cách kiểm thử việc tích hợp PJRT với JAX.
Cách tích hợp với PJRT
Bước 1: Triển khai giao diện PJRT C API
Cách A: Bạn có thể triển khai trực tiếp PJRT C API.
Lựa chọn B: Nếu có thể tạo bản dựng dựa trên mã C++ trong kho lưu trữ xla (thông qua forking hoặc bazel), thì bạn cũng có thể triển khai API PJRT C++ và sử dụng trình bao bọc C→C++:
- Triển khai một ứng dụng PJRT C++ kế thừa từ ứng dụng PJRT cơ sở (và các lớp PJRT có liên quan). Dưới đây là một số ví dụ về ứng dụng PJRT C++: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
- Triển khai một số phương thức API C không thuộc ứng dụng C++ PJRT:
- PJRT_Client_Create (Tạo). Dưới đây là một số mã giả mẫu (giả sử
GetPluginPjRtClient
trả về một ứng dụng C++ PJRT được triển khai ở trên):
- PJRT_Client_Create (Tạo). Dưới đây là một số mã giả mẫu (giả sử
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
args->client = pjrt::CreateWrapperClient(std::move(client));
return nullptr;
}
} // namespace my_plugin
Lưu ý: PJRT_Client_Create có thể lấy các tuỳ chọn được chuyển từ khung. Đây là ví dụ về cách ứng dụng GPU sử dụng tính năng này.
- [Không bắt buộc] PJRT_TopologyDescription_Create.
- [Không bắt buộc] PJRT_Plugin_Initialize. Đây là quá trình thiết lập trình bổ trợ một lần, khung sẽ được gọi bởi khung trước khi gọi bất kỳ hàm nào khác.
- [Không bắt buộc] PJRT_Plugin_Attributes.
Với trình bao bọc, bạn không cần triển khai các API C còn lại.
Bước 2: Triển khai GetPjRtApi
Bạn cần triển khai một phương thức GetPjRtApi
để trả về một PJRT_Api*
chứa các con trỏ hàm cho các hoạt động triển khai API PJRT C. Dưới đây là một ví dụ giả định việc triển khai thông qua trình bao bọc (tương tự như pjrt_c_api_cpu.cc):
const PJRT_Api* GetPjrtApi() {
static const PJRT_Api pjrt_api =
pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
return &pjrt_api;
}
Bước 3: Kiểm thử việc triển khai API C
Bạn có thể gọi RegisterPjRtCApiTestFactory để chạy một tập hợp nhỏ các bài kiểm thử cho các hành vi cơ bản của API PJRT C.
Cách sử dụng trình bổ trợ PJRT của JAX
Bước 1: Thiết lập JAX
Bạn có thể sử dụng JAX mỗi đêm
pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>
hoặc tạo JAX từ nguồn.
Hiện tại, bạn cần so khớp phiên bản jaxlib với phiên bản PJRT C API. Thông thường, bạn chỉ nên sử dụng phiên bản jaxlib ban đêm từ cùng ngày với cam kết TF mà bạn đang xây dựng trình bổ trợ, ví dụ:
pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
Bạn cũng có thể tạo jaxlib từ nguồn theo đúng cam kết XLA mà bạn đang xây dựng (instructions).
Chúng tôi sẽ sớm bắt đầu hỗ trợ khả năng tương thích với ABI.
Bước 2: Sử dụng không gian tên jax_plugins hoặc thiết lập enter_point
Có hai cách để JAX phát hiện trình bổ trợ của bạn.
- Sử dụng gói không gian tên (ref). Định nghĩa một mô-đun duy nhất trên toàn hệ thống trong gói không gian tên
jax_plugins
(tức là chỉ cần tạo một thư mụcjax_plugins
và xác định mô-đun bên dưới mô-đun đó). Dưới đây là ví dụ về cấu trúc thư mục:
jax_plugins/
my_plugin/
__init__.py
my_plugin.so
- Sử dụng siêu dữ liệu của gói (ref). Nếu xây dựng một gói thông qua pyproject.toml hoặc setup.py, hãy quảng cáo tên mô-đun trình bổ trợ của bạn bằng cách đưa một điểm truy cập vào trong nhóm
jax_plugins
trỏ đến tên mô-đun đầy đủ của bạn. Dưới đây là ví dụ thông qua pyproject.toml hoặc setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'
# use setup.py
entry_points={
"jax_plugins": [
"my_plugin = my_plugin",
],
}
Dưới đây là các ví dụ về cách triển khai openxla-pjrt-plugin bằng cách sử dụng Cách 2: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120
Bước 3: Triển khai phương thức launch()
Bạn cần triển khai phương thứcinitialize() trong mô-đun python của mình để đăng ký trình bổ trợ, ví dụ:
import os
import jax._src.xla_bridge as xb
def initialize():
path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)
Vui lòng tham khảo tại đây để biết cách sử dụng xla_bridge.register_plugin
. Hiện tại, đây là một phương thức riêng tư. Một API công khai sẽ được phát hành trong tương lai.
Bạn có thể chạy dòng bên dưới để xác minh rằng trình bổ trợ đã được đăng ký và báo lỗi nếu không tải được.
jax.config.update("jax_platforms", "my_plugin")
JAX có thể có nhiều phần phụ trợ/trình bổ trợ. Có một số cách giúp đảm bảo trình bổ trợ của bạn được dùng làm phần phụ trợ mặc định:
- Lựa chọn 1: chạy
jax.config.update("jax_platforms", "my_plugin")
ở đầu chương trình. - Cách 2: đặt ENV
JAX_PLATFORMS=my_plugin
. - Lựa chọn 3: đặt mức độ ưu tiên đủ cao khi gọi xb.register_plugin (giá trị mặc định là 400, cao hơn các phần phụ trợ hiện có khác). Xin lưu ý rằng phần phụ trợ có mức độ ưu tiên cao nhất sẽ chỉ được dùng khi
JAX_PLATFORMS=''
. Giá trị mặc định củaJAX_PLATFORMS
là''
nhưng đôi khi sẽ bị ghi đè.
Cách kiểm thử bằng JAX
Một số trường hợp kiểm thử cơ bản nên thử:
# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2
# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0
# pmap
arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))
# single device: [0]
# 4 devices: [6 7 8 9]
(Chúng tôi sẽ sớm bổ sung hướng dẫn để chạy kiểm thử đơn vị jax đối với trình bổ trợ của bạn!)
Ví dụ: Trình bổ trợ JAX CUDA
- Triển khai API PJRT C thông qua trình bao bọc (pjrt_c_api_gpu.h).
- Thiết lập điểm truy cập cho gói (setup.py).
- Triển khai phương thức launch() (__init__.py).
- Có thể kiểm thử với bất kỳ kiểm thử jax nào cho CUDA. ```