PJRT C++ Device API 概览

背景

PJRT 是我们希望添加到机器学习生态系统中的统一设备 API。长期愿景是:

  1. 框架(JAX、TF 等)将调用 PJRT,PJRT 具有对框架不透明的设备专用实现;
  2. 每台设备都专注于实现 PJRT API,并且对框架而言是不可见的。

PJRT 同时提供 C API 和 C++ API。您可以在任一层插入,C++ API 使用类来提取一些概念,但与 XLA 数据类型的联系更紧密。本页重点介绍 C++ API。

PJRT 组件

PJRT 组件

PjRtClient

完整参考文档参见 pjrt_client.h > PjRtClient

客户端管理设备与框架之间的所有通信,并封装通信中使用的所有状态。它们拥有一组通用 API 来与 PJRT 插件交互,并且拥有给定插件的设备和内存空间。

PjRtDevice

完整参考文档:pjrt_client.h > PjRtDevicepjrt_device_description.h

设备类用于描述单个设备。设备具有设备说明,可帮助识别其类型(用于识别 GPU/CPU/xPU 的唯一哈希值),以及在本地和全球设备网格中的相应位置。

设备还知道其关联的内存空间以及其所属的客户端。

设备一定知道与其关联的实际数据的缓冲区,但可以通过查看其关联的内存空间来确定这一点。

PjRtMemorySpace

完整参考文档参见 pjrt_client.h > PjRtMemorySpace

内存空间可用于描述内存的位置。这些文件可以取消固定,并可自由存储在任何位置,但必须可通过设备访问;也可以固定,并必须存储在特定设备上。

内存空间知道其关联的数据缓冲区,以及内存空间关联的设备(复数),以及它所属的客户端。

PjRtBuffer

完整参考文档参见 pjrt_client.h > PjRtBuffer

缓冲区以易于在插件内处理的格式(例如 MLIR 元素属性或专有张量格式)在设备上存储数据。框架可能会尝试以 xla::Literal 的形式将数据发送到设备,即将模块的输入参数克隆(或借用)到设备的内存。当不再需要缓冲区时,框架会调用 Delete 方法进行清理。

缓冲区知道它所属的内存空间,并且可以通过传递方式找出哪些设备可以访问它,但缓冲区不一定知道其设备。

为了与框架通信,缓冲区知道如何与 xla::Literal 类型进行转换:

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

用于创建缓冲区的 API 具有缓冲区语义,可帮助确定主机缓冲区中的字面量数据能否共享、复制或更改。

最后,如果缓冲区分配给框架层 x = jit(foo)(10) 中的变量,则可能需要持续的时间比其执行范围更长。在这些情况下,缓冲区允许构建外部引用,这些引用会提供指向缓冲区所持数据的临时所有权指针,以及用于解读底层数据的元数据(dtype / dim 大小)。

PjRtCompiler

完整参考文档参见 pjrt_compiler.h > PjRtCompiler

PjRtCompiler 类可为 XLA 后端提供实用的实现细节,但插件无需实现该类。从理论上讲,PjRtCompilerPjRtClient::Compile 方法的职责是接受输入模块并返回 PjRtLoadedExecutable

PjRtExecutable / PjRtLoadedExecutable

如需查看完整参考文档,请参阅 pjrt_executable.h > PjRtExecutablepjrt_client.h > PjRtLoadedExecutable

PjRtExecutable 知道如何获取已编译工件和执行选项并对其进行序列化/反序列化,以便根据需要存储和加载可执行文件。

PjRtLoadedExecutable 是内存中已编译的可执行文件,可供输入参数执行,它是 PjRtExecutable 的子类。

可执行文件通过客户端的某个 Execute 方法进行接口:

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

在调用 Execute 之前,框架会将所有必需数据传输到执行客户端拥有的 PjRtBuffers,但会返回以供框架引用。然后,这些缓冲区会作为参数提供给 Execute 方法。

PJRT 概念

PjRtFuture 和异步计算

如果插件的任何部分是异步实现的,则必须正确实现 Future。

请考虑以下程序:

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

异步插件能够将计算 x 加入队列,并立即返回一个尚未准备好读取的缓冲区,但执行将填充该缓冲区。在 x 之后,执行可以继续将不需要 x 的必要计算加入队列,包括在其他 PJRT 设备上执行。需要 x 的值后,执行将被阻塞,直到缓冲区通过 GetReadyFuture 返回的 Future 声明自己已就绪。

在确定对象(包括设备和缓冲区)何时可用时,Future 会很有用。

高级概念

除了实现基本 API 之外,还可以扩展可供插件使用的 JAX 功能。这些都是可选功能,也就是说,在典型的 JIT 和执行工作流中,即使没有这些功能,也能正常运行,但对于生产质量流水线,可能需要考虑对 PJRT API 支持的任何这些功能的支持程度:

  • 内存空间
  • 自定义布局
  • 发送/接收等通信操作
  • 主机分流
  • 分片

典型的 PJRT 框架-设备通信

日志示例

以下是用于加载 PJRT 插件并执行 y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1) 的调用方法的日志。在本例中,我们会记录 JAX 与 StableHLO 参考 PJRT 插件交互的情况。

日志示例

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)