XLA 自定义通话

本文档介绍了如何使用 XLA FFI 库编写和使用 XLA 自定义调用。自定义调用是一种机制,用于在 HLO 模块中向 XLA 编译器(在编译时)描述外部“操作”,而 XLA FFI 是一种机制,用于向 XLA(在运行时)注册此类操作的实现。FFI 是“外部函数接口”(foreign function interface) 的缩写,是一组 C API,用于定义 XLA 调用其他编程语言编写的外部代码的二进制接口 (ABI)。XLA 为以 C++ 编写的 XLA FFI 提供仅包含头文件的绑定,从而向最终用户隐藏底层 C API 的所有低级详细信息。

JAX + XLA 自定义调用

如需查看将自定义调用和 XLA FFI 与 JAX 集成的端到端示例,请参阅 JAX 文档

XLA FFI 绑定

XLA FFI 绑定是自定义调用签名的编译时规范:自定义调用实参、属性及其类型,以及通过执行上下文(即 GPU 后端的 GPU 流)传递的其他参数。XLA FFI 绑定可以绑定到任何具有兼容 operator() 签名的可调用 C++ 对象(函数指针、lambda 等)。构建的处理程序可解码 XLA FFI 调用帧(由稳定的 C API 定义)、对所有形参进行类型检查,并将解码结果转发给用户定义的回调。

XLA FFI 绑定严重依赖于模板元编程,以便能够将构建的处理程序编译为最有效的机器代码。对于每个自定义调用参数,运行时开销约为几纳秒。

以模板特化形式实现的 XLA FFI 自定义点,用户可以定义如何解码其自定义类型,即可以为用户定义的 enum class 类型定义自定义解码。

从自定义调用返回错误

自定义调用实现必须返回 xla::ffi::Error 值,以向 XLA 运行时发出成功或错误信号。它类似于 absl::Status,并且具有相同的错误代码集。我们不使用 absl::Status,因为它的 ABI 不稳定,在动态加载的自定义调用库和 XLA 本身之间传递它是不安全的。

// Handler that always returns an error.
auto always_error = Ffi::Bind().To(
    []() { return Error(ErrorCode::kInternal, "Oops!"); });

// Handler that always returns a success.
auto always_success = Ffi::Bind().To(
    []() { return Error::Success(); });

缓冲区实参和结果

XLA 使用目标传递样式来处理结果:自定义调用(或任何其他 XLA 操作)不会为结果分配内存,而是写入由 XLA 运行时传递的目标。XLA 使用静态缓冲区分配,并在编译时根据所有值的生命周期范围为其分配缓冲区。

传递给 FFI 处理程序的结果封装在具有类似指针语义的 Result<T> 模板中:operator-> 可用于访问底层参数。

AnyBuffer 实参和结果可用于访问任何数据类型的自定义调用缓冲区参数。当自定义调用具有适用于多种数据类型的通用实现,并且自定义调用实现基于数据类型运行时间调度时,此功能非常有用。AnyBuffer 提供对缓冲区数据类型、维度和指向缓冲区本身的指针的访问权限。

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  api_version = 4 : i32
} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// Buffers of any number of dimensions and data type.
auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
    [](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
      void* arg_data = arg.untyped_data();
      void* res_data = res->untyped_data();
      return Error::Success();
    });

受限的缓冲区实参和结果

Buffer 允许对缓冲区数据类型和维度数量添加限制,如果运行时实参与 FFI 处理程序签名不匹配,处理程序将自动检查这些限制并向 XLA 运行时返回错误。

// Buffers of any number of dimensions and F32 data type.
auto handler = Ffi::Bind().Arg<Buffer<F32>>().Ret<Buffer<F32>>().To(
    [](Buffer<F32> arg, Result<Buffer<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });
// Buffers of number of dimensions 2 and F32 data type.
auto handler = Ffi::Bind().Arg<BufferR2<F32>>().Ret<BufferR2<F32>>().To(
    [](BufferR2<F32> arg, Result<BufferR2<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });

可变数量的实参和结果

如果自定义调用的不同实例中的实参和结果数量可能不同,则可以在运行时使用 RemainingArgsRemainingRets 对它们进行解码。

auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To(
    [](RemainingArgs args, RemainingRets results) -> Error {
      ErrorOr<AnyBuffer> arg = args.get<AnyBuffer>(0);
      ErrorOr<Result<AnyBuffer>> res = results.get<AnyBuffer>(0);

      if (!arg.has_value()) {
        return Error(ErrorCode::kInternal, arg.error());
      }

      if (!res.has_value()) {
        return Error(ErrorCode::kInternal, res.error());
      }

      return Error::Success();
    });

可变实参和结果可以在常规实参和结果之后声明,但绑定常规实参和结果之后的可变实参是非法的。

auto handler =
    Ffi::Bind()
        .Arg<AnyBuffer>()
        .RemainingArgs()
        .Ret<AnyBuffer>()
        .RemainingRets()
        .To([](AnyBuffer arg, RemainingArgs args, AnyBuffer ret,
               RemainingRets results) -> Error { return Error::Success(); });

属性

XLA FFI 支持将作为 custom_call backend_config 传递的 mlir::DictionaryAttr 自动解码为 FFI 处理程序实参。

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    i32 = 42 : i32,
    str = "string"
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

在此示例中,自定义调用具有单个缓冲区实参和两个属性,XLA FFI 可以自动对它们进行解码并传递给用户定义的可调用对象。

auto handler = Ffi::Bind()
  .Arg<BufferR0<F32>>()
  .Attr<int32_t>("i32")
  .Attr<std::string_view>("str")
  .To([](BufferR0<F32> buffer, int32_t i32, std::string_view str) {
    return Error::Success();
  });

用户定义的枚举属性

XLA FFI 可以自动将整数 MLIR 属性解码为用户定义的枚举。枚举类必须具有相同的底层整数类型,并且必须通过 XLA FFI 显式注册解码。

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    command = 0 : i32
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
enum class Command : int32_t {
  kAdd = 0,
  kMul = 1,
};

XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command);

auto handler = Ffi::Bind().Attr<Command>("command").To(
    [](Command command) -> Error { return Error::Success(); });

绑定所有自定义通话属性

您可以获取对所有自定义调用属性的访问权限,并以字典形式表示,然后在运行时仅延迟解码所需的属性。

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<int32_t> i32 = attrs.get<int32_t>("i32");
  return Error::Success();
});

用户定义的结构体属性

XLA FFI 可以将字典属性解码为用户定义的结构体。

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    range = { lo = 0 : i64, hi = 42 : i64 }
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

在上面的示例中,range 是一个 mlir::DictionaryAttr 属性,它可以自动解码为 C++ 结构体,而不是按名称访问字典字段。必须使用 XLA_FFI_REGISTER_STRUCT_ATTR_DECODING 宏显式注册解码(在幕后,它在 ::xla::ffi 命名空间中定义模板特例化,因此必须将宏添加到全局命名空间)。

struct Range {
  int64_t lo;
  int64_t hi;
};

XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("lo"),
                                             StructMember<int64_t>("hi"));

auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
  return Error::Success();
});

自定义属性可以像任何其他属性一样从字典中加载。在下面的示例中,所有自定义通话属性都解码为 Dictionary,并且可以通过名称访问 range

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<Range> range = attrs.get<Range>("range");
  return Error::Success();
});

在 CPU 上创建自定义调用

您可以通过 XLA 的客户端 API 创建表示自定义调用的 HLO 指令。例如,以下代码使用自定义调用在 CPU 上计算 A[i] = B[i % 128]+ C[i]。(当然可以,而且应该这样做!- 使用常规 HLO 执行此操作。)

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
        /*opaque=*/"", /*has_side_effect=*/false,
        /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
        /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to 1-dimensional buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C++ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];

  // Check that dimensions are compatible.
  assert(out->dimensions[0] == d1 && "unexpected dimensions");

  for (size_t i = 0; i < d1; ++i) {
    out->data[i] = in0.data[i % d0] + in1.data[i];
  }
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Arg<Buffer>()
                           .Arg<Buffer>()
                           .Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "Host", handler);

在 GPU 上创建自定义调用

使用 XLA FFI 注册 GPU 自定义调用的方式几乎完全相同,唯一的区别在于,对于 GPU,您需要请求底层平台流(CUDA 或 ROCM 流),以便能够在设备上启动内核。以下是一个 CUDA 示例,它执行与上述 CPU 代码相同的计算 (A[i] = B[i % 128] + C[i])。

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];
  size_t d2 = out->dimensions[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
    in0.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<CUstream>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "CUDA", handler);

首先请注意,GPU 自定义调用函数仍然是在 CPU 上执行的函数do_custom_call CPU 函数负责将工作加入 GPU 的队列。此处启动了一个 CUDA 内核,但它也可以执行其他操作,例如调用 cuBLAS。

实参和结果也位于主机上,并且数据成员包含指向设备(即 GPU)内存的指针。传递给自定义调用处理程序的缓冲区具有底层设备缓冲区的形状,因此自定义调用可以从中计算内核启动参数。

将元组传递给自定义调用

请考虑以下自定义调用。

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);

在 CPU 和 GPU 上,元组在内存中都表示为指针数组。当 XLA 使用元组实参或结果调用自定义调用时,它会将其扁平化并作为常规缓冲区实参或结果传递。

将元组输出作为临时缓冲区

元组输入对于自定义调用来说很方便,但并非绝对必要。如果我们不支持向自定义调用传递元组输入,您始终可以使用 get-tuple-element 解封装元组,然后再将其传递给自定义调用。

另一方面,元组输出可让您执行其他方式无法实现的操作。

使用元组输出的明显原因是,自定义调用(或任何其他 XLA 操作)就是通过元组输出来返回多个独立数组的。

但不太明显的是,元组输出也是为自定义调用提供临时内存的一种方式。可以,输出可以表示临时缓冲区。假设一个输出缓冲区具有以下属性:操作可以写入该缓冲区,并且在写入后可以从中读取数据。这正是您希望临时缓冲区发挥的作用。

在上面的示例中,假设我们想使用 F32[1024] 作为临时缓冲区。然后,我们会像上面一样编写 HLO,只是永远不会读取自定义调用的输出的元组索引 1。