Présentation de l'API Device C++ de PJRT

Contexte

PJRT est l'API uniforme que nous souhaitons ajouter à l'écosystème de ML. La vision à long terme est la suivante:

  1. Les frameworks (JAX, TF, etc.) appellent PJRT, qui comporte des implémentations spécifiques à l'appareil qui ne sont pas adaptées aux frameworks.
  2. Chaque appareil se concentre sur l'implémentation des API PJRT et peut être opaque pour les frameworks.

PJRT propose à la fois une API C et une API C++. Vous pouvez vous connecter à n'importe quelle couche. L'API C++ utilise des classes pour éliminer certains concepts, mais elle est également plus étroitement liée aux types de données XLA. Cette page se concentre sur l'API C++.

Composants PJRT

Composants PJRT

PjRtClient

Référence complète sur pjrt_client.h > PjRtClient.

Les clients gèrent toutes les communications entre l'appareil et le framework, et encapsulent tous les états utilisés dans la communication. Ils disposent d'un ensemble générique d'API pour interagir avec un plug-in PJRT, et ils sont propriétaires des appareils et des espaces de mémoire d'un plug-in donné.

PjRtDevice

Références complètes sur pjrt_client.h > PjRtDevice et pjrt_device_description.h

Une classe d'appareil sert à décrire un seul appareil. Un appareil dispose d'une description qui permet d'identifier son type (hachage unique pour identifier le GPU/CPU/xPU) et son emplacement dans une grille d'appareils, à la fois localement et globalement.

Les appareils connaissent également leurs espaces de mémoire associés et le client auquel ils appartiennent.

Un appareil ne connaît pas nécessairement les tampons de données réels qui lui sont associés, mais il peut le déterminer en examinant ses espaces de mémoire associés.

PjRtMemorySpace

Référence complète sur pjrt_client.h > PjRtMemorySpace.

Les espaces mémoire peuvent être utilisés pour décrire un emplacement de mémoire. Ils peuvent être épinglés et doivent se trouver sur un appareil spécifique, ou non épinglés et être accessibles depuis un appareil.

Les espaces mémoire connaissent les tampons de données associés, ainsi que les appareils (au pluriel) auxquels un espace mémoire est associé, ainsi que le client auquel il appartient.

PjRtBuffer

Pour plus d'informations, consultez la page pjrt_client.h > PjRtBuffer.

Un tampon contient des données sur un appareil dans un format facile à utiliser dans le plug-in, tel qu'un attribut d'éléments MLIR ou un format de Tensor propriétaire. Un framework peut essayer d'envoyer des données à un appareil sous la forme d'un xla::Literal, c'est-à-dire un argument d'entrée du module, qui doit être cloné (ou emprunté) à la mémoire de l'appareil. Une fois qu'un tampon n'est plus nécessaire, la méthode Delete est appelée par le framework pour le nettoyer.

Un tampon connaît l'espace mémoire dont il fait partie et peut déterminer de manière transitoire quels appareils peuvent y accéder, mais les tampons ne connaissent pas nécessairement leurs appareils.

Pour communiquer avec les frameworks, les tampons savent convertir en type xla::Literal et en type xla::Literal:

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

Les API permettant de créer un tampon utilisent la sémantique du tampon, qui permet de déterminer si les données littérales du tampon de l'hôte peuvent être partagées, copiées ou modifiées.

Enfin, un tampon peut avoir besoin de durer plus longtemps que le champ d'application de son exécution. S'il est attribué à une variable de la couche de framework x = jit(foo)(10), les tampons permettent dans ce cas de créer des références externes qui fournissent un pointeur temporairement détenu vers les données détenues par le tampon, ainsi que des métadonnées (dtype / dim Size) pour interpréter les données sous-jacentes.

PjRtCompiler

Référence complète sur pjrt_compiler.h > PjRtCompiler.

La classe PjRtCompiler fournit des détails d'implémentation utiles pour les backends XLA, mais n'est pas nécessaire pour l'implémentation d'un plug-in. En théorie, la responsabilité d'un PjRtCompiler, ou de la méthode PjRtClient::Compile, consiste à utiliser un module d'entrée et à renvoyer un PjRtLoadedExecutable.

PjRtExecutable / PjRtLoadedExecutable

Référence complète sur pjrt_executable.h > PjRtExecutable et pjrt_client.h > PjRtLoadedExecutable.

Un PjRtExecutable sait comment prendre un artefact compilé et des options d'exécution, et les sérialiser/désérialiser afin qu'un exécutable puisse être stocké et chargé si nécessaire.

PjRtLoadedExecutable est l'exécutable compilé en mémoire qui est prêt à exécuter les arguments d'entrée. Il s'agit d'une sous-classe de PjRtExecutable.

Les fichiers exécutables sont associés via l'une des méthodes Execute du client:

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

Avant d'appeler Execute, le framework transfère toutes les données requises vers PjRtBuffers, qui appartient au client en cours d'exécution, mais est renvoyé pour que le framework puisse les référencer. Ces tampons sont ensuite fournis en tant qu'arguments à la méthode Execute.

Concepts PJRT

PjRtFutures et calculs asynchrones

Si une partie d'un plug-in est implémentée de manière asynchrone, elle doit implémenter correctement les futures.

Prenons l'exemple de programme suivant:

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

Un plug-in asynchrone peut mettre en file d'attente le calcul x et renvoyer immédiatement un tampon qui n'est pas encore prêt à être lu, mais l'exécution le remplira. L'exécution peut continuer à mettre en file d'attente les calculs nécessaires après x, qui ne nécessitent pas x, y compris sur d'autres appareils PJRT. Une fois que la valeur de x est nécessaire, l'exécution est bloquée jusqu'à ce que le tampon se déclare prêt via l'objet Future renvoyé par GetReadyFuture.

Les futures peuvent être utiles pour déterminer quand un objet devient disponible, y compris les appareils et les tampons.

Concepts avancés

Si vous allez au-delà de l'implémentation des API de base, vous étendrez les fonctionnalités de JAX pouvant être utilisées par un plug-in. Il s'agit de fonctionnalités activées par défaut, dans le sens où le workflow JIT et d'exécution typiques fonctionneront sans elles. Toutefois, pour un pipeline de qualité de production, il est conseillé de réfléchir au niveau de prise en charge de toutes ces fonctionnalités prises en charge par les API PJRT:

  • Espaces mémoire
  • Mise en page personnalisée
  • Opérations de communication telles que l'envoi/la réception
  • Déchargement de l'hôte
  • Segmentation

Communication typique entre le framework PJRT et l'appareil

Exemple de journal

Vous trouverez ci-dessous un journal des méthodes appelées pour charger le plug-in PJRT et exécuter y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1). Dans ce cas, nous journalisons JAX qui interagit avec le plug-in PJRT de référence StableHLO.

Exemple de journal

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)