Tài liệu này mô tả cách viết và sử dụng các lệnh gọi tuỳ chỉnh XLA bằng thư viện XLA FFI. Lệnh gọi tuỳ chỉnh là một cơ chế để mô tả "hoạt động" bên ngoài trong mô-đun HLO cho trình biên dịch XLA (tại thời điểm biên dịch) và XLA FFI là một cơ chế để đăng ký việc triển khai các hoạt động như vậy với XLA (tại thời điểm chạy). FFI là viết tắt của "giao diện hàm ngoài" và là một tập hợp các API C xác định giao diện nhị phân (ABI) cho XLA để gọi vào mã bên ngoài được viết bằng các ngôn ngữ lập trình khác. XLA cung cấp các liên kết chỉ dành cho tiêu đề cho XLA FFI được viết bằng C++, giúp ẩn tất cả thông tin chi tiết cấp thấp của các API C cơ bản khỏi người dùng cuối.
JAX + XLA Custom Calls
Hãy xem tài liệu JAX để biết các ví dụ từ đầu đến cuối về việc tích hợp các lệnh gọi tuỳ chỉnh và XLA FFI với JAX.
Liên kết FFI XLA
Liên kết XLA FFI là một quy cách thời gian biên dịch của chữ ký lệnh gọi tuỳ chỉnh: đối số lệnh gọi tuỳ chỉnh, thuộc tính và các loại của chúng, cũng như các tham số bổ sung được truyền qua ngữ cảnh thực thi (tức là luồng GPU cho phần phụ trợ GPU). Liên kết XLA FFI có thể được liên kết với mọi hàm có thể gọi C++ (con trỏ hàm, lambda, v.v.) có chữ ký operator() tương thích. Trình xử lý được tạo sẽ giải mã khung lệnh gọi XLA FFI (do API C ổn định xác định), kiểm tra kiểu của tất cả các tham số và chuyển tiếp kết quả đã giải mã đến lệnh gọi lại do người dùng xác định.
Liên kết FFI XLA phụ thuộc nhiều vào siêu lập trình mẫu để có thể biên dịch trình xử lý đã tạo thành mã máy hiệu quả nhất. Thời gian chạy có độ trễ theo thứ tự vài nano giây cho mỗi tham số lệnh gọi tuỳ chỉnh.
Các điểm tuỳ chỉnh XLA FFI được triển khai dưới dạng các chuyên môn hoá mẫu và người dùng có thể xác định cách giải mã các loại tuỳ chỉnh của họ, tức là có thể xác định quy trình giải mã tuỳ chỉnh cho các loại enum class do người dùng xác định.
Trả về lỗi từ các lệnh gọi tuỳ chỉnh
Các phương thức triển khai lệnh gọi tuỳ chỉnh phải trả về giá trị xla::ffi::Error để báo hiệu thành công hoặc lỗi cho thời gian chạy XLA. Phương thức này tương tự như absl::Status và có cùng một bộ mã lỗi. Chúng tôi không sử dụng absl::Status vì nó không có ABI ổn định và sẽ không an toàn khi truyền ABI này giữa thư viện gọi tuỳ chỉnh được tải động và chính XLA.
// Handler that always returns an error.
auto always_error = Ffi::Bind().To(
[]() { return Error(ErrorCode::kInternal, "Oops!"); });
// Handler that always returns a success.
auto always_success = Ffi::Bind().To(
[]() { return Error::Success(); });
Đối số và kết quả của vùng đệm
XLA sử dụng kiểu truyền đích cho kết quả: các lệnh gọi tuỳ chỉnh (hoặc bất kỳ thao tác XLA nào khác) không phân bổ bộ nhớ cho kết quả mà thay vào đó, ghi vào các đích được truyền bởi thời gian chạy XLA. XLA sử dụng việc chỉ định vùng đệm tĩnh và phân bổ vùng đệm cho tất cả các giá trị dựa trên phạm vi hoạt động của chúng tại thời gian biên dịch.
Kết quả được truyền đến các trình xử lý FFI được gói trong một mẫu Result<T> có ngữ nghĩa giống như con trỏ: operator-> cho phép truy cập vào tham số cơ bản.
AnyBuffer đối số và kết quả cho phép truy cập vào các thông số bộ nhớ đệm cuộc gọi tuỳ chỉnh của mọi loại dữ liệu. Điều này hữu ích khi lệnh gọi tuỳ chỉnh có một cách triển khai chung hoạt động cho nhiều loại dữ liệu và cách triển khai lệnh gọi tuỳ chỉnh thực hiện việc phân phối thời gian chạy dựa trên loại dữ liệu. AnyBuffer cho phép truy cập vào loại dữ liệu vùng đệm, các phương diện và một con trỏ trỏ đến chính vùng đệm.
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
api_version = 4 : i32
} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// Buffers of any number of dimensions and data type.
auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
[](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
void* arg_data = arg.untyped_data();
void* res_data = res->untyped_data();
return Error::Success();
});
Đối số và kết quả của vùng đệm bị hạn chế
Buffer cho phép thêm các ràng buộc vào kiểu dữ liệu vùng đệm và số lượng phương diện, đồng thời các ràng buộc này sẽ được trình xử lý tự động kiểm tra và trả về lỗi cho thời gian chạy XLA, nếu các đối số thời gian chạy không khớp với chữ ký trình xử lý FFI.
// Buffers of any number of dimensions and F32 data type.
auto handler = Ffi::Bind().Arg<Buffer<F32>>().Ret<Buffer<F32>>().To(
[](Buffer<F32> arg, Result<Buffer<F32>> res) -> Error {
float* arg_data = arg.typed_data();
float* res_data = res->typed_data();
return Error::Success();
});
// Buffers of number of dimensions 2 and F32 data type.
auto handler = Ffi::Bind().Arg<BufferR2<F32>>().Ret<BufferR2<F32>>().To(
[](BufferR2<F32> arg, Result<BufferR2<F32>> res) -> Error {
float* arg_data = arg.typed_data();
float* res_data = res->typed_data();
return Error::Success();
});
Đối số và kết quả có thể thay đổi
Nếu số lượng đối số và kết quả có thể khác nhau trong các phiên bản khác nhau của một lệnh gọi tuỳ chỉnh, thì chúng có thể được giải mã tại thời gian chạy bằng cách sử dụng RemainingArgs và RemainingRets.
auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To(
[](RemainingArgs args, RemainingRets results) -> Error {
ErrorOr<AnyBuffer> arg = args.get<AnyBuffer>(0);
ErrorOr<Result<AnyBuffer>> res = results.get<AnyBuffer>(0);
if (!arg.has_value()) {
return Error(ErrorCode::kInternal, arg.error());
}
if (!res.has_value()) {
return Error(ErrorCode::kInternal, res.error());
}
return Error::Success();
});
Bạn có thể khai báo các đối số và kết quả có độ dài thay đổi sau các đối số và kết quả thông thường, tuy nhiên, việc liên kết các đối số và kết quả thông thường sau đối số có độ dài thay đổi là không hợp lệ.
auto handler =
Ffi::Bind()
.Arg<AnyBuffer>()
.RemainingArgs()
.Ret<AnyBuffer>()
.RemainingRets()
.To([](AnyBuffer arg, RemainingArgs args, AnyBuffer ret,
RemainingRets results) -> Error { return Error::Success(); });
Thuộc tính
FFI XLA hỗ trợ việc tự động giải mã mlir::DictionaryAttr được truyền dưới dạng custom_call backend_config thành các đối số trình xử lý FFI.
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
backend_config= {
i32 = 42 : i32,
str = "string"
},
api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
Trong ví dụ này, lệnh gọi tuỳ chỉnh có một đối số vùng đệm và hai thuộc tính, đồng thời XLA FFI có thể tự động giải mã các đối số và thuộc tính này rồi chuyển đến hàm có thể gọi do người dùng xác định.
auto handler = Ffi::Bind()
.Arg<BufferR0<F32>>()
.Attr<int32_t>("i32")
.Attr<std::string_view>("str")
.To([](BufferR0<F32> buffer, int32_t i32, std::string_view str) {
return Error::Success();
});
Thuộc tính Enum do người dùng xác định
FFI XLA có thể tự động giải mã các thuộc tính MLIR nguyên vẹn thành các enum do người dùng xác định. Lớp Enum phải có cùng kiểu số nguyên cơ bản và quá trình giải mã phải được đăng ký rõ ràng với XLA FFI.
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
backend_config= {
command = 0 : i32
},
api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
enum class Command : int32_t {
kAdd = 0,
kMul = 1,
};
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command);
auto handler = Ffi::Bind().Attr<Command>("command").To(
[](Command command) -> Error { return Error::Success(); });
Liên kết tất cả các thuộc tính tuỳ chỉnh của cuộc gọi
Bạn có thể truy cập vào tất cả các thuộc tính cuộc gọi tuỳ chỉnh dưới dạng từ điển và chỉ giải mã các thuộc tính cần thiết tại thời gian chạy.
auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
ErrorOr<int32_t> i32 = attrs.get<int32_t>("i32");
return Error::Success();
});
Thuộc tính Struct do người dùng xác định
FFI XLA có thể giải mã các thuộc tính từ điển thành các cấu trúc do người dùng xác định.
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
backend_config= {
range = { lo = 0 : i64, hi = 42 : i64 }
},
api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
Trong ví dụ trên, range là một thuộc tính mlir::DictionaryAttr và thay vì truy cập vào các trường từ điển theo tên, bạn có thể tự động giải mã thuộc tính này dưới dạng một cấu trúc C++. Bạn phải đăng ký quá trình giải mã một cách rõ ràng bằng macro XLA_FFI_REGISTER_STRUCT_ATTR_DECODING (phía sau cảnh, macro này xác định một chuyên môn hoá mẫu trong không gian tên ::xla::ffi, do đó, macro phải được thêm vào không gian tên chung).
struct Range {
int64_t lo;
int64_t hi;
};
XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("lo"),
StructMember<int64_t>("hi"));
auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
return Error::Success();
});
Bạn có thể tải các thuộc tính tuỳ chỉnh từ một từ điển, giống như mọi thuộc tính khác. Trong ví dụ bên dưới, tất cả các thuộc tính cuộc gọi tuỳ chỉnh được giải mã dưới dạng Dictionary và bạn có thể truy cập vào range theo tên.
auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
ErrorOr<Range> range = attrs.get<Range>("range");
return Error::Success();
});
Tạo lệnh gọi tuỳ chỉnh trên CPU
Bạn có thể tạo một chỉ dẫn HLO đại diện cho một lệnh gọi tuỳ chỉnh thông qua API máy khách của XLA. Ví dụ: đoạn mã sau đây sử dụng một lệnh gọi tuỳ chỉnh để tính toán A[i] = B[i %
128]+ C[i] trên CPU. (Tất nhiên là bạn có thể và nên làm! – thực hiện việc này với HLO thông thường.)
#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"
void do_it() {
xla::XlaBuilder b("do_it");
xla::XlaOp param0 =
xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
xla::XlaOp param1 =
xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
xla::XlaOp custom_call =
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
/*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
/*opaque=*/"", /*has_side_effect=*/false,
/*output_operand_aliasing=*/{}, /*literal=*/nullptr,
/*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
/*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}
// Constrain custom call arguments to 1-dimensional buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;
// Implement a custom call as a C++ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
xla::ffi::Result<BufferF32> out) {
size_t d0 = in0.dimensions[0];
size_t d1 = in1.dimensions[0];
// Check that dimensions are compatible.
assert(out->dimensions[0] == d1 && "unexpected dimensions");
for (size_t i = 0; i < d1; ++i) {
out->data[i] = in0.data[i % d0] + in1.data[i];
}
}
// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
ffi::Ffi::Bind()
.Arg<Buffer>()
.Arg<Buffer>()
.Ret<Buffer>());
// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
"Host", handler);
Tạo lệnh gọi tuỳ chỉnh trên GPU
Việc đăng ký lệnh gọi tuỳ chỉnh GPU bằng XLA FFI gần như giống hệt nhau, điểm khác biệt duy nhất là đối với GPU, bạn cần yêu cầu một luồng nền tảng cơ bản (luồng CUDA hoặc ROCM) để có thể khởi chạy hạt nhân trên thiết bị. Dưới đây là một ví dụ về CUDA thực hiện cùng một phép tính (A[i] = B[i % 128] + C[i]) như mã CPU ở trên.
void do_it() { /* same implementation as above */ }
__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
out[idx] = in0[idx % 128] + in1[idx];
}
void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
xla::ffi::Result<BufferF32> out) {
size_t d0 = in0.dimensions[0];
size_t d1 = in1.dimensions[0];
size_t d2 = out->dimensions[0];
assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");
const int64_t block_dim = 64;
const int64_t grid_dim = 2048 / block_dim;
custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
in0.data, in1.data, out->data);
}
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
ffi::Ffi::Bind()
.Ctx<xla::ffi::PlatformStream<CUstream>>()
.Arg<BufferF32>()
.Arg<BufferF32>()
.Ret<BufferF32>());
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
"CUDA", handler);
Trước tiên, hãy lưu ý rằng hàm gọi tuỳ chỉnh GPU vẫn là một hàm được thực thi trên CPU. Hàm CPU do_custom_call chịu trách nhiệm xếp hàng công việc trên GPU. Ở đây, nó khởi chạy một nhân CUDA, nhưng cũng có thể làm việc khác, chẳng hạn như gọi cuBLAS.
Các đối số và kết quả cũng nằm trên máy chủ lưu trữ, đồng thời thành viên dữ liệu chứa một con trỏ đến bộ nhớ thiết bị (tức là GPU). Các vùng đệm được truyền đến trình xử lý lệnh gọi tuỳ chỉnh có hình dạng của các vùng đệm thiết bị cơ bản, vì vậy, lệnh gọi tuỳ chỉnh có thể tính toán các tham số khởi động nhân từ các vùng đệm đó.
Truyền các bộ đến các lệnh gọi tuỳ chỉnh
Hãy xem xét lệnh gọi tuỳ chỉnh sau.
using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {32}),
ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {64}),
ShapeUtil::MakeShape(F32, {128}),
}),
ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");
Shape out_shape = ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {512}),
ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);
Trên cả CPU và GPU, một bộ được biểu thị trong bộ nhớ dưới dạng một mảng con trỏ. Khi XLA gọi các lệnh gọi tuỳ chỉnh bằng đối số hoặc kết quả của bộ giá trị, XLA sẽ làm phẳng các đối số hoặc kết quả đó và truyền dưới dạng đối số hoặc kết quả của vùng đệm thông thường.
Đầu ra của bộ dữ liệu dưới dạng vùng đệm tạm thời
Đầu vào bộ dữ liệu cho các lệnh gọi tuỳ chỉnh là một tiện ích, nhưng không thực sự cần thiết. Nếu chúng tôi không hỗ trợ các đầu vào bộ giá trị cho các lệnh gọi tuỳ chỉnh, bạn luôn có thể giải nén các bộ giá trị bằng cách sử dụng get-tuple-element trước khi truyền chúng đến lệnh gọi tuỳ chỉnh.
Mặt khác, đầu ra của bộ dữ liệu cho phép bạn làm những việc mà bạn không thể làm được nếu không có bộ dữ liệu.
Lý do rõ ràng để có các đầu ra của bộ giá trị là các đầu ra của bộ giá trị là cách một lệnh gọi tuỳ chỉnh (hoặc bất kỳ thao tác XLA nào khác) trả về nhiều mảng độc lập.
Tuy nhiên, ít rõ ràng hơn là đầu ra của bộ giá trị cũng là một cách để cung cấp bộ nhớ tạm thời cho lệnh gọi tuỳ chỉnh của bạn. Có, đầu ra có thể biểu thị một vùng đệm tạm thời. Hãy xem xét, một vùng đệm đầu ra có thuộc tính mà thao tác có thể ghi vào đó và có thể đọc từ đó sau khi được ghi vào. Đó chính xác là những gì bạn muốn ở một vùng đệm tạm thời.
Trong ví dụ trên, giả sử chúng ta muốn dùng F32[1024] làm vùng đệm tạm thời.
Sau đó, chúng ta sẽ viết HLO giống như trên và chúng ta sẽ không bao giờ đọc chỉ mục bộ dữ liệu 1 của đầu ra lệnh gọi tuỳ chỉnh.