Tổng quan về API thiết bị PJRT C++

Thông tin khái quát

PJRT là API Thiết bị đồng nhất mà chúng tôi muốn thêm vào hệ sinh thái học máy. Tầm nhìn dài hạn là:

  1. Các khung (JAX, TF, v.v.) sẽ gọi PJRT, có các phương thức triển khai dành riêng cho thiết bị không rõ ràng đối với các khung;
  2. Mỗi thiết bị tập trung vào việc triển khai các API PJRT và có thể không rõ ràng đối với các khung.

PJRT cung cấp cả API C và API C++. Bạn có thể cắm vào một trong hai lớp, API C++ sử dụng các lớp để trừu tượng hoá một số khái niệm, nhưng cũng có mối liên kết chặt chẽ hơn với các loại dữ liệu XLA. Trang này tập trung vào API C++.

Thành phần PJRT

Các thành phần PJRT

PjRtClient

Thông tin tham khảo đầy đủ tại pjrt_client.h > PjRtClient.

Ứng dụng quản lý tất cả hoạt động giao tiếp giữa thiết bị và khung, đồng thời đóng gói tất cả trạng thái được sử dụng trong hoạt động giao tiếp. Các lớp này có một bộ API chung để tương tác với trình bổ trợ PJRT, đồng thời sở hữu các thiết bị và không gian bộ nhớ cho một trình bổ trợ nhất định.

PjRtDevice

Thông tin tham khảo đầy đủ tại pjrt_client.h > PjRtDevicepjrt_device_description.h

Lớp thiết bị được dùng để mô tả một thiết bị. Mỗi thiết bị có nội dung mô tả về thiết bị để giúp xác định loại thiết bị (hàm băm duy nhất để xác định GPU/CPU/xPU) và vị trí trong một lưới thiết bị cả cục bộ và trên toàn cục.

Thiết bị cũng biết không gian bộ nhớ liên kết và ứng dụng sở hữu không gian bộ nhớ đó.

Thiết bị không nhất thiết phải biết vùng đệm của dữ liệu thực tế liên kết với thiết bị, nhưng thiết bị có thể tìm hiểu điều đó bằng cách xem qua các không gian bộ nhớ liên kết.

PjRtMemorySpace

Thông tin tham khảo đầy đủ tại pjrt_client.h > PjRtMemorySpace.

Bạn có thể dùng không gian bộ nhớ để mô tả vị trí của bộ nhớ. Bạn có thể bỏ ghim và để các ứng dụng này ở bất kỳ đâu nhưng có thể truy cập được từ một thiết bị, hoặc ghim các ứng dụng này và phải để trên một thiết bị cụ thể.

Không gian bộ nhớ biết vùng đệm dữ liệu được liên kết và các thiết bị (số nhiều) mà không gian bộ nhớ được liên kết, cũng như ứng dụng mà không gian bộ nhớ đó thuộc về.

PjRtBuffer

Tài liệu tham khảo đầy đủ tại pjrt_client.h > PjRtBuffer.

Bộ đệm lưu giữ dữ liệu trên thiết bị ở một số định dạng mà bạn có thể dễ dàng thao tác bên trong trình bổ trợ, chẳng hạn như thuộc tính phần tử MLIR hoặc định dạng tensor độc quyền. Một khung có thể cố gắng gửi dữ liệu đến thiết bị ở dạng xla::Literal, nghĩa là đối với đối số đầu vào cho mô-đun, đối số này phải được sao chép (hoặc mượn) vào bộ nhớ của thiết bị. Khi không cần vùng đệm nữa, khung sẽ gọi phương thức Delete để dọn dẹp.

Vùng đệm biết không gian bộ nhớ mà nó thuộc về và có thể xác định được thiết bị nào có thể truy cập vào vùng đệm đó, nhưng vùng đệm không nhất thiết phải biết thiết bị của mình.

Để giao tiếp với các khung, vùng đệm biết cách chuyển đổi sang và từ loại xla::Literal:

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

Các API để tạo vùng đệm có Ngữ nghĩa vùng đệm giúp cho biết liệu có thể chia sẻ hoặc sao chép hoặc thay đổi dữ liệu cố định từ vùng đệm máy chủ hay không.

Cuối cùng, vùng đệm có thể cần kéo dài hơn phạm vi thực thi của vùng đệm, nếu vùng đệm được gán cho một biến trong lớp khung x = jit(foo)(10), trong những trường hợp này, vùng đệm cho phép tạo các tệp tham chiếu bên ngoài cung cấp con trỏ tạm thời sở hữu dữ liệu do vùng đệm lưu giữ, cùng với siêu dữ liệu (kích thước dtype / dim) để diễn giải dữ liệu cơ bản.

PjRtCompiler

Thông tin tham khảo đầy đủ tại pjrt_compiler.h > PjRtCompiler.

Lớp PjRtCompiler cung cấp thông tin triển khai hữu ích cho phần phụ trợ XLA, nhưng không cần thiết phải triển khai trình bổ trợ. Theo lý thuyết, trách nhiệm của PjRtCompiler hoặc phương thức PjRtClient::Compile là lấy mô-đun đầu vào và trả về một PjRtLoadedExecutable.

PjRtExecutable / PjRtLoadedExecutable

Thông tin tham khảo đầy đủ tại pjrt_executable.h > PjRtExecutablepjrt_client.h > PjRtLoadedExecutable.

PjRtExecutable biết cách lấy một cấu phần phần mềm được biên dịch và các tuỳ chọn thực thi, đồng thời chuyển đổi tuần tự/huỷ chuyển đổi tuần tự các cấu phần phần mềm đó để có thể lưu trữ và tải tệp thực thi khi cần.

PjRtLoadedExecutable là tệp thực thi được biên dịch trong bộ nhớ, sẵn sàng để các đối số đầu vào thực thi, đây là một lớp con của PjRtExecutable.

Các tệp thực thi được giao tiếp thông qua một trong các phương thức Execute của ứng dụng:

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

Trước khi gọi Execute, khung sẽ chuyển tất cả dữ liệu bắt buộc đến PjRtBuffers do ứng dụng thực thi sở hữu, nhưng được trả về để khung tham chiếu. Sau đó, các vùng đệm này được cung cấp dưới dạng đối số cho phương thức Execute.

Khái niệm về PJRT

PjRtFutures và phép tính không đồng bộ

Nếu bất kỳ phần nào của trình bổ trợ được triển khai không đồng bộ, thì phần đó phải triển khai đúng cách các futures.

Hãy xem xét chương trình sau:

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

Trình bổ trợ không đồng bộ có thể thêm x tính toán vào hàng đợi và ngay lập tức trả về một bộ đệm chưa sẵn sàng để đọc, nhưng quá trình thực thi sẽ điền sẵn bộ đệm đó. Quá trình thực thi có thể tiếp tục thêm các phép tính cần thiết vào hàng đợi sau x mà không cần x, bao gồm cả quá trình thực thi trên các thiết bị PJRT khác. Khi cần giá trị của x, quá trình thực thi sẽ bị chặn cho đến khi vùng đệm tự khai báo sẵn sàng thông qua giá trị trong tương lai do GetReadyFuture trả về.

Futures có thể hữu ích để xác định thời điểm một đối tượng có sẵn, bao gồm cả thiết bị và vùng đệm.

Khái niệm nâng cao

Việc mở rộng ngoài việc triển khai các API cơ sở sẽ mở rộng các tính năng của JAX mà một trình bổ trợ có thể sử dụng. Đây đều là các tính năng không bắt buộc sử dụng, tức là quy trình làm việc JIT và thực thi thông thường sẽ hoạt động mà không cần các tính năng này, nhưng đối với quy trình chất lượng sản xuất, bạn nên cân nhắc mức độ hỗ trợ cho bất kỳ tính năng nào trong số các tính năng này do API PJRT hỗ trợ:

  • Không gian bộ nhớ
  • Bố cục tuỳ chỉnh
  • Các thao tác liên lạc như gửi/nhận
  • Giảm tải máy chủ
  • Phân đoạn

Giao tiếp khung PJRT-thiết bị thông thường

Nhật ký mẫu

Sau đây là nhật ký của các phương thức được gọi để tải trình bổ trợ PJRT và thực thi y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1). Trong trường hợp này, chúng ta ghi lại JAX tương tác với trình bổ trợ PJRT tham chiếu StableHLO.

Nhật ký mẫu

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)