Panoramica dell'API PJRT C++ Device

Sfondo

PJRT è l'API Device uniforme che vogliamo aggiungere all'ecosistema ML. La visione a lungo termine è la seguente:

  1. I framework (JAX, TF, ecc.) chiameranno PJRT, che ha implementazioni specifiche per i dispositivi opache ai framework;
  2. Ogni dispositivo si concentra sull'implementazione delle API PJRT e può essere opaco per i framework.

PJRT offre sia un'API C sia un'API C++. Il collegamento a entrambi i livelli è consentito. L'API C++ utilizza classi per astrarre alcuni concetti, ma ha anche legami più stretti con i tipi di dati XLA. Questa pagina è incentrata sull'API C++.

Componenti PJRT

Componenti PJRT

PjRtClient

Riferimento completo alla pagina pjrt_client.h > PjRtClient.

I client gestiscono tutte le comunicazioni tra il dispositivo e il framework e coprono tutto lo stato utilizzato nella comunicazione. Hanno un insieme generico di API per interagire con un plug-in PJRT e possiedono i dispositivi e gli spazi di memoria per un determinato plug-in.

PjRtDevice

Riferimenti completi su pjrt_client.h > PjRtDevice e pjrt_device_description.h

Una classe di dispositivi viene utilizzata per descrivere un singolo dispositivo. Un dispositivo ha una descrizione per identificare il tipo (hash univoco per identificare GPU/CPU/xPU) e la posizione all'interno di una griglia di dispositivi, sia localmente che a livello globale.

I dispositivi conoscono anche gli spazi di memoria associati e il client di cui sono di proprietà.

Un dispositivo non conosce necessariamente i buffer dei dati effettivi associati, ma può scoprirli esaminando gli spazi di memoria associati.

PjRtMemorySpace

Per riferimento completo, visita la pagina pjrt_client.h > PjRtMemorySpace.

Gli spazi di memoria possono essere utilizzati per descrivere una posizione della memoria. Possono essere sganciati e possono essere pubblicati ovunque, ma essere accessibili da un dispositivo oppure possono essere fissati e devono essere pubblicati su un dispositivo specifico.

Gli spazi di memoria conoscono i relativi buffer di dati associati, i dispositivi (plurale) a cui è associato uno spazio di memoria e il client di cui fanno parte.

PjRtBuffer

Per riferimento completo, visita la pagina pjrt_client.h > PjRtBuffer.

Un buffer memorizza i dati su un dispositivo in un formato con cui sarà facile lavorare all'interno del plug-in, ad esempio un attributo elementi MLIR o un formato tensore proprietario. Un framework può provare a inviare dati a un dispositivo sotto forma di xla::Literal, ovvero per un argomento di input al modulo, che deve essere clonato (o preso in prestito) nella memoria del dispositivo. Quando un buffer non è più necessario, il metodo Delete viene richiamato dal framework per la pulizia.

Un buffer conosce lo spazio di memoria di cui fa parte e può capire indirettamente quali dispositivi sono in grado di accedervi, ma i buffer non conoscono necessariamente i propri dispositivi.

Per comunicare con i framework, i buffer sanno come eseguire la conversione da e verso un tipo xla::Literal:

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

Le API per la creazione di un buffer hanno una semantica del buffer che consente di stabilire se i dati letterali del buffer host possono essere condivisi, copiati o modificati.

Infine, un buffer potrebbe dover durare più a lungo dell'ambito della sua esecuzione, se viene assegnato a una variabile nel livello del framework x = jit(foo)(10). In questi casi, i buffer consentono di creare riferimenti esterni che forniscono un puntatore di proprietà temporanea ai dati memorizzati nel buffer, insieme ai metadati (dimensioni di tipo di dati / dimensioni) per interpretare i dati sottostanti.

PjRtCompiler

Per riferimento completo, visita la pagina pjrt_compiler.h > PjRtCompiler.

La classe PjRtCompiler fornisce dettagli di implementazione utili per i backend XLA, ma non è necessaria per l'implementazione di un plug-in. In teoria, la responsabilità di un PjRtCompiler o del metodo PjRtClient::Compile è prendere un modulo di input e restituire un PjRtLoadedExecutable.

PjRtExecutable / PjRtLoadedExecutable

Per riferimenti completi, consulta pjrt_executable.h > PjRtExecutable e pjrt_client.h > PjRtLoadedExecutable.

Un PjRtExecutable sa come prendere un artefatto compilato e le opzioni di esecuzione e serializzarle/deserializzarle in modo che un file eseguibile possa essere memorizzato e caricato come necessario.

PjRtLoadedExecutable è l'eseguibile compilato in memoria pronto per l'esecuzione di argomenti di input. È una sottoclasse di PjRtExecutable.

Gli eseguibili vengono interfacciati tramite uno dei metodi Execute del client:

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

Prima di chiamare Execute, il framework trasferirà tutti i dati richiesti a Execute di proprietà del client in esecuzione, ma restituiti per il riferimento del framework.PjRtBuffers Questi buffer vengono quindi forniti come argomenti al metodo Execute.

PJRT Concepts

PjRtFutures e calcoli asincroni

Se una parte di un plug-in viene implementata in modo asincrono, deve implementare correttamente i futures.

Considera il seguente programma:

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

Un plug-in asincrono sarebbe in grado di accodare il calcolo x e restituire immediatamente un buffer non ancora pronto per essere letto, ma verrà completato l'esecuzione. L'esecuzione può continuare ad accodare i calcoli necessari dopo il giorno x, che non richiedono x, inclusa l'esecuzione su altri dispositivi PJRT. Una volta che il valore di x è necessario, l'esecuzione verrà bloccata finché il buffer non si dichiara pronto tramite il futuro restituito da GetReadyFuture.

I futures possono essere utili per determinare quando un oggetto diventa disponibile, inclusi dispositivi e buffer.

Concetti avanzati

Se vai oltre l'implementazione delle API di base, potrai espandere le funzionalità di JAX che possono essere utilizzate da un plug-in. Si tratta di funzionalità che richiedono l'attivazione nel senso che il flusso di lavoro di JIT ed esecuzione tipico funzionerà senza di esse, ma per una pipeline di qualità di produzione è necessario valutare il grado di supporto di queste funzionalità supportate dalle API PJRT:

  • Spazi di memoria
  • Layout personalizzati
  • Operazioni di comunicazione come invio/ricezione
  • Offloading dell'host
  • Partizionamento orizzontale

Comunicazione tipica tra il framework PJRT e il dispositivo

Log di esempio

Di seguito è riportato un log dei metodi chiamati per caricare il plug-in PJRT ed eseguire y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1). In questo caso, registriamo l'interazione di JAX con il plug-in PJRT di riferimento StableHLO.

Log di esempio

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)