XLA 自定义通话

本文档介绍了如何编写和使用 XLA 自定义调用。借助自定义调用,您可以从 XLA 程序中调用使用 C++ 或 CUDA 等编程语言编写的代码。

在 CPU 上创建自定义调用

您可以通过 XLA 的客户端 API 创建表示自定义调用的 HLO 指令。例如,以下代码使用自定义调用在 CPU 上计算 A[i] = B[i % 128]+ C[i]。(当然可以,而且应该!请使用常规 HLO 执行此操作。)

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                      /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}));
}

void do_custom_call(void* out, const void** in) {
  float* out_buf = reinterpret_cast<float*>(out);
  const float* in0 = reinterpret_cast<const float*>(in[0]);
  const float* in1 = reinterpret_cast<const float*>(in[1]);
  for (int i = 0; i < 2048; ++i) {
    out_buf[i] = in0[i % 128] + in1[i];
  }
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "Host");

请注意,do_custom_call 函数需要知道其操作的缓冲区的尺寸。在本示例中,我们对尺寸 1282048 进行了硬编码。如果您不希望这样做,则可以将维度作为参数传入调用。

在 GPU 上创建自定义调用

GPU 自定义调用框架与 CPU 上的调用框架略有不同。下面的 CUDA 示例执行与上述 CPU 代码相同的计算 (A[i] = B[i % 128] + C[i])。

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, void** buffers,
                    const char* opaque, size_t opaque_len) {
  const float* in0 = reinterpret_cast<const float*>(buffers[0]);
  const float* in1 = reinterpret_cast<const float*>(buffers[1]);
  float* out = reinterpret_cast<float*>(buffers[2]);

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim,
                       /*dynamic_shared_mem_bytes=*/0, stream>>>(in0, in1, out);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "CUDA");

首先请注意,GPU 自定义调用函数仍然是在 CPU 上执行的函数do_custom_call CPU 函数负责将 GPU 上的工作加入队列。在这里,它会启动 CUDA 内核,但也可能会执行其他任务,例如调用 cuBLAS。

buffers 是位于主机上的指针数组,其中包含的每个元素都指向设备(即 GPU)内存。参数在前,后跟输出值。这与 CPU 调用规范明显不同,后者有两个参数:insout。借助 GPU 调用规范,您可以高效处理元组形状的输入/输出。

与 CPU 示例一样,我们已将输入和输出缓冲区大小硬编码到自定义调用中。但是,与 CPU 的情况不同,将缓冲区大小作为操作数传递给自定义调用的效果并不理想。通常需要在 CPU 上可用的缓冲区大小(例如,在启动内核时,我们需要知道要使用的块/网格尺寸)。但是,如果我们要将缓冲区大小作为操作数传递给自定义调用,其值将位于 GPU 内存中。然后,我们必须在操作开始时执行开销非常大的设备到主机 memcpy,以便读取大小。

为帮助您解决此问题,我们提供了 opaque 参数。在创建自定义调用时,您可以将此属性设置为任意字节字符串:

std::string opaque = "...";
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
                opaque);

由于 xla::Shape 具有协议缓冲区表示法,因此您可以将此序列化 proto 存储在 opaque 内,并在 GPU 自定义调用中对其进行反序列化。但请注意,虽然 xla::ShapeProto 不会经常更改,但它确实会更改。查看 Git 日志,了解过去的变化情况。

发出错误信号

如果自定义调用遇到错误,您可以通过对函数使用以下签名来向 XLA 运行时发出错误信号(而不是例如在输出缓冲区中崩溃或返回无意义内容):

在 CPU 上

#include "xla/service/custom_call_status.h"

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status);

在 GPU 上

#include "xla/service/custom_call_status.h"

void do_custom_call(CUstream stream, void** buffers, const char* opaque,
                    size_t opaque_len, xla::XlaCustomCallStatus* status);

您可以使用 XlaCustomCallStatusSetFailure 表示失败,例如:

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status) {
  // ... do some work.

  if (bad_condition) {
    char* error_message = "An error occurred";
    XlaCustomCallStatusSetFailure(status, error_message, strlen(error_message));
    return;
  }

  // ... continue.
}

您还可以使用 XlaCustomCallStatusSetSuccess 来指示成功,但 XlaCustomCallStatus 默认处于成功状态,因此完全忽略该状态也表示成功。

使用此签名使用自定义调用函数时,您必须使用适当的 API 版本集创建相应的 custom-call 操作,例如:

xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(F32, {2048}),
                opaque, /*has_side_effect=*/false,
                /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
                /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
                /*api_version=*/API_VERSION_STATUS_RETURNING);

失败时,将不会使用任何自定义调用输出;XLA 运行时将终止计算。HLO 计算无法从错误中恢复(例如,通过捕获和处理错误)。

将元组传递给自定义调用

考虑以下自定义调用。

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape);

在 CPU 和 GPU 上,元组在内存中表示为指针数组。在 C++ 伪代码中,上述参数 0 的布局如下。

// In-memory layout of parameter 0 from custom call above. True on both CPU
// and GPU.
float* subbuf0 = new float[32];
float* subbuf1 = new float[64];
float* subbuf2 = new float[128]
float* subbuf3 = new float[256];

void* subtuple = new void*[2];
(*subtuple)[0] = subbuf1;
(*subtuple)[1] = subbuf2;

void* p0 = new void*[3];
(*p0)[0] = subbuf0;
(*p0)[1] = subtuple;
(*p0)[2] = subbuf3;

尽管元组在内存中的表示法在 CPU 和 GPU 中相同,但在 CPU 和 GPU 自定义调用调用规范中的处理方式不同。

作为临时缓冲区的元组输出

自定义调用的元组输入只是为了方便起见,但并非绝对必需。如果我们不支持对自定义调用的元组输入,您始终可以在将元组传递给自定义调用之前使用 get-tuple-element 解压缩这些元组。

另一方面,元组输出可让您执行通过其他方式无法实现的一些操作。

具有元组输出的明显原因是元组输出是自定义调用(或任何其他 XLA 操作)返回多个独立数组的方式。

但不太明显,元组输出也是一种提供自定义调用临时内存的方式。可以,输出可以表示临时缓冲区。设想一下,输出缓冲区具有一项操作,可向其写入数据,且可在其被写入后从该缓冲区读取数据的属性。这正是你想要从临时缓冲区获得的内容。

在上面的示例中,假设我们要将 F32[1024] 用作临时缓冲区。然后,我们如上所述编写 HLO,并且永远不会读取自定义调用输出的元组索引 1。

CPU 自定义调用中的元组数

在 CPU 代码中,有一个函数 do_custom_call(const void** ins, void* out)ins 是一个仅包含一个元素的数组,该数组指向 param0。您可以通过解引用该指针来访问 param0 的子缓冲区,通过解引用 out 来访问 output_tuple 的子缓冲区。

GPU 自定义调用中的元组数

在 GPU 代码中,有一个函数 do_custom_call(..., void** buffers, ...)。在本例中,buffers 是由六个设备指针组成的主机数组,输入/输出中的每个叶缓冲区对应一个指针。为了生成扁平列表,我们需要迭代参数和输出,并针对每项参数和输出对其形状进行预排序遍历。具体而言:

// Layout of `buffers` parameter to GPU custom call function for custom-call
// above.
buffers[0] == subbuf0
buffers[1] == subbuf1
buffers[2] == subbuf2
buffers[3] == subbuf3
buffers[4] == output_subbuf0
buffers[5] == output_subbuf1