Descripción general de la API de dispositivos de PJRT C++

Segundo plano

PJRT es la API de Device uniforme que queremos agregar al ecosistema de AA. La visión a largo plazo es la siguiente:

  1. Los frameworks (JAX, TF, etc.) llamarán a PJRT, que tiene implementaciones específicas del dispositivo que son opacas para los frameworks.
  2. Cada dispositivo se enfoca en implementar las APIs de PJRT y puede ser opaco para los marcos de trabajo.

PJRT ofrece una API de C y una de C++. Puedes conectarte en cualquiera de las capas. La API de C++ usa clases para abstraer algunos conceptos, pero también tiene vínculos más sólidos con los tipos de datos de XLA. En esta página, se enfoca en la API de C++.

Componentes de PJRT

Componentes de PJRT

PjRtClient

Consulta la referencia completa en pjrt_client.h > PjRtClient.

Los clientes administran toda la comunicación entre el dispositivo y el framework, y encapsulan todo el estado que se usa en la comunicación. Tienen un conjunto genérico de APIs para interactuar con un complemento de PJRT y son propietarios de los dispositivos y espacios de memoria de un complemento determinado.

PjRtDevice

Referencias completas en pjrt_client.h > PjRtDevice y pjrt_device_description.h

Una clase de dispositivo se usa para describir un solo dispositivo. Un dispositivo tiene una descripción para ayudar a identificar su tipo (hash único para identificar GPU/CPU/xPU) y su ubicación dentro de una cuadrícula de dispositivos, tanto a nivel local como global.

Los dispositivos también conocen sus espacios de memoria asociados y el cliente al que pertenecen.

Un dispositivo no conoce necesariamente los búferes de datos reales asociados con él, pero puede averiguarlo a través de sus espacios de memoria asociados.

PjRtMemorySpace

Consulta la referencia completa en pjrt_client.h > PjRtMemorySpace.

Los espacios de memoria se pueden usar para describir una ubicación de memoria. Se pueden desfijar y pueden estar en cualquier lugar, pero se puede acceder a ellos desde un dispositivo, o bien se pueden fijar y deben estar en un dispositivo específico.

Los espacios de memoria conocen sus búferes de datos asociados y los dispositivos (en plural) con los que está asociado un espacio de memoria, así como el cliente del que forma parte.

PjRtBuffer

Consulta la referencia completa en pjrt_client.h > PjRtBuffer.

Un búfer contiene datos en un dispositivo en algún formato con el que será fácil trabajar dentro del complemento, como un atributo de elementos MLIR o un formato de tensor propietario. Un framework puede intentar enviar datos a un dispositivo en forma de xla::Literal, es decir, para un argumento de entrada al módulo, que se debe clonar (o tomar prestado), a la memoria del dispositivo. Una vez que ya no se necesita un búfer, el framework invoca el método Delete para realizar la limpieza.

Un búfer conoce el espacio de memoria del que forma parte y puede determinar de forma transitiva qué dispositivos pueden acceder a él, pero los búferes no necesariamente conocen sus dispositivos.

Para comunicarse con frameworks, los búferes saben cómo convertir de un tipo xla::Literal y a uno de estos:

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

Las APIs para crear un búfer tienen semántica de búfer, que ayuda a determinar si los datos literales del búfer del host se pueden compartir, copiar o mutar.

Por último, es posible que un búfer deba durar más que el alcance de su ejecución, si se asigna a una variable en la capa del framework x = jit(foo)(10). En estos casos, los búferes permiten crear referencias externas que proporcionan un puntero de propiedad temporal a los datos que contiene el búfer, junto con metadatos (tamaños de tipo de datos o dimensiones) para interpretar los datos subyacentes.

PjRtCompiler

Consulta la referencia completa en pjrt_compiler.h > PjRtCompiler.

La clase PjRtCompiler proporciona detalles de implementación útiles para los backends de XLA, pero no es necesario que un complemento la implemente. En teoría, la responsabilidad de un PjRtCompiler, o el método PjRtClient::Compile, es tomar un módulo de entrada y mostrar un PjRtLoadedExecutable.

PjRtExecutable / PjRtLoadedExecutable

Consulta la referencia completa en pjrt_executable.h > PjRtExecutable y pjrt_client.h > PjRtLoadedExecutable.

Un PjRtExecutable sabe cómo tomar un artefacto compilado y opciones de ejecución y serializarlos o deserializarlos para que un ejecutable se pueda almacenar y cargar según sea necesario.

PjRtLoadedExecutable es el ejecutable compilado en la memoria que está listo para que se ejecuten los argumentos de entrada. Es una subclase de PjRtExecutable.

Los ejecutables se comunican a través de uno de los métodos Execute del cliente:

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

Antes de llamar a Execute, el framework transferirá todos los datos necesarios a PjRtBuffers, que es propiedad del cliente que realiza la ejecución, pero se muestra para que el framework lo use como referencia. Luego, estos búferes se proporcionan como argumentos al método Execute.

Conceptos de PJRT

PjRtFutures y cálculos asíncronos

Si alguna parte de un complemento se implementa de forma asíncrona, debe implementar correctamente los futuros.

Considera el siguiente programa:

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

Un complemento asíncrono podría poner en cola el cálculo x y mostrar inmediatamente un búfer que aún no está listo para leerse, pero la ejecución lo propagará. La ejecución puede seguir poniendo en cola los cálculos necesarios después de x, que no requieren x, incluida la ejecución en otros dispositivos PJRT. Una vez que se necesite el valor de x, la ejecución se bloqueará hasta que el búfer se declare listo a través del futuro que devuelve GetReadyFuture.

Los futuros pueden ser útiles para determinar cuándo un objeto estará disponible, incluidos los dispositivos y los búferes.

Conceptos avanzados

Si extiendes la implementación de las APIs básicas, se expandirán las funciones de JAX que puede usar un complemento. Todas estas son funciones de habilitación en el sentido de que, en el JIT y el flujo de trabajo de ejecución típicos, funcionarán sin ellas, pero para una canalización de calidad de producción, es probable que se deba tener en cuenta el grado de compatibilidad con cualquiera de estas funciones compatibles con las APIs de PJRT:

  • Espacios de memoria
  • Diseños personalizados
  • Operaciones de comunicación, como enviar o recibir
  • Descarga del host
  • Fragmentación

Comunicación típica entre el framework y el dispositivo de PJRT

Ejemplo de registro

El siguiente es un registro de los métodos a los que se llamó para cargar el complemento PJRT y ejecutar y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1). En este caso, registramos JAX interactuando con el complemento PJRT de referencia de StableHLO.

Registro de ejemplo

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)