Sfondo
PJRT è l'API Device uniforme che vogliamo aggiungere all'ecosistema ML. La visione a lungo termine è la seguente:
- I framework (JAX, TF e così via) chiameranno PJRT, che ha implementazioni specifiche del dispositivo opache per i framework;
- Ogni dispositivo si concentra sull'implementazione delle API PJRT e può essere opaco per i framework.
PJRT offre un'API C e un'API C++. L'inserimento in uno dei due livelli è accettabile. L'API C++ utilizza classi per astrarre alcuni concetti, ma ha anche legami più stretti con i tipi di dati XLA. Questa pagina è incentrata sull'API C++.
PJRT Components
PjRtClient
Riferimento completo all'indirizzo pjrt_client.h > PjRtClient.
I client gestiscono tutte le comunicazioni tra il dispositivo e il framework e incapsulano tutto lo stato utilizzato nella comunicazione. Hanno un insieme generico di API per interagire con un plug-in PJRT e possiedono i dispositivi e gli spazi di memoria per un determinato plug-in.
PjRtDevice
Riferimenti completi su pjrt_client.h > PjRtDevice,
e pjrt_device_description.h
Una classe di dispositivi viene utilizzata per descrivere un singolo dispositivo. Un dispositivo ha una descrizione per aiutare a identificarne il tipo (hash univoco per identificare GPU/CPU/xPU) e la posizione all'interno di una griglia di dispositivi sia a livello locale che globale.
I dispositivi conoscono anche gli spazi di memoria associati e il client a cui appartengono.
Un dispositivo non conosce necessariamente i buffer dei dati effettivi associati, ma può scoprirlo esaminando gli spazi di memoria associati.
PjRtMemorySpace
Riferimento completo all'indirizzo pjrt_client.h > PjRtMemorySpace.
Gli spazi di memoria possono essere utilizzati per descrivere una posizione della memoria. Questi possono essere sganciati e possono essere posizionati ovunque, ma sono accessibili da un dispositivo, oppure possono essere bloccati e devono essere posizionati su un dispositivo specifico.
Gli spazi di memoria conoscono i buffer di dati associati e i dispositivi (plurale) a cui sono associati, nonché il client di cui fanno parte.
PjRtBuffer
Riferimento completo all'indirizzo pjrt_client.h > PjRtBuffer.
Un buffer contiene dati su un dispositivo in un formato facile da utilizzare
all'interno del plug-in, ad esempio un attributo di elementi MLIR o un formato tensore proprietario.
Un framework potrebbe tentare di inviare dati a un dispositivo sotto forma di xla::Literal,
ovvero per un argomento di input al modulo, che deve essere clonato (o preso in prestito) nella
memoria del dispositivo. Quando un buffer non è più necessario, il framework richiama il metodo Delete per eseguire la pulizia.
Un buffer conosce lo spazio di memoria di cui fa parte e può capire in modo transitivo quali dispositivi possono accedervi, ma i buffer non conoscono necessariamente i propri dispositivi.
Per comunicare con i framework, i buffer sanno come eseguire la conversione da e verso un tipo xla::Literal:
// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}
// Buffer to Literal
xla::Future<> ToLiteral(xla::MutableLiteralBase* literal) override {...}
Le API per la creazione di un buffer hanno semantica del buffer che aiuta a stabilire se i dati letterali del buffer host possono essere condivisi, copiati o modificati.
Infine, un buffer potrebbe dover durare più a lungo dell'ambito della sua esecuzione se viene assegnato a una variabile nel livello del framework x = jit(foo)(10). In questi casi, i buffer consentono di creare riferimenti esterni che forniscono un puntatore temporaneamente di proprietà ai dati contenuti nel buffer, insieme ai metadati (dtype / dimensioni delle dimensioni) per interpretare i dati sottostanti.
PjRtCompiler
Riferimento completo all'indirizzo pjrt_compiler.h > PjRtCompiler.
La classe PjRtCompiler fornisce dettagli di implementazione utili per i backend XLA, ma non è necessaria per l'implementazione di un plug-in. In teoria, la
responsabilità di un PjRtCompiler o del metodo PjRtClient::Compile è
prendere un modulo di input e restituire un PjRtLoadedExecutable.
PjRtExecutable / PjRtLoadedExecutable
Riferimento completo all'indirizzo pjrt_executable.h > PjRtExecutable,
e pjrt_client.h > PjRtLoadedExecutable.
Un PjRtExecutable sa come prendere un artefatto compilato e le opzioni di esecuzione
e serializzarli/deserializzarli in modo che un eseguibile possa essere archiviato e caricato in base
alle necessità.
PjRtLoadedExecutable è l'eseguibile compilato in memoria pronto
per l'esecuzione degli argomenti di input. Si tratta di una sottoclasse di PjRtExecutable.
Gli eseguibili vengono interfacciati tramite uno dei metodi Execute del client:
// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}
// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
Prima di chiamare Execute, il framework trasferirà tutti i dati richiesti a
PjRtBuffers di proprietà del client di esecuzione, ma restituiti per il riferimento del framework. Questi buffer vengono quindi forniti come argomenti al metodo Execute.
PJRT Concepts
Futures e calcoli asincroni
Se una parte di un plug-in viene implementata in modo asincrono, deve implementare correttamente i futuri.
Considera il seguente programma:
@jax.jit
def foo(x): return x + 1
x = foo(1)
# [...] other logic not using `x`
print(x + 1)
Un plug-in asincrono sarebbe in grado di mettere in coda il calcolo x e restituire immediatamente un buffer non ancora pronto per la lettura, ma l'esecuzione lo riempirà. L'esecuzione può continuare a mettere in coda i calcoli necessari dopo x, che
non richiedono x, inclusa l'esecuzione su altri dispositivi PJRT. Una volta che è necessario il valore di
x, l'esecuzione verrà bloccata finché il buffer non si dichiara pronto tramite
il futuro restituito da GetReadyFuture.
Le future possono essere utili per determinare quando un oggetto diventa disponibile, inclusi dispositivi e buffer.
Concetti avanzati
L'estensione oltre l'implementazione delle API di base amplierà le funzionalità di JAX che possono essere utilizzate da un plug-in. Si tratta di funzionalità di attivazione, nel senso che il flusso di lavoro tipico di JIT e di esecuzione funzionerà senza, ma per una pipeline di qualità di produzione è necessario riflettere sul grado di supporto per una qualsiasi di queste funzionalità supportate dalle API PJRT:
- Spazi di memoria
- Layout personalizzati
- Operazioni di comunicazione come invio/ricezione
- Offload dell'host
- Sharding
Comunicazione tipica tra framework PJRT e dispositivo
Log di esempio
Di seguito è riportato un log dei metodi chiamati per caricare il plug-in PJRT ed eseguire y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1). In questo caso,
registriamo l'interazione di JAX con il plug-in PJRT di riferimento StableHLO.
Log di esempio
////////////////////////////////// // Load the plugin ////////////////////////////////// I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400) I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8) I device.cc:104] StablehloReferenceDevice(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I client_cpp_pjrt.cc:86] devices(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I device.cc:168] description(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:86] Attributes(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:148] memory_spaces(0x23bac4e0) Creating PJRT Client from client I client_cpp_pjrt.cc:108] platform_version(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:57] id(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:61] process_index(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:153] default_memory_space(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) ////////////////////////////////// // RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)` ////////////////////////////////// I executable.cc:309] num_partitions(0x240bab70) I executable.cc:305] num_replicas(0x240bab70) I executable.cc:309] num_partitions(0x240bab70) I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400) I buffer.cc:285] CreateMlirBufferFromLiteral I buffer.cc:98] CreateFromLiteral I buffer.cc:99] CreateFromLiteral: s32[] 2 I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:102] CreateFromLiteral -> 0x240bb050 I buffer.cc:158] device(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I executable.cc:328] GetHloModules(0x240bab70) I executable.cc:240] Execute(0x240bab70) I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x240bb050) I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32> I executable.cc:205] EvalModule: module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) { // ... return %3 : tensor<i32> } } I executable.cc:206] Inputs: [dense<2> : tensor<i32>] I executable.cc:213] Results: [dense<2> : tensor<i32>] I device.cc:153] default_memory_space(0x23bac4e0) I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x22cea630) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630 ////////////////////////////////// // RUN: `print(y)` ////////////////////////////////// I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references. I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x2404d560) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050 I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x240b6010) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x23b2db60) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60 I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630)