Ringkasan PJRT C++ Device API

Latar belakang

PJRT adalah Device API seragam yang ingin kita tambahkan ke ekosistem ML. Visi jangka panjangnya adalah:

  1. Framework (JAX, TF, dll.) akan memanggil PJRT, yang memiliki implementasi khusus perangkat yang tidak transparan bagi framework;
  2. Setiap perangkat berfokus pada penerapan API PJRT, dan dapat bersifat buram bagi framework.

PJRT menawarkan C API dan C++ API. Tidak masalah jika Anda menggunakan lapisan mana pun, C++ API menggunakan class untuk mengabstraksi beberapa konsep, tetapi juga memiliki hubungan yang lebih kuat dengan tipe data XLA. Halaman ini berfokus pada C++ API.

Komponen PJRT

Komponen PJRT

PjRtClient

Referensi lengkap di pjrt_client.h > PjRtClient.

Klien mengelola semua komunikasi antara perangkat dan framework, serta merangkum semua status yang digunakan dalam komunikasi. Mereka memiliki serangkaian API umum untuk berinteraksi dengan plugin PJRT, dan mereka memiliki perangkat dan ruang memori untuk plugin tertentu.

PjRtDevice

Referensi lengkap di pjrt_client.h > PjRtDevice, dan pjrt_device_description.h

Class perangkat digunakan untuk mendeskripsikan satu perangkat. Perangkat memiliki deskripsi perangkat untuk membantu mengidentifikasi jenisnya (hash unik untuk mengidentifikasi GPU/CPU/xPU), dan lokasinya dalam petak perangkat baik secara lokal maupun global.

Perangkat juga mengetahui ruang memori terkait dan klien yang memilikinya.

Perangkat tidak selalu mengetahui buffer data sebenarnya yang terkait dengannya, tetapi perangkat dapat mengetahuinya dengan melihat ruang memori yang terkait.

PjRtMemorySpace

Referensi lengkap di pjrt_client.h > PjRtMemorySpace.

Ruang memori dapat digunakan untuk menjelaskan lokasi memori. Item ini dapat dilepas dan bebas berada di mana saja, tetapi dapat diakses dari perangkat, atau item ini dapat disematkan dan harus berada di perangkat tertentu.

Ruang memori mengetahui buffer data terkaitnya, dan perangkat (jamak) yang terkait dengan ruang memori, serta klien yang menjadi bagiannya.

PjRtBuffer

Referensi lengkap di pjrt_client.h > PjRtBuffer.

Buffer menyimpan data di perangkat dalam beberapa format yang akan mudah digunakan di dalam plugin, seperti atribut elemen MLIR atau format tensor eksklusif. Framework dapat mencoba mengirim data ke perangkat dalam bentuk xla::Literal, yaitu untuk argumen input ke modul, yang harus di-clone (atau dipinjam), ke memori perangkat. Setelah buffer tidak lagi diperlukan, metode Delete dipanggil oleh framework untuk membersihkan.

Buffer mengetahui ruang memori yang menjadi bagiannya, dan secara transitif dapat mengetahui perangkat mana yang dapat mengaksesnya, tetapi buffer tidak selalu mengetahui perangkatnya.

Untuk berkomunikasi dengan framework, buffer mengetahui cara mengonversi ke dan dari jenis xla::Literal:

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::Future<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

API untuk membuat buffer memiliki Semantik Buffer yang membantu menentukan apakah data literal dari buffer host dapat dibagikan atau disalin atau diubah.

Terakhir, buffer mungkin perlu bertahan lebih lama daripada cakupan eksekusinya, jika ditetapkan ke variabel di lapisan framework x = jit(foo)(10). Dalam kasus ini, buffer memungkinkan pembuatan referensi eksternal yang menyediakan pointer yang dimiliki sementara ke data yang disimpan oleh buffer, beserta metadata (dtype / ukuran dim) untuk menafsirkan data pokok.

PjRtCompiler

Referensi lengkap di pjrt_compiler.h > PjRtCompiler.

Class PjRtCompiler memberikan detail penerapan yang berguna untuk backend XLA, tetapi tidak diperlukan agar plugin dapat diterapkan. Secara teori, tanggung jawab PjRtCompiler, atau metode PjRtClient::Compile, adalah mengambil modul input dan menampilkan PjRtLoadedExecutable.

PjRtExecutable / PjRtLoadedExecutable

Referensi lengkap di pjrt_executable.h > PjRtExecutable, dan pjrt_client.h > PjRtLoadedExecutable.

PjRtExecutable tahu cara mengambil artefak yang dikompilasi dan opsi eksekusi serta melakukan serialisasi/deserialisasi sehingga file yang dapat dieksekusi dapat disimpan dan dimuat sesuai kebutuhan.

PjRtLoadedExecutable adalah executable yang dikompilasi dalam memori yang siap untuk dieksekusi dengan argumen input, dan merupakan subclass dari PjRtExecutable.

Dapat dieksekusi berinteraksi melalui salah satu metode Execute klien:

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

Sebelum memanggil Execute, framework akan mentransfer semua data yang diperlukan ke PjRtBuffers yang dimiliki oleh klien yang menjalankan, tetapi ditampilkan untuk dirujuk oleh framework. Buffer ini kemudian diberikan sebagai argumen ke metode Execute.

Konsep PJRT

Future & Komputasi Asinkron

Jika bagian plugin diimplementasikan secara asinkron, plugin tersebut harus menerapkan future dengan benar.

Perhatikan program berikut:

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

Plugin asinkron akan dapat mengantrekan komputasi x, dan segera mengembalikan buffer yang belum siap dibaca, tetapi eksekusi akan mengisinya. Eksekusi dapat terus mengantrekan komputasi yang diperlukan setelah x, yang tidak memerlukan x, termasuk eksekusi di perangkat PJRT lainnya. Setelah nilai x diperlukan, eksekusi akan diblokir hingga buffer menyatakan dirinya siap melalui future yang ditampilkan oleh GetReadyFuture.

Future dapat berguna untuk menentukan kapan suatu objek tersedia, termasuk perangkat dan buffer.

Konsep lanjutan

Melampaui penerapan API dasar akan memperluas fitur JAX yang dapat digunakan oleh plugin. Semua fitur ini bersifat opsional dalam arti bahwa alur kerja JIT dan eksekusi yang umum akan berfungsi tanpa fitur tersebut, tetapi untuk pipeline kualitas produksi, sebaiknya pertimbangkan tingkat dukungan untuk salah satu fitur yang didukung oleh PJRT API ini:

  • Ruang memori
  • Tata letak kustom
  • Operasi komunikasi seperti mengirim/menerima
  • Pengalihan host
  • Sharding

Komunikasi perangkat framework PJRT umum

Contoh Log

Berikut adalah log metode yang dipanggil untuk memuat plugin PJRT dan mengeksekusi y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1). Dalam kasus ini, kita mencatat interaksi JAX dengan plugin PJRT Referensi StableHLO.

Contoh log

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)