PJRT C++ Device API Overview

Background

PJRT is the uniform Device API that we want to add to the ML ecosystem. The long-term vision is that:

  1. Frameworks (JAX, TF, etc.) will call PJRT, which has device-specific implementations that are opaque to the frameworks;
  2. Each device focuses on implementing PJRT APIs, and can be opaque to the frameworks.

PJRT offers both a C API and a C++ API. Plugging in at either layer is OK. The C++ API uses classes to abstract away some concepts, but it also has stronger ties to XLA datatypes. This page focuses on the C++ API.

PJRT Components

PjRtClient

Full reference at pjrt_client.h > PjRtClient.

Clients manage all communication between the device and framework, and encapsulate all state used in the communication. They have a generic set of APIs for interacting with a PJRT plugin, and they own the devices and memory spaces for a given plugin.

PjRtDevice

Full references at pjrt_client.h > PjRtDevice and pjrt_device_description.h.

A device class describes a single device. Each device has a device description that identifies its kind (a unique hash distinguishing GPU/CPU/xPU) and its location within a grid of devices, both locally and globally.

Devices also know their associated memory spaces and the client that owns them.

A device does not necessarily know the buffers of actual data associated with it, but it can figure that out by looking through its associated memory spaces.

PjRtMemorySpace

Full reference at pjrt_client.h > PjRtMemorySpace.

Memory spaces describe a location of memory. A memory space is either unpinned, in which case it is free to live anywhere as long as it is accessible from a device, or pinned, in which case it must live on a specific device.

A memory space knows its associated buffers of data, the devices (plural) it is associated with, and the client it is a part of.
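
The relationships described so far (a client owns devices and memory spaces, devices point at memory spaces, memory spaces point at buffers) can be sketched with plain structs. This is a minimal illustration only: every Sketch* type and the BuffersReachableFrom helper are hypothetical stand-ins, not the real classes from pjrt_client.h, and treating "pinned" as "attached to exactly one device" is an assumption made to keep the example small.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-ins for illustration only; these are NOT the real
// PjRtClient / PjRtDevice / PjRtMemorySpace / PjRtBuffer classes.
struct SketchClient;
struct SketchDevice;
struct SketchMemorySpace;

struct SketchBuffer {
  std::string label;
  SketchMemorySpace* memory_space = nullptr;  // A buffer knows its memory space.
};

struct SketchMemorySpace {
  SketchClient* client = nullptr;      // The client this space is a part of.
  std::vector<SketchDevice*> devices;  // Devices (plural) that can access it.
  std::vector<SketchBuffer*> buffers;  // Buffers currently living in this space.
};

struct SketchDevice {
  SketchClient* client = nullptr;                 // The owning client.
  std::vector<SketchMemorySpace*> memory_spaces;  // Associated memory spaces.
};

struct SketchClient {
  // The client owns the devices and memory spaces for the plugin.
  std::vector<std::unique_ptr<SketchDevice>> devices;
  std::vector<std::unique_ptr<SketchMemorySpace>> memory_spaces;

  SketchDevice* AddDevice() {
    devices.push_back(std::make_unique<SketchDevice>());
    devices.back()->client = this;
    return devices.back().get();
  }

  // Creates a memory space and links it to every device that can access it.
  SketchMemorySpace* AddMemorySpace(
      const std::vector<SketchDevice*>& attached) {
    memory_spaces.push_back(std::make_unique<SketchMemorySpace>());
    SketchMemorySpace* space = memory_spaces.back().get();
    space->client = this;
    space->devices = attached;
    for (SketchDevice* device : attached) {
      assert(device->client == this);  // Only link devices this client owns.
      device->memory_spaces.push_back(space);
    }
    return space;
  }
};

// A device does not track buffers directly; it finds them by walking its
// associated memory spaces.
std::vector<SketchBuffer*> BuffersReachableFrom(const SketchDevice& device) {
  std::vector<SketchBuffer*> result;
  for (const SketchMemorySpace* space : device.memory_spaces) {
    result.insert(result.end(), space->buffers.begin(), space->buffers.end());
  }
  return result;
}
```

In this sketch, a buffer placed in a memory space attached to two devices shows up in BuffersReachableFrom for both of them, while a buffer in a single-device space is visible only from that device.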

PjRtBuffer

Full reference at pjrt_client.h > PjRtBuffer.

A buffer holds data on a device in a format that is easy to work with inside the plugin, such as an MLIR elements attr or a proprietary tensor format. A framework may send data to a device in the form of an xla::Literal, e.g. as an input argument to the module, which must be cloned (or borrowed) into the device's memory. Once a buffer is no longer needed, the framework invokes the Delete method to clean up.

A buffer knows the memory space it is a part of, and transitively can figure out which devices are able to access it, but buffers don't necessarily know their devices.

For communicating with frameworks, buffers know how to convert to and from an xla::Literal type:

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

APIs for creating a buffer take buffer semantics, which dictate whether data from the host buffer can be shared, copied, or mutated.

Lastly, a buffer may need to last longer than the scope of its execution, for example if it is assigned to a variable in the framework layer, as in x = jit(foo)(10). In these cases, buffers allow building external references, which provide a temporarily owned pointer to the data held by the buffer, along with metadata (dtype / dim sizes) for interpreting the underlying data.
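
One way to picture an external reference is as a second owner of the buffer's storage that also carries the metadata needed to interpret it. The sketch below assumes shared ownership via std::shared_ptr. RefCountedBuffer and ExternalView are hypothetical names, and a real plugin may keep the data alive by other means.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical view handed to the framework; not the real
// PjRtBuffer::ExternalReference type.
struct ExternalView {
  std::shared_ptr<std::vector<std::int32_t>> storage;  // Keeps the data alive.
  std::vector<std::int64_t> dims;                      // Metadata: dim sizes.
  const std::int32_t* data() const { return storage->data(); }
};

// Hypothetical buffer that shares ownership of its storage with any
// outstanding external views.
class RefCountedBuffer {
 public:
  RefCountedBuffer(std::vector<std::int32_t> values,
                   std::vector<std::int64_t> dims)
      : storage_(std::make_shared<std::vector<std::int32_t>>(
            std::move(values))),
        dims_(std::move(dims)) {}

  // Drops the buffer's own hold on the data (the role Delete plays above).
  void Delete() { storage_.reset(); }
  bool IsDeleted() const { return storage_ == nullptr; }

  // Hands out a temporarily owned pointer plus metadata. The data stays
  // valid for as long as the returned view exists, even after Delete().
  ExternalView AcquireExternalReference() const {
    assert(!IsDeleted());
    return ExternalView{storage_, dims_};
  }

 private:
  std::shared_ptr<std::vector<std::int32_t>> storage_;
  std::vector<std::int64_t> dims_;
};
```

With this design the framework can keep reading through the view after the plugin-side buffer has been deleted, and the storage is released once the last view goes away.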

PjRtCompiler

Full reference at pjrt_compiler.h > PjRtCompiler.

The PjRtCompiler class provides useful implementation details for XLA backends, but is not necessary for a plugin to implement. In theory, the responsibility of a PjRtCompiler, or the PjRtClient::Compile method, is to take an input module and return a PjRtLoadedExecutable.

PjRtExecutable / PjRtLoadedExecutable

Full reference at pjrt_executable.h > PjRtExecutable, and pjrt_client.h > PjRtLoadedExecutable.

A PjRtExecutable knows how to take a compiled artifact and execution options and serialize/deserialize them so an executable can be stored and loaded as needed.

The PjRtLoadedExecutable is the in-memory compiled executable, ready to execute once it is given input arguments. It is a subclass of PjRtExecutable.

A loaded executable is run through one of its Execute methods:

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

Before calling Execute, the framework transfers all required data to PjRtBuffers. These buffers are owned by the executing client but are returned to the framework so it can reference them. The framework then passes them as arguments to the Execute method.
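
The nested containers in the Execute signature can be hard to read at first. The sketch below mirrors only their shape, on the assumption that the outer index runs over devices and the inner index over arguments or results. IntBuffer and SketchLoadedExecutable are hypothetical; std::vector stands in for absl::Span and the absl::StatusOr wrapper is omitted. The "program" is hard-coded to add one to its single argument.

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical scalar buffer; real buffers hold device data.
struct IntBuffer {
  int value;
};

// Hypothetical executable whose compiled program is `x + 1`.
class SketchLoadedExecutable {
 public:
  // argument_handles[d][i] is the i-th argument for the d-th device.
  // The result keeps the same outer dimension: results[d][j] is the j-th
  // output produced on the d-th device.
  std::vector<std::vector<std::unique_ptr<IntBuffer>>> Execute(
      const std::vector<std::vector<IntBuffer*>>& argument_handles) const {
    std::vector<std::vector<std::unique_ptr<IntBuffer>>> results;
    results.reserve(argument_handles.size());
    for (const std::vector<IntBuffer*>& device_args : argument_handles) {
      assert(device_args.size() == 1);  // The program takes one argument.
      std::vector<std::unique_ptr<IntBuffer>> device_results;
      device_results.push_back(
          std::make_unique<IntBuffer>(IntBuffer{device_args[0]->value + 1}));
      results.push_back(std::move(device_results));
    }
    return results;
  }
};
```

The arguments are passed as raw pointers because the framework keeps referencing buffers the client owns, while the results come back as std::unique_ptr because each output is a newly created buffer handed to the caller.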

PJRT Concepts

PjRtFutures & Async Computations

If any part of a plugin is implemented asynchronously, it must properly implement futures.

Consider the following program:

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

An async plugin can enqueue the computation of x and immediately return a buffer that is not yet ready to be read but that execution will eventually populate. Execution can continue to enqueue computations that come after x but don't depend on it, including computations on other PJRT devices. Once the value of x is needed, execution blocks until the buffer declares itself ready via the future returned by GetReadyFuture.

Futures can be useful to determine when an object becomes available, including devices and buffers.
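
The enqueue-now, wait-later pattern can be illustrated without any threads. The sketch below is a simplified single-threaded model, not how xla::PjRtFuture is implemented: SketchStream, AsyncBuffer, AsyncState, and EnqueueAddOne are all hypothetical names, and "blocking" is modeled by draining a work queue until the buffer is marked ready.

```cpp
#include <cassert>
#include <deque>
#include <functional>
#include <memory>
#include <stdexcept>
#include <utility>

// Hypothetical shared state filled in when the computation finishes.
struct AsyncState {
  bool ready = false;
  int value = 0;
};

// Hypothetical single-threaded stand-in for a device work queue.
class SketchStream {
 public:
  void Enqueue(std::function<void()> task) {
    pending_.push_back(std::move(task));
  }

  // Runs the oldest pending task; returns false if the queue is idle.
  bool RunOne() {
    if (pending_.empty()) return false;
    std::function<void()> task = std::move(pending_.front());
    pending_.pop_front();
    task();
    return true;
  }

 private:
  std::deque<std::function<void()>> pending_;
};

// Hypothetical buffer that is returned before its contents exist.
struct AsyncBuffer {
  std::shared_ptr<AsyncState> state = std::make_shared<AsyncState>();

  bool IsReady() const { return state->ready; }

  // Stands in for waiting on the future returned by GetReadyFuture:
  // drive the queue until this buffer has been populated.
  int Await(SketchStream& stream) const {
    while (!state->ready) {
      if (!stream.RunOne()) {
        throw std::logic_error("no pending work can make this buffer ready");
      }
    }
    return state->value;
  }
};

// Enqueues `x + 1` and returns immediately with a not-yet-ready buffer.
AsyncBuffer EnqueueAddOne(SketchStream& stream, int x) {
  AsyncBuffer result;
  std::shared_ptr<AsyncState> state = result.state;
  stream.Enqueue([state, x] {
    assert(!state->ready);  // Each buffer is populated exactly once.
    state->value = x + 1;
    state->ready = true;
  });
  return result;
}
```

Calling EnqueueAddOne returns at once, later work can be enqueued behind it, and only Await forces the queue to make progress. This matches the flow in the JAX program above, where print(x + 1) is the first point that needs the value.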

Advanced concepts

Extending beyond the base APIs expands the set of JAX features that a plugin can use. These are all opt-in features in the sense that a typical JIT and execute workflow will work without them, but a production-quality pipeline should consider what degree of support to provide for each of these features supported by PJRT APIs:

  • Memory spaces
  • Custom layouts
  • Communication ops like send/recv
  • Host offloading
  • Sharding

Typical PJRT framework-device communication

Example Log

The following is a log of the methods called to load the PJRT plugin and execute y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1). In this case we log JAX interacting with the StableHLO Reference PJRT plugin.

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)