PJRT C++ Device API の概要

背景

PJRT は、ML エコシステムに追加する統一された Device API です。長期的なビジョンは次のとおりです。

  1. フレームワーク(JAX、TF など)は PJRT を呼び出します。PJRT には、フレームワークに対して不透明なデバイス固有の実装があります。
  2. 各デバイスは PJRT API の実装に重点を置いており、フレームワークに対して不透明になる可能性があります。

PJRT には、C API と C++ API の両方が用意されています。どちらのレイヤにプラグインしてもかまいません。C++ API はクラスを使用して一部のコンセプトを抽象化しますが、XLA データ型との関連性も強くなります。このページでは、C++ API について説明します。

PJRT コンポーネント

PJRT コンポーネント

PjRtClient

詳細なリファレンスは pjrt_client.h > PjRtClient をご覧ください。

クライアントは、デバイスとフレームワーク間のすべての通信を管理し、通信で使用されるすべての状態をカプセル化します。PJRT プラグインとやり取りするための汎用 API セットがあり、特定のプラグインのデバイスとメモリ空間を所有します。

PjRtDevice

完全な参照は pjrt_client.h > PjRtDevicepjrt_device_description.h をご覧ください

デバイスクラスは、単一のデバイスを記述するために使用されます。デバイスには、デバイスの種類(GPU/CPU/xPU を識別する一意のハッシュ)と、ローカルとグローバルの両方でデバイスのグリッド内の位置を特定するためのデバイスの説明があります。

また、デバイスは、関連付けられたメモリ空間と、そのメモリ空間を所有するクライアントも認識します。

デバイスは、それに関連付けられている実際のデータのバッファを必ずしも認識しているわけではありませんが、関連するメモリ空間を調べることで、そのバッファを把握できます。

PjRtMemorySpace

詳細なリファレンスは pjrt_client.h > PjRtMemorySpace をご覧ください。

メモリ領域を使用して、メモリの場所を記述できます。これらは、固定を解除して任意の場所に配置し、デバイスからアクセスすることも、固定して特定のデバイスに配置することもできます。

メモリ空間は、関連付けられたデータのバッファ、メモリ空間が関連付けられているデバイス(複数)、およびその一部であるクライアントを認識します。

PjRtBuffer

詳細なリファレンスは pjrt_client.h > PjRtBuffer をご覧ください。

バッファは、MLIR 要素属性や独自のテンソル形式など、プラグイン内で簡単に操作できる形式でデバイス上のデータを保持します。フレームワークは、xla::Literal の形式でデバイスにデータを送信しようとすることがあります。つまり、モジュールへの入力引数として、デバイスのメモリにクローンを作成(または借用)する必要があります。バッファが不要になると、フレームワークによって Delete メソッドが呼び出され、クリーンアップされます。

バッファは、自分が属するメモリ空間を認識し、そのメモリ空間にアクセスできるデバイスを推移的に把握できますが、バッファがデバイスを認識するとは限りません。

フレームワークとの通信では、バッファは xla::Literal 型との間で変換する方法を認識しています。

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

バッファを作成するための API にはバッファ セマンティクスがあり、ホストバッファのリテラル データを共有、コピー、変更できるかどうかを指定できます。

最後に、バッファがフレームワーク レイヤ x = jit(foo)(10) の変数に割り当てられている場合、バッファは実行スコープよりも長く存続する必要がある場合があります。このような場合、バッファは、バッファによって保持されるデータへの一時的に所有されるポインタと、基盤となるデータを解釈するためのメタデータ(データ型 / サイズ)を提供する外部参照を構築できます。

PjRtCompiler

詳細なリファレンスは pjrt_compiler.h > PjRtCompiler をご覧ください。

PjRtCompiler クラスは XLA バックエンドの実装に役立ちますが、プラグインの実装に必要ではありません。理論的には、PjRtCompiler または PjRtClient::Compile メソッドの役割は、入力モジュールを受け取って PjRtLoadedExecutable を返すことです。

PjRtExecutable / PjRtLoadedExecutable

詳細なリファレンスは pjrt_executable.h > PjRtExecutablepjrt_client.h > PjRtLoadedExecutable をご覧ください。

PjRtExecutable は、コンパイルされたアーティファクトと実行オプションを受け取り、シリアル化/シリアル化解除する方法を知っているので、必要に応じて実行可能ファイルを保存して読み込むことができます。

PjRtLoadedExecutable は、入力引数を実行する準備ができているメモリ内コンパイル実行可能ファイルであり、PjRtExecutable のサブクラスです。

実行可能ファイルは、クライアントの Execute メソッドのいずれかを通じてインターフェースされます。

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

Execute を呼び出す前に、フレームワークは必要なすべてのデータを、実行中のクライアントが所有する PjRtBuffers に転送しますが、フレームワークが参照できるように返します。これらのバッファは、Execute メソッドの引数として指定されます。

PJRT のコンセプト

PjRtFutures と非同期計算

プラグインの一部が非同期で実装されている場合は、適切にフューチャーを実装する必要があります。

次のプログラムについて考えてみましょう。

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

非同期プラグインは、計算 x をキューに登録し、まだ読み取りの準備ができていないバッファをすぐに返すことができます。このバッファは実行時に入力されます。実行は、x 後に、x を必要としない必要な計算をキューに追加できます(他の PJRT デバイスでの実行など)。x の値が必要な場合、バッファが GetReadyFuture によって返されたフューチャーを介して準備完了を宣言するまで、実行はブロックされます。

デバイスやバッファなど、オブジェクトが使用可能になるタイミングを判断する場合に、フューチャーが役立ちます。

高度なコンセプト

ベース API の実装を超えて拡張すると、プラグインで使用できる JAX の機能が拡張されます。これらはすべてオプトイン機能です。つまり、一般的な JIT と実行ワークフローは、これらの機能がなくても動作します。ただし、本番環境品質のパイプラインでは、PJRT API でサポートされているこれらの機能のサポート度について検討する必要があります。

  • メモリ空間
  • カスタム レイアウト
  • 送信/受信などの通信オペレーション
  • ホストへのオフロード
  • シャーディング

一般的な PJRT フレームワークとデバイスの通信

ログの例

以下は、PJRT プラグインを読み込んで y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1) を実行するために呼び出されたメソッドのロギングです。この場合、StableHLO リファレンス PJRT プラグインとやり取りする JAX をロギングします。

ログの例

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)