PJRT C++ Device API – Übersicht

Hintergrund

PJRT ist die einheitliche Geräte-API, die wir dem ML-Ökosystem hinzufügen möchten. Die langfristige Vision ist:

  1. Frameworks (JAX, TF usw.) rufen PJRT auf, das gerätespezifische Implementierungen hat, die für die Frameworks nicht sichtbar sind.
  2. Jedes Gerät konzentriert sich auf die Implementierung von PJRT-APIs und kann für die Frameworks undurchsichtig sein.

PJRT bietet sowohl eine C- als auch eine C++-API. Es ist in Ordnung, die API in einer der beiden Schichten zu verwenden. Die C++-API verwendet Klassen, um einige Konzepte zu abstrahieren, ist aber auch stärker an XLA-Datentypen gebunden. Auf dieser Seite geht es um die C++ API.

PJRT-Komponenten

PJRT-Komponenten

PjRtClient

Vollständige Referenz unter pjrt_client.h > PjRtClient.

Clients verwalten die gesamte Kommunikation zwischen dem Gerät und dem Framework und kapseln den gesamten in der Kommunikation verwendeten Status. Sie haben einen generischen Satz von APIs für die Interaktion mit einem PJRT-Plug-in und sind für die Geräte und Speicherbereiche für ein bestimmtes Plug-in verantwortlich.

PjRtDevice

Vollständige Referenzen unter pjrt_client.h > PjRtDevice und pjrt_device_description.h

Eine Geräteklasse wird verwendet, um ein einzelnes Gerät zu beschreiben. Ein Gerät hat eine Gerätebeschreibung, die hilft, seinen Typ zu identifizieren (eindeutiger Hash zur Identifizierung von GPU/CPU/xPU) und seinen Standort in einem Raster von Geräten sowohl lokal als auch global zu bestimmen.

Geräte kennen auch die zugehörigen Speicherbereiche und den Client, dem sie gehören.

Ein Gerät kennt nicht unbedingt die Puffer der tatsächlichen Daten, die ihm zugeordnet sind, kann dies aber herausfinden, indem es die zugehörigen Speicherbereiche durchsucht.

PjRtMemorySpace

Vollständige Referenz unter pjrt_client.h > PjRtMemorySpace.

Speicherbereiche können verwendet werden, um einen Speicherort zu beschreiben. Sie können entweder nicht angepinnt sein und sich an einem beliebigen Ort befinden, aber über ein Gerät zugänglich sein, oder sie können angepinnt sein und müssen sich auf einem bestimmten Gerät befinden.

Speicherbereiche kennen die zugehörigen Datenpuffer und die Geräte, mit denen ein Speicherbereich verknüpft ist, sowie den Client, zu dem er gehört.

PjRtBuffer

Vollständige Referenz unter pjrt_client.h > PjRtBuffer.

In einem Puffer werden Daten auf einem Gerät in einem Format gespeichert, das sich im Plug-in einfach verarbeiten lässt, z. B. ein MLIR-Elementattribut oder ein proprietäres Tensorformat. Ein Framework kann versuchen, Daten in Form eines xla::Literal an ein Gerät zu senden, d.h. für ein Eingabeargument für das Modul, das in den Speicher des Geräts geklont (oder ausgeliehen) werden muss. Wenn ein Puffer nicht mehr benötigt wird, wird die Methode Delete vom Framework aufgerufen, um ihn zu bereinigen.

Ein Puffer kennt den Speicherbereich, zu dem er gehört, und kann transitiv herausfinden, welche Geräte darauf zugreifen können. Puffer kennen jedoch nicht unbedingt ihre Geräte.

Für die Kommunikation mit Frameworks wissen Puffer, wie sie in den Typ xla::Literal konvertiert werden und umgekehrt:

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::Future<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

APIs zum Erstellen eines Puffers haben Puffersemantik, die festlegt, ob Literaldaten aus dem Hostpuffer freigegeben, kopiert oder geändert werden können.

Schließlich muss ein Puffer möglicherweise länger als der Umfang seiner Ausführung dauern, wenn er einer Variablen in der Framework-Ebene x = jit(foo)(10) zugewiesen wird. In diesen Fällen ermöglichen Puffer das Erstellen externer Referenzen, die einen temporär besessenen Zeiger auf die vom Puffer gehaltenen Daten zusammen mit Metadaten (dtype / dim-Größen) zur Interpretation der zugrunde liegenden Daten bereitstellen.

PjRtCompiler

Vollständige Referenz unter pjrt_compiler.h > PjRtCompiler.

Die Klasse PjRtCompiler enthält nützliche Implementierungsdetails für XLA-Back-Ends, ist aber für die Implementierung eines Plug-ins nicht erforderlich. Theoretisch besteht die Aufgabe eines PjRtCompiler oder der Methode PjRtClient::Compile darin, ein Eingabemodul zu verwenden und ein PjRtLoadedExecutable zurückzugeben.

PjRtExecutable / PjRtLoadedExecutable

Vollständige Referenz unter pjrt_executable.h > PjRtExecutable und pjrt_client.h > PjRtLoadedExecutable.

Eine PjRtExecutable kann ein kompiliertes Artefakt und Ausführungsoptionen serialisieren/deserialisieren, sodass eine ausführbare Datei nach Bedarf gespeichert und geladen werden kann.

PjRtLoadedExecutable ist die im Arbeitsspeicher kompilierte ausführbare Datei, die für die Ausführung von Eingabeargumenten bereit ist. Sie ist eine Unterklasse von PjRtExecutable.

Die Kommunikation mit den ausführbaren Dateien erfolgt über eine der Execute-Methoden des Clients:

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

Bevor Execute aufgerufen wird, überträgt das Framework alle erforderlichen Daten an PjRtBuffers, die dem ausführenden Client gehören, aber für das Framework zurückgegeben werden, damit es darauf verweisen kann. Diese Puffer werden dann als Argumente für die Execute-Methode bereitgestellt.

PJRT-Konzepte

Futures und asynchrone Berechnungen

Wenn ein Teil eines Plug-ins asynchron implementiert wird, muss er Futures richtig implementieren.

Sehen Sie sich das folgende Programm an:

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

Ein asynchrones Plug-in könnte die Berechnung x in die Warteschlange stellen und sofort einen Puffer zurückgeben, der noch nicht gelesen werden kann, aber durch die Ausführung gefüllt wird. Die Ausführung kann nach x weiterhin erforderliche Berechnungen in die Warteschlange stellen, für die x nicht erforderlich ist, einschließlich der Ausführung auf anderen PJRT-Geräten. Sobald der Wert von x benötigt wird, wird die Ausführung blockiert, bis der Puffer sich über den von GetReadyFuture zurückgegebenen Future als bereit erklärt.

Futures können nützlich sein, um zu ermitteln, wann ein Objekt verfügbar wird, einschließlich Geräten und Puffern.

Erweiterte Konzepte

Wenn Sie über die Implementierung der Basis-APIs hinausgehen, können Sie die Funktionen von JAX erweitern, die von einem Plug-in verwendet werden können. Das sind alles Opt-in-Funktionen. Ein typischer JIT- und Ausführungsablauf funktioniert auch ohne sie. Für eine Pipeline in Produktionsqualität sollte jedoch überlegt werden, inwieweit die von PJRT-APIs unterstützten Funktionen unterstützt werden sollen:

  • Arbeitsspeicherbereiche
  • Benutzerdefinierte Layouts
  • Kommunikationsvorgänge wie Senden/Empfangen
  • Host-Offloading
  • Fragmentierung

Typische PJRT-Framework-Gerätekommunikation

Beispiellog

Im Folgenden sehen Sie ein Log der Methoden, die zum Laden des PJRT-Plug-ins und zum Ausführen von y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1) aufgerufen werden. In diesem Fall protokollieren wir die Interaktion von JAX mit dem StableHLO-Referenz-PJRT-Plug-in.

Beispiellog

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)