Ringkasan PJRT C++ Device API

Latar belakang

PJRT adalah Device API seragam yang ingin kita tambahkan ke ekosistem ML. Visi jangka panjangnya adalah:

  1. Framework (JAX, TF, dll.) akan memanggil PJRT, yang memiliki implementasi khusus perangkat yang tidak transparan untuk framework;
  2. Setiap perangkat berfokus pada penerapan PJRT API, dan dapat buram untuk framework.

PJRT menawarkan C API dan C++ API. Menyambungkan di salah satu lapisan tidak masalah, API C++ menggunakan class untuk memisahkan beberapa konsep, tetapi juga memiliki hubungan yang lebih kuat dengan jenis data XLA. Halaman ini berfokus pada C++ API.

Komponen PJRT

Komponen PJRT

PjRtClient

Referensi lengkap di pjrt_client.h > PjRtClient.

Klien mengelola semua komunikasi antara perangkat dan framework, serta mengaitkan semua status yang digunakan dalam komunikasi. Plugin ini memiliki kumpulan API umum untuk berinteraksi dengan plugin PJRT, dan memiliki perangkat serta ruang memori untuk plugin tertentu.

PjRtDevice

Referensi lengkap di pjrt_client.h > PjRtDevice, dan pjrt_device_description.h

Class perangkat digunakan untuk mendeskripsikan satu perangkat. Perangkat memiliki deskripsi perangkat untuk membantu mengidentifikasi jenisnya (hash unik untuk mengidentifikasi GPU/CPU/xPU), dan lokasi dalam petak perangkat secara lokal dan global.

Perangkat juga mengetahui ruang memori terkait dan klien yang memilikinya.

Perangkat tidak harus mengetahui buffering data sebenarnya yang terkait dengan perangkat, tetapi perangkat dapat mengetahuinya dengan melihat ruang memori terkait.

PjRtMemorySpace

Referensi lengkap di pjrt_client.h > PjRtMemorySpace.

Ruang memori dapat digunakan untuk menjelaskan lokasi memori. Aplikasi ini dapat di-unpin, dan dapat ditayangkan di mana saja, tetapi dapat diakses dari perangkat, atau dapat disematkan dan harus ditayangkan di perangkat tertentu.

Ruang memori mengetahui buffer data terkait, dan perangkat (jamak) yang terkait dengan ruang memori, serta klien yang menjadi bagiannya.

PjRtBuffer

Referensi lengkap di pjrt_client.h > PjRtBuffer.

Buffer menyimpan data di perangkat dalam beberapa format yang akan mudah digunakan di dalam plugin, seperti atribut elemen MLIR atau format tensor eksklusif. Framework dapat mencoba mengirim data ke perangkat dalam bentuk xla::Literal, yaitu untuk argumen input ke modul, yang harus di-clone (atau dipinjam), ke memori perangkat. Setelah buffer tidak lagi diperlukan, metode Delete akan dipanggil oleh framework untuk membersihkan.

Buffer mengetahui ruang memori yang menjadi bagiannya, dan secara transitif dapat mengetahui perangkat mana yang dapat mengaksesnya, tetapi buffer tidak selalu mengetahui perangkatnya.

Untuk berkomunikasi dengan framework, buffering mengetahui cara mengonversi ke dan dari jenis xla::Literal:

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

API untuk membuat buffering memiliki Semantik Buffer yang membantu menentukan apakah data literal dari buffering host dapat dibagikan atau disalin atau diubah.

Terakhir, buffering mungkin perlu bertahan lebih lama dari cakupan eksekusinya, jika ditetapkan ke variabel di lapisan framework x = jit(foo)(10), dalam hal ini buffer memungkinkan pembuatan referensi eksternal yang memberikan pointer yang dimiliki sementara ke data yang disimpan oleh buffering, bersama dengan metadata (ukuran dtype / dim) untuk menafsirkan data pokok.

PjRtCompiler

Referensi lengkap di pjrt_compiler.h > PjRtCompiler.

Class PjRtCompiler memberikan detail implementasi yang berguna untuk backend XLA, tetapi tidak diperlukan untuk implementasi plugin. Secara teori, tanggung jawab PjRtCompiler, atau metode PjRtClient::Compile, adalah mengambil modul input dan menampilkan PjRtLoadedExecutable.

PjRtExecutable / PjRtLoadedExecutable

Referensi lengkap di pjrt_executable.h > PjRtExecutable, dan pjrt_client.h > PjRtLoadedExecutable.

PjRtExecutable mengetahui cara mengambil artefak yang dikompilasi dan opsi eksekusi serta melakukan serialisasi/deserialisasi sehingga file yang dapat dieksekusi dapat disimpan dan dimuat sesuai kebutuhan.

PjRtLoadedExecutable adalah file yang dapat dieksekusi yang dikompilasi dalam memori dan siap untuk dieksekusi argumen input, yang merupakan subclass dari PjRtExecutable.

File yang dapat dieksekusi dihubungkan melalui salah satu metode Execute klien:

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

Sebelum memanggil Execute, framework akan mentransfer semua data yang diperlukan ke PjRtBuffers yang dimiliki oleh klien yang menjalankan, tetapi ditampilkan untuk direferensikan oleh framework. Buffer ini kemudian diberikan sebagai argumen ke metode Execute.

Konsep PJRT

PjRtFutures & Komputasi Asinkron

Jika ada bagian plugin yang diterapkan secara asinkron, bagian tersebut harus menerapkan future dengan benar.

Pertimbangkan program berikut:

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

Plugin asinkron akan dapat mengantrekan komputasi x, dan langsung menampilkan buffer yang belum siap dibaca, tetapi eksekusi akan mengisinya. Eksekusi dapat terus mengantrekan komputasi yang diperlukan setelah x, yang tidak memerlukan x, termasuk eksekusi di perangkat PJRT lainnya. Setelah nilai x diperlukan, eksekusi akan diblokir hingga buffering menyatakan dirinya siap melalui masa mendatang yang ditampilkan oleh GetReadyFuture.

Futures dapat berguna untuk menentukan kapan objek tersedia, termasuk perangkat dan buffering.

Konsep lanjutan

Memperluas selain menerapkan API dasar akan memperluas fitur JAX yang dapat digunakan oleh plugin. Semua ini adalah fitur keikutsertaan dalam arti bahwa pada JIT dan alur kerja eksekusi standar akan berfungsi tanpanya, tetapi untuk pipeline kualitas produksi, beberapa pemikiran mungkin harus dimasukkan ke dalam tingkat dukungan untuk salah satu fitur ini yang didukung oleh PJRT API:

  • Ruang memori
  • Tata letak kustom
  • Operasi komunikasi seperti kirim/terima
  • Pemindahan beban host
  • Sharding

Komunikasi perangkat framework PJRT yang umum

Contoh Log

Berikut adalah log metode yang dipanggil untuk memuat plugin PJRT dan menjalankan y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1). Dalam hal ini, kita mencatat JAX yang berinteraksi dengan plugin PJRT Referensi StableHLO.

Contoh log

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)