In diesem Dokument wird beschrieben, wie Sie benutzerdefinierte XLA-Aufrufe mit der XLA-FFI-Bibliothek schreiben und verwenden. Ein benutzerdefinierter Aufruf ist ein Mechanismus, mit dem eine externe „Operation“ im HLO-Modul dem XLA-Compiler (zur Kompilierzeit) beschrieben wird. XLA FFI ist ein Mechanismus, mit dem die Implementierung solcher Vorgänge bei XLA (zur Laufzeit) registriert wird. Finanzinstitute für „Fremdfunktionsschnittstelle“ Es handelt sich um eine Reihe von C-APIs, die ein Binärprogramm definieren, Interface (ABI) für XLA, um externen Code aufzurufen, der in einer anderen Programmierung geschrieben wurde Sprachen. XLA bietet Header-only-Bindungen für XLA FFI, die in C++ geschrieben sind. Dadurch werden alle Low-Level-Details der zugrunde liegenden C-APIs vor dem Endnutzer ausgeblendet.
Benutzerdefinierte JAX- und XLA-Aufrufe
In der JAX-Dokumentation finden Sie End-to-End-Beispiele für die Integration benutzerdefinierter Aufrufe und XLA-FFI in JAX.
XLA-FFI-Bindung
XLA-FFI-Bindung ist eine Kompilierungszeit-Spezifikation der benutzerdefinierten Aufrufsignatur:
Argumente für benutzerdefinierte Aufrufe, Attribute und deren Typen sowie zusätzliche Parameter
wird über den Ausführungskontext übergeben (d.h. GPU-Stream für GPU-Back-End). XLA FFI
Ergebnis kann an jede C++-Callable (Funktionszeiger, Lambda usw.) mit
kompatible operator()
-Signatur. Der erstellte Handler decodiert den XLA-FFI-Aufruf-Frame (definiert durch die stabile C API), führt eine Typprüfung aller Parameter durch und leitet die decodierten Ergebnisse an den benutzerdefinierten Rückruf weiter.
XLA-FFI-Bindung basiert stark auf Vorlagen-Metaprogrammierung, um den erstellten Handler zum effizientesten Maschinencode kompilieren. Laufzeit liegen die Overheads für jeden benutzerdefinierten Aufruf im Bereich von wenigen Nanosekunden, .
XLA-FFI-Anpassungspunkte, die als Vorlagenspezialisierungen implementiert sind, und Nutzer können festlegen, wie ihre benutzerdefinierten Typen decodiert werden. Es ist also möglich, eine benutzerdefinierte Dekodierung für benutzerdefinierte enum class
-Typen zu definieren.
Zurückgegebene Fehler bei benutzerdefinierten Aufrufen
Implementierungen benutzerdefinierter Aufrufe müssen den Wert xla::ffi::Error
zurückgeben, um der XLA-Laufzeit einen Erfolg oder Fehler zu signalisieren. Er ähnelt absl::Status
und hat dieselben Fehlercodes. Wir verwenden absl::Status
nicht, da es keine stabile ABI hat und es nicht sicher wäre, es zwischen der dynamisch geladenen benutzerdefinierten Aufrufbibliothek und XLA selbst zu übergeben.
// Handler that always returns an error.
auto always_error = Ffi::Bind().To(
[]() { return Error(ErrorCode::kInternal, "Oops!"); });
// Handler that always returns a success.
auto always_success = Ffi::Bind().To(
[]() { return Error::Success(); });
Pufferargumente und -ergebnisse
XLA verwendet einen Stil für die Zielübergabe für Ergebnisse: benutzerdefinierte Aufrufe (oder ein anderes XLA) Operationen für diese Angelegenheit) weisen den Ergebnissen keinen Arbeitsspeicher zu. Stattdessen in Ziele schreiben, die von der XLA-Laufzeit übergeben werden. XLA verwendet einen statischen Zwischenspeicher -Zuweisung und weist Puffer für alle Werte basierend auf ihren Live-Bereichen zu der Kompilierungszeit.
Ergebnisse, die an FFI-Handler übergeben werden, die in eine Result<T>
-Vorlage verpackt sind, die eine pointerähnliche Semantik hat: operator->
gewährt Zugriff auf den zugrunde liegenden Parameter.
AnyBuffer
-Argumente und -Ergebnisse ermöglichen Zugriff auf benutzerdefinierte Zwischenspeicherparameter für Aufrufe
eines beliebigen Datentyps. Dies ist nützlich, wenn der benutzerdefinierte Aufruf eine allgemeine Implementierung
die für mehrere Datentypen geeignet ist, und die benutzerdefinierte Anrufimplementierung führt eine Laufzeit durch.
auf Basis des Datentyps. AnyBuffer
gewährt Zugriff auf die Zwischenspeicherdaten
-Typ, Dimensionen und einen Zeiger auf den Zwischenspeicher selbst.
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
api_version = 4 : i32
} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// Buffers of any rank and data type.
auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
[](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
void* arg_data = arg.untyped_data();
void* res_data = res->untyped_data();
return Error::Success();
});
Begrenzte Pufferargumente und -ergebnisse
Mit Buffer
können Sie Einschränkungen für den Datentyp und den Rang des Buffers hinzufügen. Diese werden vom Handler automatisch geprüft und es wird ein Fehler an die XLA-Laufzeit zurückgegeben, wenn die Laufzeitargumente nicht mit der FFI-Handlersignatur übereinstimmen.
// Buffers of any rank and F32 data type.
auto handler = Ffi::Bind().Arg<Buffer<F32>>().Ret<Buffer<F32>>().To(
[](Buffer<F32> arg, Result<Buffer<F32>> res) -> Error {
float* arg_data = arg.typed_data();
float* res_data = res->typed_data();
return Error::Success();
});
// Buffers of rank 2 and F32 data type.
auto handler = Ffi::Bind().Arg<BufferR2<F32>>().Ret<BufferR2<F32>>().To(
[](BufferR2<F32> arg, Result<BufferR2<F32>> res) -> Error {
float* arg_data = arg.typed_data();
float* res_data = res->typed_data();
return Error::Success();
});
Variadische Argumente und Ergebnisse
Wenn sich die Anzahl der Argumente und das Ergebnis in verschiedenen Instanzen eines benutzerdefinierten Aufrufs unterscheiden können, können sie mit RemainingArgs
und RemainingRets
zur Laufzeit decodiert werden.
auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To(
[](RemainingArgs args, RemainingRets results) -> Error {
ErrorOr<AnyBuffer> arg = args.get<AnyBuffer>(0);
ErrorOr<Result<AnyBuffer>> res = results.get<AnyBuffer>(0);
if (!arg.has_value()) {
return Error(ErrorCode::kInternal, arg.error());
}
if (!res.has_value()) {
return Error(ErrorCode::kInternal, res.error());
}
return Error::Success();
});
Variadische Argumente und Ergebnisse können nach regulären Argumenten Bindung regulärer Argumente und Ergebnisse nach variadischer illegal sind.
auto handler =
Ffi::Bind()
.Arg<AnyBuffer>()
.RemainingArgs()
.Ret<AnyBuffer>()
.RemainingRets()
.To([](AnyBuffer arg, RemainingArgs args, AnyBuffer ret,
RemainingRets results) -> Error { return Error::Success(); });
Attribute
XLA FFI unterstützt die automatische Dekodierung von mlir::DictionaryAttr
, die als custom_call
backend_config
an FFI-Handlerargumente übergeben wird.
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
backend_config= {
i32 = 42 : i32,
str = "string"
},
api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
In diesem Beispiel hat der benutzerdefinierte Aufruf ein einzelnes Pufferargument und zwei Attribute. XLA FFI kann sie automatisch decodieren und an das benutzerdefinierte Callable übergeben.
auto handler = Ffi::Bind()
.Arg<BufferR0<F32>>()
.Attr<int32_t>("i32")
.Attr<std::string_view>("str")
.To([](BufferR0<F32> buffer, int32_t i32, std::string_view str) {
return Error::Success();
});
Benutzerdefinierte Enum-Attribute
XLA FFI kann integrale MLIR-Attribute automatisch in benutzerdefinierte Enumerationen decodieren. Die Enum-Klasse muss denselben zugrunde liegenden Ganzzahltyp haben und die Dekodierung muss explizit bei XLA FFI registriert werden.
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
backend_config= {
command = 0 : i32
},
api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
enum class Command : int32_t {
kAdd = 0,
kMul = 1,
};
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command);
auto handler = Ffi::Bind().Attr<Command>("command").To(
[](Command command) -> Error { return Error::Success(); });
Alle benutzerdefinierten Aufrufattribute binden
Es ist möglich, auf alle benutzerdefinierten Aufrufattribute als Dictionary zuzugreifen und nur die Attribute zu decodieren, die zur Laufzeit benötigt werden.
auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
ErrorOr<int32_t> i32 = attrs.get<int32_t>("i32");
return Error::Success();
});
Benutzerdefinierte Strukturattribute
XLA FFI kann Wörterbuchattribute in benutzerdefinierte Strukturen decodieren.
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
backend_config= {
range = { lo = 0 : i64, hi = 42 : i64 }
},
api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
Im obigen Beispiel ist range
ein mlir::DictionaryAttr
-Attribut.
auf Wörterbuchfelder nach Name zugreifen, können sie automatisch decodiert werden als
eine C++-Struktur. Die Dekodierung muss explizit mit einem XLA_FFI_REGISTER_STRUCT_ATTR_DECODING
-Makro registriert werden. Im Hintergrund wird eine Vorlagenspezialisierung im ::xla::ffi
-Namespace definiert. Daher muss das Makro dem globalen Namespace hinzugefügt werden.
struct Range {
int64_t lo;
int64_t hi;
};
XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("i64"),
StructMember<int64_t>("i64"));
auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
return Error::Success();
});
Benutzerdefinierte Attribute können wie jedes andere Attribut aus einem Wörterbuch geladen werden. Im folgenden Beispiel werden alle benutzerdefinierten Aufrufattribute, die als
Dictionary
und ein range
kann über den Namen aufgerufen werden.
auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
ErrorOr<Range> range = attrs.get<Range>("range");
return Error::Success();
});
Benutzerdefinierten Aufruf an die CPU erstellen
Sie können eine HLO-Anweisung erstellen, die einen benutzerdefinierten Aufruf über die Client-API von XLA darstellt. Im folgenden Code wird beispielsweise ein benutzerdefinierter Aufruf zur Berechnung von A[i] = B[i %
128]+ C[i]
auf der CPU verwendet. (Natürlich könnten und sollten Sie das. – das mit der regulären HLO tun.)
#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"
void do_it() {
xla::XlaBuilder b("do_it");
xla::XlaOp param0 =
xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
xla::XlaOp param1 =
xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
xla::XlaOp custom_call =
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
/*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
/*opaque=*/"", /*has_side_effect=*/false,
/*output_operand_aliasing=*/{}, /*literal=*/nullptr,
/*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
/*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}
// Constrain custom call arguments to rank-1 buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;
// Implement a custom call as a C+ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
xla::ffi::Result<BufferF32> out) {
size_t d0 = in0.dimensions[0];
size_t d1 = in1.dimensions[0];
// Check that dimensions are compatible.
assert(out->dimensions[0] == d1 && "unexpected dimensions");
for (size_t i = 0; i < d1; ++i) {
out->data[i] = in0.data[i % d0] + in1.data[i];
}
}
// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
ffi::Ffi::Bind()
.Arg<Buffer>()
.Arg<Buffer>()
.Ret<Buffer>());
// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
"Host", handler);
Benutzerdefinierten Aufruf auf der GPU erstellen
Die Registrierung von benutzerdefinierten GPU-Aufrufen mit XLA FFI ist fast identisch, die einzige
Der Unterschied besteht darin, dass Sie für die GPU einen zugrunde liegenden Plattformstream anfordern müssen.
(CUDA- oder ROCM-Stream), um den Kernel auf dem Gerät starten zu können. Hier ist ein CUDA
Beispiel, das dieselbe Berechnung (A[i] = B[i % 128] + C[i]
) wie die CPU durchführt
Code oben.
void do_it() { /* same implementation as above */ }
__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
out[idx] = in0[idx % 128] + in1[idx];
}
void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
xla::ffi::Result<BufferF32> out) {
size_t d0 = in0.dimensions[0];
size_t d1 = in1.dimensions[0];
size_t d2 = out->dimensions[0];
assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");
const int64_t block_dim = 64;
const int64_t grid_dim = 2048 / block_dim;
custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
in0.data, in1.data, out->data);
}
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
ffi::Ffi::Bind()
.Ctx<xla::ffi::PlatformStream<CUstream>>()
.Arg<BufferF32>()
.Arg<BufferF32>()
.Ret<BufferF32>());
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
"CUDA", handler);
Beachten Sie zuerst, dass die benutzerdefinierte GPU-Aufruffunktion noch eine Funktion ist, die auf
der CPU. Die CPU-Funktion do_custom_call
ist für die Einreihung von Aufgaben in die GPU-Warteschlange verantwortlich. Hier startet er einen CUDA-Kernel, kann aber auch etwas anderes tun,
wie „cuBLAS“ aufrufen.
Argumente und Ergebnisse befinden sich ebenfalls auf dem Host und das Datenmitglied enthält einen Zeiger. zum Arbeitsspeicher des Geräts (z.B. GPU) zu wechseln. Die an den benutzerdefinierten Aufruf-Handler übergebenen Puffer haben die Form der zugrunde liegenden Gerätepuffer, sodass der benutzerdefinierte Aufruf Kernelstartparameter daraus berechnen kann.
Tupel an benutzerdefinierte Aufrufe übergeben
Betrachten Sie den folgenden benutzerdefinierten Aufruf.
using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {32}),
ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {64}),
ShapeUtil::MakeShape(F32, {128}),
}),
ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");
Shape out_shape = ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {512}),
ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);
Sowohl auf der CPU als auch auf der GPU wird ein Tupel im Arbeitsspeicher als Array von Pointern dargestellt. Wenn XLA benutzerdefinierte Aufrufe mit Tupelargumenten oder -ergebnissen aufruft, werden diese flachgelegt und als normale Pufferargumente oder -ergebnisse übergeben.
Tupelausgaben als temporäre Puffer
Tupel-Eingaben für benutzerdefinierte Aufrufe sind praktisch, aber nicht genau notwendig ist. Wenn wir keine Tupeleingaben für benutzerdefinierte Aufrufe unterstützen, entpacken Sie die Tupel mithilfe des get-Tupel-Elements, bevor Sie sie an den benutzerdefinierten aufrufen.
Andererseits ermöglichen Ihnen Tupel-Ausgaben Dinge, die Sie sonst nicht tun könnten.
Der offensichtliche Grund für Tupelausgaben ist, dass ein benutzerdefinierter Aufruf (oder eine andere XLA-Operation) mehrere unabhängige Arrays zurückgibt.
Weniger offensichtlich ist jedoch, dass eine Tupelausgabe auch eine Möglichkeit ist, Ihrem benutzerdefinierten Aufruf temporären Speicher zuzuweisen. Ja, eine Ausgabe kann einen temporären Zwischenspeicher darstellen. Ein Ausgabebuffer hat die Eigenschaft, dass die Operation darauf schreiben und nach dem Schreiben daraus lesen kann. Das ist genau das, was Sie sich von einem temporären Zwischenspeicher wünschen.
Angenommen, im obigen Beispiel möchten wir F32[1024]
als temporären Puffer verwenden.
Dann schreiben wir das HLO genau wie oben und lesen nie den Tupelindex 1.
der Ausgabe des benutzerdefinierten Aufrufs.