背景
PJRT は、ML エコシステムに追加する統一されたデバイス API です。長期的なビジョンは次のとおりです。
- フレームワーク(JAX、TF など)は、フレームワークに対して不透明なデバイス固有の実装を持つ PJRT を呼び出します。
- 各デバイスは PJRT API の実装に重点を置いており、フレームワークに対して不透明にすることができます。
PJRT は、C API と C++ API の両方を提供します。どちらのレイヤにプラグインしても問題ありません。C++ API はクラスを使用して一部のコンセプトを抽象化しますが、XLA データ型との結びつきも強くなっています。このページでは、C++ API について説明します。
PJRT コンポーネント
PjRtClient
完全なリファレンスは pjrt_client.h > PjRtClient をご覧ください。
クライアントはデバイスとフレームワーク間のすべての通信を管理し、通信で使用されるすべての状態をカプセル化します。PJRT プラグインとやり取りするための汎用 API セットを備え、特定のプラグインのデバイスとメモリ空間を所有します。
PjRtDevice
完全なリファレンスは pjrt_client.h > PjRtDevice と pjrt_device_description.h をご覧ください
デバイスクラスは、単一のデバイスを記述するために使用されます。デバイスには、デバイスの種類(GPU/CPU/xPU を識別する一意のハッシュ)と、デバイスのグリッド内のローカルおよびグローバルな位置を特定するためのデバイスの説明があります。
デバイスは、関連付けられているメモリ空間と、所有しているクライアントも認識しています。
デバイスは、それに関連付けられた実際のデータのバッファを必ずしも認識しているわけではありませんが、関連付けられたメモリ空間を調べることで、それを把握できます。
PjRtMemorySpace
完全なリファレンスは pjrt_client.h > PjRtMemorySpace をご覧ください。
メモリ空間は、メモリの場所を記述するために使用できます。これらは、ピン留めを解除して、デバイスからアクセスできる任意の場所に保存することも、ピン留めして特定のデバイスに保存することもできます。
メモリ空間は、関連付けられたデータバッファ、メモリ空間が関連付けられているデバイス(複数)、およびそれが属するクライアントを認識します。
PjRtBuffer
完全なリファレンスは pjrt_client.h > PjRtBuffer をご覧ください。
バッファは、MLIR 要素の属性や独自のテンソル形式など、プラグイン内で扱いやすい形式でデバイス上のデータを保持します。フレームワークは、xla::Literal の形式でデバイスにデータを送信しようとする場合があります。つまり、モジュールの入力引数として、デバイスのメモリにクローン(または借用)する必要があります。バッファが不要になると、フレームワークによって Delete メソッドが呼び出され、クリーンアップが行われます。
バッファは、それが属するメモリ空間を認識し、どのデバイスがそれにアクセスできるかを推移的に把握できますが、バッファが必ずしもデバイスを認識しているとは限りません。
フレームワークとの通信では、バッファは xla::Literal 型との間で変換する方法を認識しています。
// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}
// Buffer to Literal
xla::Future<> ToLiteral(xla::MutableLiteralBase* literal) override {...}
バッファを作成する API には、ホストバッファからのリテラル データを共有、コピー、変更できるかどうかを判断するのに役立つバッファ セマンティクスがあります。
最後に、フレームワーク レイヤ x = jit(foo)(10) の変数に割り当てられている場合、バッファは実行のスコープよりも長く存続する必要がある場合があります。このような場合、バッファは、バッファが保持するデータへの一時的に所有されるポインタと、基盤となるデータを解釈するためのメタデータ(dtype / dim サイズ)を提供する外部参照を構築できます。
PjRtCompiler
完全なリファレンスは pjrt_compiler.h > PjRtCompiler をご覧ください。
PjRtCompiler クラスは XLA バックエンドに役立つ実装の詳細を提供しますが、プラグインが実装する必要はありません。理論上、PjRtCompiler または PjRtClient::Compile メソッドの役割は、入力モジュールを受け取って PjRtLoadedExecutable を返すことです。
PjRtExecutable / PjRtLoadedExecutable
完全なリファレンスは、pjrt_executable.h > PjRtExecutable と pjrt_client.h > PjRtLoadedExecutable をご覧ください。
PjRtExecutable は、コンパイルされたアーティファクトと実行オプションを取得してシリアル化/逆シリアル化する方法を認識しているため、実行可能ファイルを保存して必要に応じて読み込むことができます。
PjRtLoadedExecutable は、実行する入力引数の準備が整ったメモリ内コンパイル済み実行可能ファイルです。これは PjRtExecutable のサブクラスです。
実行可能ファイルは、クライアントの Execute メソッドのいずれかを介してインターフェースされます。
// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}
// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
Execute を呼び出す前に、フレームワークは必要なすべてのデータを実行中のクライアントが所有する PjRtBuffers に転送しますが、フレームワークが参照するために返します。これらのバッファは、Execute メソッドの引数として提供されます。
PJRT のコンセプト
Future と非同期計算
プラグインの一部が非同期で実装されている場合、フューチャーを適切に実装する必要があります。
次のプログラムを考えてみましょう。
@jax.jit
def foo(x): return x + 1
x = foo(1)
# [...] other logic not using `x`
print(x + 1)
非同期プラグインは、計算 x をキューに登録し、まだ読み取りの準備ができていないバッファをすぐに返しますが、実行によってバッファが入力されます。実行は、x を必要としない必要な計算(他の PJRT デバイスでの実行など)を x の後に引き続きキューに追加できます。x の値が必要になると、GetReadyFuture から返された Future を介してバッファが準備完了を宣言するまで、実行はブロックされます。
Future は、デバイスやバッファなどのオブジェクトがいつ使用可能になるかを判断するのに役立ちます。
高度なコンセプト
ベース API の実装を超えて拡張すると、プラグインで使用できる JAX の機能が拡張されます。これらはすべてオプトイン機能です。通常の JIT と実行のワークフローはこれらの機能なしで動作しますが、本番環境品質のパイプラインでは、PJRT API でサポートされているこれらの機能のサポートの程度について検討する必要があります。
- メモリ空間
- カスタム レイアウト
- 送信/受信などの通信オペレーション
- ホスト オフロード
- シャーディング
一般的な PJRT フレームワークとデバイスの通信
ログの例
以下は、PJRT プラグインを読み込んで y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1) を実行するために呼び出されたメソッドのログです。この例では、StableHLO Reference PJRT プラグインとやり取りする JAX をロギングします。
ログの例
////////////////////////////////// // Load the plugin ////////////////////////////////// I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400) I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8) I device.cc:104] StablehloReferenceDevice(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I client_cpp_pjrt.cc:86] devices(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I device.cc:168] description(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:86] Attributes(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:148] memory_spaces(0x23bac4e0) Creating PJRT Client from client I client_cpp_pjrt.cc:108] platform_version(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:57] id(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:61] process_index(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:153] default_memory_space(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) ////////////////////////////////// // RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)` ////////////////////////////////// I executable.cc:309] num_partitions(0x240bab70) I executable.cc:305] num_replicas(0x240bab70) I executable.cc:309] num_partitions(0x240bab70) I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400) I buffer.cc:285] CreateMlirBufferFromLiteral I buffer.cc:98] CreateFromLiteral I buffer.cc:99] CreateFromLiteral: s32[] 2 I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:102] CreateFromLiteral -> 0x240bb050 I buffer.cc:158] device(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I executable.cc:328] GetHloModules(0x240bab70) I executable.cc:240] Execute(0x240bab70) I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x240bb050) I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32> I executable.cc:205] EvalModule: module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) { // ... return %3 : tensor<i32> } } I executable.cc:206] Inputs: [dense<2> : tensor<i32>] I executable.cc:213] Results: [dense<2> : tensor<i32>] I device.cc:153] default_memory_space(0x23bac4e0) I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x22cea630) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630 ////////////////////////////////// // RUN: `print(y)` ////////////////////////////////// I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references. I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x2404d560) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050 I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x240b6010) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x23b2db60) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60 I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630)