เอกสารนี้อธิบายวิธีเขียนและใช้การเรียกที่กำหนดเองของ XLA โดยใช้ไลบรารี XLA FFI การเรียกที่กำหนดเองเป็นกลไกในการอธิบาย "การดำเนินการ" ภายนอกในโมดูล HLO ให้กับคอมไพเลอร์ XLA (ในเวลาคอมไพล์) และ XLA FFI เป็นกลไกในการลงทะเบียนการใช้งานการดำเนินการดังกล่าวกับ XLA (ในเวลาเรียกใช้) FFI ย่อมาจาก "Foreign Function Interface" และเป็นชุด API ของ C ที่กำหนดอินเทอร์เฟซไบนารี (ABI) สำหรับ XLA เพื่อเรียกใช้โค้ดภายนอกที่เขียนด้วยภาษาโปรแกรมอื่นๆ XLA มีการเชื่อมโยงส่วนหัวเท่านั้นสำหรับ XLA FFI ที่เขียนด้วย C++ ซึ่ง ซ่อนรายละเอียดระดับต่ำทั้งหมดของ C API พื้นฐานจากผู้ใช้
การเรียกที่กำหนดเองของ JAX + XLA
ดูเอกสารประกอบของ JAX สำหรับ ตัวอย่างแบบครบวงจรของการผสานรวมการเรียกที่กำหนดเองและ XLA FFI กับ JAX
การเชื่อมโยง FFI ของ XLA
การเชื่อมโยง FFI ของ XLA เป็นข้อกำหนดเวลาคอมไพล์ของลายเซ็นการเรียกที่กำหนดเอง
อาร์กิวเมนต์การเรียกที่กำหนดเอง แอตทริบิวต์และประเภทของอาร์กิวเมนต์ และพารามิเตอร์เพิ่มเติม
ที่ส่งผ่านบริบทการดำเนินการ (เช่น สตรีม GPU สำหรับแบ็กเอนด์ GPU) การเชื่อมโยง XLA FFI
สามารถเชื่อมโยงกับ C++ ที่เรียกใช้ได้ (ตัวชี้ฟังก์ชัน, แลมบ์ดา ฯลฯ) ที่มีoperator()ลายเซ็นที่เข้ากันได้ ตัวแฮนเดิลที่สร้างขึ้นจะถอดรหัสการเรียก XLA FFI
เฟรม (กำหนดโดย C API ที่เสถียร) ตรวจสอบประเภทพารามิเตอร์ทั้งหมด และส่งต่อ
ผลลัพธ์ที่ถอดรหัสแล้วไปยัง Callback ที่ผู้ใช้กำหนด
การเชื่อมโยง XLA FFI อาศัยการเขียนโปรแกรมเมตาเทมเพลตอย่างมากเพื่อให้สามารถ คอมไพล์แฮนเดิลอร์ที่สร้างขึ้นเป็นโค้ดเครื่องที่มีประสิทธิภาพสูงสุด ค่าใช้จ่ายรันไทม์ จะอยู่ในลําดับของ 2-3 นาโนวินาทีสําหรับพารามิเตอร์การเรียกที่กําหนดเองแต่ละรายการ
จุดปรับแต่ง FFI ของ XLA ที่ใช้เป็นการปรับแต่งเทมเพลต และ
ผู้ใช้สามารถกำหนดวิธีถอดรหัสประเภทที่กำหนดเองได้ กล่าวคือ สามารถ
กำหนดการถอดรหัสที่กำหนดเองสำหรับประเภท enum class ที่ผู้ใช้กำหนดได้
การแสดงข้อผิดพลาดจากการเรียกที่กำหนดเอง
การติดตั้งใช้งานการเรียกที่กำหนดเองต้องแสดงผลxla::ffi::Error value เพื่อส่งสัญญาณ
ความสำเร็จหรือข้อผิดพลาดไปยังรันไทม์ XLA ซึ่งคล้ายกับ absl::Status และมีชุดรหัสข้อผิดพลาดเดียวกัน เราไม่ได้ใช้ absl::Status เนื่องจากไม่มี ABI ที่เสถียร และการส่งผ่านระหว่างไลบรารีการเรียกที่กำหนดเองที่โหลดแบบไดนามิกกับ XLA เองนั้นจะไม่ปลอดภัย
// Handler that always returns an error.
auto always_error = Ffi::Bind().To(
[]() { return Error(ErrorCode::kInternal, "Oops!"); });
// Handler that always returns a success.
auto always_success = Ffi::Bind().To(
[]() { return Error::Success(); });
อาร์กิวเมนต์และผลลัพธ์ของบัฟเฟอร์
XLA ใช้รูปแบบการส่งผ่านปลายทางสำหรับผลลัพธ์ โดยการเรียกที่กำหนดเอง (หรือการดำเนินการ XLA อื่นๆ) จะไม่จัดสรรหน่วยความจำสำหรับผลลัพธ์ แต่จะเขียนลงในปลายทางที่ส่งผ่านโดยรันไทม์ของ XLA แทน XLA ใช้การกำหนดบัฟเฟอร์แบบคงที่และจัดสรรบัฟเฟอร์สำหรับค่าทั้งหมดตามช่วงที่มีการใช้งานในเวลาคอมไพล์
ผลลัพธ์ที่ส่งไปยังแฮนเดิล FFI จะอยู่ในเทมเพลต Result<T> ซึ่งมี
ความหมายคล้ายกับพอยน์เตอร์: operator-> ช่วยให้เข้าถึงพารามิเตอร์พื้นฐานได้
AnyBuffer arguments และ results ให้สิทธิ์เข้าถึงพารามิเตอร์บัฟเฟอร์การเรียกที่กำหนดเอง
ของข้อมูลประเภทใดก็ได้ ซึ่งจะเป็นประโยชน์เมื่อการเรียกที่กำหนดเองมีการติดตั้งใช้งานทั่วไป
ที่ใช้ได้กับข้อมูลหลายประเภท และการติดตั้งใช้งานการเรียกที่กำหนดเองจะเรียกใช้การส่งรันไทม์
ตามประเภทข้อมูล AnyBuffer ให้สิทธิ์เข้าถึงข้อมูลบัฟเฟอร์
ประเภท มิติข้อมูล และพอยน์เตอร์ไปยังบัฟเฟอร์เอง
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
api_version = 4 : i32
} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// Buffers of any number of dimensions and data type.
auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
[](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
void* arg_data = arg.untyped_data();
void* res_data = res->untyped_data();
return Error::Success();
});
อาร์กิวเมนต์และผลลัพธ์ของบัฟเฟอร์ที่จำกัด
Buffer ช่วยให้เพิ่มข้อจำกัดในประเภทข้อมูลบัฟเฟอร์และจำนวน
มิติข้อมูลได้ และตัวแฮนเดิลจะตรวจสอบข้อจำกัดเหล่านี้โดยอัตโนมัติและแสดง
ข้อผิดพลาดไปยังรันไทม์ XLA หากอาร์กิวเมนต์รันไทม์ไม่ตรงกับลายเซ็นของตัวแฮนเดิล FFI
// Buffers of any number of dimensions and F32 data type.
auto handler = Ffi::Bind().Arg<Buffer<F32>>().Ret<Buffer<F32>>().To(
[](Buffer<F32> arg, Result<Buffer<F32>> res) -> Error {
float* arg_data = arg.typed_data();
float* res_data = res->typed_data();
return Error::Success();
});
// Buffers of number of dimensions 2 and F32 data type.
auto handler = Ffi::Bind().Arg<BufferR2<F32>>().Ret<BufferR2<F32>>().To(
[](BufferR2<F32> arg, Result<BufferR2<F32>> res) -> Error {
float* arg_data = arg.typed_data();
float* res_data = res->typed_data();
return Error::Success();
});
อาร์กิวเมนต์และผลลัพธ์แบบ Variadic
หากจำนวนอาร์กิวเมนต์และผลลัพธ์อาจแตกต่างกันในอินสแตนซ์ต่างๆ ของ
การเรียกที่กำหนดเอง คุณจะถอดรหัสได้ในรันไทม์โดยใช้ RemainingArgs และ
RemainingRets
auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To(
[](RemainingArgs args, RemainingRets results) -> Error {
ErrorOr<AnyBuffer> arg = args.get<AnyBuffer>(0);
ErrorOr<Result<AnyBuffer>> res = results.get<AnyBuffer>(0);
if (!arg.has_value()) {
return Error(ErrorCode::kInternal, arg.error());
}
if (!res.has_value()) {
return Error(ErrorCode::kInternal, res.error());
}
return Error::Success();
});
คุณสามารถประกาศอาร์กิวเมนต์และผลลัพธ์แบบ Variadic หลังจากอาร์กิวเมนต์และผลลัพธ์ปกติได้ แต่การเชื่อมโยงอาร์กิวเมนต์และผลลัพธ์ปกติหลังจากอาร์กิวเมนต์แบบ Variadic จะไม่ถูกต้อง
auto handler =
Ffi::Bind()
.Arg<AnyBuffer>()
.RemainingArgs()
.Ret<AnyBuffer>()
.RemainingRets()
.To([](AnyBuffer arg, RemainingArgs args, AnyBuffer ret,
RemainingRets results) -> Error { return Error::Success(); });
Attributes
FFI ของ XLA รองรับการถอดรหัสอัตโนมัติของ mlir::DictionaryAttr ที่ส่งเป็น
custom_call backend_config ลงในอาร์กิวเมนต์ของตัวแฮนเดิล FFI
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
backend_config= {
i32 = 42 : i32,
str = "string"
},
api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
ในตัวอย่างนี้ การเรียกที่กำหนดเองมีอาร์กิวเมนต์บัฟเฟอร์เดียวและแอตทริบิวต์ 2 รายการ และ XLA FFI สามารถถอดรหัสอาร์กิวเมนต์และแอตทริบิวต์เหล่านั้นโดยอัตโนมัติ แล้วส่งไปยังฟังก์ชันที่เรียกได้ซึ่งผู้ใช้กำหนด
auto handler = Ffi::Bind()
.Arg<BufferR0<F32>>()
.Attr<int32_t>("i32")
.Attr<std::string_view>("str")
.To([](BufferR0<F32> buffer, int32_t i32, std::string_view str) {
return Error::Success();
});
แอตทริบิวต์ Enum ที่ผู้ใช้กำหนด
FFI ของ XLA สามารถถอดรหัสแอตทริบิวต์ MLIR ที่เป็นจำนวนเต็มเป็น enum ที่ผู้ใช้กำหนดได้โดยอัตโนมัติ คลาส Enum ต้องมีประเภทจำนวนเต็มพื้นฐานเดียวกัน และต้องลงทะเบียนการถอดรหัสกับ XLA FFI อย่างชัดเจน
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
backend_config= {
command = 0 : i32
},
api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
enum class Command : int32_t {
kAdd = 0,
kMul = 1,
};
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command);
auto handler = Ffi::Bind().Attr<Command>("command").To(
[](Command command) -> Error { return Error::Success(); });
การเชื่อมโยงแอตทริบิวต์การโทรที่กำหนดเองทั้งหมด
คุณสามารถเข้าถึงแอตทริบิวต์การโทรที่กำหนดเองทั้งหมดเป็นพจนานุกรม และถอดรหัสเฉพาะแอตทริบิวต์ที่จำเป็นในขณะรันไทม์ได้
auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
ErrorOr<int32_t> i32 = attrs.get<int32_t>("i32");
return Error::Success();
});
แอตทริบิวต์ Struct ที่ผู้ใช้กำหนด
XLA FFI สามารถถอดรหัสแอตทริบิวต์พจนานุกรมเป็นโครงสร้างที่ผู้ใช้กำหนดได้
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
backend_config= {
range = { lo = 0 : i64, hi = 42 : i64 }
},
api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
ในตัวอย่างข้างต้น range คือแอตทริบิวต์ mlir::DictionaryAttr และแทนที่จะเข้าถึงฟิลด์พจนานุกรมตามชื่อ ระบบจะถอดรหัสโดยอัตโนมัติเป็นโครงสร้าง C++ ได้
ต้องลงทะเบียนการถอดรหัสอย่างชัดเจนด้วยมาโคร a
XLA_FFI_REGISTER_STRUCT_ATTR_DECODING (เบื้องหลังจะกำหนด
การปรับแต่งเทมเพลตในเนมสเปซ ::xla::ffi ดังนั้นจึงต้องเพิ่มมาโครลงใน
เนมสเปซส่วนกลาง)
struct Range {
int64_t lo;
int64_t hi;
};
XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("lo"),
StructMember<int64_t>("hi"));
auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
return Error::Success();
});
คุณโหลดแอตทริบิวต์ที่กำหนดเองจากพจนานุกรมได้เช่นเดียวกับแอตทริบิวต์อื่นๆ
ในตัวอย่างด้านล่าง แอตทริบิวต์การโทรที่กำหนดเองทั้งหมดจะถอดรหัสเป็น a
Dictionary และเข้าถึง range ได้ตามชื่อ
auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
ErrorOr<Range> range = attrs.get<Range>("range");
return Error::Success();
});
สร้างการเรียกที่กำหนดเองใน CPU
คุณสร้างคำสั่ง HLO ที่แสดงถึงการเรียกที่กำหนดเองผ่านไคลเอ็นต์ของ XLA
API ได้ ตัวอย่างเช่น โค้ดต่อไปนี้ใช้การเรียกที่กำหนดเองเพื่อคำนวณ A[i] = B[i %
128]+ C[i] ใน CPU (แน่นอนว่าคุณทำได้และควรทำด้วย - ทำสิ่งนี้
ด้วย HLO ปกติ)
#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"
void do_it() {
xla::XlaBuilder b("do_it");
xla::XlaOp param0 =
xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
xla::XlaOp param1 =
xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
xla::XlaOp custom_call =
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
/*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
/*opaque=*/"", /*has_side_effect=*/false,
/*output_operand_aliasing=*/{}, /*literal=*/nullptr,
/*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
/*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}
// Constrain custom call arguments to 1-dimensional buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;
// Implement a custom call as a C++ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
xla::ffi::Result<BufferF32> out) {
size_t d0 = in0.dimensions[0];
size_t d1 = in1.dimensions[0];
// Check that dimensions are compatible.
assert(out->dimensions[0] == d1 && "unexpected dimensions");
for (size_t i = 0; i < d1; ++i) {
out->data[i] = in0.data[i % d0] + in1.data[i];
}
}
// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
ffi::Ffi::Bind()
.Arg<Buffer>()
.Arg<Buffer>()
.Ret<Buffer>());
// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
"Host", handler);
สร้างการเรียกที่กำหนดเองใน GPU
การลงทะเบียนการเรียกที่กำหนดเองของ GPU ด้วย XLA FFI นั้นเกือบจะเหมือนกัน ความแตกต่างเพียงอย่างเดียวคือสำหรับ GPU คุณต้องขอสตรีมแพลตฟอร์มพื้นฐาน (สตรีม CUDA หรือ ROCM) เพื่อให้สามารถเปิดใช้เคอร์เนลในอุปกรณ์ได้ ต่อไปนี้คือตัวอย่าง CUDA
ที่ทำการคำนวณเดียวกัน (A[i] = B[i % 128] + C[i]) กับโค้ด CPU
ด้านบน
void do_it() { /* same implementation as above */ }
__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
out[idx] = in0[idx % 128] + in1[idx];
}
void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
xla::ffi::Result<BufferF32> out) {
size_t d0 = in0.dimensions[0];
size_t d1 = in1.dimensions[0];
size_t d2 = out->dimensions[0];
assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");
const int64_t block_dim = 64;
const int64_t grid_dim = 2048 / block_dim;
custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
in0.data, in1.data, out->data);
}
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
ffi::Ffi::Bind()
.Ctx<xla::ffi::PlatformStream<CUstream>>()
.Arg<BufferF32>()
.Arg<BufferF32>()
.Ret<BufferF32>());
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
"CUDA", handler);
โปรดทราบว่าฟังก์ชันการเรียกที่กำหนดเองของ GPU ยังคงเป็นฟังก์ชันที่ดำเนินการใน CPU ฟังก์ชัน CPU ของ do_custom_call มีหน้าที่จัดคิวงาน
ใน GPU ในที่นี้จะเปิดใช้เคอร์เนล CUDA แต่ก็อาจทำอย่างอื่นได้เช่นกัน
เช่น เรียกใช้ cuBLAS
อาร์กิวเมนต์และผลลัพธ์จะอยู่ในโฮสต์ และสมาชิกข้อมูลจะมีพอยน์เตอร์ ไปยังหน่วยความจำของอุปกรณ์ (เช่น GPU) บัฟเฟอร์ที่ส่งไปยังตัวแฮนเดิลการโทรที่กำหนดเองจะมีรูปร่างเหมือนบัฟเฟอร์ของอุปกรณ์พื้นฐาน ดังนั้นการโทรที่กำหนดเองจึงสามารถคำนวณพารามิเตอร์การเปิดใช้เคอร์เนลจากบัฟเฟอร์เหล่านั้นได้
การส่งผ่านทูเพิลไปยังการเรียกที่กำหนดเอง
ลองพิจารณาการเรียกที่กำหนดเองต่อไปนี้
using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {32}),
ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {64}),
ShapeUtil::MakeShape(F32, {128}),
}),
ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");
Shape out_shape = ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {512}),
ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);
ในทั้ง CPU และ GPU ระบบจะแสดง Tuple ในหน่วยความจำเป็นอาร์เรย์ของพอยน์เตอร์ เมื่อ XLA เรียกการเรียกที่กำหนดเองด้วยอาร์กิวเมนต์หรือผลลัพธ์ของ Tuple ระบบจะทำให้แบนและ ส่งเป็นอาร์กิวเมนต์หรือผลลัพธ์ของบัฟเฟอร์ปกติ
เอาต์พุตของ Tuple เป็นบัฟเฟอร์ชั่วคราว
อินพุต Tuple สำหรับการเรียกที่กำหนดเองเป็นเพียงความสะดวก แต่ไม่จำเป็นต้องใช้ อย่างเคร่งครัด หากเราไม่รองรับอินพุต Tuple ในการเรียกที่กำหนดเอง คุณจะสามารถ แกะ Tuple โดยใช้ get-tuple-element ก่อนที่จะส่งไปยังการเรียกที่กำหนดเอง ได้เสมอ
ในทางกลับกัน เอาต์พุตของ Tuple จะช่วยให้คุณทำสิ่งต่างๆ ที่ทำไม่ได้ในกรณีอื่น
เหตุผลที่ชัดเจนในการมีเอาต์พุตของ Tuple คือเอาต์พุตของ Tuple เป็นวิธีที่การเรียกที่กำหนดเอง (หรือการดำเนินการ XLA อื่นๆ) จะแสดงผลอาร์เรย์อิสระหลายรายการ
แต่ที่เห็นได้ไม่ชัดเจนนักคือเอาต์พุตของ Tuple ยังเป็นวิธีในการให้เทมเพลตการเรียกที่กำหนดเอง หน่วยความจำ ได้ เอาต์พุตสามารถแสดงบัฟเฟอร์ชั่วคราวได้ พิจารณาว่าบัฟเฟอร์เอาต์พุต มีพร็อพเพอร์ตี้ที่การดำเนินการสามารถเขียนลงในบัฟเฟอร์นั้นได้ และอ่านจากบัฟเฟอร์นั้นได้หลังจาก เขียนลงไปแล้ว ซึ่งเป็นสิ่งที่คุณต้องการจากบัฟเฟอร์ชั่วคราว
ในตัวอย่างด้านบน สมมติว่าเราต้องการใช้ F32[1024] เป็นบัฟเฟอร์ชั่วคราว
จากนั้นเราจะเขียน HLO เหมือนกับด้านบน และจะไม่อ่านดัชนี Tuple 1
ของเอาต์พุตการเรียกที่กำหนดเอง