В этом документе описывается, как писать и использовать пользовательские вызовы XLA с использованием библиотеки XLA FFI. Пользовательский вызов — это механизм описания внешней «операции» в модуле HLO для компилятора XLA (во время компиляции), а XLA FFI — это механизм регистрации реализации таких операций с помощью XLA (во время выполнения). FFI означает «интерфейс внешней функции» и представляет собой набор API-интерфейсов C, которые определяют двоичный интерфейс (ABI) для вызова XLA во внешний код, написанный на других языках программирования. XLA предоставляет привязки только для заголовков для XLA FFI, написанного на C++, что скрывает от конечного пользователя все низкоуровневые детали базовых API C.
Пользовательские вызовы JAX + XLA
См. документацию JAX для сквозных примеров интеграции пользовательских вызовов и XLA FFI с JAX.
Крепление XLA FFI
Привязка XLA FFI — это спецификация пользовательской сигнатуры вызова во время компиляции: пользовательские аргументы вызова, атрибуты и их типы, а также дополнительные параметры, передаваемые через контекст выполнения (т. е. поток графического процессора для серверной части графического процессора). Привязка XLA FFI может быть привязана к любому вызываемому объекту C++ (указателю функции, лямбда-выражению и т. д.) с совместимой сигнатурой operator()
. Созданный обработчик декодирует кадр вызова XLA FFI (определенный стабильным C API), проверяет тип всех параметров и пересылает декодированные результаты в определяемый пользователем обратный вызов.
Привязка XLA FFI во многом зависит от метапрограммирования шаблонов, позволяющего скомпилировать построенный обработчик в наиболее эффективный машинный код. Затраты времени выполнения составляют порядка пары наносекунд для каждого пользовательского параметра вызова.
Точки настройки XLA FFI реализованы как специализации шаблона, и пользователи могут определять, как декодировать свои пользовательские типы, т. е. можно определить собственное декодирование для определяемых пользователем типов enum class
.
Возврат ошибок из пользовательских вызовов
Реализации пользовательских вызовов должны возвращать значение xla::ffi::Error
, чтобы сигнализировать об успехе или ошибке во время выполнения XLA. Он похож на absl::Status
и имеет тот же набор кодов ошибок. Мы не используем absl::Status
, поскольку у него нет стабильного ABI, и было бы небезопасно передавать его между динамически загружаемой пользовательской библиотекой вызовов и самим XLA.
// Handler that always returns an error.
auto always_error = Ffi::Bind().To(
[]() { return Error(ErrorCode::kInternal, "Oops!"); });
// Handler that always returns a success.
auto always_success = Ffi::Bind().To(
[]() { return Error::Success(); });
Буферные аргументы и результаты
XLA использует стиль передачи результатов по назначению: пользовательские вызовы (или любые другие операции XLA в этом отношении) не выделяют память для результатов, а вместо этого записывают в места назначения, передаваемые средой выполнения XLA. XLA использует статическое назначение буфера и выделяет буферы для всех значений на основе их текущих диапазонов во время компиляции.
Результаты передаются обработчикам FFI, заключенным в шаблон Result<T>
, который имеет семантику, подобную указателю: operator->
предоставляет доступ к базовому параметру.
Аргументы и результаты AnyBuffer
предоставляют доступ к пользовательским параметрам буфера вызовов любого типа данных. Это полезно, когда пользовательский вызов имеет общую реализацию, которая работает для нескольких типов данных, а реализация пользовательского вызова выполняет диспетчеризацию во время выполнения на основе типа данных. AnyBuffer
предоставляет доступ к типу данных буфера, размерам и указателю на сам буфер.
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
api_version = 4 : i32
} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// Buffers of any rank and data type.
auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
[](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
void* arg_data = arg.untyped_data();
void* res_data = res->untyped_data();
return Error::Success();
});
Аргументы и результаты ограниченного буфера
Buffer
позволяет добавлять ограничения на тип и ранг данных буфера, и они будут автоматически проверяться обработчиком и возвращать ошибку во время выполнения XLA, если аргументы времени выполнения не соответствуют сигнатуре обработчика FFI.
// Buffers of any rank and F32 data type.
auto handler = Ffi::Bind().Arg<Buffer<F32>>().Ret<Buffer<F32>>().To(
[](Buffer<F32> arg, Result<Buffer<F32>> res) -> Error {
float* arg_data = arg.typed_data();
float* res_data = res->typed_data();
return Error::Success();
});
// Buffers of rank 2 and F32 data type.
auto handler = Ffi::Bind().Arg<BufferR2<F32>>().Ret<BufferR2<F32>>().To(
[](BufferR2<F32> arg, Result<BufferR2<F32>> res) -> Error {
float* arg_data = arg.typed_data();
float* res_data = res->typed_data();
return Error::Success();
});
Вариатические аргументы и результаты
Если количество аргументов и результат могут различаться в разных экземплярах пользовательского вызова, их можно декодировать во время выполнения с помощью RemainingArgs
и RemainingRets
.
auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To(
[](RemainingArgs args, RemainingRets results) -> Error {
ErrorOr<AnyBuffer> arg = args.get<AnyBuffer>(0);
ErrorOr<Result<AnyBuffer>> res = results.get<AnyBuffer>(0);
if (!arg.has_value()) {
return Error(ErrorCode::kInternal, arg.error());
}
if (!res.has_value()) {
return Error(ErrorCode::kInternal, res.error());
}
return Error::Success();
});
Вариативные аргументы и результаты могут быть объявлены после обычных аргументов и результатов, однако привязка регулярных аргументов и результатов после вариативного аргумента незаконна.
auto handler =
Ffi::Bind()
.Arg<AnyBuffer>()
.RemainingArgs()
.Ret<AnyBuffer>()
.RemainingRets()
.To([](AnyBuffer arg, RemainingArgs args, AnyBuffer ret,
RemainingRets results) -> Error { return Error::Success(); });
Атрибуты
XLA FFI поддерживает автоматическое декодирование mlir::DictionaryAttr
переданного как custom_call
backend_config
в аргументы обработчика FFI.
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
backend_config= {
i32 = 42 : i32,
str = "string"
},
api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
В этом примере пользовательский вызов имеет один аргумент буфера и два атрибута, и XLA FFI может автоматически декодировать их и передать определяемому пользователем вызываемому объекту.
auto handler = Ffi::Bind()
.Arg<BufferR0<F32>>()
.Attr<int32_t>("i32")
.Attr<std::string_view>("str")
.To([](BufferR0<F32> buffer, int32_t i32, std::string_view str) {
return Error::Success();
});
Определяемые пользователем атрибуты перечисления
XLA FFI может автоматически декодировать встроенные атрибуты MLIR в определяемые пользователем перечисления. Класс Enum должен иметь тот же базовый целочисленный тип, а декодирование должно быть явно зарегистрировано в XLA FFI.
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
backend_config= {
command = 0 : i32
},
api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
enum class Command : int32_t {
kAdd = 0,
kMul = 1,
};
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command);
auto handler = Ffi::Bind().Attr<Command>("command").To(
[](Command command) -> Error { return Error::Success(); });
Привязка всех пользовательских атрибутов вызова
Можно получить доступ ко всем пользовательским атрибутам вызовов в виде словаря и лениво декодировать только те атрибуты, которые необходимы во время выполнения.
auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
ErrorOr<int32_t> i32 = attrs.get<int32_t>("i32");
return Error::Success();
});
Пользовательские атрибуты структуры
XLA FFI может декодировать атрибуты словаря в определяемые пользователем структуры.
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
backend_config= {
range = { lo = 0 : i64, hi = 42 : i64 }
},
api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
В приведенном выше примере range
представляет собой атрибут mlir::DictionaryAttr
, и вместо доступа к полям словаря по имени его можно автоматически декодировать как структуру C++. Декодирование должно быть явно зарегистрировано с помощью макроса XLA_FFI_REGISTER_STRUCT_ATTR_DECODING
(за кадром оно определяет специализацию шаблона в пространстве имен ::xla::ffi
, поэтому макрос необходимо добавить в глобальное пространство имен).
struct Range {
int64_t lo;
int64_t hi;
};
XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("lo"),
StructMember<int64_t>("hi"));
auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
return Error::Success();
});
Пользовательские атрибуты можно загрузить из словаря, как и любой другой атрибут. В приведенном ниже примере все пользовательские атрибуты вызова декодируются как Dictionary
, а доступ к range
можно получить по имени.
auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
ErrorOr<Range> range = attrs.get<Range>("range");
return Error::Success();
});
Создать собственный вызов на ЦП
Вы можете создать инструкцию HLO, которая представляет собой пользовательский вызов через клиентский API XLA. Например, следующий код использует специальный вызов для вычисления A[i] = B[i % 128]+ C[i]
на ЦП. (Конечно, вы можете – и должны! – сделать это с помощью обычного HLO.)
#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"
void do_it() {
xla::XlaBuilder b("do_it");
xla::XlaOp param0 =
xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
xla::XlaOp param1 =
xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
xla::XlaOp custom_call =
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
/*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
/*opaque=*/"", /*has_side_effect=*/false,
/*output_operand_aliasing=*/{}, /*literal=*/nullptr,
/*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
/*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}
// Constrain custom call arguments to rank-1 buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;
// Implement a custom call as a C+ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
xla::ffi::Result<BufferF32> out) {
size_t d0 = in0.dimensions[0];
size_t d1 = in1.dimensions[0];
// Check that dimensions are compatible.
assert(out->dimensions[0] == d1 && "unexpected dimensions");
for (size_t i = 0; i < d1; ++i) {
out->data[i] = in0.data[i % d0] + in1.data[i];
}
}
// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
ffi::Ffi::Bind()
.Arg<Buffer>()
.Arg<Buffer>()
.Ret<Buffer>());
// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
"Host", handler);
Создайте собственный вызов на графическом процессоре
Пользовательская регистрация вызовов графического процессора с помощью XLA FFI практически идентична, с той лишь разницей, что для графического процессора вам необходимо запросить базовый поток платформы (поток CUDA или ROCM), чтобы иметь возможность запускать ядро на устройстве. Вот пример CUDA, который выполняет те же вычисления ( A[i] = B[i % 128] + C[i]
), что и приведенный выше код ЦП.
void do_it() { /* same implementation as above */ }
__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
out[idx] = in0[idx % 128] + in1[idx];
}
void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
xla::ffi::Result<BufferF32> out) {
size_t d0 = in0.dimensions[0];
size_t d1 = in1.dimensions[0];
size_t d2 = out->dimensions[0];
assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");
const int64_t block_dim = 64;
const int64_t grid_dim = 2048 / block_dim;
custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
in0.data, in1.data, out->data);
}
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
ffi::Ffi::Bind()
.Ctx<xla::ffi::PlatformStream<CUstream>>()
.Arg<BufferF32>()
.Arg<BufferF32>()
.Ret<BufferF32>());
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
"CUDA", handler);
Прежде всего обратите внимание, что пользовательская функция вызова графического процессора по-прежнему является функцией, выполняемой на ЦП . Функция CPU do_custom_call
отвечает за постановку в очередь работы на графическом процессоре. Здесь он запускает ядро CUDA, но может делать и что-то еще, например вызывать cuBLAS.
Аргументы и результаты также находятся на хосте, а элемент данных содержит указатель на память устройства (т. е. графического процессора). Буферы, передаваемые пользовательскому обработчику вызовов, имеют форму базовых буферов устройства, поэтому пользовательский вызов может вычислять из них параметры запуска ядра.
Передача кортежей в пользовательские вызовы
Рассмотрим следующий пользовательский вызов.
using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {32}),
ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {64}),
ShapeUtil::MakeShape(F32, {128}),
}),
ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");
Shape out_shape = ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {512}),
ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);
Как на процессоре, так и на графическом процессоре кортеж представлен в памяти как массив указателей. Когда XLA вызывает пользовательские вызовы с аргументами или результатами кортежа, он выравнивает их и передает как обычные аргументы или результаты буфера.
Кортеж выводится как временные буферы
Ввод кортежей в пользовательские вызовы удобен, но не является строго необходимым. Если бы мы не поддерживали ввод кортежей для пользовательских вызовов, вы всегда могли бы распаковать кортежи с помощью get-tuple-element перед передачей их в пользовательский вызов.
С другой стороны, выходные данные кортежей позволяют вам делать то, что иначе вы не смогли бы.
Очевидная причина использования выходных данных кортежа заключается в том, что выходные данные кортежа — это то, как пользовательский вызов (или любая другая операция XLA) возвращает несколько независимых массивов.
Но менее очевидно, что вывод кортежа также является способом выделения временной памяти для вашего пользовательского вызова. Да, выходные данные могут представлять собой временный буфер. Учтите, что выходной буфер имеет свойство, позволяющее оператору записывать в него данные и читать из него после того, как в него была записана информация. Это именно то, что вы хотите от временного буфера.
Предположим, в приведенном выше примере мы хотим использовать F32[1024]
в качестве временного буфера. Тогда мы напишем HLO так же, как указано выше, и просто никогда не будем читать индекс кортежа 1 вывода пользовательского вызова.