Descripción general de la API de dispositivos de PJRT C++

Fondo

PJRT es la API de Device uniforme que queremos agregar al ecosistema de AA. La visión a largo plazo es la siguiente:

  1. Los frameworks (JAX, TF, etc.) llamarán a PJRT, que tiene implementaciones específicas del dispositivo que son opacas para los frameworks.
  2. Cada dispositivo se enfoca en implementar las APIs de PJRT y puede ser opaco para los frameworks.

PJRT ofrece una API de C y una API de C++. Es posible conectar en cualquiera de las capas. La API de C++ usa clases para abstraer algunos conceptos, pero también tiene vínculos más sólidos con los tipos de datos de XLA. Esta página se centra en la API de C++.

Componentes de PJRT

Componentes de PJRT

PjRtClient

Consulta la referencia completa en pjrt_client.h > PjRtClient.

Los clientes administran toda la comunicación entre el dispositivo y el framework, y encapsulan todo el estado que se usa en la comunicación. Tienen un conjunto genérico de APIs para interactuar con un complemento de PJRT y son propietarios de los dispositivos y los espacios de memoria para un complemento determinado.

PjRtDevice

Referencias completas en pjrt_client.h > PjRtDevice y pjrt_device_description.h

Una clase de dispositivo se usa para describir un solo dispositivo. Un dispositivo tiene una descripción para ayudar a identificar su tipo (hash único para identificar la GPU, la CPU o la XPU) y su ubicación dentro de una cuadrícula de dispositivos tanto a nivel local como global.

Los dispositivos también conocen sus espacios de memoria asociados y el cliente al que pertenecen.

Un dispositivo no necesariamente conoce los búferes de datos reales asociados a él, pero puede averiguarlo buscando en sus espacios de memoria asociados.

PjRtMemorySpace

Consulta la referencia completa en pjrt_client.h > PjRtMemorySpace.

Los espacios de memoria se pueden usar para describir una ubicación de memoria. Se pueden desanclar y pueden residir en cualquier lugar, pero deben ser accesibles desde un dispositivo, o bien se pueden anclar y deben residir en un dispositivo específico.

Los espacios de memoria conocen sus búferes de datos asociados y los dispositivos (en plural) con los que se asocia un espacio de memoria, así como el cliente del que forma parte.

PjRtBuffer

Consulta la referencia completa en pjrt_client.h > PjRtBuffer.

Un búfer contiene datos en un dispositivo en algún formato con el que será fácil trabajar dentro del complemento, como un atributo de elementos de MLIR o un formato de tensor propietario. Un framework puede intentar enviar datos a un dispositivo en forma de xla::Literal, es decir, para un argumento de entrada al módulo, que debe clonarse (o tomarse prestado) en la memoria del dispositivo. Una vez que ya no se necesita un búfer, el marco de trabajo invoca el método Delete para limpiarlo.

Un búfer conoce el espacio de memoria del que forma parte y, de forma transitiva, puede determinar qué dispositivos pueden acceder a él, pero los búferes no necesariamente conocen sus dispositivos.

Para comunicarse con los frameworks, los búferes saben cómo convertir un tipo xla::Literal y viceversa:

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::Future<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

Las APIs para crear un búfer tienen semántica de búfer que ayuda a determinar si los datos literales del búfer del host se pueden compartir, copiar o mutar.

Por último, es posible que un búfer deba durar más que el alcance de su ejecución si se asigna a una variable en la capa del framework x = jit(foo)(10). En estos casos, los búferes permiten crear referencias externas que proporcionan un puntero de propiedad temporal a los datos que contiene el búfer, junto con metadatos (dtype / tamaños de dimensiones) para interpretar los datos subyacentes.

PjRtCompiler

Consulta la referencia completa en pjrt_compiler.h > PjRtCompiler.

La clase PjRtCompiler proporciona detalles de implementación útiles para los backends de XLA, pero no es necesaria para que un complemento la implemente. En teoría, la responsabilidad de un PjRtCompiler, o el método PjRtClient::Compile, es tomar un módulo de entrada y devolver un PjRtLoadedExecutable.

PjRtExecutable / PjRtLoadedExecutable

Consulta la referencia completa en pjrt_executable.h > PjRtExecutable y pjrt_client.h > PjRtLoadedExecutable.

Un PjRtExecutable sabe cómo tomar un artefacto compilado y opciones de ejecución, y serializarlos o deserializarlos para que se pueda almacenar y cargar un ejecutable según sea necesario.

El PjRtLoadedExecutable es el ejecutable compilado en la memoria que está listo para recibir argumentos de entrada y ejecutarse. Es una subclase de PjRtExecutable.

Se interactúa con los ejecutables a través de uno de los métodos Execute del cliente:

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

Antes de llamar a Execute, el framework transferirá todos los datos requeridos a PjRtBuffers, que es propiedad del cliente que ejecuta la acción, pero se devuelve para que el framework haga referencia a ellos. Luego, estos búferes se proporcionan como argumentos al método Execute.

Conceptos de PJRT

Cálculos asíncronos y futuros

Si alguna parte de un complemento se implementa de forma asíncrona, debe implementar correctamente los futuros.

Considera el siguiente programa:

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

Un complemento asíncrono podría poner en cola el cálculo x y devolver de inmediato un búfer que aún no esté listo para leerse, pero la ejecución lo completará. La ejecución puede seguir poniendo en cola los cálculos necesarios después de x, que no requieren x, incluida la ejecución en otros dispositivos de PJRT. Una vez que se necesite el valor de x, la ejecución se bloqueará hasta que el búfer se declare listo a través del futuro que devuelve GetReadyFuture.

Los futuros pueden ser útiles para determinar cuándo un objeto estará disponible, incluidos los dispositivos y los búferes.

Conceptos avanzados

Extenderse más allá de la implementación de las APIs básicas ampliará las funciones de JAX que puede usar un complemento. Todas estas son funciones opcionales en el sentido de que un flujo de trabajo típico de JIT y ejecución funcionará sin ellas, pero para una canalización de calidad de producción, probablemente se debería considerar el grado de compatibilidad con cualquiera de estas funciones compatibles con las APIs de PJRT:

  • Espacios de memoria
  • Diseños personalizados
  • Operaciones de comunicación, como enviar o recibir
  • Descarga del host
  • Fragmentación

Comunicación típica entre el dispositivo y el framework de PJRT

Registro de ejemplo

A continuación, se muestra un registro de los métodos llamados para cargar el complemento de PJRT y ejecutar y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1). En este caso, registramos la interacción de JAX con el complemento de PJRT de referencia de StableHLO.

Ejemplo de registro

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)