Latar belakang
PJRT adalah Device API seragam yang ingin kita tambahkan ke ekosistem ML. Visi jangka panjangnya adalah:
- Framework (JAX, TF, dll.) akan memanggil PJRT, yang memiliki implementasi khusus perangkat yang tidak transparan untuk framework;
- Setiap perangkat berfokus pada penerapan PJRT API, dan dapat buram untuk framework.
PJRT menawarkan C API dan C++ API. Menyambungkan di salah satu lapisan tidak masalah, API C++ menggunakan class untuk memisahkan beberapa konsep, tetapi juga memiliki hubungan yang lebih kuat dengan jenis data XLA. Halaman ini berfokus pada C++ API.
Komponen PJRT
PjRtClient
Referensi lengkap di pjrt_client.h > PjRtClient
.
Klien mengelola semua komunikasi antara perangkat dan framework, serta mengaitkan semua status yang digunakan dalam komunikasi. Plugin ini memiliki kumpulan API umum untuk berinteraksi dengan plugin PJRT, dan memiliki perangkat serta ruang memori untuk plugin tertentu.
PjRtDevice
Referensi lengkap di pjrt_client.h > PjRtDevice
,
dan pjrt_device_description.h
Class perangkat digunakan untuk mendeskripsikan satu perangkat. Perangkat memiliki deskripsi perangkat untuk membantu mengidentifikasi jenisnya (hash unik untuk mengidentifikasi GPU/CPU/xPU), dan lokasi dalam petak perangkat secara lokal dan global.
Perangkat juga mengetahui ruang memori terkait dan klien yang memilikinya.
Perangkat tidak harus mengetahui buffering data sebenarnya yang terkait dengan perangkat, tetapi perangkat dapat mengetahuinya dengan melihat ruang memori terkait.
PjRtMemorySpace
Referensi lengkap di pjrt_client.h > PjRtMemorySpace
.
Ruang memori dapat digunakan untuk menjelaskan lokasi memori. Aplikasi ini dapat di-unpin, dan dapat ditayangkan di mana saja, tetapi dapat diakses dari perangkat, atau dapat disematkan dan harus ditayangkan di perangkat tertentu.
Ruang memori mengetahui buffer data terkait, dan perangkat (jamak) yang terkait dengan ruang memori, serta klien yang menjadi bagiannya.
PjRtBuffer
Referensi lengkap di pjrt_client.h > PjRtBuffer
.
Buffer menyimpan data di perangkat dalam beberapa format yang akan mudah digunakan
di dalam plugin, seperti atribut elemen MLIR atau format tensor eksklusif.
Framework dapat mencoba mengirim data ke perangkat dalam bentuk xla::Literal
,
yaitu untuk argumen input ke modul, yang harus di-clone (atau dipinjam), ke
memori perangkat. Setelah buffer tidak lagi diperlukan, metode Delete
akan dipanggil oleh framework untuk membersihkan.
Buffer mengetahui ruang memori yang menjadi bagiannya, dan secara transitif dapat mengetahui perangkat mana yang dapat mengaksesnya, tetapi buffer tidak selalu mengetahui perangkatnya.
Untuk berkomunikasi dengan framework, buffering mengetahui cara mengonversi ke dan dari
jenis xla::Literal
:
// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}
// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}
API untuk membuat buffering memiliki Semantik Buffer yang membantu menentukan apakah data literal dari buffering host dapat dibagikan atau disalin atau diubah.
Terakhir, buffering mungkin perlu bertahan lebih lama dari cakupan eksekusinya, jika
ditetapkan ke variabel di lapisan framework x = jit(foo)(10)
, dalam hal ini
buffer memungkinkan pembuatan referensi eksternal yang memberikan pointer
yang dimiliki sementara ke data yang disimpan oleh buffering, bersama dengan metadata (ukuran dtype / dim)
untuk menafsirkan data pokok.
PjRtCompiler
Referensi lengkap di pjrt_compiler.h > PjRtCompiler
.
Class PjRtCompiler
memberikan detail implementasi yang berguna untuk backend
XLA, tetapi tidak diperlukan untuk implementasi plugin. Secara teori, tanggung jawab PjRtCompiler
, atau metode PjRtClient::Compile
, adalah
mengambil modul input dan menampilkan PjRtLoadedExecutable
.
PjRtExecutable / PjRtLoadedExecutable
Referensi lengkap di pjrt_executable.h > PjRtExecutable
,
dan pjrt_client.h > PjRtLoadedExecutable
.
PjRtExecutable
mengetahui cara mengambil artefak yang dikompilasi dan opsi eksekusi
serta melakukan serialisasi/deserialisasi sehingga file yang dapat dieksekusi dapat disimpan dan dimuat
sesuai kebutuhan.
PjRtLoadedExecutable
adalah file yang dapat dieksekusi yang dikompilasi dalam memori dan siap
untuk dieksekusi argumen input, yang merupakan subclass dari PjRtExecutable
.
File yang dapat dieksekusi dihubungkan melalui salah satu metode Execute
klien:
// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}
// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
Sebelum memanggil Execute
, framework akan mentransfer semua data yang diperlukan ke
PjRtBuffers
yang dimiliki oleh klien yang menjalankan, tetapi ditampilkan untuk direferensikan
oleh framework. Buffer ini kemudian diberikan sebagai argumen ke metode Execute
.
Konsep PJRT
PjRtFutures & Komputasi Asinkron
Jika ada bagian plugin yang diterapkan secara asinkron, bagian tersebut harus menerapkan future dengan benar.
Pertimbangkan program berikut:
@jax.jit
def foo(x): return x + 1
x = foo(1)
# [...] other logic not using `x`
print(x + 1)
Plugin asinkron akan dapat mengantrekan komputasi x
, dan langsung
menampilkan buffer yang belum siap dibaca, tetapi eksekusi akan mengisinya. Eksekusi dapat terus mengantrekan komputasi yang diperlukan setelah x
, yang
tidak memerlukan x
, termasuk eksekusi di perangkat PJRT lainnya. Setelah nilai
x
diperlukan, eksekusi akan diblokir hingga buffering menyatakan dirinya siap melalui
masa mendatang yang ditampilkan oleh GetReadyFuture
.
Futures dapat berguna untuk menentukan kapan objek tersedia, termasuk perangkat dan buffering.
Konsep lanjutan
Memperluas selain menerapkan API dasar akan memperluas fitur JAX yang dapat digunakan oleh plugin. Semua ini adalah fitur keikutsertaan dalam arti bahwa pada JIT dan alur kerja eksekusi standar akan berfungsi tanpanya, tetapi untuk pipeline kualitas produksi, beberapa pemikiran mungkin harus dimasukkan ke dalam tingkat dukungan untuk salah satu fitur ini yang didukung oleh PJRT API:
- Ruang memori
- Tata letak kustom
- Operasi komunikasi seperti kirim/terima
- Pemindahan beban host
- Sharding
Komunikasi perangkat framework PJRT yang umum
Contoh Log
Berikut adalah log metode yang dipanggil untuk memuat plugin PJRT dan
menjalankan y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)
. Dalam hal ini,
kita mencatat JAX yang berinteraksi dengan plugin PJRT Referensi StableHLO.
Contoh log
////////////////////////////////// // Load the plugin ////////////////////////////////// I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400) I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8) I device.cc:104] StablehloReferenceDevice(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I client_cpp_pjrt.cc:86] devices(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I device.cc:168] description(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:86] Attributes(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:148] memory_spaces(0x23bac4e0) Creating PJRT Client from client I client_cpp_pjrt.cc:108] platform_version(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:57] id(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:61] process_index(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:153] default_memory_space(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) ////////////////////////////////// // RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)` ////////////////////////////////// I executable.cc:309] num_partitions(0x240bab70) I executable.cc:305] num_replicas(0x240bab70) I executable.cc:309] num_partitions(0x240bab70) I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400) I buffer.cc:285] CreateMlirBufferFromLiteral I buffer.cc:98] CreateFromLiteral I buffer.cc:99] CreateFromLiteral: s32[] 2 I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:102] CreateFromLiteral -> 0x240bb050 I buffer.cc:158] device(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I executable.cc:328] GetHloModules(0x240bab70) I executable.cc:240] Execute(0x240bab70) I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x240bb050) I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32> I executable.cc:205] EvalModule: module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) { // ... return %3 : tensor<i32> } } I executable.cc:206] Inputs: [dense<2> : tensor<i32>] I executable.cc:213] Results: [dense<2> : tensor<i32>] I device.cc:153] default_memory_space(0x23bac4e0) I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x22cea630) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630 ////////////////////////////////// // RUN: `print(y)` ////////////////////////////////// I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references. I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x2404d560) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050 I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x240b6010) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x23b2db60) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60 I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630)