Llamadas personalizadas de XLA

En este documento, se describe cómo escribir y usar llamadas personalizadas de XLA con una FFI de XLA biblioteca. La llamada personalizada es un mecanismo para describir una “operación” externa en la el módulo HLO al compilador XLA (en el tiempo de compilación), y la FFI de XLA es un mecanismo para registrar la implementación de dichas operaciones con XLA (en el entorno de ejecución). Posiciones de la FFI sobre la “interfaz de función externa” y es un conjunto de APIs de C que define un objeto binario (ABI) para que XLA llame a un código externo escrito en otro lenguaje de programación idiomas. XLA proporciona vinculaciones de solo encabezado para la FFI de XLA escritas en C++, que oculta todos los detalles de bajo nivel de las APIs de C subyacentes al usuario final.

Llamadas personalizadas de JAX + XLA

Consulta la documentación de JAX para ver ejemplos de extremo a extremo de integración de llamadas personalizadas y FFI de XLA con JAX.

Vinculación de FFI de XLA

La vinculación de FFI de XLA es una especificación del tiempo de compilación de la firma de llamada personalizada: argumentos de llamada personalizados, atributos y sus tipos, y parámetros adicionales que se pasan a través del contexto de ejecución (es decir, flujo de GPU para el backend de GPU). El hallazgo de la FFI de XLA se puede vincular a cualquier elemento que se pueda llamar en C++ (puntero de función, lambda, etc.) con una firma operator() compatible. El controlador construido decodifica la llamada de XLA FFI fotograma (definido por la API estable de C), verificación de tipo de todos los parámetros y reenviar los resultados decodificados a la devolución de llamada definida por el usuario.

La vinculación de FFI de XLA depende en gran medida de la metaprogramación de plantillas para poder y compilar controladores en el código máquina más eficiente. Tiempo de ejecución sobrecargas están en orden de un par de nanosegundos para cada llamada parámetro.

Puntos de personalización de la FFI de XLA implementados como especializaciones de plantillas, y los usuarios pueden definir cómo decodificar sus tipos personalizados, es decir, es posible definir la decodificación personalizada para los tipos enum class definidos por el usuario.

Cómo mostrar errores de llamadas personalizadas

Las implementaciones de llamadas personalizadas deben mostrar el valor xla::ffi::Error para indicar el éxito o el error al entorno de ejecución de XLA. Es similar a absl::Status y tiene el mismo conjunto de códigos de error. No usamos absl::Status porque lo hace. tienen una ABI estable y no sería seguro pasarla de forma dinámica una biblioteca de llamadas personalizada y XLA.

// Handler that always returns an error.
auto always_error = Ffi::Bind().To(
    []() { return Error(ErrorCode::kInternal, "Oops!"); });

// Handler that always returns a success.
auto always_success = Ffi::Bind().To(
    []() { return Error::Success(); });

Argumentos y resultados del búfer

XLA usa el estilo de transferencia de destino para los resultados: llamadas personalizadas (o cualquier otro XLA). para las operaciones) no asignan memoria para los resultados y, en su lugar, en los destinos pasados por el entorno de ejecución de XLA. XLA usa un búfer estático asignación y asigna búferes para todos los valores según sus rangos en vivo en tiempo de compilación.

Los resultados pasados a controladores de FFI unidos a una plantilla Result<T>, que tiene una semántica similar a un puntero: operator-> brinda acceso al parámetro.

Los argumentos y resultados de AnyBuffer proporcionan acceso a los parámetros del búfer de llamadas personalizados de cualquier tipo de datos. Esto es útil cuando la llamada personalizada tiene una implementación genérica que funciona para varios tipos de datos, y la implementación de llamadas personalizadas se ejecuta el despacho de productos según el tipo de datos. AnyBuffer otorga acceso a los datos del búfer. las dimensiones y un puntero al búfer en sí.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  api_version = 4 : i32
} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// Buffers of any rank and data type.
auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
    [](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
      void* arg_data = arg.untyped_data();
      void* res_data = res->untyped_data();
      return Error::Success();
    });

Argumentos y resultados de búfer restringidos

Buffer permite agregar restricciones en el tipo de datos y la clasificación del búfer, y el controlador verificará automáticamente y devolverá un error al entorno de ejecución de XLA, si los argumentos de tiempo de ejecución no coinciden con la firma del controlador de la FFI.

// Buffers of any rank and F32 data type.
auto handler = Ffi::Bind().Arg<Buffer<F32>>().Ret<Buffer<F32>>().To(
    [](Buffer<F32> arg, Result<Buffer<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });
// Buffers of rank 2 and F32 data type.
auto handler = Ffi::Bind().Arg<BufferR2<F32>>().Ret<BufferR2<F32>>().To(
    [](BufferR2<F32> arg, Result<BufferR2<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });

Argumentos y resultados variadic

Si la cantidad de argumentos y resultados puede ser diferente en diferentes instancias de una llamada personalizada, se pueden decodificar en el entorno de ejecución con RemainingArgs y RemainingRets

auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To(
    [](RemainingArgs args, RemainingRets results) -> Error {
      ErrorOr<AnyBuffer> arg = args.get<AnyBuffer>(0);
      ErrorOr<Result<AnyBuffer>> res = results.get<AnyBuffer>(0);

      if (!arg.has_value()) {
        return Error(ErrorCode::kInternal, arg.error());
      }

      if (!res.has_value()) {
        return Error(ErrorCode::kInternal, res.error());
      }

      return Error::Success();
    });

Los argumentos y resultados variadic se pueden declarar después de los argumentos y resultados normales, pero vincular argumentos y resultados normales después de los variadic es ilegal.

auto handler =
    Ffi::Bind()
        .Arg<AnyBuffer>()
        .RemainingArgs()
        .Ret<AnyBuffer>()
        .RemainingRets()
        .To([](AnyBuffer arg, RemainingArgs args, AnyBuffer ret,
               RemainingRets results) -> Error { return Error::Success(); });

Atributos

La FFI de XLA admite la decodificación automática de mlir::DictionaryAttr que se pasa como un custom_call backend_config a los argumentos del controlador de FFI.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    i32 = 42 : i32,
    str = "string"
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

En este ejemplo, la llamada personalizada tiene un solo argumento de búfer y dos atributos, y la FFI de XLA puede decodificarlos automáticamente y pasarlos a la función de llamada definida por el usuario.

auto handler = Ffi::Bind()
  .Arg<BufferR0<F32>>()
  .Attr<int32_t>("i32")
  .Attr<std::string_view>("str")
  .To([](BufferR0<F32> buffer, int32_t i32, std::string_view str) {
    return Error::Success();
  });

Atributos de enumeración definidos por el usuario

Las FFI de XLA pueden decodificar automáticamente atributos integrales de MLIR en los atributos enums. La clase enum debe tener el mismo tipo de integral subyacente y la decodificación debe estar registrado explícitamente en la FFI de XLA.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    command = 0 : i32
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
enum class Command : int32_t {
  kAdd = 0,
  kMul = 1,
};

XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command);

auto handler = Ffi::Bind().Attr<Command>("command").To(
    [](Command command) -> Error { return Error::Success(); });

Vincula todos los atributos de llamada personalizados

Es posible obtener acceso a todos los atributos de llamada personalizados como un diccionario. y decodificar de forma diferida solo los atributos necesarios en el tiempo de ejecución.

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<int32_t> i32 = attrs.get<int32_t>("i32");
  return Error::Success();
});

Atributos struct definidos por el usuario

La FFI de XLA puede decodificar atributos del diccionario en structs definidos por el usuario.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    range = { lo = 0 : i64, hi = 42 : i64 }
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

En el ejemplo anterior, range es un atributo mlir::DictionaryAttr y, en lugar de acceder a los campos del diccionario por nombre, se puede decodificar automáticamente como una estructura C++. La decodificación se debe registrar de forma explícita con una macro XLA_FFI_REGISTER_STRUCT_ATTR_DECODING (en segundo plano, define una especialización de plantilla en el espacio de nombres ::xla::ffi, por lo que se debe agregar la macro al espacio de nombres global).

struct Range {
  int64_t lo;
  int64_t hi;
};

XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("i64"),
                                             StructMember<int64_t>("i64"));

auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
  return Error::Success();
});

Los atributos personalizados se pueden cargar desde un diccionario, al igual que cualquier otro atributo. En el siguiente ejemplo, todos los atributos de llamada personalizados se decodifican como un Dictionary y se puede acceder a un range por nombre.

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<Range> range = attrs.get<Range>("range");
  return Error::Success();
});

Crea una llamada personalizada en la CPU

Puedes crear una instrucción HLO que represente una llamada personalizada a través del cliente de XLA. API de gcloud. Por ejemplo, el siguiente código usa una llamada personalizada para calcular A[i] = B[i % 128]+ C[i] en la CPU. (Por supuesto que podrías... y debes. – Haz esto con HLO normal).

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
        /*opaque=*/"", /*has_side_effect=*/false,
        /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
        /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to rank-1 buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C+ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];

  // Check that dimensions are compatible.
  assert(out->dimensions[0] == d1 && "unexpected dimensions");

  for (size_t i = 0; i < d1; ++i) {
    out->data[i] = in0.data[i % d0] + in1.data[i];
  }
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Arg<Buffer>()
                           .Arg<Buffer>()
                           .Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "Host", handler);

Crea una llamada personalizada en la GPU

El registro de llamadas personalizadas de la GPU con la FFI de XLA es casi idéntico. La única diferencia es que, para la GPU, debes solicitar una transmisión de plataforma subyacente (transmisión de CUDA o ROCM) para poder iniciar el kernel en el dispositivo. Aquí tienes una cuenta de CUDA ejemplo que realiza el mismo procesamiento (A[i] = B[i % 128] + C[i]) que la CPU el código anterior.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];
  size_t d2 = out->dimensions[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
    in0.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<CUstream>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "CUDA", handler);

Observa primero que la función de llamada personalizada de GPU sigue siendo una función ejecutada en la CPU. La función de CPU do_custom_call se encarga de poner el trabajo en cola. en la GPU. Aquí lanza un kernel CUDA, pero también podría hacer otra cosa, como llamar CUBLAS.

Los argumentos y los resultados también se alojan en el host, y el miembro de datos contiene un puntero. a la memoria del dispositivo (es decir, la GPU). Los búferes que se pasan al controlador de llamadas personalizado tienen la forma de los búferes subyacentes del dispositivo, por lo que la llamada personalizada puede calcular los parámetros de inicio del kernel a partir de ellos.

Cómo pasar tuplas a llamadas personalizadas

Considera la siguiente llamada personalizada.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);

En la CPU y la GPU, una tupla se representa en la memoria como un array de punteros. Cuando XLA llama a las llamadas personalizadas con argumentos o resultados de tupla, las aplana y pasa como argumentos o resultados del búfer regulares.

Salidas de tupla como búferes de temperatura

Las entradas de tupla a las llamadas personalizadas son convenientes, pero no son estrictamente necesarias. Si no admitiéramos entradas de tupla en las llamadas personalizadas, siempre podrías desempaquetar las tuplas con el elemento get-tuple antes de pasarlas al llamada.

Por otro lado, los outputs de tupla te permiten hacer cosas que de otra forma no podrías.

La razón obvia para tener salidas de tupla es que las salidas de tupla son la forma en que una llamada personalizada (o cualquier otra operación de XLA) muestra varios arrays independientes.

Menos obviamente, un resultado de tupla también es una forma de asignarle a tu temperatura de llamada personalizada. memoria. Sí, una salida puede representar un búfer temporal. Considera que un búfer de salida tiene la propiedad en la que la operación puede escribir y puede leer después de que se haya escrito. Eso es exactamente lo que quieres de un búfer temporal.

En el ejemplo anterior, supongamos que queremos usar F32[1024] como búfer temporal. Luego, escribiríamos la HLO de la misma manera que antes y nunca leeríamos el índice de tupla 1 del resultado de la llamada personalizada.