การโทรที่กำหนดเองสำหรับ XLA

เอกสารนี้จะอธิบายวิธีเขียนและใช้การเรียก XLA ที่กำหนดเองโดยใช้ไลบรารี XLA FFI การเรียกที่กำหนดเองเป็นกลไกในการอธิบาย "การดำเนินการ" ภายนอกในโมดูล HLO ไปยังคอมไพเลอร์ XLA (ขณะคอมไพล์) และ XLA FFI เป็นกลไกในการลงทะเบียนการใช้งานการดำเนินการดังกล่าวกับ XLA (ในเวลาที่เรียกใช้) FFI ย่อมาจาก "Foreign Function Interface" และเป็นชุดของ C API ที่กำหนดอินเทอร์เฟซแบบไบนารี (ABI) สำหรับ XLA ในการเรียกใช้โค้ดภายนอกที่เขียนในภาษาโปรแกรมอื่นๆ XLA ให้บริการการเชื่อมโยงส่วนหัวเท่านั้นสำหรับ XLA FFI ที่เขียนด้วย C++ ซึ่งจะซ่อนรายละเอียดระดับต่ำทั้งหมดของ C API ที่เกี่ยวข้องจากผู้ใช้ปลายทาง

สร้างการโทรที่กำหนดเองโดยใช้ CPU

คุณสามารถสร้างคำสั่ง HLO ซึ่งแสดงถึงการเรียกที่กำหนดเองผ่าน API ไคลเอ็นต์ของ XLA ได้ ตัวอย่างเช่น โค้ดต่อไปนี้ใช้การเรียกที่กำหนดเองเพื่อคำนวณ A[i] = B[i % 128]+ C[i] บน CPU (แน่นอนว่าคุณทำได้และควรทำ – ให้ทำด้วย HLO ปกติ)

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
        /*opaque=*/"", /*has_side_effect=*/false,
        /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
        /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to rank-1 buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C+ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];

  // Check that dimensions are compatible.
  assert(out->dimensions[0] == d1 && "unexpected dimensions");

  for (size_t i = 0; i < d1; ++i) {
    out->data[i] = in0.data[i % d0] + in1.data[i];
  }
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Arg<Buffer>()
                           .Arg<Buffer>()
                           .Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "Host", handler);

สร้างการโทรที่กำหนดเองบน GPU

การลงทะเบียนการเรียกใช้ที่กำหนดเองของ GPU กับ XLA FFI แทบจะเหมือนกันทั้งหมด ความแตกต่างเพียงอย่างเดียวคือสำหรับ GPU คุณต้องขอสตรีมแพลตฟอร์มที่สำคัญ (สตรีม CUDA หรือ ROCM) เพื่อให้สามารถเปิดเคอร์เนลในอุปกรณ์ได้ ต่อไปนี้คือตัวอย่าง CUDA ที่ประมวลผลแบบเดียวกับโค้ด CPU ด้านบน (A[i] = B[i % 128] + C[i])

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];
  size_t d2 = out->dimensions[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
    in0.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<CUstream>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "CUDA", handler);

โปรดสังเกตก่อนว่าฟังก์ชันการโทรที่กำหนดเองของ GPU ยังคงเป็นฟังก์ชันที่ดำเนินการบน CPU ฟังก์ชัน CPU ของ do_custom_call มีหน้าที่กำหนดคิวงานบน GPU ซึ่งในตัวอย่างนี้ เคอร์เนลของ CUDA สามารถทำอะไรอย่างอื่นได้ เช่น เรียกใช้ cuBLAS

อาร์กิวเมนต์และผลลัพธ์ยังพร้อมใช้งานในโฮสต์และสมาชิกข้อมูลมีตัวชี้ไปยังหน่วยความจำของอุปกรณ์ (เช่น GPU) บัฟเฟอร์ที่ส่งไปยังเครื่องจัดการการเรียกที่กำหนดเองจะมีรูปร่างของบัฟเฟอร์อุปกรณ์ที่สำคัญ ดังนั้นการเรียกใช้ที่กำหนดเองจึงคำนวณพารามิเตอร์การเรียกใช้เคอร์เนลจากบัฟเฟอร์เหล่านั้นได้

การส่งต่อไปยังการโทรที่กำหนดเอง

ลองพิจารณาการเรียกที่กำหนดเองต่อไปนี้

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);

ทั้ง CPU และ GPU จะแสดง Tuple ในหน่วยความจำเป็นอาร์เรย์ของตัวชี้ เมื่อ XLA เรียกใช้การเรียกที่กำหนดเองที่มีอาร์กิวเมนต์ tuple หรือผลลัพธ์ให้แบนราบและส่งผ่านเป็นอาร์กิวเมนต์หรือผลลัพธ์บัฟเฟอร์ปกติ

เอาต์พุต Tuple เป็นบัฟเฟอร์ชั่วคราว

การป้อนข้อมูลเป็น 2 ส่วนในการโทรที่กำหนดเองนั้นสะดวก แต่ไม่ได้จำเป็นเสมอไป หากเราไม่รองรับอินพุต Tuple ในการเรียกที่กำหนดเอง คุณสามารถคลายแพ็ก Tuple ได้ทุกเมื่อโดยใช้องค์ประกอบ get-tuple ก่อนส่งไปยังการโทรที่กำหนดเอง

ในทางกลับกัน เอาต์พุตของ Tuple ช่วยให้คุณสามารถทำในสิ่งที่ทำไม่ได้

เหตุผลที่ชัดเจนของการมีเอาต์พุต Tuple คือเอาต์พุต Tuple คือวิธีที่การเรียกที่กำหนดเอง (หรือ XLA OP อื่นๆ) แสดงผลอาร์เรย์อิสระหลายรายการ

แต่เห็นได้ชัดน้อยกว่านั้น เอาต์พุต Tuple ก็เป็นอีกวิธีหนึ่งในการสร้างหน่วยความจำชั่วคราวสำหรับการโทรที่กำหนดเองของคุณ มี เอาต์พุตสามารถแสดงบัฟเฟอร์ชั่วคราวได้ ลองพิจารณาว่าบัฟเฟอร์เอาต์พุตจะมีพร็อพเพอร์ตี้ที่ฝ่ายเขียนสามารถเขียนลงในที่เก็บได้ และจะอ่านได้จากบัฟเฟอร์ดังกล่าวหลังจากที่เขียนลงไปแล้ว นั่นคือสิ่งที่คุณต้องการจากบัฟเฟอร์ชั่วคราว

ในตัวอย่างด้านบน สมมติว่าเราต้องการใช้ F32[1024] เป็นบัฟเฟอร์ชั่วคราว จากนั้น เราจะเขียน HLO เหมือนด้านบน และเราจะไม่อ่านดัชนี Tuple 1 ของเอาต์พุตของการเรียกที่กำหนดเอง