PJRT C++ Device API'ye Genel Bakış

Arka plan

PJRT, makine öğrenimi ekosistemine eklemek istediğimiz tek tip Device API'dir. Uzun vadeli hedefimiz şudur:

  1. Çerçeveler (JAX, TF vb.), çerçeveler için şeffaf olmayan cihaza özgü uygulamalara sahip PJRT'yi çağırır.
  2. Her cihaz, PJRT API'lerini uygulamaya odaklanır ve çerçeveler için opak olabilir.

PJRT hem C API'si hem de C++ API'si sunar. Her iki katmana da bağlanabilirsiniz. C++ API, bazı kavramları soyutlamak için sınıfları kullanır ancak XLA veri türleriyle daha güçlü bağları da vardır. Bu sayfada C++ API'ye odaklanılmıştır.

PJRT Bileşenleri

PJRT Bileşenleri

PjRtClient

Tam referans için pjrt_client.h > PjRtClient adresini ziyaret edin.

İstemciler, cihaz ile çerçeve arasındaki tüm iletişimi yönetir ve iletişimde kullanılan tüm durumu kapsar. PJRT eklentisiyle etkileşim kurmak için genel bir API grubuna sahiptirler ve belirli bir eklenti için cihazlara ve bellek alanlarına sahiptirler.

PjRtDevice

Tam referansları pjrt_client.h > PjRtDevice ve pjrt_device_description.h adreslerinde bulabilirsiniz.

Tek bir cihazı tanımlamak için cihaz sınıfı kullanılır. Cihazların türü (GPU/CPU/xPU'yu tanımlamak için benzersiz karma oluşturma) ve hem yerel hem de küresel cihaz ızgarasındaki konumu hakkında bilgi veren bir cihaz açıklaması vardır.

Cihazlar, ilişkili bellek alanlarını ve sahip oldukları müşteriyi de bilir.

Bir cihaz, kendisiyle ilişkili gerçek verilerin tamponlarını bilemeyebilir ancak ilişkili bellek alanlarına bakarak bunu belirleyebilir.

PjRtMemorySpace

Tam referans için pjrt_client.h > PjRtMemorySpace adresini ziyaret edin.

Bellek alanları, bellekteki bir konumu tanımlamak için kullanılabilir. Bu öğeler sabitlenebilir ve belirli bir cihazda yer almalıdır. Alternatif olarak, sabitlenmemiş olarak herhangi bir yerde yer alabilir ancak bir cihazdan erişilebilir olmalıdır.

Bellek alanları, ilişkili veri arabelleklerini, bir bellek alanının ilişkili olduğu cihazları (çoğul) ve parçası olduğu istemciyi bilir.

PjRtBuffer

Tam referans için pjrt_client.h > PjRtBuffer adresini ziyaret edin.

Arabellek, cihazdaki verileri, MLIR öğeleri özelliği veya özel bir tensör biçimi gibi, eklenti içinde çalışmanın kolay olacağı bir biçimde tutar. Bir çerçeve, verileri bir cihaza xla::Literal biçiminde göndermeye çalışabilir. Yani, modüle gönderilecek bir giriş bağımsız değişkeni için cihazın belleğine klonlanması (veya ödünç alınması) gereken veriler gönderilebilir. Bir arabelleğe artık ihtiyaç duyulmadığında çerçeve, temizleme işlemi için Delete yöntemini çağırır.

Tampon, parçası olduğu bellek alanını bilir ve hangi cihazların kendisine erişebileceğini dolaylı olarak anlayabilir ancak tamponlar cihazlarını her zaman bilmez.

Arabellekler, çerçevelerle iletişim kurmak için xla::Literal türüne nasıl dönüştürüleceğini ve bu türden nasıl dönüştürüleceğini bilir:

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

Arabellek oluşturmaya yönelik API'lerde, ana makine arabelleğindeki değişmez verilerin paylaşılıp paylaşılamayacağını veya kopyalanıp kopyalanmayacağını belirlemeye yardımcı olan Arabellek Anlamları bulunur.

Son olarak, bir arabelleğin, yürütme kapsamından daha uzun süre dayanması gerekebilir. Bu durumda, arabellekler, arabelleğin tuttuğu verilere geçici olarak sahip olunan bir işaretçi sağlayan harici referanslar oluşturmanıza olanak tanır. Bu referanslar, temel verileri yorumlamak için meta veriler (dtype / boyut boyutları) içerir.x = jit(foo)(10)

PjRtCompiler

Tam referans için pjrt_compiler.h > PjRtCompiler adresini ziyaret edin.

PjRtCompiler sınıfı, XLA arka uçları için yararlı uygulama ayrıntıları sağlar ancak bir eklentinin uygulanması için gerekli değildir. Teorik olarak, bir PjRtCompiler'ün veya PjRtClient::Compile yönteminin sorumluluğu, bir giriş modülü alıp PjRtLoadedExecutable döndürmektir.

PjRtExecutable / PjRtLoadedExecutable

Tam referans için pjrt_executable.h > PjRtExecutable ve pjrt_client.h > PjRtLoadedExecutable adresini ziyaret edin.

PjRtExecutable, derlenmiş bir yapıyı ve yürütme seçeneklerini alıp bir yürütülebilir dosyanın gerektiği gibi depolanması ve yüklenmesi için bunları nasıl serileştireceğini/nesneleri serileştirip ayıklamayı bilir.

PjRtLoadedExecutable, giriş bağımsız değişkenlerinin yürütülmeye hazır olduğu, bellekte derlenmiş yürütülebilir dosyadır ve PjRtExecutable sınıfının alt sınıfıdır.

Yürütülebilir dosyalar, istemcinin Execute yöntemlerinden biri aracılığıyla arayüz oluşturur:

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

Çerçeve, Execute çağrılmadan önce gerekli tüm verileri yürüten istemcinin sahip olduğu PjRtBuffers öğesine aktarır ancak çerçevenin referans olarak kullanması için döndürür. Bu tamponlar daha sonra Execute yöntemine bağımsız değişken olarak sağlanır.

PJRT Kavramları

PjRtFutures ve Asynchrone Hesaplamalar

Bir eklentinin herhangi bir bölümü eşzamansız olarak uygulanıyorsa gelecekleri düzgün şekilde uygulamalıdır.

Aşağıdaki programı ele alalım:

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

Asenkron bir eklenti, x hesaplamasını sıraya ekleyebilir ve hemen henüz okunmaya hazır olmayan ancak yürütme sırasında doldurulacak bir arabellek döndürebilir. Diğer PJRT cihazlarında yürütme dahil olmak üzere, x'den sonra x gerektirmeyen gerekli hesaplamaları yürütmeye devam edebilir. x değerine ihtiyaç duyulduğunda, arabelleğin GetReadyFuture tarafından döndürülen future aracılığıyla hazır olduğunu beyan edene kadar yürütme engellenir.

Gelecek sözleşmeleri, cihazlar ve arabellekler dahil olmak üzere bir nesnenin ne zaman kullanılabileceğini belirlemek için yararlı olabilir.

İleri düzey kavramlar

Temel API'leri uygulamanın ötesine geçmek, JAX'ın bir eklenti tarafından kullanılabilecek özelliklerini genişletir. Bunların hepsi, tipik bir JIT ve yürütme iş akışı olduğunda bunlar olmadan da çalışacak şekilde isteğe bağlı özelliklerdir, ancak üretim kalitesinde bir ardışık düzen için PJRT API'leri tarafından desteklenen şu özelliklerin herhangi biri için destek düzeyi üzerine düşünmek çok önemlidir:

  • Anı alanları
  • Özel düzenler
  • Gönderme/alma gibi iletişim işlemleri
  • Ana makineyi boşaltma
  • Parçalama

Tipik PJRT çerçevesi-cihaz iletişimi

Örnek Günlük

Aşağıda, PJRT eklentisini yüklemek ve y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)'ü yürütmek için çağrılan yöntemlerin günlüğü verilmiştir. Bu durumda, JAX'in StableHLO Reference PJRT eklentisiyle etkileşime geçtiğini günlüğe kaydederiz.

Örnek günlük

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)