Questo documento descrive come scrivere e utilizzare chiamate personalizzate XLA utilizzando la libreria XLA FFI. La chiamata personalizzata è un meccanismo per descrivere un'"operazione" esterna nel modulo HLO al compilatore XLA (in fase di compilazione), mentre XLA FFI è un meccanismo per registrare l'implementazione di queste operazioni con XLA (in fase di runtime). FFI sta per "foreign function interface" ed è un insieme di API C che definiscono un'interfaccia binaria (ABI) per XLA per chiamare codice esterno scritto in altri linguaggi di programmazione. XLA fornisce binding solo di intestazione per XLA FFI scritto in C++, che nasconde all'utente finale tutti i dettagli di basso livello delle API C sottostanti.
JAX + XLA Custom Calls
Consulta la documentazione di JAX per esempi end-to-end di integrazione di chiamate personalizzate e XLA FFI con JAX.
XLA FFI Binding
Il binding FFI di XLA è una specifica in fase di compilazione della firma della chiamata personalizzata:
argomenti, attributi e tipi della chiamata personalizzata e parametri aggiuntivi
trasmessi tramite il contesto di esecuzione (ad es. stream GPU per il backend GPU). Il binding XLA FFI può essere associato a qualsiasi elemento chiamabile C++ (puntatore a funzione, lambda e così via) con una firma operator() compatibile. Il gestore costruito decodifica il frame di chiamata XLA FFI (definito dall'API C stabile), controlla il tipo di tutti i parametri e inoltra i risultati decodificati al callback definito dall'utente.
Il binding FFI di XLA si basa fortemente sulla metaprogrammazione dei modelli per poter compilare l'handler costruito nel codice macchina più efficiente. I sovraccarichi di runtime sono dell'ordine di un paio di nanosecondi per ogni parametro di chiamata personalizzata.
Punti di personalizzazione FFI XLA implementati come specializzazioni di modelli e
gli utenti possono definire come decodificare i propri tipi personalizzati, ovvero è possibile
definire la decodifica personalizzata per i tipi enum class definiti dall'utente.
Restituzione di errori da chiamate personalizzate
Le implementazioni di chiamate personalizzate devono restituire il valore xla::ffi::Error per segnalare
l'esito positivo o l'errore al runtime XLA. È simile a absl::Status e ha
lo stesso insieme di codici di errore. Non utilizziamo absl::Status perché non ha un'ABI stabile e sarebbe pericoloso passarlo tra la libreria di chiamate personalizzate caricata dinamicamente e XLA stessa.
// Handler that always returns an error.
auto always_error = Ffi::Bind().To(
[]() { return Error(ErrorCode::kInternal, "Oops!"); });
// Handler that always returns a success.
auto always_success = Ffi::Bind().To(
[]() { return Error::Success(); });
Argomenti e risultati del buffer
XLA utilizza lo stile di passaggio della destinazione per i risultati: le chiamate personalizzate (o qualsiasi altra operazione XLA) non allocano memoria per i risultati e scrivono invece nelle destinazioni passate dal runtime XLA. XLA utilizza l'assegnazione statica dei buffer e alloca i buffer per tutti i valori in base ai relativi intervalli live in fase di compilazione.
I risultati passati ai gestori FFI vengono inseriti in un modello Result<T>, che
ha una semantica simile a un puntatore: operator-> consente l'accesso al
parametro sottostante.
AnyBuffer e risultati consente di accedere ai parametri del buffer di chiamata personalizzati
di qualsiasi tipo di dati. Ciò è utile quando la chiamata personalizzata ha un'implementazione generica
che funziona per più tipi di dati e l'implementazione della chiamata personalizzata esegue l'invio
in fase di runtime in base al tipo di dati. AnyBuffer fornisce l'accesso al tipo di dati del buffer, alle dimensioni e a un puntatore al buffer stesso.
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
api_version = 4 : i32
} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// Buffers of any number of dimensions and data type.
auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
[](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
void* arg_data = arg.untyped_data();
void* res_data = res->untyped_data();
return Error::Success();
});
Argomenti e risultati del buffer vincolato
Buffer consente di aggiungere vincoli al tipo di dati del buffer e al numero di
dimensioni, che verranno controllati automaticamente dal gestore e restituiranno
un errore al runtime XLA se gli argomenti di runtime non corrispondono alla firma
del gestore FFI.
// Buffers of any number of dimensions and F32 data type.
auto handler = Ffi::Bind().Arg<Buffer<F32>>().Ret<Buffer<F32>>().To(
[](Buffer<F32> arg, Result<Buffer<F32>> res) -> Error {
float* arg_data = arg.typed_data();
float* res_data = res->typed_data();
return Error::Success();
});
// Buffers of number of dimensions 2 and F32 data type.
auto handler = Ffi::Bind().Arg<BufferR2<F32>>().Ret<BufferR2<F32>>().To(
[](BufferR2<F32> arg, Result<BufferR2<F32>> res) -> Error {
float* arg_data = arg.typed_data();
float* res_data = res->typed_data();
return Error::Success();
});
Argomenti e risultati variadici
Se il numero di argomenti e il risultato possono essere diversi in istanze diverse di
una chiamata personalizzata, possono essere decodificati in fase di runtime utilizzando RemainingArgs e
RemainingRets.
auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To(
[](RemainingArgs args, RemainingRets results) -> Error {
ErrorOr<AnyBuffer> arg = args.get<AnyBuffer>(0);
ErrorOr<Result<AnyBuffer>> res = results.get<AnyBuffer>(0);
if (!arg.has_value()) {
return Error(ErrorCode::kInternal, arg.error());
}
if (!res.has_value()) {
return Error(ErrorCode::kInternal, res.error());
}
return Error::Success();
});
Gli argomenti e i risultati variadici possono essere dichiarati dopo gli argomenti e i risultati regolari, ma l'associazione di argomenti e risultati regolari dopo quelli variadici non è consentita.
auto handler =
Ffi::Bind()
.Arg<AnyBuffer>()
.RemainingArgs()
.Ret<AnyBuffer>()
.RemainingRets()
.To([](AnyBuffer arg, RemainingArgs args, AnyBuffer ret,
RemainingRets results) -> Error { return Error::Success(); });
Attributi
XLA FFI supporta la decodifica automatica di mlir::DictionaryAttr passati come
custom_call backend_config negli argomenti del gestore FFI.
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
backend_config= {
i32 = 42 : i32,
str = "string"
},
api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
In questo esempio, la chiamata personalizzata ha un singolo argomento buffer e due attributi e XLA FFI può decodificarli automaticamente e passarli alla funzione chiamabile definita dall'utente.
auto handler = Ffi::Bind()
.Arg<BufferR0<F32>>()
.Attr<int32_t>("i32")
.Attr<std::string_view>("str")
.To([](BufferR0<F32> buffer, int32_t i32, std::string_view str) {
return Error::Success();
});
Attributi enum definiti dall'utente
XLA FFI può decodificare automaticamente gli attributi MLIR integrali in enum definiti dall'utente. La classe enum deve avere lo stesso tipo integrale sottostante e la decodifica deve essere registrata in modo esplicito con XLA FFI.
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
backend_config= {
command = 0 : i32
},
api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
enum class Command : int32_t {
kAdd = 0,
kMul = 1,
};
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command);
auto handler = Ffi::Bind().Attr<Command>("command").To(
[](Command command) -> Error { return Error::Success(); });
Associazione di tutti gli attributi di chiamata personalizzati
È possibile accedere a tutti gli attributi di chiamata personalizzati come dizionario e decodificare in modo differito solo gli attributi necessari in fase di runtime.
auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
ErrorOr<int32_t> i32 = attrs.get<int32_t>("i32");
return Error::Success();
});
Attributi Struct definiti dall'utente
XLA FFI può decodificare gli attributi del dizionario in struct definiti dall'utente.
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
backend_config= {
range = { lo = 0 : i64, hi = 42 : i64 }
},
api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
Nell'esempio precedente, range è un attributo mlir::DictionaryAttr e, anziché accedere ai campi del dizionario per nome, può essere decodificato automaticamente come struct C++. La decodifica deve essere registrata in modo esplicito con una
macro XLA_FFI_REGISTER_STRUCT_ATTR_DECODING (dietro le quinte definisce
una specializzazione del modello nello spazio dei nomi ::xla::ffi, pertanto la macro deve essere aggiunta
allo spazio dei nomi globale).
struct Range {
int64_t lo;
int64_t hi;
};
XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("lo"),
StructMember<int64_t>("hi"));
auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
return Error::Success();
});
Gli attributi personalizzati possono essere caricati da un dizionario, proprio come qualsiasi altro
attributo. Nell'esempio seguente, tutti gli attributi di chiamata personalizzati decodificati come
Dictionary e un range sono accessibili per nome.
auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
ErrorOr<Range> range = attrs.get<Range>("range");
return Error::Success();
});
Crea una chiamata personalizzata sulla CPU
Puoi creare un'istruzione HLO che rappresenti una chiamata personalizzata tramite l'API client di XLA. Ad esempio, il seguente codice utilizza una chiamata personalizzata per calcolare A[i] = B[i %
128]+ C[i] sulla CPU. (Certo che puoi, anzi dovresti! – do this
with regular HLO.)
#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"
void do_it() {
xla::XlaBuilder b("do_it");
xla::XlaOp param0 =
xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
xla::XlaOp param1 =
xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
xla::XlaOp custom_call =
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
/*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
/*opaque=*/"", /*has_side_effect=*/false,
/*output_operand_aliasing=*/{}, /*literal=*/nullptr,
/*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
/*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}
// Constrain custom call arguments to 1-dimensional buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;
// Implement a custom call as a C++ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
xla::ffi::Result<BufferF32> out) {
size_t d0 = in0.dimensions[0];
size_t d1 = in1.dimensions[0];
// Check that dimensions are compatible.
assert(out->dimensions[0] == d1 && "unexpected dimensions");
for (size_t i = 0; i < d1; ++i) {
out->data[i] = in0.data[i % d0] + in1.data[i];
}
}
// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
ffi::Ffi::Bind()
.Arg<Buffer>()
.Arg<Buffer>()
.Ret<Buffer>());
// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
"Host", handler);
Crea una chiamata personalizzata sulla GPU
La registrazione di chiamate personalizzate della GPU con XLA FFI è quasi identica, l'unica
differenza è che per la GPU devi richiedere un flusso della piattaforma sottostante
(flusso CUDA o ROCM) per poter avviare il kernel sul dispositivo. Ecco un esempio di CUDA
che esegue lo stesso calcolo (A[i] = B[i % 128] + C[i]) del codice
della CPU riportato sopra.
void do_it() { /* same implementation as above */ }
__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
out[idx] = in0[idx % 128] + in1[idx];
}
void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
xla::ffi::Result<BufferF32> out) {
size_t d0 = in0.dimensions[0];
size_t d1 = in1.dimensions[0];
size_t d2 = out->dimensions[0];
assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");
const int64_t block_dim = 64;
const int64_t grid_dim = 2048 / block_dim;
custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
in0.data, in1.data, out->data);
}
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
ffi::Ffi::Bind()
.Ctx<xla::ffi::PlatformStream<CUstream>>()
.Arg<BufferF32>()
.Arg<BufferF32>()
.Ret<BufferF32>());
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
"CUDA", handler);
Tieni presente innanzitutto che la funzione di chiamata personalizzata della GPU è ancora una funzione eseguita sulla CPU. La funzione CPU do_custom_call è responsabile dell'accodamento del lavoro
sulla GPU. Qui viene avviato un kernel CUDA, ma potrebbe anche fare altro,
come chiamare cuBLAS.
Gli argomenti e i risultati si trovano anche sull'host e il membro dati contiene un puntatore alla memoria del dispositivo (ad es. GPU). I buffer passati al gestore di chiamate personalizzato hanno la forma dei buffer del dispositivo sottostanti, quindi la chiamata personalizzata può calcolare i parametri di avvio del kernel a partire da questi.
Trasferimento di tuple a chiamate personalizzate
Considera la seguente chiamata personalizzata.
using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {32}),
ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {64}),
ShapeUtil::MakeShape(F32, {128}),
}),
ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");
Shape out_shape = ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {512}),
ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);
Sia sulla CPU che sulla GPU, una tupla è rappresentata in memoria come un array di puntatori. Quando XLA chiama chiamate personalizzate con argomenti o risultati di tuple, li appiattisce e li passa come argomenti o risultati di buffer regolari.
Tuple output as temp buffers
Gli input di tuple per le chiamate personalizzate sono una comodità, ma non sono strettamente necessari. Se non supportassimo gli input di tuple per le chiamate personalizzate, potresti sempre decomprimere le tuple utilizzando get-tuple-element prima di passarle alla chiamata personalizzata.
D'altra parte, le uscite delle tuple ti consentono di fare cose che altrimenti non potresti fare.
Il motivo ovvio per cui sono presenti output di tuple è che in questo modo una chiamata personalizzata (o qualsiasi altra operazione XLA) restituisce più array indipendenti.
In modo meno ovvio, un output di tupla è anche un modo per dare alla tua chiamata personalizzata una memoria temporanea. Sì, un output può rappresentare un buffer temporaneo. Considera che un buffer di output ha la proprietà che l'operazione può scriverci e può leggerlo dopo che è stato scritto. È esattamente ciò che vuoi da un buffer temporaneo.
Nell'esempio precedente, supponiamo di voler utilizzare F32[1024] come buffer temporaneo.
Poi scriveremmo l'HLO come sopra e non leggeremmo mai l'indice della tupla 1
dell'output della chiamata personalizzata.