Wywołania niestandardowe XLA

Ten dokument opisuje, jak zapisywać i wykorzystywać niestandardowe wywołania XLA za pomocą biblioteki FFI XLA. Wywołanie niestandardowe to mechanizm opisania zewnętrznej „operacji” w module HLO dla kompilatora XLA (w czasie kompilowania), a XLA FFI to mechanizm rejestrowania implementacji takich operacji za pomocą XLA (w czasie działania). FFI oznacza „interfejs obcych funkcji”. Jest to zestaw interfejsów API typu C, które definiują interfejs binarny (ABI) do wywoływania przez XLA kodu zewnętrznego napisanego w innych językach programowania. XLA udostępnia powiązania tylko nagłówka dla FFI XLA napisane w C++, co ukrywa przed użytkownikiem wszystkie szczegóły niskiego poziomu bazowych interfejsów API C.

Tworzenie niestandardowego wywołania przy procesorze

Możesz utworzyć instrukcję HLO, która reprezentuje niestandardowe wywołanie za pomocą interfejsu API klienta XLA. Na przykład ten kod używa niestandardowego wywołania, aby obliczyć A[i] = B[i % 128]+ C[i] na procesorze. Oczywiście, że można – i warto! – użyj zwykłego HLO.

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
        /*opaque=*/"", /*has_side_effect=*/false,
        /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
        /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to rank-1 buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C+ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];

  // Check that dimensions are compatible.
  assert(out->dimensions[0] == d1 && "unexpected dimensions");

  for (size_t i = 0; i < d1; ++i) {
    out->data[i] = in0.data[i % d0] + in1.data[i];
  }
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Arg<Buffer>()
                           .Arg<Buffer>()
                           .Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "Host", handler);

Utwórz wywołanie niestandardowe w GPU

Rejestracja niestandardowego wywołania GPU z XLA FFI jest niemal identyczna. Jedyną różnicą jest to, że w przypadku GPU konieczne jest przesłanie prośby o bazowy strumień platformy (strumień CUDA lub ROCM), aby móc uruchomić jądro na urządzeniu. Oto przykład CUDA, który wykonuje te same obliczenia (A[i] = B[i % 128] + C[i]) co powyższy kod procesora.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];
  size_t d2 = out->dimensions[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
    in0.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<CUstream>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "CUDA", handler);

Po pierwsze, niestandardowa funkcja wywołania GPU jest nadal funkcją wykonywaną na procesorze. Funkcja procesora do_custom_call odpowiada za dodawanie do kolejki pracy na GPU. Tutaj uruchamia jądro CUDA, ale może też robić coś innego, na przykład cuBLAS.

Argumenty i wyniki również znajdują się na hoście, a użytkownik danych zawiera wskaźnik do pamięci urządzenia (GPU). Bufety przekazywane do niestandardowego modułu obsługi połączeń mają kształt bazowych buforów urządzeń, więc wywołanie niestandardowe może obliczać z nich parametry uruchamiania jądra.

Przekazywanie krotek do wywołań niestandardowych

Rozważ poniższe wywołanie niestandardowe.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);

W przypadku procesora i GPU krotka jest reprezentowana w pamięci jako tablica wskaźników. Gdy XLA wywołuje niestandardowe wywołania z argumentami kropki lub wyniki, spłaszcza je i przekazuje jako zwykłe argumenty bufora lub wyniki.

Dane wyjściowe krotek jako bufory tymczasowe

Wpisywanie krotek w wywołaniach niestandardowych ułatwia korzystanie z nich, ale nie jest niezbędne. Gdybyśmy nie obsługiwali danych wejściowych krotek w wywołaniach niestandardowych, przed przekazaniem ich do niestandardowego wywołania zawsze można było je rozpakować za pomocą funkcji get-tuple-element.

Z drugiej strony dzięki wyjściom kropek możesz robić rzeczy, które w innym przypadku nie byłyby możliwe.

Oczywistym powodem występowania danych wyjściowych krotek jest to, że dane wyjściowe krotek są sposobem, w jaki niestandardowe wywołanie (lub dowolne inne działanie XLA) zwraca wiele niezależnych tablic.

Jednak w mniej oczywistym przypadku dane wyjściowe kropki są też sposobem na nadanie niestandardowej tymczasowej pamięci wywołania. Tak. Dane wyjściowe mogą reprezentować tymczasowy bufor. Weźmy na przykład bufor danych wyjściowych, który ma właściwość, w której operacja może zapisać dane, i możliwość jej odczytu po zapisywaniu. To jest dokładnie to, czego oczekujesz od bufora tymczasowego.

Załóżmy, że chcemy użyć F32[1024] jako tymczasowego bufora Następnie zapiszemy HLO tak jak powyżej i nigdy nie odczytamy indeksu krotek 1 wyników niestandardowego wywołania.