Llamadas personalizadas de XLA

En este documento, se describe cómo escribir y usar llamadas personalizadas de XLA con la biblioteca de FFI de XLA. La llamada personalizada es un mecanismo para describir una "operación" externa en el módulo HLO para el compilador de XLA (en tiempo de compilación), y la FFI de XLA es un mecanismo para registrar la implementación de tales operaciones con XLA (en tiempo de ejecución). FFI significa "interfaz de función externa" y es un conjunto de APIs de C que definen una interfaz binaria (ABI) para que XLA llame a código externo escrito en otros lenguajes de programación. XLA proporciona vinculaciones solo de encabezado para la FFI de XLA escrita en C++, que oculta todos los detalles de bajo nivel de las APIs de C subyacentes del usuario final.

Llamadas personalizadas de JAX + XLA

Consulta la documentación de JAX para ver ejemplos de extremo a extremo de la integración de llamadas personalizadas y la FFI de XLA con JAX.

Vinculación de FFI de XLA

La vinculación de FFI de XLA es una especificación en tiempo de compilación de la firma de la llamada personalizada: argumentos de la llamada personalizada, atributos y sus tipos, y parámetros adicionales que se pasan a través del contexto de ejecución (es decir, la transmisión de GPU para el backend de GPU). La vinculación de FFI de XLA se puede vincular a cualquier elemento C++ invocable (puntero de función, lambda, etcétera) con una firma operator() compatible. El controlador construido decodifica el marco de llamada de FFI de XLA (definido por la API de C estable), verifica el tipo de todos los parámetros y reenvía los resultados decodificados a la devolución de llamada definida por el usuario.

La vinculación de FFI de XLA se basa en gran medida en la metaprogramación de plantillas para poder compilar el controlador construido en el código máquina más eficiente. Los gastos generales de tiempo de ejecución son del orden de unos pocos nanosegundos para cada parámetro de llamada personalizado.

Puntos de personalización de FFI de XLA implementados como especializaciones de plantillas, y los usuarios pueden definir cómo decodificar sus tipos personalizados, es decir, es posible definir una decodificación personalizada para los tipos enum class definidos por el usuario.

Cómo devolver errores desde llamadas personalizadas

Las implementaciones de llamadas personalizadas deben devolver el valor xla::ffi::Error para indicar éxito o error al tiempo de ejecución de XLA. Es similar a absl::Status y tiene el mismo conjunto de códigos de error. No usamos absl::Status porque no tiene una ABI estable y sería inseguro pasarlo entre la biblioteca de llamadas personalizadas cargada de forma dinámica y XLA.

// Handler that always returns an error.
auto always_error = Ffi::Bind().To(
    []() { return Error(ErrorCode::kInternal, "Oops!"); });

// Handler that always returns a success.
auto always_success = Ffi::Bind().To(
    []() { return Error::Success(); });

Almacenamiento en búfer de argumentos y resultados

XLA usa el estilo de paso de destino para los resultados: las llamadas personalizadas (o cualquier otra operación de XLA) no asignan memoria para los resultados, sino que escriben en los destinos que pasa el tiempo de ejecución de XLA. XLA usa la asignación de búferes estáticos y asigna búferes para todos los valores según sus rangos activos en el tiempo de compilación.

Los resultados que se pasan a los controladores de FFI se encapsulan en una plantilla Result<T>, que tiene una semántica similar a la de un puntero: operator-> da acceso al parámetro subyacente.

Los argumentos y los resultados de AnyBuffer permiten acceder a parámetros de búfer de llamadas personalizados de cualquier tipo de datos. Esto es útil cuando la llamada personalizada tiene una implementación genérica que funciona para varios tipos de datos y la implementación de la llamada personalizada realiza el envío en tiempo de ejecución según el tipo de datos. AnyBuffer brinda acceso al tipo de datos del búfer, las dimensiones y un puntero al búfer en sí.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  api_version = 4 : i32
} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// Buffers of any number of dimensions and data type.
auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
    [](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
      void* arg_data = arg.untyped_data();
      void* res_data = res->untyped_data();
      return Error::Success();
    });

Argumentos y resultados del búfer restringido

Buffer permite agregar restricciones sobre el tipo de datos del búfer y la cantidad de dimensiones, y el controlador las verificará automáticamente y devolverá un error al tiempo de ejecución de XLA si los argumentos del tiempo de ejecución no coinciden con la firma del controlador de FFI.

// Buffers of any number of dimensions and F32 data type.
auto handler = Ffi::Bind().Arg<Buffer<F32>>().Ret<Buffer<F32>>().To(
    [](Buffer<F32> arg, Result<Buffer<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });
// Buffers of number of dimensions 2 and F32 data type.
auto handler = Ffi::Bind().Arg<BufferR2<F32>>().Ret<BufferR2<F32>>().To(
    [](BufferR2<F32> arg, Result<BufferR2<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });

Argumentos y resultados variádicos

Si la cantidad de argumentos y el resultado pueden ser diferentes en distintas instancias de una llamada personalizada, se pueden decodificar en el tiempo de ejecución con RemainingArgs y RemainingRets.

auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To(
    [](RemainingArgs args, RemainingRets results) -> Error {
      ErrorOr<AnyBuffer> arg = args.get<AnyBuffer>(0);
      ErrorOr<Result<AnyBuffer>> res = results.get<AnyBuffer>(0);

      if (!arg.has_value()) {
        return Error(ErrorCode::kInternal, arg.error());
      }

      if (!res.has_value()) {
        return Error(ErrorCode::kInternal, res.error());
      }

      return Error::Success();
    });

Los argumentos y resultados variádicos se pueden declarar después de los argumentos y resultados normales. Sin embargo, es ilegal vincular argumentos y resultados normales después de uno variádico.

auto handler =
    Ffi::Bind()
        .Arg<AnyBuffer>()
        .RemainingArgs()
        .Ret<AnyBuffer>()
        .RemainingRets()
        .To([](AnyBuffer arg, RemainingArgs args, AnyBuffer ret,
               RemainingRets results) -> Error { return Error::Success(); });

Atributos

La FFI de XLA admite la decodificación automática de mlir::DictionaryAttr que se pasa como un custom_call backend_config en los argumentos del controlador de FFI.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    i32 = 42 : i32,
    str = "string"
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

En este ejemplo, la llamada personalizada tiene un solo argumento de búfer y dos atributos, y la FFI de XLA puede decodificarlos automáticamente y pasarlos al elemento llamable definido por el usuario.

auto handler = Ffi::Bind()
  .Arg<BufferR0<F32>>()
  .Attr<int32_t>("i32")
  .Attr<std::string_view>("str")
  .To([](BufferR0<F32> buffer, int32_t i32, std::string_view str) {
    return Error::Success();
  });

Atributos de enumeración definidos por el usuario

La FFI de XLA puede decodificar automáticamente los atributos integrales de MLIR en enumeraciones definidas por el usuario. La clase enum debe tener el mismo tipo integral subyacente, y la decodificación debe registrarse de forma explícita con XLA FFI.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    command = 0 : i32
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
enum class Command : int32_t {
  kAdd = 0,
  kMul = 1,
};

XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command);

auto handler = Ffi::Bind().Attr<Command>("command").To(
    [](Command command) -> Error { return Error::Success(); });

Vincula todos los atributos de llamada personalizados

Es posible acceder a todos los atributos de llamada personalizados como un diccionario y decodificar de forma diferida solo los atributos que se necesitan en el tiempo de ejecución.

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<int32_t> i32 = attrs.get<int32_t>("i32");
  return Error::Success();
});

Atributos de Struct definidos por el usuario

La FFI de XLA puede decodificar atributos de diccionario en structs definidos por el usuario.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    range = { lo = 0 : i64, hi = 42 : i64 }
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

En el ejemplo anterior, range es un atributo mlir::DictionaryAttr y, en lugar de acceder a los campos del diccionario por nombre, se puede decodificar automáticamente como una struct de C++. La decodificación debe registrarse de forma explícita con una macro XLA_FFI_REGISTER_STRUCT_ATTR_DECODING (en segundo plano, define una especialización de plantilla en el espacio de nombres ::xla::ffi, por lo que la macro debe agregarse al espacio de nombres global).

struct Range {
  int64_t lo;
  int64_t hi;
};

XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("lo"),
                                             StructMember<int64_t>("hi"));

auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
  return Error::Success();
});

Los atributos personalizados se pueden cargar desde un diccionario, al igual que cualquier otro atributo. En el siguiente ejemplo, todos los atributos de llamada personalizados se decodifican como un Dictionary, y se puede acceder a un range por su nombre.

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<Range> range = attrs.get<Range>("range");
  return Error::Success();
});

Crea una llamada personalizada en la CPU

Puedes crear una instrucción de HLO que represente una llamada personalizada a través de la API del cliente de XLA. Por ejemplo, el siguiente código usa una llamada personalizada para calcular A[i] = B[i % 128]+ C[i] en la CPU. (Por supuesto que puedes y debes hacerlo. – Haz esto con el HLO habitual.

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
        /*opaque=*/"", /*has_side_effect=*/false,
        /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
        /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to 1-dimensional buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C++ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];

  // Check that dimensions are compatible.
  assert(out->dimensions[0] == d1 && "unexpected dimensions");

  for (size_t i = 0; i < d1; ++i) {
    out->data[i] = in0.data[i % d0] + in1.data[i];
  }
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Arg<Buffer>()
                           .Arg<Buffer>()
                           .Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "Host", handler);

Crea una llamada personalizada en la GPU

El registro de llamadas personalizadas de GPU con XLA FFI es casi idéntico. La única diferencia es que, para la GPU, debes solicitar una transmisión de plataforma subyacente (transmisión de CUDA o ROCM) para poder iniciar el kernel en el dispositivo. A continuación, se muestra un ejemplo de CUDA que realiza el mismo cálculo (A[i] = B[i % 128] + C[i]) que el código de CPU anterior.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];
  size_t d2 = out->dimensions[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
    in0.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<CUstream>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "CUDA", handler);

Primero, ten en cuenta que la función de llamada personalizada de la GPU sigue siendo una función que se ejecuta en la CPU. La función de CPU do_custom_call es responsable de poner en cola el trabajo en la GPU. Aquí se inicia un kernel de CUDA, pero también se podría hacer otra cosa, como llamar a cuBLAS.

Los argumentos y los resultados también se encuentran en el host, y el miembro de datos contiene un puntero a la memoria del dispositivo (es decir, la GPU). Los búferes que se pasan al controlador de llamadas personalizado tienen la forma de los búferes del dispositivo subyacente, por lo que la llamada personalizada puede calcular los parámetros de inicio del kernel a partir de ellos.

Cómo pasar tuplas a llamadas personalizadas

Considera la siguiente llamada personalizada.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);

Tanto en la CPU como en la GPU, una tupla se representa en la memoria como un array de punteros. Cuando XLA llama a llamadas personalizadas con argumentos o resultados de tuplas, los aplana y los pasa como argumentos o resultados de búfer normales.

Tuplas de salida como búferes temporales

Las tuplas de entrada para las llamadas personalizadas son convenientes, pero no son estrictamente necesarias. Si no admitiéramos entradas de tuplas para las llamadas personalizadas, siempre podrías desempaquetar las tuplas con get-tuple-element antes de pasarlas a la llamada personalizada.

Por otro lado, las salidas de tuplas te permiten hacer cosas que no podrías hacer de otra manera.

La razón obvia para tener salidas de tuplas es que así es como una llamada personalizada (o cualquier otra operación de XLA) devuelve varios arrays independientes.

Sin embargo, de forma menos obvia, una salida de tupla también es una forma de darle memoria temporal a tu llamada personalizada. Sí, un output puede representar un búfer temporal. Considera que un búfer de salida tiene la propiedad de que la operación puede escribir en él y leerlo después de que se haya escrito en él. Eso es exactamente lo que quieres de un búfer temporal.

En el ejemplo anterior, supongamos que queremos usar F32[1024] como un búfer temporal. Luego, escribiríamos el HLO como se indicó anteriormente y simplemente nunca leeríamos el índice de tupla 1 de la salida de la llamada personalizada.