Thông tin khái quát
PJRT là Device API đồng nhất mà chúng tôi muốn thêm vào hệ sinh thái ML. Tầm nhìn dài hạn là:
- Các khung (JAX, TF, v.v.) sẽ gọi PJRT, có các triển khai dành riêng cho thiết bị mà các khung không thể truy cập;
- Mỗi thiết bị tập trung vào việc triển khai các API PJRT và có thể không rõ ràng đối với các khung.
PJRT cung cấp cả API C và API C++. Bạn có thể cắm ở bất kỳ lớp nào, API C++ sử dụng các lớp để trừu tượng hoá một số khái niệm, nhưng cũng có mối liên kết chặt chẽ hơn với kiểu dữ liệu XLA. Trang này tập trung vào API C++.
Thành phần PJRT
PjRtClient
Thông tin tham khảo đầy đủ tại pjrt_client.h > PjRtClient.
Các ứng dụng quản lý mọi hoạt động giao tiếp giữa thiết bị và khung, đồng thời đóng gói tất cả trạng thái được dùng trong hoạt động giao tiếp. Chúng có một nhóm API chung để tương tác với một trình bổ trợ PJRT, đồng thời sở hữu các thiết bị và không gian bộ nhớ cho một trình bổ trợ nhất định.
PjRtDevice
Tài liệu tham khảo đầy đủ tại pjrt_client.h > PjRtDevice và pjrt_device_description.h
Lớp thiết bị được dùng để mô tả một thiết bị duy nhất. Thiết bị có một phần mô tả thiết bị để giúp xác định loại thiết bị (hàm băm duy nhất để xác định GPU/CPU/xPU) và vị trí trong một lưới thiết bị cả cục bộ và trên toàn cầu.
Các thiết bị cũng biết không gian bộ nhớ được liên kết và ứng dụng mà thiết bị thuộc về.
Thiết bị không nhất thiết phải biết các vùng đệm của dữ liệu thực tế được liên kết với thiết bị đó, nhưng thiết bị có thể tìm ra bằng cách xem xét các không gian bộ nhớ được liên kết.
PjRtMemorySpace
Thông tin tham khảo đầy đủ tại pjrt_client.h > PjRtMemorySpace.
Bạn có thể dùng không gian bộ nhớ để mô tả vị trí của bộ nhớ. Các ứng dụng này có thể không được ghim và có thể nằm ở bất kỳ vị trí nào nhưng có thể truy cập được từ một thiết bị, hoặc chúng có thể được ghim và phải nằm trên một thiết bị cụ thể.
Các không gian bộ nhớ biết vùng đệm dữ liệu được liên kết và các thiết bị (số nhiều) mà một không gian bộ nhớ được liên kết, cũng như ứng dụng mà không gian bộ nhớ là một phần của.
PjRtBuffer
Thông tin tham khảo đầy đủ tại pjrt_client.h > PjRtBuffer.
Vùng đệm lưu trữ dữ liệu trên thiết bị ở một số định dạng mà bạn có thể dễ dàng làm việc trong trình bổ trợ, chẳng hạn như thuộc tính phần tử MLIR hoặc định dạng tensor độc quyền.
Khung có thể cố gắng gửi dữ liệu đến một thiết bị ở dạng xla::Literal, tức là đối với một đối số đầu vào cho mô-đun, đối số này phải được sao chép (hoặc mượn) vào bộ nhớ của thiết bị. Khi không cần đến vùng đệm nữa, khung sẽ gọi phương thức Delete để dọn dẹp.
Vùng đệm biết không gian bộ nhớ mà nó là một phần của, và có thể gián tiếp tìm ra những thiết bị có thể truy cập vào vùng đệm đó, nhưng vùng đệm không nhất thiết phải biết các thiết bị của chúng.
Để giao tiếp với các khung, các vùng đệm biết cách chuyển đổi thành và từ loại xla::Literal:
// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}
// Buffer to Literal
xla::Future<> ToLiteral(xla::MutableLiteralBase* literal) override {...}
Các API để tạo vùng đệm có Ngữ nghĩa vùng đệm giúp xác định xem dữ liệu theo nghĩa đen từ vùng đệm máy chủ có thể được chia sẻ, sao chép hay thay đổi hay không.
Cuối cùng, vùng đệm có thể cần tồn tại lâu hơn phạm vi thực thi của nó, nếu được chỉ định cho một biến trong lớp khung x = jit(foo)(10), trong những trường hợp này, vùng đệm cho phép tạo các tham chiếu bên ngoài cung cấp con trỏ tạm thời thuộc sở hữu của dữ liệu do vùng đệm nắm giữ, cùng với siêu dữ liệu (dtype / kích thước dim) để diễn giải dữ liệu cơ bản.
PjRtCompiler
Thông tin tham khảo đầy đủ tại pjrt_compiler.h > PjRtCompiler.
Lớp PjRtCompiler cung cấp thông tin triển khai hữu ích cho các chương trình phụ trợ XLA, nhưng không cần thiết để một trình bổ trợ triển khai. Về lý thuyết, trách nhiệm của PjRtCompiler hoặc phương thức PjRtClient::Compile là lấy một mô-đun đầu vào và trả về PjRtLoadedExecutable.
PjRtExecutable / PjRtLoadedExecutable
Tài liệu tham khảo đầy đủ tại pjrt_executable.h > PjRtExecutable và pjrt_client.h > PjRtLoadedExecutable.
PjRtExecutable biết cách lấy một cấu phần phần mềm đã biên dịch và các lựa chọn thực thi, đồng thời chuyển đổi/huỷ chuyển đổi các lựa chọn đó để có thể lưu trữ và tải một tệp thực thi khi cần.
PjRtLoadedExecutable là tệp thực thi đã biên dịch trong bộ nhớ, sẵn sàng cho các đối số đầu vào để thực thi, đây là một lớp con của PjRtExecutable.
Các tệp thực thi được kết nối thông qua một trong các phương thức Execute của ứng dụng:
// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}
// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
Trước khi gọi Execute, khung sẽ chuyển tất cả dữ liệu bắt buộc đến PjRtBuffers do ứng dụng khách đang thực thi sở hữu, nhưng được trả về để khung tham chiếu. Sau đó, các vùng đệm này được cung cấp dưới dạng đối số cho phương thức Execute.
Các khái niệm về PJRT
Hợp đồng tương lai và tính toán không đồng bộ
Nếu bất kỳ phần nào của một trình bổ trợ được triển khai không đồng bộ, thì phần đó phải triển khai đúng cách các đối tượng tương lai.
Hãy xem xét chương trình sau:
@jax.jit
def foo(x): return x + 1
x = foo(1)
# [...] other logic not using `x`
print(x + 1)
Một trình bổ trợ không đồng bộ sẽ có thể xếp hàng đợi cho phép tính x và ngay lập tức trả về một bộ đệm chưa sẵn sàng để đọc, nhưng quá trình thực thi sẽ điền sẵn bộ đệm đó. Quá trình thực thi có thể tiếp tục xếp hàng các phép tính cần thiết sau x, không yêu cầu x, bao gồm cả việc thực thi trên các thiết bị PJRT khác. Khi cần giá trị của x, quá trình thực thi sẽ bị chặn cho đến khi vùng đệm tự khai báo là đã sẵn sàng thông qua đối tượng tương lai do GetReadyFuture trả về.
Các đối tượng Future có thể hữu ích trong việc xác định thời điểm một đối tượng trở nên có sẵn, bao gồm cả thiết bị và vùng đệm.
Khái niệm nâng cao
Ngoài việc triển khai các API cơ sở, việc mở rộng sẽ mở rộng các tính năng của JAX mà một trình bổ trợ có thể sử dụng. Đây đều là những tính năng chọn sử dụng theo nghĩa là quy trình JIT và thực thi thông thường sẽ hoạt động mà không cần những tính năng này, nhưng đối với quy trình chất lượng sản xuất, bạn nên cân nhắc mức độ hỗ trợ cho bất kỳ tính năng nào trong số này do API PJRT hỗ trợ:
- Không gian lưu trữ kỷ niệm
- Bố cục tuỳ chỉnh
- Các thao tác giao tiếp như gửi/nhận
- Tải xuống từ máy chủ
- Phân đoạn
Giao tiếp điển hình giữa thiết bị và khung PJRT
Ví dụ về nhật ký
Sau đây là nhật ký các phương thức được gọi để tải trình bổ trợ PJRT và thực thi y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1). Trong trường hợp này, chúng ta sẽ ghi lại JAX tương tác với trình bổ trợ PJRT tham chiếu StableHLO.
Ví dụ về nhật ký
////////////////////////////////// // Load the plugin ////////////////////////////////// I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400) I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8) I device.cc:104] StablehloReferenceDevice(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I client_cpp_pjrt.cc:86] devices(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I device.cc:168] description(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:86] Attributes(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:148] memory_spaces(0x23bac4e0) Creating PJRT Client from client I client_cpp_pjrt.cc:108] platform_version(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:57] id(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:61] process_index(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:153] default_memory_space(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) ////////////////////////////////// // RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)` ////////////////////////////////// I executable.cc:309] num_partitions(0x240bab70) I executable.cc:305] num_replicas(0x240bab70) I executable.cc:309] num_partitions(0x240bab70) I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400) I buffer.cc:285] CreateMlirBufferFromLiteral I buffer.cc:98] CreateFromLiteral I buffer.cc:99] CreateFromLiteral: s32[] 2 I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:102] CreateFromLiteral -> 0x240bb050 I buffer.cc:158] device(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I executable.cc:328] GetHloModules(0x240bab70) I executable.cc:240] Execute(0x240bab70) I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x240bb050) I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32> I executable.cc:205] EvalModule: module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) { // ... return %3 : tensor<i32> } } I executable.cc:206] Inputs: [dense<2> : tensor<i32>] I executable.cc:213] Results: [dense<2> : tensor<i32>] I device.cc:153] default_memory_space(0x23bac4e0) I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x22cea630) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630 ////////////////////////////////// // RUN: `print(y)` ////////////////////////////////// I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references. I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x2404d560) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050 I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x240b6010) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x23b2db60) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60 I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630)