Visão geral da API PJRT C++ para dispositivos

Contexto

PJRT é a API Device uniforme que queremos adicionar ao ecossistema de ML. A visão de longo prazo é a seguinte:

  1. Frameworks (JAX, TF etc.) vão chamar o PJRT, que tem implementações específicas para dispositivos que são opacas para os frameworks.
  2. Cada dispositivo se concentra na implementação de APIs PJRT e pode ser opaco para os frameworks.

O PJRT oferece uma API C e uma API C++. A conexão em qualquer camada é aceitável. A API C++ usa classes para abstrair alguns conceitos, mas também tem vínculos mais fortes com os tipos de dados da XLA. Esta página se concentra na API C++.

Componentes PJRT

Componentes PJRT

PjRtClient

Referência completa em pjrt_client.h > PjRtClient.

Os clientes gerenciam toda a comunicação entre o dispositivo e a estrutura e encapsulam todo o estado usado na comunicação. Eles têm um conjunto genérico de APIs para interagir com um plug-in PJRT e são proprietários dos dispositivos e espaços de memória de um determinado plug-in.

PjRtDevice

Referências completas em pjrt_client.h > PjRtDevice e pjrt_device_description.h

Uma classe de dispositivo é usada para descrever um único dispositivo. Um dispositivo tem uma descrição para ajudar a identificar o tipo (hash exclusivo para identificar GPU/CPU/xPU) e o local dentro de uma grade de dispositivos, local e globalmente.

Os dispositivos também conhecem os espaços de memória associados e o cliente a que pertencem.

Um dispositivo não necessariamente conhece os buffers de dados reais associados a ele, mas pode descobrir isso procurando nos espaços de memória associados.

PjRtMemorySpace

Referência completa em pjrt_client.h > PjRtMemorySpace.

Os espaços de memória podem ser usados para descrever um local de memória. Eles podem ser desfixados e ficar disponíveis em qualquer lugar, mas acessíveis em um dispositivo, ou podem ser fixados e precisam estar em um dispositivo específico.

Os espaços de memória sabem quais são os buffers de dados associados e os dispositivos (plural) com que um espaço de memória está associado, além do cliente ao qual ele pertence.

PjRtBuffer

Referência completa em pjrt_client.h > PjRtBuffer.

Um buffer armazena dados em um dispositivo em algum formato fácil de trabalhar no plug-in, como um atributo de elementos MLIR ou um formato de tensor proprietário. Um framework pode tentar enviar dados para um dispositivo na forma de xla::Literal, ou seja, para um argumento de entrada do módulo, que precisa ser clonado (ou emprestado), para a memória do dispositivo. Quando um buffer não é mais necessário, o método Delete é invocado pelo framework para limpar.

Um buffer sabe o espaço de memória de que faz parte e pode descobrir quais dispositivos podem acessá-lo, mas os buffers não necessariamente conhecem os dispositivos.

Para se comunicar com os frameworks, os buffers sabem como converter para e de um tipo xla::Literal:

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

As APIs para criação de um buffer têm semântica de buffer, que ajudam a determinar se os dados literais do buffer do host podem ser compartilhados, copiados ou modificados.

Por fim, um buffer pode durar mais do que o escopo da execução, se for atribuído a uma variável na camada do framework x = jit(foo)(10). Nesses casos, os buffers permitem criar referências externas que fornecem um ponteiro de propriedade temporária para os dados mantidos pelo buffer, junto com metadados (tipos de dados / tamanhos de dimensão) para interpretar os dados subjacentes.

PjRtCompiler

Referência completa em pjrt_compiler.h > PjRtCompiler.

A classe PjRtCompiler fornece detalhes de implementação úteis para back-ends XLA, mas não é necessária para a implementação de um plug-in. Em teoria, a responsabilidade de um PjRtCompiler, ou do método PjRtClient::Compile, é receber um módulo de entrada e retornar um PjRtLoadedExecutable.

PjRtExecutable / PjRtLoadedExecutable

Referência completa em pjrt_executable.h > PjRtExecutable e pjrt_client.h > PjRtLoadedExecutable.

Um PjRtExecutable sabe como pegar um artefato compilado e opções de execução e serializar/desserializar para que um executável possa ser armazenado e carregado conforme necessário.

O PjRtLoadedExecutable é o executável compilado na memória que está pronto para executar argumentos de entrada. Ele é uma subclasse de PjRtExecutable.

Os executáveis são conectados por um dos métodos Execute do cliente:

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

Antes de chamar Execute, o framework vai transferir todos os dados necessários para PjRtBuffers, que é de propriedade do cliente de execução, mas retornado para que o framework se refira a ele. Esses buffers são fornecidos como argumentos para o método Execute.

Conceitos de PJRT

PjRtFutures e computações assíncronas

Se qualquer parte de um plug-in for implementada de forma assíncrona, ele precisa implementar corretamente os futuros.

Considere o seguinte programa:

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

Um plug-in assíncrono pode enfileirar a x de computação e retornar imediatamente um buffer que ainda não está pronto para leitura, mas que será preenchido pela execução. A execução pode continuar a enfileirar as computações necessárias após x, que não exigem x, incluindo a execução em outros dispositivos PJRT. Quando o valor de x for necessário, a execução será bloqueada até que o buffer se declare pronto pelo futuro retornado por GetReadyFuture.

Os futuros podem ser úteis para determinar quando um objeto fica disponível, incluindo dispositivos e buffers.

Conceitos avançados

A extensão além da implementação das APIs básicas vai ampliar os recursos do JAX que podem ser usados por um plug-in. Todos esses recursos são opcionais, no sentido de que o fluxo de trabalho de execução e JIT típico vai funcionar sem eles, mas para um pipeline de qualidade de produção, é necessário pensar no nível de suporte para qualquer um desses recursos com suporte às APIs PJRT:

  • Espaços de memória
  • Layouts personalizados
  • Operações de comunicação, como enviar/receber
  • Desativação do host
  • Fragmentação

Comunicação típica de framework-dispositivo do PJRT

Exemplo de registro

A seguir, há um registro dos métodos chamados para carregar o plug-in PJRT e executar y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1). Neste caso, registramos o JAX interagindo com o plug-in PJRT de referência da StableHLO.

Exemplo de registro

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)