Latar belakang
PJRT adalah Device API seragam yang ingin kita tambahkan ke ekosistem ML. Visi jangka panjangnya adalah:
- Framework (JAX, TF, dll.) akan memanggil PJRT, yang memiliki implementasi khusus perangkat yang tidak transparan bagi framework;
- Setiap perangkat berfokus pada penerapan API PJRT, dan dapat bersifat buram bagi framework.
PJRT menawarkan C API dan C++ API. Tidak masalah jika Anda menggunakan lapisan mana pun, C++ API menggunakan class untuk mengabstraksi beberapa konsep, tetapi juga memiliki hubungan yang lebih kuat dengan tipe data XLA. Halaman ini berfokus pada C++ API.
Komponen PJRT
PjRtClient
Referensi lengkap di pjrt_client.h > PjRtClient.
Klien mengelola semua komunikasi antara perangkat dan framework, serta merangkum semua status yang digunakan dalam komunikasi. Mereka memiliki serangkaian API umum untuk berinteraksi dengan plugin PJRT, dan mereka memiliki perangkat dan ruang memori untuk plugin tertentu.
PjRtDevice
Referensi lengkap di pjrt_client.h > PjRtDevice,
dan pjrt_device_description.h
Class perangkat digunakan untuk mendeskripsikan satu perangkat. Perangkat memiliki deskripsi perangkat untuk membantu mengidentifikasi jenisnya (hash unik untuk mengidentifikasi GPU/CPU/xPU), dan lokasinya dalam petak perangkat baik secara lokal maupun global.
Perangkat juga mengetahui ruang memori terkait dan klien yang memilikinya.
Perangkat tidak selalu mengetahui buffer data sebenarnya yang terkait dengannya, tetapi perangkat dapat mengetahuinya dengan melihat ruang memori yang terkait.
PjRtMemorySpace
Referensi lengkap di pjrt_client.h > PjRtMemorySpace.
Ruang memori dapat digunakan untuk menjelaskan lokasi memori. Item ini dapat dilepas dan bebas berada di mana saja, tetapi dapat diakses dari perangkat, atau item ini dapat disematkan dan harus berada di perangkat tertentu.
Ruang memori mengetahui buffer data terkaitnya, dan perangkat (jamak) yang terkait dengan ruang memori, serta klien yang menjadi bagiannya.
PjRtBuffer
Referensi lengkap di pjrt_client.h > PjRtBuffer.
Buffer menyimpan data di perangkat dalam beberapa format yang akan mudah digunakan
di dalam plugin, seperti atribut elemen MLIR atau format tensor eksklusif.
Framework dapat mencoba mengirim data ke perangkat dalam bentuk xla::Literal, yaitu untuk argumen input ke modul, yang harus di-clone (atau dipinjam), ke memori perangkat. Setelah buffer tidak lagi diperlukan, metode Delete dipanggil oleh framework untuk membersihkan.
Buffer mengetahui ruang memori yang menjadi bagiannya, dan secara transitif dapat mengetahui perangkat mana yang dapat mengaksesnya, tetapi buffer tidak selalu mengetahui perangkatnya.
Untuk berkomunikasi dengan framework, buffer mengetahui cara mengonversi ke dan dari jenis
xla::Literal:
// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}
// Buffer to Literal
xla::Future<> ToLiteral(xla::MutableLiteralBase* literal) override {...}
API untuk membuat buffer memiliki Semantik Buffer yang membantu menentukan apakah data literal dari buffer host dapat dibagikan atau disalin atau diubah.
Terakhir, buffer mungkin perlu bertahan lebih lama daripada cakupan eksekusinya, jika ditetapkan ke variabel di lapisan framework x = jit(foo)(10). Dalam kasus ini, buffer memungkinkan pembuatan referensi eksternal yang menyediakan pointer yang dimiliki sementara ke data yang disimpan oleh buffer, beserta metadata (dtype / ukuran dim) untuk menafsirkan data pokok.
PjRtCompiler
Referensi lengkap di pjrt_compiler.h > PjRtCompiler.
Class PjRtCompiler memberikan detail penerapan yang berguna untuk backend XLA, tetapi tidak diperlukan agar plugin dapat diterapkan. Secara teori, tanggung jawab PjRtCompiler, atau metode PjRtClient::Compile, adalah mengambil modul input dan menampilkan PjRtLoadedExecutable.
PjRtExecutable / PjRtLoadedExecutable
Referensi lengkap di pjrt_executable.h > PjRtExecutable,
dan pjrt_client.h > PjRtLoadedExecutable.
PjRtExecutable tahu cara mengambil artefak yang dikompilasi dan opsi eksekusi
serta melakukan serialisasi/deserialisasi sehingga file yang dapat dieksekusi dapat disimpan dan dimuat sesuai
kebutuhan.
PjRtLoadedExecutable adalah executable yang dikompilasi dalam memori yang siap
untuk dieksekusi dengan argumen input, dan merupakan subclass dari PjRtExecutable.
Dapat dieksekusi berinteraksi melalui salah satu metode Execute klien:
// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}
// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
Sebelum memanggil Execute, framework akan mentransfer semua data yang diperlukan ke
PjRtBuffers yang dimiliki oleh klien yang menjalankan, tetapi ditampilkan untuk dirujuk oleh
framework. Buffer ini kemudian diberikan sebagai argumen ke metode Execute.
Konsep PJRT
Future & Komputasi Asinkron
Jika bagian plugin diimplementasikan secara asinkron, plugin tersebut harus menerapkan future dengan benar.
Perhatikan program berikut:
@jax.jit
def foo(x): return x + 1
x = foo(1)
# [...] other logic not using `x`
print(x + 1)
Plugin asinkron akan dapat mengantrekan komputasi x, dan segera
mengembalikan buffer yang belum siap dibaca, tetapi eksekusi akan mengisinya. Eksekusi dapat terus mengantrekan komputasi yang diperlukan setelah x, yang tidak memerlukan x, termasuk eksekusi di perangkat PJRT lainnya. Setelah nilai
x diperlukan, eksekusi akan diblokir hingga buffer menyatakan dirinya siap melalui
future yang ditampilkan oleh GetReadyFuture.
Future dapat berguna untuk menentukan kapan suatu objek tersedia, termasuk perangkat dan buffer.
Konsep lanjutan
Melampaui penerapan API dasar akan memperluas fitur JAX yang dapat digunakan oleh plugin. Semua fitur ini bersifat opsional dalam arti bahwa alur kerja JIT dan eksekusi yang umum akan berfungsi tanpa fitur tersebut, tetapi untuk pipeline kualitas produksi, sebaiknya pertimbangkan tingkat dukungan untuk salah satu fitur yang didukung oleh PJRT API ini:
- Ruang memori
- Tata letak kustom
- Operasi komunikasi seperti mengirim/menerima
- Pengalihan host
- Sharding
Komunikasi perangkat framework PJRT umum
Contoh Log
Berikut adalah log metode yang dipanggil untuk memuat plugin PJRT dan
mengeksekusi y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1). Dalam kasus ini, kita mencatat interaksi JAX dengan plugin PJRT Referensi StableHLO.
Contoh log
////////////////////////////////// // Load the plugin ////////////////////////////////// I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400) I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8) I device.cc:104] StablehloReferenceDevice(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I client_cpp_pjrt.cc:86] devices(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I device.cc:168] description(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:86] Attributes(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:148] memory_spaces(0x23bac4e0) Creating PJRT Client from client I client_cpp_pjrt.cc:108] platform_version(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:57] id(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:61] process_index(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:153] default_memory_space(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) ////////////////////////////////// // RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)` ////////////////////////////////// I executable.cc:309] num_partitions(0x240bab70) I executable.cc:305] num_replicas(0x240bab70) I executable.cc:309] num_partitions(0x240bab70) I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400) I buffer.cc:285] CreateMlirBufferFromLiteral I buffer.cc:98] CreateFromLiteral I buffer.cc:99] CreateFromLiteral: s32[] 2 I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:102] CreateFromLiteral -> 0x240bb050 I buffer.cc:158] device(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I executable.cc:328] GetHloModules(0x240bab70) I executable.cc:240] Execute(0x240bab70) I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x240bb050) I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32> I executable.cc:205] EvalModule: module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) { // ... return %3 : tensor<i32> } } I executable.cc:206] Inputs: [dense<2> : tensor<i32>] I executable.cc:213] Results: [dense<2> : tensor<i32>] I device.cc:153] default_memory_space(0x23bac4e0) I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x22cea630) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630 ////////////////////////////////// // RUN: `print(y)` ////////////////////////////////// I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references. I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x2404d560) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050 I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x240b6010) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x23b2db60) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60 I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630)