XLA-Anfragen

In diesem Dokument wird beschrieben, wie Sie benutzerdefinierte XLA-Aufrufe mit der XLA-FFI-Bibliothek schreiben und verwenden. Ein benutzerdefinierter Aufruf ist ein Mechanismus, um dem XLA-Compiler (zur Kompilierzeit) einen externen „Vorgang“ im HLO-Modul zu beschreiben. XLA FFI ist ein Mechanismus, um die Implementierung solcher Vorgänge bei XLA zu registrieren (zur Laufzeit). FFI steht für „foreign function interface“ (Schnittstelle für fremde Funktionen). Es handelt sich um eine Reihe von C-APIs, die eine binäre Schnittstelle (ABI) für XLA definieren, um externen Code aufzurufen, der in anderen Programmiersprachen geschrieben wurde. XLA bietet Header-only-Bindungen für XLA FFI, die in C++ geschrieben sind. Dadurch werden alle Details der zugrunde liegenden C-APIs auf niedriger Ebene vor dem Endnutzer verborgen.

JAX + XLA Custom Calls

JAX-Dokumentation mit End-to-End-Beispielen für die Integration benutzerdefinierter Aufrufe und XLA FFI in JAX.

XLA-FFI-Bindung

Die XLA-FFI-Bindung ist eine Spezifikation der benutzerdefinierten Aufrufsignatur zur Kompilierungszeit: Argumente, Attribute und ihre Typen sowie zusätzliche Parameter, die über den Ausführungskontext übergeben werden (z.B. GPU-Stream für das GPU-Backend). XLA FFI-Bindungen können an jedes aufrufbare C++-Element (Funktionszeiger, Lambda usw.) mit einer kompatiblen operator()-Signatur gebunden werden. Der erstellte Handler decodiert den XLA-FFI-Aufruf-Frame (definiert durch die stabile C-API), führt eine Typüberprüfung aller Parameter durch und leitet die decodierten Ergebnisse an den benutzerdefinierten Callback weiter.

Die XLA-FFI-Bindung basiert stark auf der Template-Metaprogrammierung, um den erstellten Handler in den effizientesten Maschinencode zu kompilieren. Die Laufzeit-Overheads liegen bei einigen Nanosekunden für jeden benutzerdefinierten Anrufparameter.

XLA-FFI-Anpassungspunkte, die als Vorlagenspezialisierungen implementiert werden, und Nutzer können definieren, wie ihre benutzerdefinierten Typen decodiert werden. Es ist also möglich, benutzerdefiniertes Decodieren für benutzerdefinierte enum class-Typen zu definieren.

Fehler bei benutzerdefinierten Aufrufen zurückgeben

Benutzerdefinierte Aufrufe müssen den Wert xla::ffi::Error zurückgeben, um der XLA-Laufzeit Erfolg oder Fehler zu signalisieren. Er ähnelt absl::Status und hat dieselben Fehlercodes. Wir verwenden absl::Status nicht, da es keine stabile ABI hat und es unsicher wäre, es zwischen der dynamisch geladenen benutzerdefinierten Aufrufbibliothek und XLA selbst zu übergeben.

// Handler that always returns an error.
auto always_error = Ffi::Bind().To(
    []() { return Error(ErrorCode::kInternal, "Oops!"); });

// Handler that always returns a success.
auto always_success = Ffi::Bind().To(
    []() { return Error::Success(); });

Argumente und Ergebnisse puffern

XLA verwendet den Stil „Ziel übergeben“ für Ergebnisse: Bei benutzerdefinierten Aufrufen (oder anderen XLA-Vorgängen) wird kein Speicherplatz für Ergebnisse zugewiesen. Stattdessen wird in Ziele geschrieben, die von der XLA-Laufzeit übergeben werden. XLA verwendet die statische Pufferzuweisung und weist Puffer für alle Werte basierend auf ihren Live-Bereichen zur Kompilierzeit zu.

Ergebnisse, die an FFI-Handler übergeben werden, die in eine Result<T>-Vorlage eingebunden sind, die eine pointerähnliche Semantik hat: operator-> ermöglicht den Zugriff auf den zugrunde liegenden Parameter.

AnyBuffer-Argumente und -Ergebnisse ermöglichen den Zugriff auf benutzerdefinierte Callbuffer-Parameter beliebigen Datentyps. Das ist nützlich, wenn der benutzerdefinierte Aufruf eine generische Implementierung hat, die für mehrere Datentypen funktioniert, und die Implementierung des benutzerdefinierten Aufrufs das Laufzeit-Dispatching basierend auf dem Datentyp ausführt. AnyBuffer bietet Zugriff auf den Pufferdatentyp, die Dimensionen und einen Zeiger auf den Puffer selbst.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  api_version = 4 : i32
} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// Buffers of any number of dimensions and data type.
auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
    [](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
      void* arg_data = arg.untyped_data();
      void* res_data = res->untyped_data();
      return Error::Success();
    });

Eingeschränkte Pufferargumente und Ergebnisse

Mit Buffer können Einschränkungen für den Pufferdatentyp und die Anzahl der Dimensionen hinzugefügt werden. Diese werden automatisch vom Handler geprüft und es wird ein Fehler an die XLA-Laufzeit zurückgegeben, wenn die Laufzeitargumente nicht mit der FFI-Handlersignatur übereinstimmen.

// Buffers of any number of dimensions and F32 data type.
auto handler = Ffi::Bind().Arg<Buffer<F32>>().Ret<Buffer<F32>>().To(
    [](Buffer<F32> arg, Result<Buffer<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });
// Buffers of number of dimensions 2 and F32 data type.
auto handler = Ffi::Bind().Arg<BufferR2<F32>>().Ret<BufferR2<F32>>().To(
    [](BufferR2<F32> arg, Result<BufferR2<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });

Variadische Argumente und Ergebnisse

Wenn sich die Anzahl der Argumente und das Ergebnis in verschiedenen Instanzen eines benutzerdefinierten Aufrufs unterscheiden können, können sie zur Laufzeit mit RemainingArgs und RemainingRets decodiert werden.

auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To(
    [](RemainingArgs args, RemainingRets results) -> Error {
      ErrorOr<AnyBuffer> arg = args.get<AnyBuffer>(0);
      ErrorOr<Result<AnyBuffer>> res = results.get<AnyBuffer>(0);

      if (!arg.has_value()) {
        return Error(ErrorCode::kInternal, arg.error());
      }

      if (!res.has_value()) {
        return Error(ErrorCode::kInternal, res.error());
      }

      return Error::Success();
    });

Variadische Argumente und Ergebnisse können nach regulären Argumenten und Ergebnissen deklariert werden. Das Binden regulärer Argumente und Ergebnisse nach variadischen ist jedoch unzulässig.

auto handler =
    Ffi::Bind()
        .Arg<AnyBuffer>()
        .RemainingArgs()
        .Ret<AnyBuffer>()
        .RemainingRets()
        .To([](AnyBuffer arg, RemainingArgs args, AnyBuffer ret,
               RemainingRets results) -> Error { return Error::Success(); });

Attribute

XLA FFI unterstützt die automatische Dekodierung von mlir::DictionaryAttr, die als custom_call backend_config an FFI-Handlerargumente übergeben werden.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    i32 = 42 : i32,
    str = "string"
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

In diesem Beispiel hat der benutzerdefinierte Aufruf ein einzelnes Pufferargument und zwei Attribute. XLA FFI kann sie automatisch decodieren und an das benutzerdefinierte Callable übergeben.

auto handler = Ffi::Bind()
  .Arg<BufferR0<F32>>()
  .Attr<int32_t>("i32")
  .Attr<std::string_view>("str")
  .To([](BufferR0<F32> buffer, int32_t i32, std::string_view str) {
    return Error::Success();
  });

Benutzerdefinierte Enum-Attribute

Mit XLA FFI können integrale MLIR-Attribute automatisch in benutzerdefinierte Enums decodiert werden. Die Enum-Klasse muss denselben zugrunde liegenden integralen Typ haben und die Dekodierung muss explizit bei XLA FFI registriert werden.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    command = 0 : i32
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
enum class Command : int32_t {
  kAdd = 0,
  kMul = 1,
};

XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command);

auto handler = Ffi::Bind().Attr<Command>("command").To(
    [](Command command) -> Error { return Error::Success(); });

Alle benutzerdefinierten Anrufattribute binden

Sie können auf alle benutzerdefinierten Anrufattribute als Dictionary zugreifen und nur die Attribute, die zur Laufzeit benötigt werden, verzögert decodieren.

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<int32_t> i32 = attrs.get<int32_t>("i32");
  return Error::Success();
});

Benutzerdefinierte Strukturattribute

Mit XLA FFI können Wörterbuchattribute in benutzerdefinierte Strukturen decodiert werden.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    range = { lo = 0 : i64, hi = 42 : i64 }
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

Im obigen Beispiel ist range ein mlir::DictionaryAttr-Attribut. Anstatt über den Namen auf die Wörterbuchfelder zuzugreifen, kann es automatisch als C++-Struct decodiert werden. Die Dekodierung muss explizit mit einem XLA_FFI_REGISTER_STRUCT_ATTR_DECODING-Makro registriert werden. Im Hintergrund wird dadurch eine Vorlagenspezialisierung im ::xla::ffi-Namespace definiert. Das Makro muss also dem globalen Namespace hinzugefügt werden.

struct Range {
  int64_t lo;
  int64_t hi;
};

XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("lo"),
                                             StructMember<int64_t>("hi"));

auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
  return Error::Success();
});

Benutzerdefinierte Attribute können wie jedes andere Attribut aus einem Dictionary geladen werden. Im Beispiel unten werden alle benutzerdefinierten Anrufattribute als Dictionary decodiert und auf range kann über den Namen zugegriffen werden.

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<Range> range = attrs.get<Range>("range");
  return Error::Success();
});

Benutzerdefinierten Aufruf auf der CPU erstellen

Sie können eine HLO-Anweisung erstellen, die einen benutzerdefinierten Aufruf über die Client-API von XLA darstellt. Im folgenden Code wird beispielsweise ein benutzerdefinierter Aufruf verwendet, um A[i] = B[i % 128]+ C[i] auf der CPU zu berechnen. Natürlich können und sollten Sie das tun. – do this with regular HLO.)

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
        /*opaque=*/"", /*has_side_effect=*/false,
        /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
        /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to 1-dimensional buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C++ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];

  // Check that dimensions are compatible.
  assert(out->dimensions[0] == d1 && "unexpected dimensions");

  for (size_t i = 0; i < d1; ++i) {
    out->data[i] = in0.data[i % d0] + in1.data[i];
  }
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Arg<Buffer>()
                           .Arg<Buffer>()
                           .Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "Host", handler);

Benutzerdefinierten Aufruf auf der GPU erstellen

Die Registrierung benutzerdefinierter GPU-Aufrufe mit XLA FFI ist fast identisch. Der einzige Unterschied besteht darin, dass Sie für die GPU einen zugrunde liegenden Plattformstream (CUDA- oder ROCM-Stream) anfordern müssen, um den Kernel auf dem Gerät zu starten. Hier ist ein CUDA-Beispiel, das dieselbe Berechnung (A[i] = B[i % 128] + C[i]) wie der CPU-Code oben ausführt.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];
  size_t d2 = out->dimensions[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
    in0.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<CUstream>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "CUDA", handler);

Beachten Sie, dass die benutzerdefinierte GPU-Aufruffunktion weiterhin eine Funktion ist, die auf der CPU ausgeführt wird. Die do_custom_call-CPU-Funktion ist für das Einreihen von Aufgaben in die GPU verantwortlich. Hier wird ein CUDA-Kernel gestartet, aber es könnte auch etwas anderes passieren, z. B. ein Aufruf von cuBLAS.

Argumente und Ergebnisse befinden sich ebenfalls auf dem Host und das Datenmember enthält einen Zeiger auf den Gerätespeicher (d.h. den GPU-Speicher). Puffer, die an den benutzerdefinierten Call-Handler übergeben werden, haben die Form der zugrunde liegenden Gerätepuffer. Der benutzerdefinierte Call kann also Kernel-Startparameter daraus berechnen.

Tupel an benutzerdefinierte Aufrufe übergeben

Sehen Sie sich den folgenden benutzerdefinierten Aufruf an.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);

Sowohl auf der CPU als auch auf der GPU wird ein Tupel im Arbeitsspeicher als Array von Zeigern dargestellt. Wenn XLA benutzerdefinierte Aufrufe mit Tupelargumenten oder -ergebnissen aufruft, werden diese vereinfacht und als reguläre Pufferargumente oder -ergebnisse übergeben.

Tupelausgaben als temporäre Puffer

Tupel-Eingaben für benutzerdefinierte Aufrufe sind praktisch, aber nicht unbedingt erforderlich. Wenn wir keine Tupel-Eingaben für benutzerdefinierte Aufrufe unterstützt hätten, hätten Sie die Tupel immer mit „get-tuple-element“ entpacken können, bevor Sie sie an den benutzerdefinierten Aufruf übergeben.

Mit Tupelausgaben können Sie dagegen Dinge tun, die sonst nicht möglich wären.

Der offensichtliche Grund für Tupelausgaben ist, dass ein benutzerdefinierter Aufruf (oder ein anderer XLA-Vorgang) mehrere unabhängige Arrays zurückgibt.

Weniger offensichtlich ist jedoch, dass eine Tupelausgabe auch eine Möglichkeit ist, dem benutzerdefinierten Aufruf temporären Speicher zuzuweisen. Ja, eine Ausgabe kann einen temporären Puffer darstellen. Ein Ausgabepuffer hat die Eigenschaft, dass der Vorgang in ihn schreiben und aus ihm lesen kann, nachdem in ihn geschrieben wurde. Genau das ist es, was Sie von einem temporären Puffer erwarten.

Angenommen, wir möchten im obigen Beispiel F32[1024] als temporären Puffer verwenden. Dann würden wir das HLO wie oben schreiben und einfach nie den Tupelindex 1 der Ausgabe des benutzerdefinierten Aufrufs lesen.