การเรียกที่กำหนดเองของ XLA

เอกสารนี้อธิบายวิธีเขียนและใช้การเรียกที่กำหนดเองของ XLA โดยใช้ไลบรารี XLA FFI การเรียกที่กำหนดเองเป็นกลไกในการอธิบาย "การดำเนินการ" ภายนอกในโมดูล HLO ให้กับคอมไพเลอร์ XLA (ในเวลาคอมไพล์) และ XLA FFI เป็นกลไกในการลงทะเบียนการใช้งานการดำเนินการดังกล่าวกับ XLA (ในเวลาเรียกใช้) FFI ย่อมาจาก "Foreign Function Interface" และเป็นชุด API ของ C ที่กำหนดอินเทอร์เฟซไบนารี (ABI) สำหรับ XLA เพื่อเรียกใช้โค้ดภายนอกที่เขียนด้วยภาษาโปรแกรมอื่นๆ XLA มีการเชื่อมโยงส่วนหัวเท่านั้นสำหรับ XLA FFI ที่เขียนด้วย C++ ซึ่ง ซ่อนรายละเอียดระดับต่ำทั้งหมดของ C API พื้นฐานจากผู้ใช้

การเรียกที่กำหนดเองของ JAX + XLA

ดูเอกสารประกอบของ JAX สำหรับ ตัวอย่างแบบครบวงจรของการผสานรวมการเรียกที่กำหนดเองและ XLA FFI กับ JAX

การเชื่อมโยง FFI ของ XLA

การเชื่อมโยง FFI ของ XLA เป็นข้อกำหนดเวลาคอมไพล์ของลายเซ็นการเรียกที่กำหนดเอง อาร์กิวเมนต์การเรียกที่กำหนดเอง แอตทริบิวต์และประเภทของอาร์กิวเมนต์ และพารามิเตอร์เพิ่มเติม ที่ส่งผ่านบริบทการดำเนินการ (เช่น สตรีม GPU สำหรับแบ็กเอนด์ GPU) การเชื่อมโยง XLA FFI สามารถเชื่อมโยงกับ C++ ที่เรียกใช้ได้ (ตัวชี้ฟังก์ชัน, แลมบ์ดา ฯลฯ) ที่มีoperator()ลายเซ็นที่เข้ากันได้ ตัวแฮนเดิลที่สร้างขึ้นจะถอดรหัสการเรียก XLA FFI เฟรม (กำหนดโดย C API ที่เสถียร) ตรวจสอบประเภทพารามิเตอร์ทั้งหมด และส่งต่อ ผลลัพธ์ที่ถอดรหัสแล้วไปยัง Callback ที่ผู้ใช้กำหนด

การเชื่อมโยง XLA FFI อาศัยการเขียนโปรแกรมเมตาเทมเพลตอย่างมากเพื่อให้สามารถ คอมไพล์แฮนเดิลอร์ที่สร้างขึ้นเป็นโค้ดเครื่องที่มีประสิทธิภาพสูงสุด ค่าใช้จ่ายรันไทม์ จะอยู่ในลําดับของ 2-3 นาโนวินาทีสําหรับพารามิเตอร์การเรียกที่กําหนดเองแต่ละรายการ

จุดปรับแต่ง FFI ของ XLA ที่ใช้เป็นการปรับแต่งเทมเพลต และ ผู้ใช้สามารถกำหนดวิธีถอดรหัสประเภทที่กำหนดเองได้ กล่าวคือ สามารถ กำหนดการถอดรหัสที่กำหนดเองสำหรับประเภท enum class ที่ผู้ใช้กำหนดได้

การแสดงข้อผิดพลาดจากการเรียกที่กำหนดเอง

การติดตั้งใช้งานการเรียกที่กำหนดเองต้องแสดงผลxla::ffi::Error value เพื่อส่งสัญญาณ ความสำเร็จหรือข้อผิดพลาดไปยังรันไทม์ XLA ซึ่งคล้ายกับ absl::Status และมีชุดรหัสข้อผิดพลาดเดียวกัน เราไม่ได้ใช้ absl::Status เนื่องจากไม่มี ABI ที่เสถียร และการส่งผ่านระหว่างไลบรารีการเรียกที่กำหนดเองที่โหลดแบบไดนามิกกับ XLA เองนั้นจะไม่ปลอดภัย

// Handler that always returns an error.
auto always_error = Ffi::Bind().To(
    []() { return Error(ErrorCode::kInternal, "Oops!"); });

// Handler that always returns a success.
auto always_success = Ffi::Bind().To(
    []() { return Error::Success(); });

อาร์กิวเมนต์และผลลัพธ์ของบัฟเฟอร์

XLA ใช้รูปแบบการส่งผ่านปลายทางสำหรับผลลัพธ์ โดยการเรียกที่กำหนดเอง (หรือการดำเนินการ XLA อื่นๆ) จะไม่จัดสรรหน่วยความจำสำหรับผลลัพธ์ แต่จะเขียนลงในปลายทางที่ส่งผ่านโดยรันไทม์ของ XLA แทน XLA ใช้การกำหนดบัฟเฟอร์แบบคงที่และจัดสรรบัฟเฟอร์สำหรับค่าทั้งหมดตามช่วงที่มีการใช้งานในเวลาคอมไพล์

ผลลัพธ์ที่ส่งไปยังแฮนเดิล FFI จะอยู่ในเทมเพลต Result<T> ซึ่งมี ความหมายคล้ายกับพอยน์เตอร์: operator-> ช่วยให้เข้าถึงพารามิเตอร์พื้นฐานได้

AnyBuffer arguments และ results ให้สิทธิ์เข้าถึงพารามิเตอร์บัฟเฟอร์การเรียกที่กำหนดเอง ของข้อมูลประเภทใดก็ได้ ซึ่งจะเป็นประโยชน์เมื่อการเรียกที่กำหนดเองมีการติดตั้งใช้งานทั่วไป ที่ใช้ได้กับข้อมูลหลายประเภท และการติดตั้งใช้งานการเรียกที่กำหนดเองจะเรียกใช้การส่งรันไทม์ ตามประเภทข้อมูล AnyBuffer ให้สิทธิ์เข้าถึงข้อมูลบัฟเฟอร์ ประเภท มิติข้อมูล และพอยน์เตอร์ไปยังบัฟเฟอร์เอง

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  api_version = 4 : i32
} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// Buffers of any number of dimensions and data type.
auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
    [](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
      void* arg_data = arg.untyped_data();
      void* res_data = res->untyped_data();
      return Error::Success();
    });

อาร์กิวเมนต์และผลลัพธ์ของบัฟเฟอร์ที่จำกัด

Buffer ช่วยให้เพิ่มข้อจำกัดในประเภทข้อมูลบัฟเฟอร์และจำนวน มิติข้อมูลได้ และตัวแฮนเดิลจะตรวจสอบข้อจำกัดเหล่านี้โดยอัตโนมัติและแสดง ข้อผิดพลาดไปยังรันไทม์ XLA หากอาร์กิวเมนต์รันไทม์ไม่ตรงกับลายเซ็นของตัวแฮนเดิล FFI

// Buffers of any number of dimensions and F32 data type.
auto handler = Ffi::Bind().Arg<Buffer<F32>>().Ret<Buffer<F32>>().To(
    [](Buffer<F32> arg, Result<Buffer<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });
// Buffers of number of dimensions 2 and F32 data type.
auto handler = Ffi::Bind().Arg<BufferR2<F32>>().Ret<BufferR2<F32>>().To(
    [](BufferR2<F32> arg, Result<BufferR2<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });

อาร์กิวเมนต์และผลลัพธ์แบบ Variadic

หากจำนวนอาร์กิวเมนต์และผลลัพธ์อาจแตกต่างกันในอินสแตนซ์ต่างๆ ของ การเรียกที่กำหนดเอง คุณจะถอดรหัสได้ในรันไทม์โดยใช้ RemainingArgs และ RemainingRets

auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To(
    [](RemainingArgs args, RemainingRets results) -> Error {
      ErrorOr<AnyBuffer> arg = args.get<AnyBuffer>(0);
      ErrorOr<Result<AnyBuffer>> res = results.get<AnyBuffer>(0);

      if (!arg.has_value()) {
        return Error(ErrorCode::kInternal, arg.error());
      }

      if (!res.has_value()) {
        return Error(ErrorCode::kInternal, res.error());
      }

      return Error::Success();
    });

คุณสามารถประกาศอาร์กิวเมนต์และผลลัพธ์แบบ Variadic หลังจากอาร์กิวเมนต์และผลลัพธ์ปกติได้ แต่การเชื่อมโยงอาร์กิวเมนต์และผลลัพธ์ปกติหลังจากอาร์กิวเมนต์แบบ Variadic จะไม่ถูกต้อง

auto handler =
    Ffi::Bind()
        .Arg<AnyBuffer>()
        .RemainingArgs()
        .Ret<AnyBuffer>()
        .RemainingRets()
        .To([](AnyBuffer arg, RemainingArgs args, AnyBuffer ret,
               RemainingRets results) -> Error { return Error::Success(); });

Attributes

FFI ของ XLA รองรับการถอดรหัสอัตโนมัติของ mlir::DictionaryAttr ที่ส่งเป็น custom_call backend_config ลงในอาร์กิวเมนต์ของตัวแฮนเดิล FFI

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    i32 = 42 : i32,
    str = "string"
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

ในตัวอย่างนี้ การเรียกที่กำหนดเองมีอาร์กิวเมนต์บัฟเฟอร์เดียวและแอตทริบิวต์ 2 รายการ และ XLA FFI สามารถถอดรหัสอาร์กิวเมนต์และแอตทริบิวต์เหล่านั้นโดยอัตโนมัติ แล้วส่งไปยังฟังก์ชันที่เรียกได้ซึ่งผู้ใช้กำหนด

auto handler = Ffi::Bind()
  .Arg<BufferR0<F32>>()
  .Attr<int32_t>("i32")
  .Attr<std::string_view>("str")
  .To([](BufferR0<F32> buffer, int32_t i32, std::string_view str) {
    return Error::Success();
  });

แอตทริบิวต์ Enum ที่ผู้ใช้กำหนด

FFI ของ XLA สามารถถอดรหัสแอตทริบิวต์ MLIR ที่เป็นจำนวนเต็มเป็น enum ที่ผู้ใช้กำหนดได้โดยอัตโนมัติ คลาส Enum ต้องมีประเภทจำนวนเต็มพื้นฐานเดียวกัน และต้องลงทะเบียนการถอดรหัสกับ XLA FFI อย่างชัดเจน

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    command = 0 : i32
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
enum class Command : int32_t {
  kAdd = 0,
  kMul = 1,
};

XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command);

auto handler = Ffi::Bind().Attr<Command>("command").To(
    [](Command command) -> Error { return Error::Success(); });

การเชื่อมโยงแอตทริบิวต์การโทรที่กำหนดเองทั้งหมด

คุณสามารถเข้าถึงแอตทริบิวต์การโทรที่กำหนดเองทั้งหมดเป็นพจนานุกรม และถอดรหัสเฉพาะแอตทริบิวต์ที่จำเป็นในขณะรันไทม์ได้

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<int32_t> i32 = attrs.get<int32_t>("i32");
  return Error::Success();
});

แอตทริบิวต์ Struct ที่ผู้ใช้กำหนด

XLA FFI สามารถถอดรหัสแอตทริบิวต์พจนานุกรมเป็นโครงสร้างที่ผู้ใช้กำหนดได้

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    range = { lo = 0 : i64, hi = 42 : i64 }
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

ในตัวอย่างข้างต้น range คือแอตทริบิวต์ mlir::DictionaryAttr และแทนที่จะเข้าถึงฟิลด์พจนานุกรมตามชื่อ ระบบจะถอดรหัสโดยอัตโนมัติเป็นโครงสร้าง C++ ได้ ต้องลงทะเบียนการถอดรหัสอย่างชัดเจนด้วยมาโคร a XLA_FFI_REGISTER_STRUCT_ATTR_DECODING (เบื้องหลังจะกำหนด การปรับแต่งเทมเพลตในเนมสเปซ ::xla::ffi ดังนั้นจึงต้องเพิ่มมาโครลงใน เนมสเปซส่วนกลาง)

struct Range {
  int64_t lo;
  int64_t hi;
};

XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("lo"),
                                             StructMember<int64_t>("hi"));

auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
  return Error::Success();
});

คุณโหลดแอตทริบิวต์ที่กำหนดเองจากพจนานุกรมได้เช่นเดียวกับแอตทริบิวต์อื่นๆ ในตัวอย่างด้านล่าง แอตทริบิวต์การโทรที่กำหนดเองทั้งหมดจะถอดรหัสเป็น a Dictionary และเข้าถึง range ได้ตามชื่อ

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<Range> range = attrs.get<Range>("range");
  return Error::Success();
});

สร้างการเรียกที่กำหนดเองใน CPU

คุณสร้างคำสั่ง HLO ที่แสดงถึงการเรียกที่กำหนดเองผ่านไคลเอ็นต์ของ XLA API ได้ ตัวอย่างเช่น โค้ดต่อไปนี้ใช้การเรียกที่กำหนดเองเพื่อคำนวณ A[i] = B[i % 128]+ C[i] ใน CPU (แน่นอนว่าคุณทำได้และควรทำด้วย - ทำสิ่งนี้ ด้วย HLO ปกติ)

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
        /*opaque=*/"", /*has_side_effect=*/false,
        /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
        /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to 1-dimensional buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C++ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];

  // Check that dimensions are compatible.
  assert(out->dimensions[0] == d1 && "unexpected dimensions");

  for (size_t i = 0; i < d1; ++i) {
    out->data[i] = in0.data[i % d0] + in1.data[i];
  }
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Arg<Buffer>()
                           .Arg<Buffer>()
                           .Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "Host", handler);

สร้างการเรียกที่กำหนดเองใน GPU

การลงทะเบียนการเรียกที่กำหนดเองของ GPU ด้วย XLA FFI นั้นเกือบจะเหมือนกัน ความแตกต่างเพียงอย่างเดียวคือสำหรับ GPU คุณต้องขอสตรีมแพลตฟอร์มพื้นฐาน (สตรีม CUDA หรือ ROCM) เพื่อให้สามารถเปิดใช้เคอร์เนลในอุปกรณ์ได้ ต่อไปนี้คือตัวอย่าง CUDA ที่ทำการคำนวณเดียวกัน (A[i] = B[i % 128] + C[i]) กับโค้ด CPU ด้านบน

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];
  size_t d2 = out->dimensions[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
    in0.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<CUstream>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "CUDA", handler);

โปรดทราบว่าฟังก์ชันการเรียกที่กำหนดเองของ GPU ยังคงเป็นฟังก์ชันที่ดำเนินการใน CPU ฟังก์ชัน CPU ของ do_custom_call มีหน้าที่จัดคิวงาน ใน GPU ในที่นี้จะเปิดใช้เคอร์เนล CUDA แต่ก็อาจทำอย่างอื่นได้เช่นกัน เช่น เรียกใช้ cuBLAS

อาร์กิวเมนต์และผลลัพธ์จะอยู่ในโฮสต์ และสมาชิกข้อมูลจะมีพอยน์เตอร์ ไปยังหน่วยความจำของอุปกรณ์ (เช่น GPU) บัฟเฟอร์ที่ส่งไปยังตัวแฮนเดิลการโทรที่กำหนดเองจะมีรูปร่างเหมือนบัฟเฟอร์ของอุปกรณ์พื้นฐาน ดังนั้นการโทรที่กำหนดเองจึงสามารถคำนวณพารามิเตอร์การเปิดใช้เคอร์เนลจากบัฟเฟอร์เหล่านั้นได้

การส่งผ่านทูเพิลไปยังการเรียกที่กำหนดเอง

ลองพิจารณาการเรียกที่กำหนดเองต่อไปนี้

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);

ในทั้ง CPU และ GPU ระบบจะแสดง Tuple ในหน่วยความจำเป็นอาร์เรย์ของพอยน์เตอร์ เมื่อ XLA เรียกการเรียกที่กำหนดเองด้วยอาร์กิวเมนต์หรือผลลัพธ์ของ Tuple ระบบจะทำให้แบนและ ส่งเป็นอาร์กิวเมนต์หรือผลลัพธ์ของบัฟเฟอร์ปกติ

เอาต์พุตของ Tuple เป็นบัฟเฟอร์ชั่วคราว

อินพุต Tuple สำหรับการเรียกที่กำหนดเองเป็นเพียงความสะดวก แต่ไม่จำเป็นต้องใช้ อย่างเคร่งครัด หากเราไม่รองรับอินพุต Tuple ในการเรียกที่กำหนดเอง คุณจะสามารถ แกะ Tuple โดยใช้ get-tuple-element ก่อนที่จะส่งไปยังการเรียกที่กำหนดเอง ได้เสมอ

ในทางกลับกัน เอาต์พุตของ Tuple จะช่วยให้คุณทำสิ่งต่างๆ ที่ทำไม่ได้ในกรณีอื่น

เหตุผลที่ชัดเจนในการมีเอาต์พุตของ Tuple คือเอาต์พุตของ Tuple เป็นวิธีที่การเรียกที่กำหนดเอง (หรือการดำเนินการ XLA อื่นๆ) จะแสดงผลอาร์เรย์อิสระหลายรายการ

แต่ที่เห็นได้ไม่ชัดเจนนักคือเอาต์พุตของ Tuple ยังเป็นวิธีในการให้เทมเพลตการเรียกที่กำหนดเอง หน่วยความจำ ได้ เอาต์พุตสามารถแสดงบัฟเฟอร์ชั่วคราวได้ พิจารณาว่าบัฟเฟอร์เอาต์พุต มีพร็อพเพอร์ตี้ที่การดำเนินการสามารถเขียนลงในบัฟเฟอร์นั้นได้ และอ่านจากบัฟเฟอร์นั้นได้หลังจาก เขียนลงไปแล้ว ซึ่งเป็นสิ่งที่คุณต้องการจากบัฟเฟอร์ชั่วคราว

ในตัวอย่างด้านบน สมมติว่าเราต้องการใช้ F32[1024] เป็นบัฟเฟอร์ชั่วคราว จากนั้นเราจะเขียน HLO เหมือนกับด้านบน และจะไม่อ่านดัชนี Tuple 1 ของเอาต์พุตการเรียกที่กำหนดเอง