Hintergrund
PJRT ist die einheitliche Device API, die wir dem ML-System hinzufügen möchten. Langfristig soll Folgendes erreicht werden:
- Frameworks (JAX, TF usw.) rufen PJRT auf, das gerätespezifische Implementierungen hat, die für die Frameworks nicht transparent sind.
- Bei jedem Gerät liegt der Schwerpunkt auf der Implementierung von PJRT APIs und es kann für die Frameworks undurchsichtig sein.
PJRT bietet sowohl eine C API als auch eine C++ API. Die Einbindung in beiden Schichten ist in Ordnung. Die C++ API verwendet Klassen, um einige Konzepte zu abstrahieren, hat aber auch engere Verbindungen zu XLA-Datentypen. Auf dieser Seite geht es um die C++ API.
PJRT-Komponenten
PjRtClient
Vollständige Referenz unter pjrt_client.h > PjRtClient
Clients verwalten die gesamte Kommunikation zwischen dem Gerät und dem Framework und kapseln den gesamten in der Kommunikation verwendeten Status ein. Sie haben eine Reihe von generischen APIs für die Interaktion mit einem PJRT-Plug-in und sind für die Geräte und Speicherbereiche eines bestimmten Plug-ins verantwortlich.
PjRtDevice
Vollständige Referenzen unter pjrt_client.h > PjRtDevice
und pjrt_device_description.h
Eine Geräteklasse wird verwendet, um ein einzelnes Gerät zu beschreiben. Ein Gerät hat eine Gerätebeschreibung, die dazu beiträgt, seine Art zu identifizieren (eindeutiger Hash zur Identifizierung von GPU/CPU/xPU) und seinen Standort innerhalb eines Geräterasters sowohl lokal als auch global.
Geräte kennen auch ihre zugehörigen Speicherbereiche und den zugehörigen Client.
Ein Gerät kennt nicht unbedingt die zugehörigen Buffers mit den tatsächlichen Daten, kann dies aber durch Durchsuchen der zugehörigen Speicherbereiche herausfinden.
PjRtMemorySpace
Vollständige Referenz finden Sie unter pjrt_client.h > PjRtMemorySpace
.
Speicherbereiche können verwendet werden, um einen Speicherort zu beschreiben. Diese können entweder losgelöst werden und sind überall verfügbar, aber von einem Gerät aus zugänglich, oder sie können angepinnt werden und müssen auf einem bestimmten Gerät verfügbar sein.
Speicherbereiche kennen ihre zugehörigen Datenbuffer, die Geräte (im Plural), mit denen ein Speicherbereich verknüpft ist, sowie den Client, zu dem er gehört.
PjRtBuffer
Vollständige Referenz finden Sie unter pjrt_client.h > PjRtBuffer
.
Ein Zwischenspeicher enthält Daten auf einem Gerät in einem Format, mit dem innerhalb des Plug-ins leicht zu arbeiten ist, z. B. ein MLIR-Element oder ein proprietäres Tensorformat.
Ein Framework kann versuchen, Daten in Form einer xla::Literal
an ein Gerät zu senden, d.h. für ein Eingabeargument für das Modul, das geklont (oder ausgeliehen) werden muss, um in den Arbeitsspeicher des Geräts zu gelangen. Sobald ein Puffer nicht mehr benötigt wird, wird die Methode Delete
vom Framework aufgerufen, um ihn zu bereinigen.
Ein Puffer kennt den Speicherplatz, zu dem er gehört, und kann indirekt ermitteln, welche Geräte darauf zugreifen können. Puffer kennen jedoch nicht unbedingt ihre Geräte.
Für die Kommunikation mit Frameworks wissen Puffer, wie sie in einen und aus einem xla::Literal
-Typ konvertiert werden:
// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}
// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}
APIs zum Erstellen eines Buffers haben Buffer-Semantiken, die festlegen, ob Literaldaten aus dem Host-Puffer freigegeben, kopiert oder verändert werden können.
Schließlich kann ein Puffer länger als der Umfang seiner Ausführung dauern, wenn er einer Variablen in der Framework-Ebene x = jit(foo)(10)
zugewiesen ist. In diesen Fällen ermöglichen Puffer externe Verweise, die einen temporären Verweis auf die vom Puffer gehaltenen Daten sowie Metadaten (Datentyp / Dim-Größen) zur Interpretation der zugrunde liegenden Daten bereitstellen.
PjRtCompiler
Vollständige Referenz finden Sie unter pjrt_compiler.h > PjRtCompiler
.
Die Klasse PjRtCompiler
bietet nützliche Implementierungsdetails für XLA-Back-Ends, ist aber für die Implementierung eines Plug-ins nicht erforderlich. Theoretisch besteht die Aufgabe eines PjRtCompiler
oder der PjRtClient::Compile
-Methode darin, ein Eingabemodul zu übernehmen und einen PjRtLoadedExecutable
zurückzugeben.
PjRtExecutable / PjRtLoadedExecutable
Vollständige Referenz unter pjrt_executable.h > PjRtExecutable
und pjrt_client.h > PjRtLoadedExecutable
.
Ein PjRtExecutable
kann ein kompiliertes Artefakt und Ausführungsoptionen serialisieren/deserialisieren, damit eine ausführbare Datei bei Bedarf gespeichert und geladen werden kann.
PjRtLoadedExecutable
ist die im Speicher kompilierte ausführbare Datei, die Eingabeargumente ausführen kann. Sie ist eine abgeleitete Klasse von PjRtExecutable
.
Die Kommunikation mit ausführbaren Dateien erfolgt über eine der Execute
-Methoden des Clients:
// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}
// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
Vor dem Aufruf von Execute
überträgt das Framework alle erforderlichen Daten an PjRtBuffers
, die dem ausführenden Client gehören, aber für das Framework zurückgegeben werden. Diese Puffer werden dann als Argumente an die Execute
-Methode übergeben.
PJRT Concepts
PjRtFutures und asynchrone Berechnungen
Wenn ein Teil eines Plug-ins asynchron implementiert ist, müssen Futures ordnungsgemäß implementiert werden.
Betrachten Sie das folgende Programm:
@jax.jit
def foo(x): return x + 1
x = foo(1)
# [...] other logic not using `x`
print(x + 1)
Ein asynchrones Plug-in kann die Berechnung x
in die Warteschlange stellen und sofort einen Puffer zurückgeben, der noch nicht zum Lesen bereit ist, aber bei der Ausführung gefüllt wird. Nach x
können weiterhin erforderliche Berechnungen in die Warteschlange gestellt werden, für die keine x
erforderlich ist, einschließlich der Ausführung auf anderen PJRT-Geräten. Sobald der Wert von x
benötigt wird, wird die Ausführung blockiert, bis der Puffer sich über die von GetReadyFuture
zurückgegebene Future als bereit erklärt.
Futures können hilfreich sein, um zu ermitteln, wann ein Objekt verfügbar ist, einschließlich Geräten und Puffern.
Erweiterte Konzepte
Wenn Sie über die Implementierung der Basis-APIs hinausgehen, werden die Funktionen von JAX erweitert, die von einem Plug-in verwendet werden können. Diese Funktionen sind optional, da ein typischer JIT- und Ausführungsablauf auch ohne sie funktioniert. Bei einer Pipeline mit Produktionsqualität sollten Sie jedoch über den Grad der Unterstützung für die folgenden von PJRT APIs unterstützten Funktionen nachdenken:
- Erinnerungsbereiche
- Benutzerdefinierte Layouts
- Kommunikationsvorgänge wie Senden/Empfangen
- Host-Auslagerung
- Fragmentierung
Typische Kommunikation zwischen PJRT-Framework und Geräten
Beispiel-Log
Im Folgenden finden Sie ein Log der Methoden, die zum Laden des PJRT-Plug-ins und zum Ausführen von y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)
aufgerufen werden. In diesem Fall zeichnen wir die Interaktion von JAX mit dem PJRT-Plug-in für die StableHLO-Referenz auf.
Beispiel für ein Protokoll
////////////////////////////////// // Load the plugin ////////////////////////////////// I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400) I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8) I device.cc:104] StablehloReferenceDevice(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I client_cpp_pjrt.cc:86] devices(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I device.cc:168] description(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:86] Attributes(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:148] memory_spaces(0x23bac4e0) Creating PJRT Client from client I client_cpp_pjrt.cc:108] platform_version(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:57] id(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:61] process_index(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:153] default_memory_space(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) ////////////////////////////////// // RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)` ////////////////////////////////// I executable.cc:309] num_partitions(0x240bab70) I executable.cc:305] num_replicas(0x240bab70) I executable.cc:309] num_partitions(0x240bab70) I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400) I buffer.cc:285] CreateMlirBufferFromLiteral I buffer.cc:98] CreateFromLiteral I buffer.cc:99] CreateFromLiteral: s32[] 2 I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:102] CreateFromLiteral -> 0x240bb050 I buffer.cc:158] device(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I executable.cc:328] GetHloModules(0x240bab70) I executable.cc:240] Execute(0x240bab70) I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x240bb050) I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32> I executable.cc:205] EvalModule: module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) { // ... return %3 : tensor<i32> } } I executable.cc:206] Inputs: [dense<2> : tensor<i32>] I executable.cc:213] Results: [dense<2> : tensor<i32>] I device.cc:153] default_memory_space(0x23bac4e0) I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x22cea630) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630 ////////////////////////////////// // RUN: `print(y)` ////////////////////////////////// I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references. I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x2404d560) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050 I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x240b6010) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x23b2db60) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60 I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630)