XLA Custom Calls

W tym dokumencie opisujemy, jak pisać i używać niestandardowych wywołań XLA za pomocą biblioteki XLA FFI. Niestandardowe wywołanie to mechanizm opisywania zewnętrznej „operacji” w module HLO kompilatorowi XLA (w czasie kompilacji), a XLA FFI to mechanizm rejestrowania implementacji takich operacji w XLA (w czasie działania). FFI to skrót od „foreign function interface” (interfejs funkcji zewnętrznych). Jest to zestaw interfejsów API w języku C, które definiują interfejs binarny (ABI) dla XLA, umożliwiający wywoływanie zewnętrznego kodu napisanego w innych językach programowania. XLA udostępnia wiązania tylko w plikach nagłówkowych dla XLA FFI napisanych w C++, które ukrywają przed użytkownikiem wszystkie szczegóły niskiego poziomu bazowych interfejsów API w C.

JAX + XLA Custom Calls

W dokumentacji JAX znajdziesz przykłady kompleksowej integracji wywołań niestandardowych i XLA FFI z JAX.

Powiązanie XLA FFI

Powiązanie XLA FFI to specyfikacja sygnatury wywołania niestandardowego w czasie kompilacji: argumenty wywołania niestandardowego, atrybuty i ich typy oraz dodatkowe parametry przekazywane przez kontekst wykonania (np. strumień GPU w przypadku backendu GPU). Powiązanie XLA FFI można powiązać z dowolnym wywoływalnym elementem C++ (wskaźnik funkcji, lambda itp.) o zgodnym operator(). Skonstruowany moduł obsługi dekoduje ramkę wywołania XLA FFI (zdefiniowaną przez stabilny interfejs C API), sprawdza typ wszystkich parametrów i przekazuje zdekodowane wyniki do zdefiniowanego przez użytkownika wywołania zwrotnego.

Powiązanie XLA FFI w dużym stopniu opiera się na metaprogramowaniu szablonów, aby móc skompilować utworzony moduł obsługi do najbardziej wydajnego kodu maszynowego. Narzut czasu działania wynosi kilka nanosekund na każdy parametr wywołania niestandardowego.

Punkty dostosowywania XLA FFI zaimplementowane jako specjalizacje szablonów, a użytkownicy mogą określać sposób dekodowania typów niestandardowych, tzn. można zdefiniować niestandardowe dekodowanie typów enum class zdefiniowanych przez użytkownika.

Zwracanie błędów z połączeń niestandardowych

Implementacje wywołań niestandardowych muszą zwracać wartość xla::ffi::Error, aby sygnalizować środowisku wykonawczemu XLA powodzenie lub błąd. Jest podobny do absl::Status i ma ten sam zestaw kodów błędów. Nie używamy absl::Status, ponieważ nie ma stabilnego interfejsu ABI i przekazywanie go między dynamicznie wczytywaną biblioteką wywołań niestandardowych a samą biblioteką XLA byłoby niebezpieczne.

// Handler that always returns an error.
auto always_error = Ffi::Bind().To(
    []() { return Error(ErrorCode::kInternal, "Oops!"); });

// Handler that always returns a success.
auto always_success = Ffi::Bind().To(
    []() { return Error::Success(); });

Argumenty i wyniki bufora

XLA używa stylu przekazywania do miejsca docelowego w przypadku wyników: wywołania niestandardowe (lub inne operacje XLA) nie przydzielają pamięci na wyniki, ale zapisują je w miejscach docelowych przekazywanych przez środowisko wykonawcze XLA. XLA używa statycznego przypisywania buforów i przydziela bufory dla wszystkich wartości na podstawie ich zakresów na żywo w czasie kompilacji.

Wyniki przekazywane do modułów obsługi FFI są opakowane w szablon Result<T>, który ma semantykę podobną do wskaźnika: operator-> zapewnia dostęp do bazowego parametru.

AnyBuffer argumenty i wyniki umożliwiają dostęp do niestandardowych parametrów bufora wywołań dowolnego typu danych. Jest to przydatne, gdy wywołanie niestandardowe ma ogólną implementację, która działa w przypadku wielu typów danych, a implementacja wywołania niestandardowego wykonuje wysyłanie w czasie działania na podstawie typu danych. AnyBuffer zapewnia dostęp do danych bufora, typu, wymiarów i wskaźnika do samego bufora.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  api_version = 4 : i32
} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// Buffers of any number of dimensions and data type.
auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
    [](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
      void* arg_data = arg.untyped_data();
      void* res_data = res->untyped_data();
      return Error::Success();
    });

Argumenty i wyniki bufora z ograniczeniami

Buffer umożliwia dodawanie ograniczeń dotyczących typu danych bufora i liczby wymiarów. Będą one automatycznie sprawdzane przez moduł obsługi i w razie niezgodności argumentów czasu działania z sygnaturą modułu obsługi FFI zwracać błąd do środowiska wykonawczego XLA.

// Buffers of any number of dimensions and F32 data type.
auto handler = Ffi::Bind().Arg<Buffer<F32>>().Ret<Buffer<F32>>().To(
    [](Buffer<F32> arg, Result<Buffer<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });
// Buffers of number of dimensions 2 and F32 data type.
auto handler = Ffi::Bind().Arg<BufferR2<F32>>().Ret<BufferR2<F32>>().To(
    [](BufferR2<F32> arg, Result<BufferR2<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });

Argumenty i wyniki o zmiennej liczbie

Jeśli liczba argumentów i wynik mogą się różnić w różnych instancjach wywołania niestandardowego, można je dekodować w czasie działania za pomocą funkcji RemainingArgsRemainingRets.

auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To(
    [](RemainingArgs args, RemainingRets results) -> Error {
      ErrorOr<AnyBuffer> arg = args.get<AnyBuffer>(0);
      ErrorOr<Result<AnyBuffer>> res = results.get<AnyBuffer>(0);

      if (!arg.has_value()) {
        return Error(ErrorCode::kInternal, arg.error());
      }

      if (!res.has_value()) {
        return Error(ErrorCode::kInternal, res.error());
      }

      return Error::Success();
    });

Argumenty i wyniki zmienne mogą być deklarowane po argumentach i wynikach zwykłych, ale wiązanie argumentów i wyników zwykłych po argumentach i wynikach zmiennych jest niedozwolone.

auto handler =
    Ffi::Bind()
        .Arg<AnyBuffer>()
        .RemainingArgs()
        .Ret<AnyBuffer>()
        .RemainingRets()
        .To([](AnyBuffer arg, RemainingArgs args, AnyBuffer ret,
               RemainingRets results) -> Error { return Error::Success(); });

Atrybuty

XLA FFI obsługuje automatyczne dekodowanie mlir::DictionaryAttr przekazywanych jako custom_call backend_config do argumentów procedury obsługi FFI.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    i32 = 42 : i32,
    str = "string"
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

W tym przykładzie wywołanie niestandardowe ma 1 argument bufora i 2 atrybuty, a XLA FFI może automatycznie je dekodować i przekazywać do wywoływalnego zdefiniowanego przez użytkownika.

auto handler = Ffi::Bind()
  .Arg<BufferR0<F32>>()
  .Attr<int32_t>("i32")
  .Attr<std::string_view>("str")
  .To([](BufferR0<F32> buffer, int32_t i32, std::string_view str) {
    return Error::Success();
  });

Atrybuty wyliczeniowe zdefiniowane przez użytkownika

XLA FFI może automatycznie dekodować atrybuty MLIR w postaci liczb całkowitych na zdefiniowane przez użytkownika wyliczenia. Klasa wyliczeniowa musi mieć ten sam podstawowy typ całkowity, a dekodowanie musi być jawnie zarejestrowane w XLA FFI.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    command = 0 : i32
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
enum class Command : int32_t {
  kAdd = 0,
  kMul = 1,
};

XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command);

auto handler = Ffi::Bind().Attr<Command>("command").To(
    [](Command command) -> Error { return Error::Success(); });

Powiązywanie wszystkich atrybutów połączeń niestandardowych

Możesz uzyskać dostęp do wszystkich atrybutów połączenia niestandardowego w formie słownika i dekodować tylko te atrybuty, które są potrzebne w czasie działania programu.

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<int32_t> i32 = attrs.get<int32_t>("i32");
  return Error::Success();
});

Atrybuty struktury zdefiniowane przez użytkownika

XLA FFI może dekodować atrybuty słownika do zdefiniowanych przez użytkownika struktur.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    range = { lo = 0 : i64, hi = 42 : i64 }
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

W powyższym przykładzie range jest atrybutem mlir::DictionaryAttr, a zamiast uzyskiwać dostęp do pól słownika według nazwy, można go automatycznie dekodować jako strukturę C++. Dekodowanie musi być jawnie zarejestrowane za pomocą makra XLA_FFI_REGISTER_STRUCT_ATTR_DECODING (w tle definiuje specjalizację szablonu w przestrzeni nazw ::xla::ffi, dlatego makro musi zostać dodane do globalnej przestrzeni nazw).

struct Range {
  int64_t lo;
  int64_t hi;
};

XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("lo"),
                                             StructMember<int64_t>("hi"));

auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
  return Error::Success();
});

Atrybuty niestandardowe można wczytywać ze słownika, tak jak każdy inny atrybut. W przykładzie poniżej wszystkie atrybuty połączenia niestandardowego zdekodowane jako Dictionaryrange są dostępne według nazwy.

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<Range> range = attrs.get<Range>("range");
  return Error::Success();
});

Tworzenie niestandardowego wywołania na procesorze

Możesz utworzyć instrukcję HLO, która reprezentuje wywołanie niestandardowe za pomocą interfejsu API klienta XLA. Na przykład ten kod używa wywołania niestandardowego do obliczenia A[i] = B[i % 128]+ C[i] na procesorze. (Oczywiście możesz i powinieneś. – zrób to za pomocą zwykłego HLO).

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
        /*opaque=*/"", /*has_side_effect=*/false,
        /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
        /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to 1-dimensional buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C++ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];

  // Check that dimensions are compatible.
  assert(out->dimensions[0] == d1 && "unexpected dimensions");

  for (size_t i = 0; i < d1; ++i) {
    out->data[i] = in0.data[i % d0] + in1.data[i];
  }
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Arg<Buffer>()
                           .Arg<Buffer>()
                           .Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "Host", handler);

Tworzenie niestandardowego wywołania na procesorze graficznym

Rejestracja niestandardowego wywołania GPU za pomocą XLA FFI jest prawie identyczna. Jedyna różnica polega na tym, że w przypadku GPU musisz poprosić o strumień platformy bazowej (strumień CUDA lub ROCM), aby móc uruchomić jądro na urządzeniu. Oto przykład kodu CUDA, który wykonuje te same obliczenia (A[i] = B[i % 128] + C[i]) co kod CPU powyżej.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];
  size_t d2 = out->dimensions[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
    in0.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<CUstream>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "CUDA", handler);

Zwróć uwagę, że funkcja wywołania niestandardowego GPU jest nadal funkcją wykonywaną na procesorze. Funkcja procesora do_custom_call odpowiada za umieszczanie zadań w kolejce na GPU. W tym przypadku uruchamia on jądro CUDA, ale może też wykonywać inne czynności, np. wywoływać cuBLAS.

Argumenty i wyniki również znajdują się na hoście, a element danych zawiera wskaźnik do pamięci urządzenia (czyli GPU). Bufory przekazywane do niestandardowego modułu obsługi wywołań mają kształt buforów urządzenia, dzięki czemu niestandardowe wywołanie może obliczać na ich podstawie parametry uruchamiania jądra.

Przekazywanie krotek do wywołań niestandardowych

Rozważ użycie tego niestandardowego wywołania:

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);

Zarówno w przypadku procesora, jak i GPU krotka jest reprezentowana w pamięci jako tablica wskaźników. Gdy XLA wywołuje niestandardowe wywołania z argumentami lub wynikami w postaci krotki, spłaszcza je i przekazuje jako zwykłe argumenty lub wyniki bufora.

Krotki jako tymczasowe bufory wyjściowe

Krotki jako dane wejściowe do wywołań niestandardowych są wygodne, ale nie są bezwzględnie konieczne. Gdybyśmy nie obsługiwali krotek jako danych wejściowych w przypadku wywołań niestandardowych, zawsze można by rozpakować krotki za pomocą funkcji get-tuple-element przed przekazaniem ich do wywołania niestandardowego.

Z drugiej strony krotki outputs umożliwiają wykonywanie czynności, które w inny sposób nie byłyby możliwe.

Oczywistym powodem stosowania danych wyjściowych w postaci krotki jest to, że w ten sposób niestandardowe wywołanie (lub dowolna inna operacja XLA) zwraca wiele niezależnych tablic.

Mniej oczywiste jest to, że dane wyjściowe w postaci krotki to także sposób na zapewnienie niestandardowemu wywołaniu tymczasowej pamięci. Tak, dane wyjściowe mogą reprezentować tymczasowy bufor. Rozważmy bufor wyjściowy, do którego operacja może zapisywać dane i z którego może je odczytywać po zapisaniu. Właśnie tego oczekujesz od bufora tymczasowego.

Załóżmy, że w przykładzie powyżej chcemy użyć F32[1024] jako bufora tymczasowego. Następnie napiszemy HLO tak jak powyżej i po prostu nigdy nie odczytamy indeksu krotki 1 z danych wyjściowych wywołania niestandardowego.