Thông tin khái quát
PJRT là API Thiết bị đồng nhất mà chúng tôi muốn thêm vào hệ sinh thái học máy. Tầm nhìn dài hạn là:
- Các khung (JAX, TF, v.v.) sẽ gọi PJRT, có các phương thức triển khai dành riêng cho thiết bị không rõ ràng đối với các khung;
- Mỗi thiết bị tập trung vào việc triển khai các API PJRT và có thể không rõ ràng đối với các khung.
PJRT cung cấp cả API C và API C++. Bạn có thể cắm vào một trong hai lớp, API C++ sử dụng các lớp để trừu tượng hoá một số khái niệm, nhưng cũng có mối liên kết chặt chẽ hơn với các loại dữ liệu XLA. Trang này tập trung vào API C++.
Thành phần PJRT
PjRtClient
Thông tin tham khảo đầy đủ tại pjrt_client.h > PjRtClient
.
Ứng dụng quản lý tất cả hoạt động giao tiếp giữa thiết bị và khung, đồng thời đóng gói tất cả trạng thái được sử dụng trong hoạt động giao tiếp. Các lớp này có một bộ API chung để tương tác với trình bổ trợ PJRT, đồng thời sở hữu các thiết bị và không gian bộ nhớ cho một trình bổ trợ nhất định.
PjRtDevice
Thông tin tham khảo đầy đủ tại pjrt_client.h > PjRtDevice
và pjrt_device_description.h
Lớp thiết bị dùng để mô tả một thiết bị. Mỗi thiết bị có một nội dung mô tả thiết bị để giúp xác định loại thiết bị (hàm băm duy nhất để xác định GPU/CPU/xPU) và vị trí trong lưới thiết bị cả cục bộ và toàn cầu.
Thiết bị cũng biết không gian bộ nhớ liên kết và ứng dụng sở hữu không gian bộ nhớ đó.
Thiết bị không nhất thiết phải biết vùng đệm của dữ liệu thực tế liên kết với thiết bị, nhưng thiết bị có thể tìm hiểu điều đó bằng cách xem qua các không gian bộ nhớ liên kết.
PjRtMemorySpace
Thông tin tham khảo đầy đủ tại pjrt_client.h > PjRtMemorySpace
.
Bạn có thể dùng không gian bộ nhớ để mô tả vị trí của bộ nhớ. Bạn có thể bỏ ghim và để các ứng dụng này ở bất kỳ đâu nhưng có thể truy cập được từ một thiết bị, hoặc ghim các ứng dụng này và phải để trên một thiết bị cụ thể.
Không gian bộ nhớ biết vùng đệm dữ liệu được liên kết và các thiết bị (số nhiều) mà không gian bộ nhớ được liên kết, cũng như ứng dụng mà không gian bộ nhớ đó thuộc về.
PjRtBuffer
Thông tin tham khảo đầy đủ tại pjrt_client.h > PjRtBuffer
.
Vùng đệm lưu trữ dữ liệu trên thiết bị ở một số định dạng dễ xử lý trong trình bổ trợ, chẳng hạn như thuộc tính phần tử MLIR hoặc định dạng tensor độc quyền.
Một khung có thể cố gắng gửi dữ liệu đến thiết bị ở dạng xla::Literal
, nghĩa là đối với đối số đầu vào cho mô-đun, đối số này phải được sao chép (hoặc mượn) vào bộ nhớ của thiết bị. Khi không cần vùng đệm nữa, khung sẽ gọi phương thức Delete
để dọn dẹp.
Vùng đệm biết không gian bộ nhớ mà nó thuộc về và có thể xác định được thiết bị nào có thể truy cập vào vùng đệm đó, nhưng vùng đệm không nhất thiết phải biết thiết bị của mình.
Để giao tiếp với các khung, vùng đệm biết cách chuyển đổi sang và từ loại xla::Literal
:
// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}
// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}
API để tạo vùng đệm có Ngữ nghĩa vùng đệm giúp xác định xem dữ liệu cố định từ vùng đệm máy chủ có thể được chia sẻ, sao chép hoặc thay đổi hay không.
Cuối cùng, vùng đệm có thể cần kéo dài hơn phạm vi thực thi của vùng đệm, nếu vùng đệm được gán cho một biến trong lớp khung x = jit(foo)(10)
, trong những trường hợp này, vùng đệm cho phép tạo các tệp tham chiếu bên ngoài cung cấp con trỏ tạm thời sở hữu dữ liệu do vùng đệm lưu giữ, cùng với siêu dữ liệu (kích thước dtype / dim) để diễn giải dữ liệu cơ bản.
PjRtCompiler
Thông tin tham khảo đầy đủ tại pjrt_compiler.h > PjRtCompiler
.
Lớp PjRtCompiler
cung cấp thông tin triển khai hữu ích cho phần phụ trợ XLA, nhưng không cần thiết phải triển khai trình bổ trợ. Về lý thuyết, trách nhiệm của PjRtCompiler
hoặc phương thức PjRtClient::Compile
là lấy một mô-đun đầu vào và trả về PjRtLoadedExecutable
.
PjRtExecutable / PjRtLoadedExecutable
Thông tin tham khảo đầy đủ tại pjrt_executable.h > PjRtExecutable
và pjrt_client.h > PjRtLoadedExecutable
.
PjRtExecutable
biết cách lấy một cấu phần phần mềm được biên dịch và các tuỳ chọn thực thi, đồng thời chuyển đổi tuần tự/huỷ chuyển đổi tuần tự các cấu phần phần mềm đó để có thể lưu trữ và tải tệp thực thi khi cần.
PjRtLoadedExecutable
là tệp thực thi được biên dịch trong bộ nhớ, sẵn sàng để các đối số đầu vào thực thi, đây là một lớp con của PjRtExecutable
.
Các tệp thực thi được giao tiếp thông qua một trong các phương thức Execute
của ứng dụng:
// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}
// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
Trước khi gọi Execute
, khung sẽ chuyển tất cả dữ liệu bắt buộc đến PjRtBuffers
do ứng dụng thực thi sở hữu, nhưng được trả về để khung tham chiếu. Sau đó, các vùng đệm này được cung cấp dưới dạng đối số cho phương thức Execute
.
Khái niệm về PJRT
PjRtFutures và phép tính không đồng bộ
Nếu bất kỳ phần nào của trình bổ trợ được triển khai không đồng bộ, thì phần đó phải triển khai đúng cách các futures.
Hãy xem xét chương trình sau:
@jax.jit
def foo(x): return x + 1
x = foo(1)
# [...] other logic not using `x`
print(x + 1)
Trình bổ trợ không đồng bộ có thể thêm x
tính toán vào hàng đợi và ngay lập tức trả về một bộ đệm chưa sẵn sàng để đọc, nhưng quá trình thực thi sẽ điền sẵn bộ đệm đó. Quá trình thực thi có thể tiếp tục thêm các phép tính cần thiết vào hàng đợi sau x
mà không yêu cầu x
, bao gồm cả việc thực thi trên các thiết bị PJRT khác. Khi cần giá trị của x
, quá trình thực thi sẽ bị chặn cho đến khi vùng đệm tự khai báo sẵn sàng thông qua giá trị trong tương lai do GetReadyFuture
trả về.
Futures có thể hữu ích để xác định thời điểm một đối tượng có sẵn, bao gồm cả thiết bị và vùng đệm.
Khái niệm nâng cao
Việc mở rộng phạm vi triển khai các API cơ sở sẽ mở rộng các tính năng của JAX mà trình bổ trợ có thể sử dụng. Đây đều là các tính năng không bắt buộc sử dụng, tức là quy trình làm việc JIT và thực thi thông thường sẽ hoạt động mà không cần các tính năng này, nhưng đối với quy trình chất lượng sản xuất, bạn nên cân nhắc mức độ hỗ trợ cho bất kỳ tính năng nào trong số các tính năng này do API PJRT hỗ trợ:
- Không gian bộ nhớ
- Bố cục tuỳ chỉnh
- Các thao tác liên lạc như gửi/nhận
- Giảm tải máy chủ
- Phân đoạn
Giao tiếp khung PJRT-thiết bị thông thường
Nhật ký mẫu
Sau đây là nhật ký của các phương thức được gọi để tải trình bổ trợ PJRT và thực thi y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)
. Trong trường hợp này, chúng ta ghi lại JAX tương tác với trình bổ trợ PJRT tham chiếu StableHLO.
Nhật ký mẫu
////////////////////////////////// // Load the plugin ////////////////////////////////// I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400) I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8) I device.cc:104] StablehloReferenceDevice(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I client_cpp_pjrt.cc:86] devices(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I device.cc:168] description(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:86] Attributes(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:148] memory_spaces(0x23bac4e0) Creating PJRT Client from client I client_cpp_pjrt.cc:108] platform_version(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:57] id(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:61] process_index(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:153] default_memory_space(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) ////////////////////////////////// // RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)` ////////////////////////////////// I executable.cc:309] num_partitions(0x240bab70) I executable.cc:305] num_replicas(0x240bab70) I executable.cc:309] num_partitions(0x240bab70) I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400) I buffer.cc:285] CreateMlirBufferFromLiteral I buffer.cc:98] CreateFromLiteral I buffer.cc:99] CreateFromLiteral: s32[] 2 I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:102] CreateFromLiteral -> 0x240bb050 I buffer.cc:158] device(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I executable.cc:328] GetHloModules(0x240bab70) I executable.cc:240] Execute(0x240bab70) I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x240bb050) I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32> I executable.cc:205] EvalModule: module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) { // ... return %3 : tensor<i32> } } I executable.cc:206] Inputs: [dense<2> : tensor<i32>] I executable.cc:213] Results: [dense<2> : tensor<i32>] I device.cc:153] default_memory_space(0x23bac4e0) I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x22cea630) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630 ////////////////////////////////// // RUN: `print(y)` ////////////////////////////////// I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references. I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x2404d560) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050 I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x240b6010) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x23b2db60) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60 I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630)