Présentation de l'API Device C++ de PJRT

Arrière-plan

PJRT est l'API Device uniforme que nous souhaitons ajouter à l'écosystème de ML. À long terme, notre vision est la suivante :

  1. Les frameworks (JAX, TF, etc.) appellent PJRT, qui possède des implémentations spécifiques à l'appareil qui sont opaques aux frameworks.
  2. Chaque appareil se concentre sur l'implémentation des API PJRT et peut être opaque pour les frameworks.

PJRT propose une API C et une API C++. Il est possible de s'insérer à l'une ou l'autre des couches. L'API C++ utilise des classes pour abstraire certains concepts, mais elle est également plus étroitement liée aux types de données XLA. Cette page se concentre sur l'API C++.

Composants PJRT

Composants PJRT

PjRtClient

Référence complète : pjrt_client.h > PjRtClient.

Les clients gèrent toutes les communications entre l'appareil et le framework, et encapsulent tous les états utilisés dans la communication. Ils disposent d'un ensemble générique d'API pour interagir avec un plug-in PJRT et possèdent les appareils et les espaces mémoire pour un plug-in donné.

PjRtDevice

Références complètes sur pjrt_client.h > PjRtDevice et pjrt_device_description.h

Une classe d'appareils est utilisée pour décrire un seul appareil. Un appareil possède une description pour aider à identifier son type (hachage unique pour identifier le GPU/CPU/xPU) et son emplacement dans une grille d'appareils, à la fois localement et globalement.

Les appareils connaissent également les espaces de mémoire associés et le client auquel ils appartiennent.

Un appareil ne connaît pas nécessairement les tampons de données réels qui lui sont associés, mais il peut les identifier en parcourant les espaces de mémoire associés.

PjRtMemorySpace

Référence complète : pjrt_client.h > PjRtMemorySpace.

Les espaces mémoire peuvent être utilisés pour décrire un emplacement de mémoire. Ils peuvent être détachés et être accessibles depuis un appareil, mais être situés n'importe où, ou ils peuvent être épinglés et doivent être situés sur un appareil spécifique.

Les espaces mémoire connaissent leurs tampons de données associés, les appareils (au pluriel) auxquels un espace mémoire est associé, ainsi que le client dont il fait partie.

PjRtBuffer

Référence complète : pjrt_client.h > PjRtBuffer.

Un tampon contient des données sur un appareil dans un format facile à utiliser dans le plug-in, comme un attribut d'élément MLIR ou un format de Tensor propriétaire. Un framework peut essayer d'envoyer des données à un appareil sous la forme d'un xla::Literal, c'est-à-dire pour un argument d'entrée du module, qui doit être cloné (ou emprunté), à la mémoire de l'appareil. Une fois qu'un tampon n'est plus nécessaire, la méthode Delete est appelée par le framework pour le nettoyer.

Un tampon connaît l'espace mémoire dont il fait partie et peut donc déterminer de manière transitive les appareils qui peuvent y accéder, mais les tampons ne connaissent pas nécessairement leurs appareils.

Pour communiquer avec les frameworks, les tampons savent comment effectuer des conversions vers et depuis un type xla::Literal :

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::Future<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

Les API de création d'un tampon ont une sémantique de tampon qui aide à déterminer si les données littérales du tampon hôte peuvent être partagées, copiées ou modifiées.

Enfin, une mémoire tampon peut avoir besoin de durer plus longtemps que la portée de son exécution, si elle est attribuée à une variable dans la couche de framework x = jit(foo)(10). Dans ce cas, les mémoires tampons permettent de créer des références externes qui fournissent un pointeur temporairement détenu vers les données contenues dans la mémoire tampon, ainsi que des métadonnées (dtype / tailles de dimension) pour interpréter les données sous-jacentes.

PjRtCompiler

Référence complète : pjrt_compiler.h > PjRtCompiler.

La classe PjRtCompiler fournit des détails d'implémentation utiles pour les backends XLA, mais n'est pas nécessaire pour qu'un plug-in puisse être implémenté. En théorie, la responsabilité d'un PjRtCompiler ou de la méthode PjRtClient::Compile est de prendre un module d'entrée et de renvoyer un PjRtLoadedExecutable.

PjRtExecutable / PjRtLoadedExecutable

Références complètes sur pjrt_executable.h > PjRtExecutable et pjrt_client.h > PjRtLoadedExecutable.

Un PjRtExecutable sait comment prendre un artefact compilé et des options d'exécution, puis les sérialiser/désérialiser afin qu'un exécutable puisse être stocké et chargé selon les besoins.

PjRtLoadedExecutable est l'exécutable compilé en mémoire qui est prêt à exécuter les arguments d'entrée. Il s'agit d'une sous-classe de PjRtExecutable.

Les exécutables sont interfacés à l'aide de l'une des méthodes Execute du client :

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

Avant d'appeler Execute, le framework transfère toutes les données requises vers PjRtBuffers appartenant au client exécutant, mais renvoyées pour que le framework puisse les référencer. Ces tampons sont ensuite fournis en tant qu'arguments à la méthode Execute.

Concepts PJRT

Futures et calculs asynchrones

Si une partie d'un plug-in est implémentée de manière asynchrone, elle doit implémenter correctement les futures.

Prenons l'exemple du programme suivant :

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

Un plug-in asynchrone pourrait mettre en file d'attente le calcul x et renvoyer immédiatement un tampon qui n'est pas encore prêt à être lu, mais l'exécution le remplira. L'exécution peut continuer à mettre en file d'attente les calculs nécessaires après x, qui ne nécessitent pas x, y compris l'exécution sur d'autres appareils PJRT. Une fois la valeur de x nécessaire, l'exécution sera bloquée jusqu'à ce que le tampon se déclare prêt via le futur renvoyé par GetReadyFuture.

Les futurs peuvent être utiles pour déterminer quand un objet devient disponible, y compris les appareils et les tampons.

Concepts avancés

L'extension au-delà de l'implémentation des API de base permettra d'étendre les fonctionnalités de JAX qui peuvent être utilisées par un plug-in. Il s'agit de fonctionnalités facultatives, dans le sens où le workflow JIT et d'exécution typique fonctionnera sans elles. Toutefois, pour un pipeline de qualité de production, il convient probablement de réfléchir au degré de prise en charge de l'une de ces fonctionnalités prises en charge par les API PJRT :

  • Espaces de mémoire
  • Mise en page personnalisée
  • Opérations de communication telles que l'envoi/la réception
  • Déchargement de l'hôte
  • Segmentation

Communication typique entre le framework PJRT et l'appareil

Exemple de journal

Vous trouverez ci-dessous un journal des méthodes appelées pour charger le plug-in PJRT et exécuter y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1). Dans ce cas, nous enregistrons l'interaction de JAX avec le plug-in PJRT de référence StableHLO.

Exemple de journal

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)