PJRT C++ Device API – Übersicht

Hintergrund

PJRT ist die einheitliche Device API, die wir dem ML-System hinzufügen möchten. Langfristig soll Folgendes erreicht werden:

  1. Frameworks (JAX, TF usw.) rufen PJRT auf, das gerätespezifische Implementierungen hat, die für die Frameworks nicht transparent sind.
  2. Bei jedem Gerät liegt der Schwerpunkt auf der Implementierung von PJRT APIs und es kann für die Frameworks undurchsichtig sein.

PJRT bietet sowohl eine C-API als auch eine C++-API. Die Einbindung in beiden Schichten ist in Ordnung. Die C++ API verwendet Klassen, um einige Konzepte zu abstrahieren, hat aber auch engere Verbindungen zu XLA-Datentypen. Auf dieser Seite geht es um die C++ API.

PJRT-Komponenten

PJRT-Komponenten

PjRtClient

Vollständige Referenz unter pjrt_client.h > PjRtClient

Clients verwalten die gesamte Kommunikation zwischen dem Gerät und dem Framework und kapseln den gesamten in der Kommunikation verwendeten Status ein. Sie haben eine Reihe von generischen APIs für die Interaktion mit einem PJRT-Plug-in und sind für die Geräte und Speicherbereiche eines bestimmten Plug-ins verantwortlich.

PjRtDevice

Vollständige Verweise unter pjrt_client.h > PjRtDevice und pjrt_device_description.h

Eine Geräteklasse wird verwendet, um ein einzelnes Gerät zu beschreiben. Ein Gerät hat eine Gerätebeschreibung, die dazu beiträgt, seine Art zu identifizieren (eindeutiger Hash zur Identifizierung von GPU/CPU/xPU) und seinen Standort innerhalb eines Geräterasters sowohl lokal als auch global.

Geräte kennen auch ihre zugehörigen Speicherbereiche und den zugehörigen Client.

Ein Gerät kennt nicht unbedingt die zugehörigen Buffers mit den tatsächlichen Daten, kann dies aber durch Durchsuchen der zugehörigen Speicherbereiche herausfinden.

PjRtMemorySpace

Vollständige Referenz unter pjrt_client.h > PjRtMemorySpace

Speicherbereiche können verwendet werden, um einen Speicherort zu beschreiben. Sie können entweder losgelöst sein und sich an beliebiger Stelle befinden, aber von einem Gerät aus zugänglich sein, oder angepinnt sein und sich auf einem bestimmten Gerät befinden müssen.

Speicherbereiche kennen ihre zugehörigen Datenbuffer, die Geräte (im Plural), mit denen ein Speicherbereich verknüpft ist, sowie den Client, zu dem er gehört.

PjRtBuffer

Vollständige Referenz unter pjrt_client.h > PjRtBuffer

Ein Puffer enthält Daten auf einem Gerät in einem Format, mit dem im Plug-in leicht gearbeitet werden kann, z. B. ein MLIR-Element-Attribut oder ein proprietäres Tensorformat. Ein Framework kann versuchen, Daten in Form einer xla::Literal an ein Gerät zu senden, d.h. für ein Eingabeargument für das Modul, das geklont (oder ausgeliehen) werden muss, um in den Arbeitsspeicher des Geräts zu gelangen. Sobald ein Puffer nicht mehr benötigt wird, wird die Methode Delete vom Framework aufgerufen, um ihn zu bereinigen.

Ein Puffer kennt den Speicherplatz, zu dem er gehört, und kann indirekt ermitteln, welche Geräte darauf zugreifen können. Puffer kennen jedoch nicht unbedingt ihre Geräte.

Für die Kommunikation mit Frameworks können Buffers zwischen einem xla::Literal-Typ und einem anderen konvertieren:

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

APIs zum Erstellen eines Buffers haben Buffer-Semantiken, die festlegen, ob Literaldaten aus dem Host-Puffer freigegeben, kopiert oder verändert werden können.

Schließlich kann ein Puffer länger als der Umfang seiner Ausführung dauern, wenn er einer Variablen in der Framework-Ebene x = jit(foo)(10) zugewiesen ist. In diesen Fällen ermöglichen Puffer externe Verweise, die einen temporären Verweis auf die vom Puffer gehaltenen Daten sowie Metadaten (Datentyp / Dim-Größen) zur Interpretation der zugrunde liegenden Daten bereitstellen.

PjRtCompiler

Vollständige Referenz unter pjrt_compiler.h > PjRtCompiler

Die Klasse PjRtCompiler enthält nützliche Implementierungsdetails für XLA-Backends, ist aber für die Implementierung eines Plug-ins nicht erforderlich. Theoretisch besteht die Aufgabe eines PjRtCompiler oder der PjRtClient::Compile-Methode darin, ein Eingabemodul zu übernehmen und einen PjRtLoadedExecutable zurückzugeben.

PjRtExecutable / PjRtLoadedExecutable

Vollständige Referenz unter pjrt_executable.h > PjRtExecutable und pjrt_client.h > PjRtLoadedExecutable.

Ein PjRtExecutable kann ein kompiliertes Artefakt und Ausführungsoptionen serialisieren/deserialisieren, damit eine ausführbare Datei bei Bedarf gespeichert und geladen werden kann.

PjRtLoadedExecutable ist die in der Arbeitsspeicher kompilierte ausführbare Datei, die für die Ausführung von Eingabeargumenten bereit ist. Sie ist eine Unterklasse von PjRtExecutable.

Die Kommunikation mit ausführbaren Dateien erfolgt über eine der Execute-Methoden des Clients:

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

Vor dem Aufruf von Execute überträgt das Framework alle erforderlichen Daten an PjRtBuffers, die dem ausführenden Client gehören, aber für das Framework zurückgegeben werden. Diese Puffer werden dann als Argumente an die Execute-Methode übergeben.

PJRT Concepts

PjRtFutures und asynchrone Berechnungen

Wenn ein Teil eines Plug-ins asynchron implementiert ist, müssen Futures ordnungsgemäß implementiert werden.

Betrachten Sie das folgende Programm:

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

Ein asynchrones Plug-in kann die Berechnung x in die Warteschlange stellen und sofort einen Puffer zurückgeben, der noch nicht zum Lesen bereit ist, aber bei der Ausführung gefüllt wird. Nach x können weiterhin erforderliche Berechnungen in die Warteschlange gestellt werden, für die keine x erforderlich ist, einschließlich der Ausführung auf anderen PJRT-Geräten. Sobald der Wert von x benötigt wird, wird die Ausführung blockiert, bis der Puffer sich über die von GetReadyFuture zurückgegebene Future als bereit erklärt.

Futures können hilfreich sein, um zu ermitteln, wann ein Objekt verfügbar ist, einschließlich Geräten und Puffern.

Erweiterte Konzepte

Wenn Sie über die Implementierung der Basis-APIs hinausgehen, werden die Funktionen von JAX erweitert, die von einem Plug-in verwendet werden können. Diese Funktionen sind optional, da ein typischer JIT- und Ausführungsablauf auch ohne sie funktioniert. Bei einer Pipeline mit Produktionsqualität sollten Sie jedoch über den Grad der Unterstützung für die folgenden von PJRT APIs unterstützten Funktionen nachdenken:

  • Erinnerungsbereiche
  • Benutzerdefinierte Layouts
  • Kommunikationsvorgänge wie Senden/Empfangen
  • Host-Offloading
  • Fragmentierung

Typische PJRT-Gerätekommunikation

Beispiel-Log

Im Folgenden findest du ein Protokoll der Methoden, die zum Laden des PJRT-Plug-ins und Ausführen von y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1) aufgerufen wurden. In diesem Fall zeichnen wir die Interaktion von JAX mit dem PJRT-Plug-in für die StableHLO-Referenz auf.

Beispielprotokoll

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)