Arka plan
PJRT, makine öğrenimi ekosistemine eklemek istediğimiz tek tip Cihaz API'sidir. Uzun vadeli vizyonumuz şudur:
- Framework'ler (JAX, TF vb.), PJRT'yi çağırır. PJRT'nin, framework'ler için opak olan cihaza özel uygulamaları vardır.
- Her cihaz, PJRT API'lerini uygulamaya odaklanır ve çerçeveler için opak olabilir.
PJRT hem C API'si hem de C++ API'si sunar. Her iki katmanda da eklenti kullanabilirsiniz. C++ API, bazı kavramları soyutlamak için sınıfları kullanır ancak XLA veri türleriyle de daha güçlü bağları vardır. Bu sayfada C++ API'si ele alınmaktadır.
PJRT Bileşenleri
PjRtClient
Tam referansı pjrt_client.h > PjRtClient adresinde bulabilirsiniz.
İstemciler, cihaz ile çerçeve arasındaki tüm iletişimi yönetir ve iletişimde kullanılan tüm durumu kapsar. PJRT eklentisiyle etkileşim kurmak için genel bir API kümesine sahiptirler ve belirli bir eklenti için cihazların ve bellek alanlarının sahibidirler.
PjRtDevice
Tam referanslar pjrt_client.h > PjRtDevice ve pjrt_device_description.h adreslerinde
Cihaz sınıfı, tek bir cihazı tanımlamak için kullanılır. Bir cihazın, türünü (GPU/CPU/xPU'yu tanımlayan benzersiz karma) ve hem yerel hem de global olarak bir cihaz ızgarasındaki konumunu belirlemeye yardımcı olan bir cihaz açıklaması vardır.
Cihazlar, ilişkili bellek alanlarını ve sahibi olan istemciyi de bilir.
Bir cihaz, kendisiyle ilişkili gerçek verilerin arabelleklerini bilmeyebilir ancak ilişkili bellek alanlarına bakarak bunu anlayabilir.
PjRtMemorySpace
Tam referansı pjrt_client.h > PjRtMemorySpace adresinde bulabilirsiniz.
Bellek alanları, bellek konumunu tanımlamak için kullanılabilir. Bunlar sabitlenmemiş olabilir ve herhangi bir cihazdan erişilebilmekle birlikte herhangi bir yerde bulunabilir ya da sabitlenmiş olabilir ve belirli bir cihazda bulunması gerekir.
Bellek alanları, ilişkili veri arabelleklerini, bir bellek alanının ilişkili olduğu cihazları (çoğul) ve parçası olduğu istemciyi bilir.
PjRtBuffer
Tam referansı pjrt_client.h > PjRtBuffer adresinde bulabilirsiniz.
Arabellek, bir cihazdaki verileri eklentide çalışmayı kolaylaştıracak bir biçimde (ör. MLIR öğeleri attr veya özel tensör biçimi) tutar.
Bir çerçeve, verileri bir cihaza xla::Literal biçiminde göndermeye çalışabilir. Örneğin, modülün giriş bağımsız değişkeni için verilerin cihaza kopyalanması (veya ödünç alınması) gerekir. Bir arabellek artık gerekli olmadığında, temizleme işlemi için çerçeve tarafından Delete yöntemi çağrılır.
Bir arabellek, parçası olduğu bellek alanını bilir ve hangi cihazların buna erişebileceğini dolaylı olarak anlayabilir. Ancak arabellekler, cihazlarını mutlaka bilmez.
Tamponlar, çerçevelerle iletişim kurmak için xla::Literal türüne ve bu türden nasıl dönüştürüleceğini bilir:
// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}
// Buffer to Literal
xla::Future<> ToLiteral(xla::MutableLiteralBase* literal) override {...}
Tampon oluşturmaya yönelik API'ler, ana makine tamponundaki değişmez verilerin paylaşılıp paylaşılamayacağını, kopyalanıp kopyalanamayacağını veya değiştirilip değiştirilemeyeceğini belirlemeye yardımcı olan Tampon Semantiği'ne sahiptir.
Son olarak, arabellek çerçeve katmanındaki bir değişkene x = jit(foo)(10) atanırsa yürütme kapsamından daha uzun süre dayanması gerekebilir. Bu durumlarda arabellekler, arabellek tarafından tutulan verilere geçici olarak sahip olunan bir işaretçi sağlayan harici referanslar oluşturmaya olanak tanır. Ayrıca, temel verileri yorumlamak için meta veriler (dtype / dim boyutları) da sağlanır.
PjRtCompiler
Tam referansı pjrt_compiler.h > PjRtCompiler adresinde bulabilirsiniz.
PjRtCompiler sınıfı, XLA arka uçları için faydalı uygulama ayrıntıları sağlar ancak bir eklentinin uygulanması için gerekli değildir. Teoride, PjRtCompiler veya PjRtClient::Compile yönteminin sorumluluğu bir giriş modülünü alıp PjRtLoadedExecutable döndürmektir.
PjRtExecutable / PjRtLoadedExecutable
Tam referans için pjrt_executable.h > PjRtExecutable ve pjrt_client.h > PjRtLoadedExecutable adreslerini ziyaret edin.
PjRtExecutable, derlenmiş bir yapıyı ve yürütme seçeneklerini nasıl alacağını ve bunları nasıl serileştirip/seri durumdan çıkaracağını bilir. Böylece yürütülebilir bir dosya gerektiğinde depolanabilir ve yüklenebilir.
PjRtLoadedExecutable, yürütülecek giriş bağımsız değişkenleri için hazır olan, bellekte derlenmiş yürütülebilir dosyadır ve PjRtExecutable sınıfının alt sınıfıdır.
Çalıştırılabilir dosyalar, istemcinin Execute yöntemlerinden biri aracılığıyla arayüzlenir:
// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}
// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
Execute çağrılmadan önce çerçeve, gerekli tüm verileri yürütme istemcisine ait olan PjRtBuffers öğesine aktarır ancak çerçeve tarafından başvurulmak üzere döndürülür. Bu arabellekler daha sonra Execute yöntemine bağımsız değişken olarak sağlanır.
PJRT Concepts
Gelecekler ve Eş Zamansız Hesaplamalar
Bir eklentinin herhangi bir bölümü eşzamansız olarak uygulanıyorsa gelecekteki işlevleri düzgün bir şekilde uygulamalıdır.
Aşağıdaki programı inceleyin:
@jax.jit
def foo(x): return x + 1
x = foo(1)
# [...] other logic not using `x`
print(x + 1)
Asenkron bir eklenti, hesaplamayı x sıraya alabilir ve hemen okunmaya hazır olmayan bir arabellek döndürebilir ancak yürütme işlemi arabelleği doldurur. Yürütme, x gerektirmeyen gerekli hesaplamaları x sonrasında da sıraya almaya devam edebilir. Buna diğer PJRT cihazlarında yürütme de dahildir. x değeri gerektiğinde, arabellek GetReadyFuture tarafından döndürülen gelecekteki değer aracılığıyla hazır olduğunu bildirene kadar yürütme engellenir.
Gelecekler, cihazlar ve arabellekler de dahil olmak üzere bir nesnenin ne zaman kullanıma sunulacağını belirlemek için yararlı olabilir.
İleri düzey kavramlar
Temel API'lerin uygulanmasının ötesine geçmek, bir eklenti tarafından kullanılabilecek JAX özelliklerini genişletir. Bunların tümü, normal bir JIT ve yürütme iş akışının bu özellikler olmadan çalışacağı anlamında isteğe bağlı özelliklerdir. Ancak üretim kalitesinde bir işlem hattı için PJRT API'leri tarafından desteklenen bu özelliklerin destek derecesi hakkında düşünülmesi gerekir:
- Bellek alanları
- Özel düzenler
- Gönderme/alma gibi iletişim işlemleri
- Ana makine yükünü azaltma
- Parçalama
Tipik PJRT çerçeve-cihaz iletişimi
Örnek günlük
Aşağıda, PJRT eklentisini yüklemek ve y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1) yürütmek için çağrılan yöntemlerin günlüğü verilmiştir. Bu durumda, StableHLO Reference PJRT eklentisiyle etkileşimde bulunan JAX'i günlüğe kaydederiz.
Örnek günlük
////////////////////////////////// // Load the plugin ////////////////////////////////// I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400) I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8) I device.cc:104] StablehloReferenceDevice(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I client_cpp_pjrt.cc:86] devices(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I device.cc:168] description(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:86] Attributes(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:148] memory_spaces(0x23bac4e0) Creating PJRT Client from client I client_cpp_pjrt.cc:108] platform_version(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:57] id(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:61] process_index(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:153] default_memory_space(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) ////////////////////////////////// // RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)` ////////////////////////////////// I executable.cc:309] num_partitions(0x240bab70) I executable.cc:305] num_replicas(0x240bab70) I executable.cc:309] num_partitions(0x240bab70) I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400) I buffer.cc:285] CreateMlirBufferFromLiteral I buffer.cc:98] CreateFromLiteral I buffer.cc:99] CreateFromLiteral: s32[] 2 I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:102] CreateFromLiteral -> 0x240bb050 I buffer.cc:158] device(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I executable.cc:328] GetHloModules(0x240bab70) I executable.cc:240] Execute(0x240bab70) I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x240bb050) I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32> I executable.cc:205] EvalModule: module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) { // ... return %3 : tensor<i32> } } I executable.cc:206] Inputs: [dense<2> : tensor<i32>] I executable.cc:213] Results: [dense<2> : tensor<i32>] I device.cc:153] default_memory_space(0x23bac4e0) I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x22cea630) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630 ////////////////////////////////// // RUN: `print(y)` ////////////////////////////////// I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references. I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x2404d560) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050 I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x240b6010) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x23b2db60) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60 I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630)