Llamadas personalizadas de XLA

En este documento, se describe cómo escribir y usar llamadas personalizadas de XLA con la biblioteca de XLA de la FFI. Las llamadas personalizadas son un mecanismo para describir una "operación" externa en el módulo HLO al compilador de XLA (en el tiempo de compilación), y la FFI de XLA es un mecanismo para registrar la implementación de esas operaciones con XLA (en el tiempo de ejecución). FFI significa "interfaz de función externa" y es un conjunto de APIs de C que definen una interfaz binaria (ABI) para que XLA llame a código externo escrito en otros lenguajes de programación. XLA proporciona vinculaciones solo de encabezado para la FFI de XLA escritas en C++, las cuales ocultan al usuario final todos los detalles de bajo nivel de las APIs de C subyacentes.

Crea una llamada personalizada en la CPU

Puedes crear una instrucción de HLO que represente una llamada personalizada a través de la API del cliente de XLA. Por ejemplo, en el siguiente código, se usa una llamada personalizada para calcular A[i] = B[i % 128]+ C[i] en la CPU. (Por supuesto que puedes, ¡y deberías! haz esto con HLO común).

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
        /*opaque=*/"", /*has_side_effect=*/false,
        /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
        /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to rank-1 buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C+ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];

  // Check that dimensions are compatible.
  assert(out->dimensions[0] == d1 && "unexpected dimensions");

  for (size_t i = 0; i < d1; ++i) {
    out->data[i] = in0.data[i % d0] + in1.data[i];
  }
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Arg<Buffer>()
                           .Arg<Buffer>()
                           .Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "Host", handler);

Crea una llamada personalizada en la GPU

El registro de llamadas personalizadas de la GPU con la FFI de XLA es casi idéntico. La única diferencia es que, en el caso de la GPU, debes solicitar una transmisión subyacente de la plataforma (transmisión CUDA o ROCM) para poder iniciar el kernel en el dispositivo. A continuación, se muestra un ejemplo de CUDA que realiza el mismo cálculo (A[i] = B[i % 128] + C[i]) que el código de CPU anterior.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];
  size_t d2 = out->dimensions[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
    in0.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<CUstream>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "CUDA", handler);

Primero, observa que la función de llamada personalizada de GPU sigue siendo una función ejecutada en la CPU. La función do_custom_call de la CPU se encarga de poner el trabajo en cola en la GPU. Aquí, inicia un kernel CUDA, pero también podría realizar otra acción, como llamar a cuBLAS.

Los argumentos y los resultados también se encuentran en el host, y el miembro de datos contiene un puntero a la memoria del dispositivo (es decir, GPU). Los búferes que se pasan al controlador de llamadas personalizado tienen la forma de los búferes de dispositivo subyacentes, por lo que la llamada personalizada puede calcular los parámetros de inicio del kernel a partir de ellos.

Pasa tuplas a llamadas personalizadas

Considera la siguiente llamada personalizada.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);

Tanto en la CPU como en la GPU, una tupla se representa en la memoria como un array de punteros. Cuando XLA llama a llamadas personalizadas con argumentos o resultados de tuplas, las aplana y las pasa como resultados o argumentos de búfer normales.

Salidas de tuplas como búferes de temperatura

Las entradas de tuplas para las llamadas personalizadas son convenientes, pero no son estrictamente necesarias. Si no se admitieran entradas de tuplas a las llamadas personalizadas, siempre puedes descomprimir las tuplas con get-tuple-element antes de pasarlas a la llamada personalizada.

Por otro lado, los resultados de tuplas te permiten realizar acciones que, de otra manera, no podrías realizar.

El motivo obvio para tener salidas de tupla es que esa salida es la forma en que una llamada personalizada (o cualquier otra op de XLA) muestra varios arreglos independientes.

Pero, menos obvio, el resultado de una tupla también es una forma de agregar tu memoria de temperatura de llamada personalizada. Sí, un resultado puede representar un búfer temporal. Considera que un búfer de salida tiene la propiedad que la op puede escribir en él y puede leer desde allí una vez que se escribe en él. Eso es exactamente lo que quieres de un búfer temporal.

En el ejemplo anterior, supongamos que queremos usar F32[1024] como búfer temporal. Luego, se escribe el HLO como se muestra arriba y nunca se lee el índice de la tupla 1 de la salida de la llamada personalizada.