This document describes how to write and use XLA custom calls with the XLA FFI library. A custom call is a mechanism to describe an external "operation" in the HLO module to the XLA compiler (at compile time), and XLA FFI is a mechanism to register the implementation of such operations with XLA (at run time). FFI stands for "foreign function interface", and it is a set of C APIs that define a binary interface (ABI) for XLA to call into external code written in other programming languages. XLA provides header-only C++ bindings for XLA FFI, which hide all the low-level details of the underlying C APIs from the end user.
JAX + XLA Custom Calls
See the JAX documentation for end-to-end examples of integrating custom calls and XLA FFI with JAX.
XLA FFI Binding
XLA FFI binding is a compile-time specification of the custom call signature: custom call arguments, attributes and their types, and additional parameters passed via the execution context (e.g., the GPU stream for the GPU backend). An XLA FFI binding can be bound to any C++ callable (function pointer, lambda, etc.) with a compatible operator() signature. The constructed handler decodes the XLA FFI call frame (defined by the stable C API), type-checks all parameters, and forwards the decoded results to the user-defined callback.
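The decode, type-check, and forward steps can be illustrated with a small, self-contained sketch that does not use XLA at all. Here the call frame is assumed to be a std::vector of std::variant slots, and CallFrame, BindAndCall, and BindAndCallImpl are hypothetical names invented for this example, not part of the XLA FFI API; the real bindings decode a C-ABI call frame with much lower overhead.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>
#include <variant>
#include <vector>

// Hypothetical call frame: each parameter occupies one type-tagged slot.
using Value = std::variant<int32_t, float, std::string>;
using CallFrame = std::vector<Value>;

// Type-checks every slot against Ts... and forwards the decoded values to fn
// only if all of them match. Returns an empty string on success.
template <typename... Ts, typename Fn, std::size_t... Is>
std::string BindAndCallImpl(const CallFrame& frame, Fn&& fn,
                            std::index_sequence<Is...>) {
  if (frame.size() != sizeof...(Ts)) return "wrong number of parameters";
  bool all_match = (true && ... && std::holds_alternative<Ts>(frame[Is]));
  if (!all_match) return "parameter type mismatch";
  std::forward<Fn>(fn)(std::get<Ts>(frame[Is])...);
  return "";
}

// "Binding": the parameter types are fixed at compile time, while the values
// are only decoded from the frame at run time.
template <typename... Ts, typename Fn>
std::string BindAndCall(const CallFrame& frame, Fn&& fn) {
  return BindAndCallImpl<Ts...>(frame, std::forward<Fn>(fn),
                                std::index_sequence_for<Ts...>{});
}
```

For instance, BindAndCall<int32_t, std::string>(frame, callback) invokes callback only if the frame holds exactly an int32_t followed by a string; otherwise it returns an error message and the callback never runs.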
XLA FFI binding relies heavily on template metaprogramming to be able to compile the constructed handler to the most efficient machine code. Run-time overhead is on the order of a couple of nanoseconds for each custom call parameter.
XLA FFI customization points are implemented as template specializations, so users can define how to decode their custom types; for example, it is possible to define custom decoding for user-defined enum class types.
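The customization-point pattern can be sketched in plain C++ as a trait template that generic code names, plus specializations that users add for their own types. AttrDecoding, DecodeAttr, and the value check in the Command specialization are hypothetical and only illustrate the idea; in XLA FFI itself, registration macros such as XLA_FFI_REGISTER_ENUM_ATTR_DECODING play this role.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>

// Hypothetical customization point. The primary template is declared but not
// defined, so asking for an unsupported type is a compile-time error.
template <typename T>
struct AttrDecoding;

// A "built-in" specialization for 32-bit integer attributes.
template <>
struct AttrDecoding<int32_t> {
  static std::optional<int32_t> Decode(int64_t raw) {
    if (raw < std::numeric_limits<int32_t>::min() ||
        raw > std::numeric_limits<int32_t>::max()) {
      return std::nullopt;
    }
    return static_cast<int32_t>(raw);
  }
};

// A user-defined enum class and the specialization its author provides.
enum class Command : int32_t { kAdd = 0, kMul = 1 };

template <>
struct AttrDecoding<Command> {
  static std::optional<Command> Decode(int64_t raw) {
    std::optional<int32_t> value = AttrDecoding<int32_t>::Decode(raw);
    if (!value || (*value != 0 && *value != 1)) return std::nullopt;
    return static_cast<Command>(*value);
  }
};

// Generic decoding code only ever names the customization point.
template <typename T>
std::optional<T> DecodeAttr(int64_t raw) {
  return AttrDecoding<T>::Decode(raw);
}
```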
Returning Errors From Custom Calls
Custom call implementations must return an xla::ffi::Error value to signal success or failure to the XLA runtime. It is similar to absl::Status and has the same set of error codes. We do not use absl::Status because it does not have a stable ABI, and it would be unsafe to pass it between a dynamically loaded custom call library and XLA itself.
// Handler that always returns an error.
auto always_error = Ffi::Bind().To(
[]() { return Error(ErrorCode::kInternal, "Oops!"); });
// Handler that always returns a success.
auto always_success = Ffi::Bind().To(
[]() { return Error::Success(); });
Buffer Arguments And Results
XLA uses destination-passing style for results: custom calls (or any other XLA operations, for that matter) do not allocate memory for results, and instead write into destinations passed by the XLA runtime. XLA uses static buffer assignment and allocates buffers for all values at compile time, based on their live ranges.
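A minimal sketch of destination-passing style, assuming nothing from XLA: ScaleInto and RunProgram are hypothetical names. The op only writes into memory it is handed, and the toy runtime decides up front where every value lives.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// A toy op in destination-passing style: it never allocates. The caller owns
// `out` and guarantees that it has room for `n` elements.
void ScaleInto(const float* in, float* out, std::size_t n, float factor) {
  for (std::size_t i = 0; i < n; ++i) out[i] = in[i] * factor;
}

// A toy "runtime" that allocates every buffer before running any op and then
// passes destinations in. `tmp` is dead after the second op, which is the kind
// of live-range information a static buffer assignment uses to reuse memory.
std::vector<float> RunProgram(const std::vector<float>& x) {
  std::vector<float> tmp(x.size());
  std::vector<float> result(x.size());
  ScaleInto(x.data(), tmp.data(), x.size(), 2.0f);       // tmp = 2 * x
  ScaleInto(tmp.data(), result.data(), x.size(), 3.0f);  // result = 3 * tmp
  return result;
}
```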
Results are passed to FFI handlers wrapped in a Result<T> template, which has pointer-like semantics: operator-> gives access to the underlying parameter.
AnyBuffer arguments and results give access to custom call buffer parameters of any data type. This is useful when a custom call has a generic implementation that works for multiple data types and dispatches on the data type at run time. AnyBuffer gives access to the buffer data type, dimensions, and a pointer to the buffer itself.
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
api_version = 4 : i32
} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// Buffers of any rank and data type.
auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
[](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
void* arg_data = arg.untyped_data();
void* res_data = res->untyped_data();
return Error::Success();
});
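The kind of run-time dispatch that AnyBuffer enables can be sketched without XLA as a type-erased buffer plus a switch on its data type. DType, UntypedBuffer, NegateTyped, and Negate are hypothetical stand-ins, not the XLA FFI types; a real handler would read the data type and dimensions from its AnyBuffer parameters instead.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>

// Hypothetical stand-ins for a type-erased buffer; not the XLA FFI types.
enum class DType { kF32, kS32 };

struct UntypedBuffer {
  DType dtype;
  std::size_t size;  // number of elements
  void* data;
};

// The typed kernel: written once, instantiated per supported element type.
template <typename T>
void NegateTyped(const UntypedBuffer& in, UntypedBuffer& out) {
  const T* src = static_cast<const T*>(in.data);
  T* dst = static_cast<T*>(out.data);
  for (std::size_t i = 0; i < in.size; ++i) dst[i] = -src[i];
}

// One generic entry point that dispatches on the run-time data type.
// Returns an empty string on success, or an error message.
std::string Negate(const UntypedBuffer& in, UntypedBuffer& out) {
  if (in.dtype != out.dtype || in.size != out.size) {
    return "argument and result must match";
  }
  switch (in.dtype) {
    case DType::kF32: NegateTyped<float>(in, out); return "";
    case DType::kS32: NegateTyped<int32_t>(in, out); return "";
  }
  return "unsupported data type";
}
```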
Constrained Buffer Arguments And Results
Buffer allows adding constraints on the buffer data type and rank; these are checked automatically by the handler, which returns an error to the XLA runtime if the run-time arguments do not match the FFI handler signature.
// Buffers of any rank and F32 data type.
auto handler = Ffi::Bind().Arg<Buffer<F32>>().Ret<Buffer<F32>>().To(
[](Buffer<F32> arg, Result<Buffer<F32>> res) -> Error {
float* arg_data = arg.typed_data();
float* res_data = res->typed_data();
return Error::Success();
});
// Buffers of rank 2 and F32 data type.
auto handler = Ffi::Bind().Arg<BufferR2<F32>>().Ret<BufferR2<F32>>().To(
[](BufferR2<F32> arg, Result<BufferR2<F32>> res) -> Error {
float* arg_data = arg.typed_data();
float* res_data = res->typed_data();
return Error::Success();
});
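Conceptually, the checks the handler performs for a constrained buffer look like the following sketch. BufferInfo and CheckBuffer are hypothetical, and the error strings are illustrative only; XLA FFI performs the equivalent checks inside the generated handler and reports failures to the XLA runtime as errors.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical descriptor of a run-time buffer parameter.
enum class DType { kF32, kS32 };

struct BufferInfo {
  DType dtype;
  std::vector<int64_t> dims;
};

// Checks a parameter against a declared constraint, mirroring the idea of
// Buffer<F32> (any rank, rank < 0) or BufferR2<F32> (rank == 2).
// Returns an empty string when the parameter matches.
std::string CheckBuffer(const BufferInfo& buf, DType expected, int rank = -1) {
  if (buf.dtype != expected) return "wrong data type";
  if (rank >= 0 && static_cast<int>(buf.dims.size()) != rank) {
    return "wrong rank: expected " + std::to_string(rank) + ", got " +
           std::to_string(buf.dims.size());
  }
  return "";
}
```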
Variadic Arguments And Results
If the number of arguments and results can differ between instances of a custom call, they can be decoded at run time using RemainingArgs and RemainingRets.
auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To(
[](RemainingArgs args, RemainingRets results) -> Error {
ErrorOr<AnyBuffer> arg = args.get<AnyBuffer>(0);
ErrorOr<Result<AnyBuffer>> res = results.get<AnyBuffer>(0);
if (!arg.has_value()) {
return Error(ErrorCode::kInternal, arg.error());
}
if (!res.has_value()) {
return Error(ErrorCode::kInternal, res.error());
}
return Error::Success();
});
Variadic arguments and results can be declared after regular arguments and results; however, binding regular arguments or results after variadic ones is illegal.
auto handler =
Ffi::Bind()
.Arg<AnyBuffer>()
.RemainingArgs()
.Ret<AnyBuffer>()
.RemainingRets()
.To([](AnyBuffer arg, RemainingArgs args, Result<AnyBuffer> ret,
RemainingRets results) -> Error { return Error::Success(); });
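The essence of run-time decoding of variadic parameters is that the count is known only at call time and each element is type-checked when it is accessed. A minimal sketch under that assumption follows; Remaining and SumAll are hypothetical, and returning nullptr on a mismatch is a simplification of the ErrorOr values used by the RemainingArgs handler above.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <variant>
#include <vector>

// Hypothetical parameter slot type for this sketch.
using Param = std::variant<int, double>;

// A container whose length is only known when the call happens.
class Remaining {
 public:
  explicit Remaining(std::vector<Param> params) : params_(std::move(params)) {}

  std::size_t size() const { return params_.size(); }

  // Returns the i-th parameter if it exists and has type T, else nullptr.
  template <typename T>
  const T* get(std::size_t i) const {
    if (i >= params_.size()) return nullptr;
    return std::get_if<T>(&params_[i]);
  }

 private:
  std::vector<Param> params_;
};

// Sums every argument, whatever their count, failing on the first mismatch.
bool SumAll(const Remaining& args, double& sum) {
  sum = 0;
  for (std::size_t i = 0; i < args.size(); ++i) {
    const double* v = args.get<double>(i);
    if (v == nullptr) return false;
    sum += *v;
  }
  return true;
}
```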
Attributes
XLA FFI supports automatic decoding of an mlir::DictionaryAttr passed as a custom_call backend_config into FFI handler arguments.
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
backend_config= {
i32 = 42 : i32,
str = "string"
},
api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
In this example, the custom call has a single buffer argument and two attributes, and XLA FFI can automatically decode them and pass them to the user-defined callable.
auto handler = Ffi::Bind()
.Arg<BufferR0<F32>>()
.Attr<int32_t>("i32")
.Attr<std::string_view>("str")
.To([](BufferR0<F32> buffer, int32_t i32, std::string_view str) {
return Error::Success();
});
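Decoding named attributes amounts to a lookup by name plus a type check, which the following sketch models with a std::map of std::variant values. AttrDict and GetAttr are hypothetical; the same shape of code also illustrates lazily decoding only the attributes a handler needs, as the Dictionary API described below allows.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <optional>
#include <string>
#include <variant>

// Hypothetical attribute dictionary, loosely modeled on the backend_config
// above: attribute name -> scalar or string value.
using Attr = std::variant<int32_t, std::string>;
using AttrDict = std::map<std::string, Attr>;

// Looks up `name` and returns its value only if it exists with type T.
template <typename T>
std::optional<T> GetAttr(const AttrDict& attrs, const std::string& name) {
  auto it = attrs.find(name);
  if (it == attrs.end()) return std::nullopt;
  if (const T* value = std::get_if<T>(&it->second)) return *value;
  return std::nullopt;
}
```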
User-Defined Enum Attributes
XLA FFI can automatically decode integral MLIR attributes into user-defined enums. The enum class must have the same underlying integral type as the attribute, and decoding has to be explicitly registered with XLA FFI.
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
backend_config= {
command = 0 : i32
},
api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
enum class Command : int32_t {
kAdd = 0,
kMul = 1,
};
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command);
auto handler = Ffi::Bind().Attr<Command>("command").To(
[](Command command) -> Error { return Error::Success(); });
Binding All Custom Call Attributes
It is possible to access all custom call attributes as a dictionary and lazily decode only the attributes that are needed at run time.
auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
ErrorOr<int32_t> i32 = attrs.get<int32_t>("i32");
return Error::Success();
});
User-Defined Struct Attributes
XLA FFI can decode dictionary attributes into user-defined structs.
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
backend_config= {
range = { lo = 0 : i64, hi = 42 : i64 }
},
api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
In the example above, range is an mlir::DictionaryAttr attribute, and instead of accessing dictionary fields by name, it can be automatically decoded as a C++ struct. Decoding has to be explicitly registered with the XLA_FFI_REGISTER_STRUCT_ATTR_DECODING macro (behind the scenes it defines a template specialization in the ::xla::ffi namespace, so the macro must be added to the global namespace).
struct Range {
int64_t lo;
int64_t hi;
};
XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("lo"),
StructMember<int64_t>("hi"));
auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
return Error::Success();
});
Custom attributes can be loaded from a dictionary, just like any other attribute. In the example below, all custom call attributes are decoded as a Dictionary, and range can be accessed by name.
auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
ErrorOr<Range> range = attrs.get<Range>("range");
return Error::Success();
});
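Struct decoding can be thought of as looking up each listed member by name and constructing the struct from the results. This sketch assumes a flat dictionary of int64_t fields; FieldDict and DecodeRange are hypothetical and only mirror what the StructMember<int64_t>("lo"), StructMember<int64_t>("hi") list expresses.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <optional>
#include <string>

// Hypothetical nested attribute: named int64 fields, like
// `range = { lo = 0 : i64, hi = 42 : i64 }` in the example above.
using FieldDict = std::map<std::string, int64_t>;

struct Range {
  int64_t lo;
  int64_t hi;
};

// Decodes a Range by looking up each member by name; fails if any is missing.
std::optional<Range> DecodeRange(const FieldDict& dict) {
  auto lo = dict.find("lo");
  auto hi = dict.find("hi");
  if (lo == dict.end() || hi == dict.end()) return std::nullopt;
  return Range{lo->second, hi->second};
}
```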
Create a custom call on CPU
You can create an HLO instruction that represents a custom call via XLA's client API. For example, the following code uses a custom call to compute A[i] = B[i % 128] + C[i] on the CPU. (Of course you could, and should, do this with regular HLO.)
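Stripped of XLA, the computation is a single loop. The plain C++ sketch below, with the hypothetical name ReferenceCustomCall and std::vector in place of XLA buffers, can serve as a reference to compare a handler's output against in tests.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Plain C++ reference for A[i] = B[i % 128] + C[i], without any XLA types.
// `b` plays the role of the F32[128] operand and `c` the F32[2048] operand;
// `a` is the caller-allocated destination.
void ReferenceCustomCall(const std::vector<float>& b,
                         const std::vector<float>& c, std::vector<float>& a) {
  assert(!b.empty() && a.size() == c.size() && "unexpected dimensions");
  for (std::size_t i = 0; i < c.size(); ++i) {
    a[i] = b[i % b.size()] + c[i];
  }
}
```

The XLA version wraps the same loop in an XLA FFI handler and registers it as the do_custom_call target.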
#include "xla/client/xla_builder.h"
#include "xla/shape_util.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                      /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
                      /*opaque=*/"", /*has_side_effect=*/false,
                      /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
                      /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
                      /*api_version=*/xla::CustomCallApiVersion::API_VERSION_TYPED_FFI);
}
#include <cassert>
#include <cstddef>

#include "xla/ffi/api/ffi.h"
#include "xla/ffi/ffi_api.h"

// Constrain custom call arguments to rank-1 buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C++ function. Note that we can use the `Buffer`
// type defined by XLA FFI, which gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions()[0];
  size_t d1 = in1.dimensions()[0];

  // Check that dimensions are compatible.
  assert(out->dimensions()[0] == d1 && "unexpected dimensions");

  const float* a = in0.typed_data();
  const float* b = in1.typed_data();
  float* res = out->typed_data();
  for (size_t i = 0; i < d1; ++i) {
    res[i] = a[i % d0] + b[i];
  }
  return xla::ffi::Error::Success();
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI can automatically infer the type
// signature from the custom call function, but that relies on magical template
// metaprogramming; an explicit binding provides an extra level of type checking
// and clearly states the custom call author's intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       xla::ffi::Ffi::Bind()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

// Registers `handler` with XLA FFI on the "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call", "Host",
                         handler);
Create a custom call on GPU
GPU custom call registration with XLA FFI is almost identical; the only difference is that on GPU you need to ask for the underlying platform stream (CUDA or ROCm stream) to be able to launch kernels on the device. Here is a CUDA example that does the same computation (A[i] = B[i % 128] + C[i]) as the CPU code above.
void do_it() { /* same implementation as above */ }

__global__ void custom_call_kernel(const float* in0, const float* in1,
                                   float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

xla::ffi::Error do_custom_call(cudaStream_t stream, BufferF32 in0,
                               BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions()[0];
  size_t d1 = in1.dimensions()[0];
  size_t d2 = out->dimensions()[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
      in0.typed_data(), in1.typed_data(), out->typed_data());

  // Kernel launches are asynchronous; this only catches launch failures.
  if (cudaGetLastError() != cudaSuccess) {
    return xla::ffi::Error(xla::ffi::ErrorCode::kInternal,
                           "kernel launch failed");
  }
  return xla::ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       xla::ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<cudaStream_t>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call", "CUDA",
                         handler);
Notice first that the GPU custom call function is still a function executed on the CPU. The do_custom_call CPU function is responsible for enqueueing work on the GPU. Here it launches a CUDA kernel, but it could also do something else, like call cuBLAS.
Arguments and results also live on the host, and their data pointers point to device (i.e., GPU) memory. Buffers passed to the custom call handler have the shape of the underlying device buffers, so the custom call can compute kernel launch parameters from them.
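For example, rather than hard-coding the grid size, a handler could derive it from the result buffer's dimensions. A minimal sketch, assuming a 1-D launch with one thread per output element; LaunchDims and ComputeLaunchDims are hypothetical helpers, and a kernel launched this way needs an idx < n guard, which the kernel above omits because its sizes are fixed at 2048.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical 1-D launch configuration: one thread per output element.
struct LaunchDims {
  int64_t grid;   // number of blocks
  int64_t block;  // threads per block
};

// Rounds the grid size up so that every element gets a thread even when n is
// not a multiple of the block size; the kernel must then skip idx >= n.
LaunchDims ComputeLaunchDims(int64_t n, int64_t block = 64) {
  assert(n > 0 && block > 0);
  return LaunchDims{(n + block - 1) / block, block};
}
```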
Passing tuples to custom calls
Consider the following custom call.
using xla::F32;
using xla::Shape;
using xla::ShapeUtil;

Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
// `b` is the xla::XlaBuilder from the CPU example.
xla::XlaOp p0 = xla::Parameter(&b, 0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {512}),
    ShapeUtil::MakeShape(F32, {1024}),
});

xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);
On both CPU and GPU, a tuple is represented in memory as an array of pointers. When XLA calls custom calls with tuple arguments or results, it flattens them and passes them as regular buffer arguments or results.
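The flattening can be sketched with a small shape tree. This sketch assumes that leaves are visited depth-first, left to right; ShapeNode, Leaf, Tuple, and Flatten are hypothetical and are not xla::Shape or ShapeUtil. Under that assumption, the p0 parameter above would reach the handler as four F32 buffers of 32, 64, 128, and 256 elements.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical shape tree, not xla::Shape: a leaf array with an element count,
// or a tuple of nested shapes. Empty tuples are not modeled here.
struct ShapeNode {
  int64_t size = 0;                 // element count, used by leaves only
  std::vector<ShapeNode> elements;  // non-empty for tuples only
  bool is_tuple() const { return !elements.empty(); }
};

ShapeNode Leaf(int64_t n) { return ShapeNode{n, {}}; }

ShapeNode Tuple(std::vector<ShapeNode> elements) {
  return ShapeNode{0, std::move(elements)};
}

// Collects leaf sizes depth-first, left to right: the order this sketch
// assumes for the flattened custom call parameters.
void Flatten(const ShapeNode& shape, std::vector<int64_t>& leaves) {
  if (!shape.is_tuple()) {
    leaves.push_back(shape.size);
    return;
  }
  for (const ShapeNode& element : shape.elements) Flatten(element, leaves);
}
```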
Tuple outputs as temp buffers
Tuple inputs to custom calls are a convenience, but they aren't strictly necessary. If we didn't support tuple inputs to custom calls, you could always unpack the tuples using get-tuple-element before passing them to the custom call.
On the other hand, tuple outputs do let you do things you couldn't otherwise.
The obvious reason to have tuple outputs is that tuple outputs are how a custom call (or any other XLA op) returns multiple independent arrays.
But less obviously, a tuple output is also a way to give your custom call temp memory. Yes, an output can represent a temp buffer. Consider: an output buffer has the property that the op can write to it, and it can read from it after it has been written to. That's exactly what you want from a temp buffer.
In the example above, suppose we wanted to use the F32[1024] as a temp buffer. Then we'd write the HLO just as above, and we'd simply never read tuple index 1 of the custom call's output.
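On the handler side, the temp buffer is simply a second result that the handler is free to use as scratch space. A minimal sketch, with std::vector standing in for result buffers; PrefixSumWithScratch is hypothetical, and the computation only demonstrates staging intermediate data in the scratch output and reading it back.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical handler body with two results: `out` is the real result, and
// `scratch` is the tuple element nobody reads afterwards, so it can be used
// as temporary storage.
void PrefixSumWithScratch(const std::vector<float>& in, std::vector<float>& out,
                          std::vector<float>& scratch) {
  assert(scratch.size() >= in.size() && out.size() == in.size());
  // Stage an intermediate result (the running sum) in the scratch output...
  float running = 0.0f;
  for (std::size_t i = 0; i < in.size(); ++i) {
    running += in[i];
    scratch[i] = running;
  }
  // ...then read it back, which is allowed once the output has been written.
  for (std::size_t i = 0; i < in.size(); ++i) out[i] = scratch[i] * 2.0f;
}
```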