Cuộc gọi tuỳ chỉnh XLA

Tài liệu này mô tả cách viết và sử dụng lệnh gọi tuỳ chỉnh XLA bằng thư viện XLA FFI. Lệnh gọi tuỳ chỉnh là một cơ chế để mô tả một "hoạt động" bên ngoài trong mô-đun HLO cho trình biên dịch XLA (tại thời điểm biên dịch) và XLA FFI là một cơ chế đăng ký triển khai những thao tác đó bằng XLA (tại thời điểm chạy). FFI là viết tắt của "foreign function Interface" (giao diện hàm nước ngoài), và là một tập hợp các API C xác định giao diện nhị phân (ABI) để XLA gọi vào mã bên ngoài được viết bằng các ngôn ngữ lập trình khác. XLA cung cấp các liên kết chỉ tiêu đề cho XLA FFI được viết bằng C++, giúp ẩn mọi thông tin chi tiết cấp thấp của các API C cơ bản đối với người dùng cuối.

Tạo cuộc gọi tuỳ chỉnh trên CPU

Bạn có thể tạo một lệnh HLO đại diện cho một lệnh gọi tuỳ chỉnh thông qua API ứng dụng của XLA. Ví dụ: Mã sau đây dùng lệnh gọi tuỳ chỉnh để tính toán A[i] = B[i % 128]+ C[i] trên CPU. (Tất nhiên là bạn có thể – và nên làm! – thực hiện việc này với HLO thông thường.)

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
        /*opaque=*/"", /*has_side_effect=*/false,
        /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
        /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to rank-1 buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C+ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];

  // Check that dimensions are compatible.
  assert(out->dimensions[0] == d1 && "unexpected dimensions");

  for (size_t i = 0; i < d1; ++i) {
    out->data[i] = in0.data[i % d0] + in1.data[i];
  }
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Arg<Buffer>()
                           .Arg<Buffer>()
                           .Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "Host", handler);

Tạo cuộc gọi tuỳ chỉnh trên GPU

Việc đăng ký lệnh gọi tuỳ chỉnh GPU với XLA FFI gần như giống hệt nhau, điểm khác biệt duy nhất là đối với GPU, bạn cần yêu cầu một luồng nền tảng cơ bản (luồng CUDA hoặc ROCM) để có thể khởi chạy hạt nhân trên thiết bị. Sau đây là ví dụ về CUDA thực hiện phép tính tương tự (A[i] = B[i % 128] + C[i]) như mã CPU ở trên.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];
  size_t d2 = out->dimensions[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
    in0.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<CUstream>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "CUDA", handler);

Trước tiên, hãy lưu ý rằng hàm gọi tuỳ chỉnh của GPU vẫn là một hàm được thực thi trên CPU. Hàm CPU do_custom_call chịu trách nhiệm đưa công việc vào hàng đợi trên GPU. Tại đây, mã sẽ khởi chạy một hạt nhân CUDA, nhưng cũng có thể thực hiện một tác vụ khác, chẳng hạn như gọi cuBLAS.

Các đối số và kết quả cũng nằm trên máy chủ lưu trữ và thành phần dữ liệu chứa một con trỏ tới bộ nhớ thiết bị (cụ thể là GPU). Các vùng đệm được chuyển đến trình xử lý lệnh gọi tuỳ chỉnh có hình dạng của vùng đệm thiết bị cơ bản, vì vậy, lệnh gọi tuỳ chỉnh có thể tính toán các tham số khởi chạy hạt nhân từ các vùng đệm đó.

Chuyển bộ dữ liệu vào các lệnh gọi tuỳ chỉnh

Hãy xem xét lệnh gọi tuỳ chỉnh sau đây.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);

Trên cả CPU và GPU, một bộ dữ liệu được biểu thị trong bộ nhớ dưới dạng một mảng con trỏ. Khi XLA gọi các lệnh gọi tuỳ chỉnh có đối số bộ dữ liệu hoặc kết quả, thì nó sẽ làm phẳng các đối số đó rồi truyền dưới dạng kết quả hoặc đối số vùng đệm thông thường.

Đầu ra của đầu ra ở dạng vùng đệm tạm thời

Việc nhập dữ liệu đầu vào cho lệnh gọi tuỳ chỉnh rất tiện lợi nhưng không thực sự cần thiết. Nếu chúng tôi không hỗ trợ dữ liệu đầu vào của bộ dữ liệu cho lệnh gọi tuỳ chỉnh, thì bạn luôn có thể giải nén các bộ dữ liệu bằng phần tử get-tuple trước khi chuyển chúng vào lệnh gọi tuỳ chỉnh.

Mặt khác, bộ công cụ đầu ra cho phép bạn làm những việc mà bạn không thể làm nếu không có.

Lý do rõ ràng để có đầu ra của bộ dữ liệu là đầu ra của bộ dữ liệu là cách lệnh gọi tuỳ chỉnh (hoặc bất kỳ hoạt động XLA nào khác) trả về nhiều mảng độc lập.

Nhưng ít rõ ràng hơn, dữ liệu đầu ra của bộ dữ liệu cũng là một cách cung cấp bộ nhớ tạm thời cho lệnh gọi tuỳ chỉnh. Có, đầu ra có thể biểu thị vùng đệm tạm thời. Hãy xem xét vùng đệm đầu ra có thuộc tính mà nhóm vận hành có thể ghi vào đó và có thể đọc từ vùng đệm đó sau khi được ghi vào. Đó chính xác là những gì bạn muốn từ vùng đệm tạm thời.

Trong ví dụ trên, giả sử chúng ta muốn sử dụng F32[1024] làm vùng đệm tạm thời. Sau đó, chúng tôi sẽ viết HLO như trên và đơn giản là không bao giờ đọc chỉ mục bộ dữ liệu 1 của đầu ra của lệnh gọi tuỳ chỉnh.