背景
PJRT 是我們想新增至機器學習生態系統的統一裝置 API。長期願景是:
- 架構 (JAX、TF 等) 會呼叫 PJRT,後者具有架構不透明的裝置專屬實作項目;
- 每部裝置都會專注於實作 PJRT API,且對架構而言可能是不透明的。
PJRT 提供 C API 和 C++ API。在任一層插入都可以,C++ API 會使用類別來抽象化某些概念,但與 XLA 資料型別的關聯也更緊密。本頁內容著重於 C++ API。
PJRT 元件
PjRtClient
完整參考資料請見 pjrt_client.h > PjRtClient。
用戶端會管理裝置和架構間的所有通訊,並封裝通訊中使用的所有狀態。它們有一組通用 API,可與 PJRT 外掛程式互動,並擁有特定外掛程式的裝置和記憶體空間。
PjRtDevice
完整參考資料請見 pjrt_client.h > PjRtDevice 和 pjrt_device_description.h
裝置類別用於描述單一裝置。裝置會提供裝置說明,協助識別裝置類型 (用於識別 GPU/CPU/xPU 的專屬雜湊),以及裝置在本地和全球裝置格線中的位置。
裝置也會知道相關聯的記憶體空間和所屬用戶端。
裝置不一定知道與其相關聯的實際資料緩衝區,但可以透過查看相關聯的記憶體空間來找出緩衝區。
PjRtMemorySpace
完整參考資料請見 pjrt_client.h > PjRtMemorySpace。
記憶體空間可用於描述記憶體位置。這些檔案可以取消釘選,並自由儲存在任何位置,但只能透過裝置存取;也可以釘選,但必須儲存在特定裝置上。
記憶體空間會知道相關聯的資料緩衝區、與記憶體空間相關聯的裝置 (複數),以及所屬的用戶端。
PjRtBuffer
完整參考資料請見 pjrt_client.h > PjRtBuffer。
緩衝區會以某種格式在裝置上保存資料,方便在外掛程式內使用,例如 MLIR 元素屬性或專屬張量格式。架構可能會嘗試以 xla::Literal 形式將資料傳送至裝置,也就是模組的輸入引數,必須複製 (或借用) 至裝置的記憶體。架構會呼叫 Delete 方法進行清理,因為不再需要緩衝區。
緩衝區知道自己屬於哪個記憶體空間,並可遞移地判斷哪些裝置能夠存取該空間,但緩衝區不一定知道自己的裝置。
如要與架構通訊,緩衝區必須知道如何轉換為 xla::Literal 型別,以及如何從該型別轉換:
// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}
// Buffer to Literal
xla::Future<> ToLiteral(xla::MutableLiteralBase* literal) override {...}
建立緩衝區的 API 具有 緩衝區語意,可協助判斷是否可共用、複製或變動主機緩衝區的常值資料。
最後,如果緩衝區指派給架構層 x = jit(foo)(10) 中的變數,緩衝區可能需要比執行範圍更長的時間。在這些情況下,緩衝區可建立外部參照,提供緩衝區所保存資料的暫時擁有指標,以及用於解讀基礎資料的中繼資料 (dtype / dim 大小)。
PjRtCompiler
完整參考資料請見 pjrt_compiler.h > PjRtCompiler。
PjRtCompiler 類別提供 XLA 後端的實用實作詳細資料,但外掛程式不一定要實作。從理論上來說,PjRtCompiler 或 PjRtClient::Compile 方法的責任是接收輸入模組,並傳回 PjRtLoadedExecutable。
PjRtExecutable / PjRtLoadedExecutable
完整參考資料請見 pjrt_executable.h > PjRtExecutable 和 pjrt_client.h > PjRtLoadedExecutable。
PjRtExecutable 知道如何取得已編譯的構件和執行選項,並將其序列化/還原序列化,以便儲存可執行檔並視需要載入。
PjRtLoadedExecutable 是記憶體內編譯的執行檔,可供輸入引數執行,是 PjRtExecutable 的子類別。
可執行檔會透過用戶端的其中一種 Execute 方法進行介面操作:
// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}
// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
在呼叫 Execute 之前,架構會將所有必要資料轉移至執行用戶端擁有的 PjRtBuffers,但會傳回供架構參照。這些緩衝區隨即會以引數形式提供給 Execute 方法。
PJRT 概念
Future 和非同步運算
如果外掛程式的任何部分是以非同步方式實作,則必須正確實作 Future。
請參考下列程式:
@jax.jit
def foo(x): return x + 1
x = foo(1)
# [...] other logic not using `x`
print(x + 1)
非同步外掛程式可以將運算 x 排入佇列,並立即傳回尚未準備好讀取的緩衝區,但執行作業會填入該緩衝區。執行作業可繼續在 x 後將必要運算排入佇列,這些運算不需要 x,包括在其他 PJRT 裝置上執行作業。需要 x 值時,執行作業會遭到封鎖,直到緩衝區透過 GetReadyFuture 傳回的 Future 宣告本身已準備就緒為止。
Future 可用於判斷物件何時可用,包括裝置和緩衝區。
進階概念
除了實作基本 API 之外,擴充功能還能擴展外掛程式可使用的 JAX 功能。從某種意義上來說,這些都是選擇性加入的功能,因為典型的 JIT 和執行工作流程不需要這些功能也能運作,但對於生產品質的管道,可能需要考慮對 PJRT API 支援的任何這些功能提供支援的程度:
- 記憶體空間
- 自訂版面配置
- 傳送/接收等通訊作業
- 主機卸載
- 資料分割
典型的 PJRT 架構裝置通訊
記錄範例
以下是載入 PJRT 外掛程式及執行 y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1) 時呼叫的方法記錄。在本例中,我們會記錄 JAX 與 StableHLO 參考 PJRT 外掛程式互動的情形。
記錄檔範例
////////////////////////////////// // Load the plugin ////////////////////////////////// I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400) I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8) I device.cc:104] StablehloReferenceDevice(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I client_cpp_pjrt.cc:86] devices(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I device.cc:168] description(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:86] Attributes(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:148] memory_spaces(0x23bac4e0) Creating PJRT Client from client I client_cpp_pjrt.cc:108] platform_version(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:57] id(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:61] process_index(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:153] default_memory_space(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) ////////////////////////////////// // RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)` ////////////////////////////////// I executable.cc:309] num_partitions(0x240bab70) I executable.cc:305] num_replicas(0x240bab70) I executable.cc:309] num_partitions(0x240bab70) I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400) I buffer.cc:285] CreateMlirBufferFromLiteral I buffer.cc:98] CreateFromLiteral I buffer.cc:99] CreateFromLiteral: s32[] 2 I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:102] CreateFromLiteral -> 0x240bb050 I buffer.cc:158] device(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I executable.cc:328] GetHloModules(0x240bab70) I executable.cc:240] Execute(0x240bab70) I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x240bb050) I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32> I executable.cc:205] EvalModule: module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) { // ... return %3 : tensor<i32> } } I executable.cc:206] Inputs: [dense<2> : tensor<i32>] I executable.cc:213] Results: [dense<2> : tensor<i32>] I device.cc:153] default_memory_space(0x23bac4e0) I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x22cea630) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630 ////////////////////////////////// // RUN: `print(y)` ////////////////////////////////// I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references. I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x2404d560) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050 I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x240b6010) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x23b2db60) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60 I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630)