Arka plan
PJRT, ML ekosistemine eklemek istediğimiz tek tip cihaz API'sidir. Uzun vadeli hedefimiz şudur:
- Çerçeveler (JAX, TF vb.), çerçeveler için şeffaf olmayan cihaza özel uygulamaları olan PJRT'yi çağırır;
- Her cihaz, PJRT API'lerini uygulamaya odaklanır ve çerçeveler için opak olabilir.
PJRT hem C API'si hem de C++ API'si sunar. Her iki katmana da bağlanabilirsiniz. C++ API, bazı kavramları soyutlamak için sınıfları kullanır ancak XLA veri türleriyle daha güçlü bağları da vardır. Bu sayfada C++ API'ye odaklanılmıştır.
PJRT Bileşenleri
PjRtClient
Tam referans için pjrt_client.h > PjRtClient
adresini ziyaret edin.
İstemciler, cihaz ile çerçeve arasındaki tüm iletişimi yönetir ve iletişimde kullanılan tüm durumu kapsar. PJRT eklentisiyle etkileşim kurmak için genel bir API grubuna sahiptirler ve belirli bir eklenti için cihazlara ve bellek alanlarına sahiptirler.
PjRtDevice
Tam referansları pjrt_client.h > PjRtDevice
ve pjrt_device_description.h
adreslerinde bulabilirsiniz.
Cihaz sınıfı, tek bir cihazı tanımlamak için kullanılır. Cihazların türü (GPU/CPU/xPU'yu tanımlamak için benzersiz karma oluşturma) ve hem yerel hem de küresel cihaz ızgarasındaki konumu hakkında bilgi veren bir cihaz açıklaması vardır.
Cihazlar, ilişkili bellek alanlarını ve sahip oldukları müşteriyi de bilir.
Cihaz, kendisiyle ilişkili gerçek verilerin arabelleklerini her zaman bilmeyebilir ancak ilişkili bellek alanlarını inceleyerek bu bilgiyi edinebilir.
PjRtMemorySpace
Tam referans için pjrt_client.h > PjRtMemorySpace
adresini ziyaret edin.
Bellek alanları, bellekteki bir konumu tanımlamak için kullanılabilir. Bu öğeler sabitlenebilir ve belirli bir cihazda yer almalıdır. Alternatif olarak, sabitlenmemiş olarak herhangi bir yerde yer alabilir ancak bir cihazdan erişilebilir olmalıdır.
Bellek alanları, ilişkili veri arabelleklerini, bir bellek alanının ilişkili olduğu cihazları (çoğul) ve parçası olduğu istemciyi bilir.
PjRtBuffer
Tam referans için pjrt_client.h > PjRtBuffer
adresini ziyaret edin.
Arabellek, cihazdaki verileri, MLIR öğeleri özelliği veya özel bir tensör biçimi gibi, eklenti içinde çalışmanın kolay olacağı bir biçimde tutar.
Bir çerçeve, verileri bir cihaza xla::Literal
biçiminde göndermeye çalışabilir. Yani, modüle gönderilecek bir giriş bağımsız değişkeni için cihazın belleğine klonlanması (veya ödünç alınması) gereken veriler gönderilebilir. Bir arabelleğe artık ihtiyaç duyulmadığında çerçeve, temizleme işlemi için Delete
yöntemini çağırır.
Tampon, parçası olduğu bellek alanını bilir ve hangi cihazların kendisine erişebileceğini dolaylı olarak anlayabilir ancak tamponlar cihazlarını her zaman bilmez.
Arabellekler, çerçevelerle iletişim kurmak için xla::Literal
türüne nasıl dönüştürüleceğini ve bu türden nasıl dönüştürüleceğini bilir:
// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}
// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}
Tampon oluşturmaya yönelik API'ler, ana tampondaki değişmez verilerin paylaşılıp paylaşılamayacağını, kopyalanıp kopyalanamayacağını veya değiştirilip değiştirilemeyeceğini belirlemeye yardımcı olan Tampon Anlamları'na sahiptir.
Son olarak, bir arabelleğin, yürütme kapsamından daha uzun süre dayanması gerekebilir. Bu durumda, arabellekler, arabelleğin tuttuğu verilere geçici olarak sahip olunan bir işaretçi sağlayan harici referanslar oluşturmanıza olanak tanır. Bu referanslar, temel verileri yorumlamak için meta veriler (dtype / boyut boyutları) içerir.x = jit(foo)(10)
PjRtCompiler
Tam referans için pjrt_compiler.h > PjRtCompiler
adresini ziyaret edin.
PjRtCompiler
sınıfı, XLA arka uçları için yararlı uygulama ayrıntıları sağlar ancak bir eklentinin uygulanması için gerekli değildir. Teorik olarak, bir PjRtCompiler
'ün veya PjRtClient::Compile
yönteminin sorumluluğu, bir giriş modülü alıp PjRtLoadedExecutable
döndürmektir.
PjRtExecutable / PjRtLoadedExecutable
Tam referans için pjrt_executable.h > PjRtExecutable
ve pjrt_client.h > PjRtLoadedExecutable
adresini ziyaret edin.
PjRtExecutable
, derlenmiş bir yapıyı ve yürütme seçeneklerini alıp bir yürütülebilir dosyanın gerektiği gibi depolanması ve yüklenmesi için bunları nasıl serileştireceğini/nesneleri serileştirip ayıklamayı bilir.
PjRtLoadedExecutable
, giriş bağımsız değişkenlerinin yürütülmeye hazır olduğu, bellekte derlenmiş yürütülebilir dosyadır ve PjRtExecutable
sınıfının alt sınıfıdır.
Yürütülebilir dosyalar, istemcinin Execute
yöntemlerinden biri aracılığıyla arayüz oluşturur:
// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}
// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
Çerçeve, Execute
çağrılmadan önce gerekli tüm verileri yürüten istemcinin sahip olduğu PjRtBuffers
öğesine aktarır ancak çerçevenin referans olarak kullanması için döndürür. Bu tamponlar daha sonra Execute
yöntemine bağımsız değişken olarak sağlanır.
PJRT Kavramları
PjRtFutures ve Asynchrone Hesaplamalar
Bir eklentinin herhangi bir bölümü eşzamansız olarak uygulanıyorsa gelecekleri düzgün şekilde uygulamalıdır.
Aşağıdaki programı ele alalım:
@jax.jit
def foo(x): return x + 1
x = foo(1)
# [...] other logic not using `x`
print(x + 1)
Asenkron bir eklenti, x
hesaplamasını sıraya ekleyebilir ve hemen henüz okunmaya hazır olmayan ancak yürütme sırasında doldurulacak bir arabellek döndürebilir. Diğer PJRT cihazlarında yürütme dahil olmak üzere, x
'den sonra x
gerektirmeyen gerekli hesaplamaları yürütmeye devam edebilir. x
değerine ihtiyaç duyulduğunda, arabelleğin GetReadyFuture
tarafından döndürülen future aracılığıyla hazır olduğunu beyan edene kadar yürütme engellenir.
Gelecek sözleşmeleri, cihazlar ve arabellekler dahil olmak üzere bir nesnenin ne zaman kullanılabileceğini belirlemek için yararlı olabilir.
İleri düzey kavramlar
Temel API'leri uygulamanın ötesine geçmek, JAX'ın bir eklenti tarafından kullanılabilecek özelliklerini genişletir. Bu özelliklerin tümü, tipik JIT ve yürütme iş akışı bunlar olmadan çalışacağı için etkinleştirilmesi gereken özelliklerdir. Ancak üretim kalitesinde bir ardışık düzen için PJRT API'leri tarafından desteklenen bu özelliklerden herhangi biri için destek derecesi üzerinde düşünmeniz gerekebilir:
- Anı alanları
- Özel düzenler
- Gönderme/alma gibi iletişim işlemleri
- Ana makineden yük atma
- Parçalama
Tipik PJRT çerçevesi-cihaz iletişimi
Örnek Günlük
Aşağıda, PJRT eklentisini yüklemek ve y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)
'ü yürütmek için çağrılan yöntemlerin günlüğü verilmiştir. Bu durumda, JAX'in StableHLO Reference PJRT eklentisiyle etkileşime geçtiğini günlüğe kaydederiz.
Örnek günlük
////////////////////////////////// // Load the plugin ////////////////////////////////// I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400) I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8) I device.cc:104] StablehloReferenceDevice(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I client_cpp_pjrt.cc:86] devices(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I device.cc:168] description(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:86] Attributes(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:148] memory_spaces(0x23bac4e0) Creating PJRT Client from client I client_cpp_pjrt.cc:108] platform_version(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:57] id(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:61] process_index(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:153] default_memory_space(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) ////////////////////////////////// // RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)` ////////////////////////////////// I executable.cc:309] num_partitions(0x240bab70) I executable.cc:305] num_replicas(0x240bab70) I executable.cc:309] num_partitions(0x240bab70) I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400) I buffer.cc:285] CreateMlirBufferFromLiteral I buffer.cc:98] CreateFromLiteral I buffer.cc:99] CreateFromLiteral: s32[] 2 I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:102] CreateFromLiteral -> 0x240bb050 I buffer.cc:158] device(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I executable.cc:328] GetHloModules(0x240bab70) I executable.cc:240] Execute(0x240bab70) I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x240bb050) I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32> I executable.cc:205] EvalModule: module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) { // ... return %3 : tensor<i32> } } I executable.cc:206] Inputs: [dense<2> : tensor<i32>] I executable.cc:213] Results: [dense<2> : tensor<i32>] I device.cc:153] default_memory_space(0x23bac4e0) I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x22cea630) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630 ////////////////////////////////// // RUN: `print(y)` ////////////////////////////////// I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references. I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x2404d560) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050 I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x240b6010) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x23b2db60) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60 I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630)