Contexte
PJRT est l'API Device uniforme que nous souhaitons ajouter à l'écosystème ML. La vision à long terme est la suivante:
- Les frameworks (JAX, TF, etc.) appellent PJRT, qui dispose d'implémentations spécifiques à l'appareil qui sont opaques pour les frameworks.
- Chaque appareil se concentre sur l'implémentation des API PJRT et peut être opaque pour les frameworks.
PJRT propose à la fois une API C et une API C++. Vous pouvez vous connecter à n'importe quelle couche. L'API C++ utilise des classes pour éliminer certains concepts, mais elle est également plus étroitement liée aux types de données XLA. Cette page se concentre sur l'API C++.
Composants PJRT
PjRtClient
Référence complète sur pjrt_client.h > PjRtClient
.
Les clients gèrent toutes les communications entre l'appareil et le framework, et encapsulent tous les états utilisés dans la communication. Ils disposent d'un ensemble générique d'API pour interagir avec un plug-in PJRT, et ils sont propriétaires des appareils et des espaces de mémoire d'un plug-in donné.
PjRtDevice
Références complètes sur pjrt_client.h > PjRtDevice
et pjrt_device_description.h
Une classe d'appareil permet de décrire un seul appareil. Un appareil dispose d'une description qui permet d'identifier son type (hachage unique pour identifier le GPU/CPU/xPU) et son emplacement dans une grille d'appareils, à la fois localement et globalement.
Les appareils connaissent également leurs espaces de mémoire associés et le client auquel ils appartiennent.
Un appareil ne connaît pas nécessairement les tampons de données réels qui lui sont associés, mais il peut le déterminer en examinant ses espaces de mémoire associés.
PjRtMemorySpace
Référence complète sur pjrt_client.h > PjRtMemorySpace
.
Les espaces mémoire peuvent être utilisés pour décrire un emplacement de mémoire. Ils peuvent être épinglés et doivent se trouver sur un appareil spécifique, ou non épinglés et être accessibles depuis un appareil.
Les espaces mémoire connaissent les tampons de données associés, ainsi que les appareils (au pluriel) auxquels un espace mémoire est associé, ainsi que le client auquel il appartient.
PjRtBuffer
Référence complète sur pjrt_client.h > PjRtBuffer
.
Un tampon contient des données sur un appareil dans un format facile à utiliser dans le plug-in, comme un attribut d'éléments MLIR ou un format de tenseur propriétaire.
Un framework peut essayer d'envoyer des données à un appareil sous la forme d'un xla::Literal
, c'est-à-dire pour un argument d'entrée au module, qui doit être cloné (ou emprunté), à la mémoire de l'appareil. Une fois qu'un tampon n'est plus nécessaire, la méthode Delete
est appelée par le framework pour le nettoyer.
Un tampon connaît l'espace mémoire dont il fait partie et peut déterminer de manière transitoire quels appareils peuvent y accéder, mais les tampons ne connaissent pas nécessairement leurs appareils.
Pour communiquer avec les frameworks, les tampons savent convertir en type xla::Literal
et en type xla::Literal
:
// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}
// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}
Les API de création d'un tampon disposent de sémantiques de tampon qui permettent de déterminer si les données littérales du tampon hôte peuvent être partagées, copiées ou modifiées.
Enfin, un tampon peut avoir besoin de durer plus longtemps que le champ d'application de son exécution, s'il est attribué à une variable dans la couche de framework x = jit(foo)(10)
. Dans ce cas, les tampons permettent de créer des références externes qui fournissent un pointeur temporairement détenu vers les données détenues par le tampon, ainsi que des métadonnées (taille de type / dimension) pour interpréter les données sous-jacentes.
PjRtCompiler
Référence complète sur pjrt_compiler.h > PjRtCompiler
.
La classe PjRtCompiler
fournit des détails d'implémentation utiles pour les backends XLA, mais n'est pas nécessaire pour l'implémentation d'un plug-in. En théorie, la responsabilité d'un PjRtCompiler
, ou de la méthode PjRtClient::Compile
, est de prendre un module d'entrée et de renvoyer un PjRtLoadedExecutable
.
PjRtExecutable / PjRtLoadedExecutable
Référence complète sur pjrt_executable.h > PjRtExecutable
et pjrt_client.h > PjRtLoadedExecutable
.
Un PjRtExecutable
sait comment prendre un artefact compilé et des options d'exécution, et les sérialiser/désérialiser afin qu'un exécutable puisse être stocké et chargé si nécessaire.
PjRtLoadedExecutable
est l'exécutable compilé en mémoire qui est prêt à exécuter les arguments d'entrée. Il s'agit d'une sous-classe de PjRtExecutable
.
Les exécutables sont connectés via l'une des méthodes Execute
du client:
// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}
// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
Avant d'appeler Execute
, le framework transfère toutes les données requises vers PjRtBuffers
, qui appartient au client exécutant, mais est renvoyé pour référence par le framework. Ces tampons sont ensuite fournis en tant qu'arguments à la méthode Execute
.
Concepts PJRT
PjRtFutures et calculs asynchrones
Si une partie d'un plug-in est implémentée de manière asynchrone, elle doit implémenter correctement les futures.
Prenons l'exemple de programme suivant:
@jax.jit
def foo(x): return x + 1
x = foo(1)
# [...] other logic not using `x`
print(x + 1)
Un plug-in asynchrone peut mettre en file d'attente le calcul x
et renvoyer immédiatement un tampon qui n'est pas encore prêt à être lu, mais l'exécution le remplira. L'exécution peut continuer à mettre en file d'attente les calculs nécessaires après x
, qui ne nécessitent pas x
, y compris l'exécution sur d'autres appareils PJRT. Une fois que la valeur de x
est requise, l'exécution est bloquée jusqu'à ce que le tampon se déclare prêt via le futur renvoyé par GetReadyFuture
.
Les futures peuvent être utiles pour déterminer quand un objet devient disponible, y compris les appareils et les tampons.
Concepts avancés
Si vous allez au-delà de l'implémentation des API de base, vous étendrez les fonctionnalités de JAX pouvant être utilisées par un plug-in. Il s'agit de fonctionnalités activées par défaut, dans le sens où le workflow JIT et d'exécution typiques fonctionneront sans elles. Toutefois, pour un pipeline de qualité de production, il est conseillé de réfléchir au niveau de prise en charge de toutes ces fonctionnalités prises en charge par les API PJRT:
- Espaces mémoire
- Mise en page personnalisée
- Opérations de communication telles que l'envoi/la réception
- Déchargement de l'hôte
- Segmentation
Communication typique entre le framework PJRT et l'appareil
Exemple de journal
Vous trouverez ci-dessous un journal des méthodes appelées pour charger le plug-in PJRT et exécuter y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)
. Dans ce cas, nous journalisons JAX qui interagit avec le plug-in PJRT de référence StableHLO.
Exemple de journal
////////////////////////////////// // Load the plugin ////////////////////////////////// I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400) I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8) I device.cc:104] StablehloReferenceDevice(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I client_cpp_pjrt.cc:86] devices(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I device.cc:168] description(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:86] Attributes(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:148] memory_spaces(0x23bac4e0) Creating PJRT Client from client I client_cpp_pjrt.cc:108] platform_version(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:57] id(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:61] process_index(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:153] default_memory_space(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) ////////////////////////////////// // RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)` ////////////////////////////////// I executable.cc:309] num_partitions(0x240bab70) I executable.cc:305] num_replicas(0x240bab70) I executable.cc:309] num_partitions(0x240bab70) I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400) I buffer.cc:285] CreateMlirBufferFromLiteral I buffer.cc:98] CreateFromLiteral I buffer.cc:99] CreateFromLiteral: s32[] 2 I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:102] CreateFromLiteral -> 0x240bb050 I buffer.cc:158] device(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I executable.cc:328] GetHloModules(0x240bab70) I executable.cc:240] Execute(0x240bab70) I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x240bb050) I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32> I executable.cc:205] EvalModule: module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) { // ... return %3 : tensor<i32> } } I executable.cc:206] Inputs: [dense<2> : tensor<i32>] I executable.cc:213] Results: [dense<2> : tensor<i32>] I device.cc:153] default_memory_space(0x23bac4e0) I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x22cea630) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630 ////////////////////////////////// // RUN: `print(y)` ////////////////////////////////// I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references. I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x2404d560) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050 I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x240b6010) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x23b2db60) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60 I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630)