เอกสารนี้อธิบายวิธีเขียนและใช้การเรียก XLA ที่กำหนดเองโดยใช้ XLA FFI ไลบรารี การเรียกที่กำหนดเองเป็นกลไกในการอธิบาย "การดำเนินการ" ภายนอก ในช่วง โมดูล HLO ไปยังคอมไพเลอร์ XLA (ในเวลาคอมไพเลอร์) และ XLA FFI เป็นกลไกในการเปลี่ยนแปลง ลงทะเบียนการใช้งานการดำเนินการดังกล่าวด้วย XLA (ขณะรันไทม์) FFI ยืน สำหรับ "อินเทอร์เฟซฟังก์ชันต่างประเทศ" และเป็นชุด C API ที่กำหนดไบนารี อินเทอร์เฟซ (ABI) สำหรับ XLA เพื่อเรียกใช้งานโค้ดภายนอกที่เขียนไว้ในโปรแกรมอื่นๆ ภาษา XLA ให้การเชื่อมโยงเฉพาะส่วนหัวสําหรับ XLA FFI ที่เขียนด้วย C++ ซึ่งจะซ่อนรายละเอียดระดับล่างทั้งหมดของ C API ที่เกี่ยวข้องจากผู้ใช้ปลายทาง
การโทรแบบกำหนดเอง JAX + XLA
ดูตัวอย่างการผสานรวมการเรียกที่กำหนดเองและ XLA FFI กับ JAX ตั้งแต่ต้นจนจบในเอกสารประกอบ JAX
การเชื่อมโยง FFI ของ XLA
การเชื่อมโยง FFI แบบ XLA คือข้อกำหนดเวลาคอมไพล์ของลายเซ็นเรียกที่กำหนดเอง ดังนี้
อาร์กิวเมนต์การโทรที่กำหนดเอง แอตทริบิวต์และประเภทของอาร์กิวเมนต์ และพารามิเตอร์เพิ่มเติม
ที่ส่งผ่านบริบทการดำเนินการ (เช่น สตรีม GPU สำหรับแบ็กเอนด์ GPU) การค้นหา FFI ของ XLA สามารถเชื่อมโยงกับฟังก์ชันใดก็ได้ของ C++ (ตัวชี้ฟังก์ชัน, LAMBDA ฯลฯ) ที่มีลายเซ็น operator()
ที่เข้ากันได้ ตัวแฮนเดิลที่สร้างขึ้นจะถอดรหัสการเรียก FFI ของ XLA
เฟรม (กำหนดโดย C API แบบคงที่) ให้พิมพ์ ตรวจสอบพารามิเตอร์ทั้งหมด และส่งต่อ
ผลลัพธ์ที่ถอดรหัสแล้วไปยังการติดต่อกลับที่ผู้ใช้กำหนดได้
การเชื่อมโยง FFI ของ XLA อาศัยเมตาการจัดโปรแกรมของเทมเพลตเป็นหลักเพื่อให้ คอมไพล์ตัวแฮนเดิลที่สร้างขึ้นเป็นโค้ดเครื่องที่มีประสิทธิภาพสูงสุด ค่าใช้จ่ายเพิ่มเติมของรันไทม์จะอยู่ที่ประมาณ 2-3 นาโนวินาทีสำหรับพารามิเตอร์การเรียกที่กำหนดเองแต่ละรายการ
จุดปรับแต่งของ XLA FFI ที่นำไปใช้เป็นความเชี่ยวชาญพิเศษของเทมเพลต และ
ผู้ใช้สามารถกำหนดวิธีถอดรหัสประเภทที่กำหนดเอง เช่น
เพื่อกำหนดการถอดรหัสที่กำหนดเองสำหรับประเภท enum class
ที่ผู้ใช้กำหนด
การแสดงข้อผิดพลาดจากการเรียกที่กำหนดเอง
การใช้งานการโทรที่กำหนดเองต้องแสดงค่า xla::ffi::Error
เพื่อส่งสัญญาณ
สำเร็จหรือเกิดข้อผิดพลาดในรันไทม์ XLA ซึ่งคล้ายกับ absl::Status
และมีชุดรหัสข้อผิดพลาดเดียวกัน เราไม่ได้ใช้ absl::Status
เนื่องจาก
ไม่มี ABI ที่เสถียร และจะไม่ปลอดภัยที่จะส่งต่อ ABI ระหว่างแบบไดนามิก
โหลดไลบรารีการโทรที่กำหนดเอง และ XLA เอง
// Handler that always returns an error.
auto always_error = Ffi::Bind().To(
[]() { return Error(ErrorCode::kInternal, "Oops!"); });
// Handler that always returns a success.
auto always_success = Ffi::Bind().To(
[]() { return Error::Success(); });
อาร์กิวเมนต์และผลลัพธ์ของบัฟเฟอร์
XLA ใช้รูปแบบการส่งปลายทางสำหรับผลลัพธ์: การเรียกที่กำหนดเอง (หรือ XLA อื่นๆ) สำหรับเรื่องนั้น) ไม่ต้องจัดสรรหน่วยความจำสำหรับผลลัพธ์ และแทน เขียนไปยังปลายทางที่ส่งผ่านรันไทม์ XLA XLA ใช้บัฟเฟอร์แบบคงที่ การกำหนด และจัดสรรบัฟเฟอร์สำหรับค่าทั้งหมดตามช่วงที่ใช้งานอยู่ที่ เวลาคอมไพล์
ผลลัพธ์ที่ส่งไปยังตัวแฮนเดิล FFI ที่รวมอยู่ในเทมเพลต Result<T>
ซึ่งมีความหมายเหมือนกับพอยน์เตอร์: operator->
ให้สิทธิ์เข้าถึงพารามิเตอร์พื้นฐาน
AnyBuffer
อาร์กิวเมนต์และผลลัพธ์จะให้สิทธิ์เข้าถึงพารามิเตอร์บัฟเฟอร์การเรียกที่กำหนดเองของข้อมูลประเภทใดก็ได้ ซึ่งจะมีประโยชน์เมื่อการเรียกแบบกำหนดเองมีการใช้งานทั่วไปที่ทำงานกับข้อมูลหลายประเภท และการเรียกแบบกำหนดเองมีการจัดเตรียมรันไทม์ตามประเภทข้อมูล AnyBuffer
ให้สิทธิ์เข้าถึงข้อมูลบัฟเฟอร์
ประเภท มิติข้อมูล และตัวชี้ไปยังบัฟเฟอร์
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
api_version = 4 : i32
} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// Buffers of any rank and data type.
auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
[](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
void* arg_data = arg.untyped_data();
void* res_data = res->untyped_data();
return Error::Success();
});
อาร์กิวเมนต์และผลลัพธ์ของบัฟเฟอร์แบบจำกัด
Buffer
อนุญาตให้เพิ่มข้อจำกัดในประเภทและลําดับข้อมูลบัฟเฟอร์ โดยตัวแฮนเดิลจะตรวจสอบโดยอัตโนมัติและแสดงข้อผิดพลาดไปยังรันไทม์ XLA หากอาร์กิวเมนต์รันไทม์ไม่ตรงกับลายเซ็นตัวแฮนเดิล FFI
// Buffers of any rank and F32 data type.
auto handler = Ffi::Bind().Arg<Buffer<F32>>().Ret<Buffer<F32>>().To(
[](Buffer<F32> arg, Result<Buffer<F32>> res) -> Error {
float* arg_data = arg.typed_data();
float* res_data = res->typed_data();
return Error::Success();
});
// Buffers of rank 2 and F32 data type.
auto handler = Ffi::Bind().Arg<BufferR2<F32>>().Ret<BufferR2<F32>>().To(
[](BufferR2<F32> arg, Result<BufferR2<F32>> res) -> Error {
float* arg_data = arg.typed_data();
float* res_data = res->typed_data();
return Error::Success();
});
อาร์กิวเมนต์และผลลัพธ์แบบผันแปร
หากจำนวนของอาร์กิวเมนต์และผลลัพธ์อาจแตกต่างกันในแต่ละกรณีของ
การโทรที่กำหนดเองจะถอดรหัสได้ในเวลาที่เรียกใช้โดยใช้ RemainingArgs
และ
RemainingRets
auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To(
[](RemainingArgs args, RemainingRets results) -> Error {
ErrorOr<AnyBuffer> arg = args.get<AnyBuffer>(0);
ErrorOr<Result<AnyBuffer>> res = results.get<AnyBuffer>(0);
if (!arg.has_value()) {
return Error(ErrorCode::kInternal, arg.error());
}
if (!res.has_value()) {
return Error(ErrorCode::kInternal, res.error());
}
return Error::Success();
});
อาร์กิวเมนต์และผลลัพธ์ Variadic สามารถประกาศได้หลังจากอาร์กิวเมนต์ปกติและ อย่างไรก็ตาม จะเชื่อมโยงอาร์กิวเมนต์ปกติและผลลัพธ์หลังจากตัวแปรหนึ่งเป็น ผิดกฎหมาย
auto handler =
Ffi::Bind()
.Arg<AnyBuffer>()
.RemainingArgs()
.Ret<AnyBuffer>()
.RemainingRets()
.To([](AnyBuffer arg, RemainingArgs args, AnyBuffer ret,
RemainingRets results) -> Error { return Error::Success(); });
Attributes
XLA FFI สนับสนุนการถอดรหัสอัตโนมัติของ mlir::DictionaryAttr
ที่ส่งผ่านเป็น
custom_call
backend_config
ลงในอาร์กิวเมนต์ตัวแฮนเดิล FFI
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
backend_config= {
i32 = 42 : i32,
str = "string"
},
api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
ในตัวอย่างนี้ การเรียกแบบกำหนดเองจะมีอาร์กิวเมนต์บัฟเฟอร์รายการเดียวและแอตทริบิวต์ 2 รายการ และ FFI ของ XLA จะถอดรหัสและส่งไปยังฟังก์ชันที่เรียกใช้ได้ซึ่งผู้ใช้กำหนดไว้โดยอัตโนมัติ
auto handler = Ffi::Bind()
.Arg<BufferR0<F32>>()
.Attr<int32_t>("i32")
.Attr<std::string_view>("str")
.To([](BufferR0<F32> buffer, int32_t i32, std::string_view str) {
return Error::Success();
});
แอตทริบิวต์ Enum ที่กำหนดโดยผู้ใช้
FFI ของ XLA สามารถถอดรหัสแอตทริบิวต์ MLIR รวมเป็นที่กำหนดโดยผู้ใช้ได้โดยอัตโนมัติ enum คลาส Enum ต้องมีประเภทจำนวนเต็มพื้นฐานเดียวกัน และการถอดรหัสต้องลงทะเบียนกับ XLA FFI อย่างชัดแจ้ง
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
backend_config= {
command = 0 : i32
},
api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
enum class Command : int32_t {
kAdd = 0,
kMul = 1,
};
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command);
auto handler = Ffi::Bind().Attr<Command>("command").To(
[](Command command) -> Error { return Error::Success(); });
การเชื่อมโยงแอตทริบิวต์การโทรที่กำหนดเองทั้งหมด
คุณสามารถเข้าถึงแอตทริบิวต์การโทรที่กำหนดเองทั้งหมดในรูปแบบพจนานุกรม และถอดรหัสเฉพาะแอตทริบิวต์ ที่จำเป็นในเวลาที่แสดง
auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
ErrorOr<int32_t> i32 = attrs.get<int32_t>("i32");
return Error::Success();
});
แอตทริบิวต์โครงสร้างที่ผู้ใช้กำหนด
XLA FFI สามารถถอดรหัสแอตทริบิวต์พจนานุกรมเป็นสตรูคเจอร์ที่ผู้ใช้กำหนด
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
backend_config= {
range = { lo = 0 : i64, hi = 42 : i64 }
},
api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
ในตัวอย่างนี้ range
เป็นแอตทริบิวต์ mlir::DictionaryAttr
และสามารถถอดรหัสเป็นโครงสร้าง C++ โดยอัตโนมัติแทนการเข้าถึงช่องพจนานุกรมตามชื่อ การถอดรหัสต้องลงทะเบียนอย่างชัดเจนกับ
มาโคร XLA_FFI_REGISTER_STRUCT_ATTR_DECODING
(เบื้องหลังที่กำหนด
ความเชี่ยวชาญพิเศษของเทมเพลตในเนมสเปซ ::xla::ffi
จึงต้องเพิ่มมาโครลงใน
เนมสเปซสากล)
struct Range {
int64_t lo;
int64_t hi;
};
XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("i64"),
StructMember<int64_t>("i64"));
auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
return Error::Success();
});
ระบบจะโหลดแอตทริบิวต์ที่กำหนดเองจากพจนานุกรมได้เช่นเดียวกับแอตทริบิวต์อื่นๆ ในตัวอย่างด้านล่าง แอตทริบิวต์การโทรที่กำหนดเองทั้งหมดที่ถอดรหัสเป็น
Dictionary
และ range
สามารถเข้าถึงได้ด้วยชื่อ
auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
ErrorOr<Range> range = attrs.get<Range>("range");
return Error::Success();
});
สร้างการโทรที่กำหนดเองบน CPU
คุณสามารถสร้างคำสั่ง HLO ที่แสดงการเรียกที่กำหนดเองผ่าน API ของไคลเอ็นต์ XLA ได้ เช่น โค้ดต่อไปนี้ใช้การเรียกที่กำหนดเองเพื่อคำนวณ A[i] = B[i %
128]+ C[i]
บน CPU (แน่นอนว่าคุณทำได้และควรทำ – ทําสิ่งนี้
ด้วย HLO ปกติ)
#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"
void do_it() {
xla::XlaBuilder b("do_it");
xla::XlaOp param0 =
xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
xla::XlaOp param1 =
xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
xla::XlaOp custom_call =
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
/*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
/*opaque=*/"", /*has_side_effect=*/false,
/*output_operand_aliasing=*/{}, /*literal=*/nullptr,
/*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
/*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}
// Constrain custom call arguments to rank-1 buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;
// Implement a custom call as a C+ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
xla::ffi::Result<BufferF32> out) {
size_t d0 = in0.dimensions[0];
size_t d1 = in1.dimensions[0];
// Check that dimensions are compatible.
assert(out->dimensions[0] == d1 && "unexpected dimensions");
for (size_t i = 0; i < d1; ++i) {
out->data[i] = in0.data[i % d0] + in1.data[i];
}
}
// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
ffi::Ffi::Bind()
.Arg<Buffer>()
.Arg<Buffer>()
.Ret<Buffer>());
// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
"Host", handler);
สร้างการเรียกใช้ที่กำหนดเองบน GPU
การลงทะเบียนการโทรที่กำหนดเองของ GPU ด้วย XLA FFI เกือบจะเหมือนกันทุกประการ
ความแตกต่างคือ สำหรับ GPU คุณจะต้องขอสตรีมแพลตฟอร์มที่อยู่เบื้องหลัง
(สตรีม CUDA หรือ ROCM) เพื่อเปิดใช้เคอร์เนลในอุปกรณ์ได้ นี่คือ CUDA
ตัวอย่างที่คำนวณ (A[i] = B[i % 128] + C[i]
) เดียวกับ CPU
รหัสด้านบน
void do_it() { /* same implementation as above */ }
__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
out[idx] = in0[idx % 128] + in1[idx];
}
void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
xla::ffi::Result<BufferF32> out) {
size_t d0 = in0.dimensions[0];
size_t d1 = in1.dimensions[0];
size_t d2 = out->dimensions[0];
assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");
const int64_t block_dim = 64;
const int64_t grid_dim = 2048 / block_dim;
custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
in0.data, in1.data, out->data);
}
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
ffi::Ffi::Bind()
.Ctx<xla::ffi::PlatformStream<CUstream>>()
.Arg<BufferF32>()
.Arg<BufferF32>()
.Ret<BufferF32>());
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
"CUDA", handler);
โปรดสังเกตก่อนว่าฟังก์ชันการเรียกใช้ที่กำหนดเองของ GPU ยังคงเป็นฟังก์ชันที่ดำเนินการบน
CPU ฟังก์ชัน CPU ของ do_custom_call
มีหน้าที่ในการจัดคิวงาน
บน GPU ที่นี่จะเปิดใช้งานเคอร์เนล CUDA แต่อาจทําอย่างอื่นได้ด้วย เช่น เรียกใช้ cuBLAS
อาร์กิวเมนต์และผลลัพธ์จะอยู่ในโฮสต์ด้วย และสมาชิกข้อมูลจะมีพอยน์เตอร์ไปยังหน่วยความจำของอุปกรณ์ (เช่น GPU) บัฟเฟอร์ที่ส่งไปยังเครื่องจัดการการเรียกที่กำหนดเองจะมี รูปร่างของบัฟเฟอร์อุปกรณ์พื้นฐาน ดังนั้นการเรียกที่กำหนดเองสามารถประมวลผลเคอร์เนลได้ เปิดพารามิเตอร์จากส่วนขยายนั้น
ส่งต่อ Tuple ไปยังการโทรที่กำหนดเอง
ลองดูการเรียกใช้ที่กําหนดเองต่อไปนี้
using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {32}),
ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {64}),
ShapeUtil::MakeShape(F32, {128}),
}),
ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");
Shape out_shape = ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {512}),
ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);
ทั้งใน CPU และ GPU Tuple จะแสดงเป็นอาร์เรย์ของตัวชี้ในหน่วยความจำ เมื่อ XLA เรียกการเรียกที่กำหนดเองด้วยอาร์กิวเมนต์ Tuple หรือผลลัพธ์จะเป็นค่าเดี่ยวและ เป็นอาร์กิวเมนต์หรือผลลัพธ์แบบบัฟเฟอร์ตามปกติ
เอาต์พุตของทูเปิลเป็นบัฟเฟอร์ชั่วคราว
อินพุตทูเปิลในการเรียกใช้ที่กำหนดเองเป็นวิธีที่สะดวก แต่ก็ไม่จำเป็น หากเราไม่สนับสนุนการป้อนข้อมูล tuple ไปยังการเรียกที่กำหนดเอง คุณสามารถ คลายแพคคู่โดยใช้องค์ประกอบ get-tuple ก่อนที่จะส่งไปยังแอตทริบิวต์ที่กำหนดเอง การโทร
ในทางกลับกัน เอาต์พุตของทิวเปิลช่วยให้คุณทําสิ่งต่างๆ ที่ทําไม่ได้
เหตุผลที่ชัดเจนที่ควรใช้เอาต์พุตทูเปิลคือเอาต์พุตทูเปิลเป็นวิธีที่การเรียกแบบกำหนดเอง (หรือการดำเนินการ XLA อื่นๆ) แสดงผลอาร์เรย์อิสระหลายรายการ
แต่ที่ไม่ค่อยชัดว่าเอาต์พุต Tuple ก็เป็นวิธีกำหนดอุณหภูมิการโทรที่กำหนดเองเช่นกัน ความทรงจำ ได้ เอาต์พุตแสดงถึงบัฟเฟอร์ชั่วคราวได้ พิจารณาบัฟเฟอร์เอาต์พุต มีคุณสมบัติที่ op สามารถเขียนได้ และสามารถอ่านจาก เขียนถึง นั่นคือสิ่งที่คุณต้องการจากบัฟเฟอร์ชั่วคราว
ในตัวอย่างด้านบน สมมติว่าเราต้องการใช้ F32[1024]
เป็นบัฟเฟอร์ชั่วคราว
เราจะเขียน HLO เหมือนกับด้านบน และเราไม่อ่านดัชนี tuple 1 เลย
ของเอาต์พุตของการเรียกที่กำหนดเอง