PJRT C++ 裝置 API 總覽

背景

PJRT 是我們希望加入機器學習生態系統的統一裝置 API。長期願景是:

  1. 架構 (JAX、TF 等) 會呼叫 PJRT,而 PJRT 則包含裝置專屬的實作項目,這些項目對架構而言是不可見的;
  2. 每部裝置著重於實作 PJRT API,且對架構無法透明。

PJRT 提供 C API 和 C++ API。插入任一層都沒有問題,C++ API 會使用類別來抽取部分概念,但也對 XLA 資料類型有更高的關聯性。本頁內容著重於 C++ API。

PJRT 元件

PJRT 元件

PjRtClient

如需完整參考資料,請參閱 pjrt_client.h > PjRtClient

用戶端會管理裝置和架構之間的所有通訊,並封裝在通訊中使用的所有狀態。這些類別會提供一組通用 API,用於與 PJRT 外掛程式互動,且擁有特定外掛程式的裝置和記憶體空間。

PjRtDevice

完整參考資料請見 pjrt_client.h > PjRtDevicepjrt_device_description.h

裝置類別用於描述單一裝置。裝置會提供裝置說明,協助識別裝置類型 (用於識別 GPU/CPU/xPU 的專屬雜湊),以及裝置在本機和全球裝置格線中的位置。

裝置也知道其相關聯的記憶體空間,以及該空間的擁有者。

裝置並「不一定」知道與其相關聯的實際資料緩衝區,但可以透過檢視相關聯的記憶體空間來找出答案。

PjRtMemorySpace

如需完整參考資料,請參閱 pjrt_client.h > PjRtMemorySpace

記憶體空間可用來說明記憶體的位置。這些資料可以取消固定,並自由存放於任何位置,但必須透過裝置存取;也可以固定,並存放於特定裝置上。

記憶體空間知道其相關聯的資料緩衝區、與記憶體空間相關聯的裝置 (複數),以及屬於該記憶體的用戶端。

PjRtBuffer

如需完整參考資料,請參閱 pjrt_client.h > PjRtBuffer

緩衝區會以某種格式在裝置上儲存資料,方便在外掛程式中使用,例如 MLIR 元素屬性或專屬張量格式。架構可能會嘗試以 xla::Literal 的形式將資料傳送至裝置,也就是模組的輸入引數,必須複製 (或借用) 至裝置的記憶體。當緩衝區不再需要時,架構會叫用 Delete 方法來清理。

緩衝區會知道自己所屬的記憶體空間,並可透過推導方式找出哪些裝置可以存取,但緩衝區不一定會知道自己的裝置。

為了與架構通訊,緩衝區會瞭解如何轉換 xla::Literal 類型:

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

用於建立緩衝區的 API 具有「緩衝區語意」,可協助判斷是否可分享、複製或變更主機緩衝區中的文字資料。

最後,如果緩衝區指派給架構層 x = jit(foo)(10) 中的變數,則緩衝區可能需要比執行範圍更長的時間,在這種情況下,緩衝區可建立外部參照,提供緩衝區所保留資料的暫時擁有指標,以及用於解讀基礎資料的中繼資料 (dtype / dim 大小)。

PjRtCompiler

完整參考資料請見 pjrt_compiler.h > PjRtCompiler

PjRtCompiler 類別可為 XLA 後端提供實用實作詳細資料,但外掛程式不必實作此類別。理論上,PjRtCompilerPjRtClient::Compile 方法的責任是接收輸入模組並傳回 PjRtLoadedExecutable

PjRtExecutable / PjRtLoadedExecutable

完整參考資料位於 pjrt_executable.h > PjRtExecutablepjrt_client.h > PjRtLoadedExecutable

PjRtExecutable 會瞭解如何取得已編譯的構件和執行選項,並將其序列化/反序列化,以便視需要儲存及載入可執行檔。

PjRtLoadedExecutable 是記憶體中編譯的執行檔,可供輸入引數執行,也是 PjRtExecutable 的子類別。

可執行檔會透過用戶端的其中一個 Execute 方法進行介面連結:

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

在呼叫 Execute 之前,架構會將所有必要資料轉移至執行中用戶端擁有的 PjRtBuffers,但會傳回架構供參考。這些緩衝區會做為引數提供給 Execute 方法。

PJRT 概念

PjRtFutures 和非同步運算

如果外掛程式的任何部分是異步實作,則必須正確實作 Future。

請考慮使用下列程式:

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

非同步外掛程式可將運算 x 排入佇列,並立即傳回尚未準備好供讀取的緩衝區,但執行作業會填入該緩衝區。執行作業可在 x 之後繼續排入必要的運算作業,這些運算作業不需要 x,包括在其他 PJRT 裝置上執行的作業。需要 x 的值後,系統會停止執行作業,直到緩衝區透過 GetReadyFuture 傳回的 Future 宣告自身準備就緒為止。

未來可用來判斷物件 (包括裝置和緩衝區) 何時可供使用。

進階概念

除了實作基本 API 之外,您還可以擴充 JAX 的功能,讓外掛程式可以使用這些功能。這些都是選擇加入的功能,因為在一般 JIT 和執行工作流程中,即使沒有這些功能,也能正常運作。不過,如果是正式品質的管道,則應考慮將這些功能納入 PJRT API 支援的支援程度:

  • 記憶體空間
  • 自訂版面配置
  • 傳送/接收等通訊作業
  • 主機卸載
  • 資料分割

一般 PJRT 架構與裝置通訊

記錄範例

以下是載入 PJRT 外掛程式並執行 y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1) 時呼叫的方法記錄。在本例中,我們會記錄 JAX 與 StableHLO 參考 PJRT 外掛程式互動。

記錄示例

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)