Wywołania niestandardowe XLA

Ten dokument opisuje, jak zapisywać i używać niestandardowych wywołań XLA. Wywołania niestandardowe pozwalają wywoływać kod napisany w języku programowania, np. C++ czy CUDA, z programu XLA.

Utwórz niestandardowe wywołanie procesora

Możesz utworzyć instrukcję HLO reprezentującą niestandardowe wywołanie za pomocą interfejsu API klienta XLA. Na przykład ten kod używa niestandardowego wywołania do obliczania A[i] = B[i % 128]+ C[i] na procesorze. (Oczywiście, że można – i należy! – użyj zwykłego typu HLO).

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                      /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}));
}

void do_custom_call(void* out, const void** in) {
  float* out_buf = reinterpret_cast<float*>(out);
  const float* in0 = reinterpret_cast<const float*>(in[0]);
  const float* in1 = reinterpret_cast<const float*>(in[1]);
  for (int i = 0; i < 2048; ++i) {
    out_buf[i] = in0[i % 128] + in1[i];
  }
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "Host");

Zwróć uwagę, że funkcja do_custom_call musi znać wymiary buforów, nad którymi działa. W tym przykładzie zakodowaliśmy na stałe rozmiary 128 i 2048. Jeśli nie chcesz tego robić, możesz przekazać wymiary jako parametry do wywołania.

Utwórz niestandardowe wywołanie GPU

Niestandardowa struktura wywołań GPU różni się nieco od platformy procesora. Oto przykład CUDA, który wykonuje te same obliczenia (A[i] = B[i % 128] + C[i]) co powyższy kod procesora.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, void** buffers,
                    const char* opaque, size_t opaque_len) {
  const float* in0 = reinterpret_cast<const float*>(buffers[0]);
  const float* in1 = reinterpret_cast<const float*>(buffers[1]);
  float* out = reinterpret_cast<float*>(buffers[2]);

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim,
                       /*dynamic_shared_mem_bytes=*/0, stream>>>(in0, in1, out);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "CUDA");

Zwróć uwagę, że niestandardowa funkcja wywołania GPU wciąż jest funkcją wykonywaną na procesorze. Funkcja CPU do_custom_call odpowiada za kolejkowanie zadań na GPU. Tutaj uruchamia jądro CUDA, ale może też robić coś innego, na przykład wywoływać cuBLAS.

buffers to tablica wskaźników żyjących na hoście, a każdy element, który zawiera, wskazuje pamięć urządzenia (GPU). Parametry są na pierwszym miejscu, a po nim wartość wyjściowa. Znacząco różni się od konwencji wywoływania procesora, która ma 2 parametry: ins i out. Konwencja wywoływania GPU umożliwia sprawną obsługę danych wejściowych i wyjściowych w kształcie krotek.

Tak jak w przykładzie z procesorem, zakodowaliśmy na stałe rozmiary bufora wejściowego i wyjściowego w naszym niestandardowym wywołaniu. Jednak w przeciwieństwie do procesorów, przekazywanie rozmiarów bufora jako argumentów do niestandardowego wywołania nie działa dobrze. Zwykle potrzebujemy rozmiarów bufora dostępnych na procesorze (np. przy uruchamianiu jądra trzeba znać wymiary bloku/siatki). Gdyby jednak przekazywać rozmiary buforów jako argumenty do niestandardowego wywołania, ich wartości znajdowałyby się w pamięci GPU. Musielibyśmy na początku operacji wykonać kosztowny synchroniczny protokół memcpy między urządzeniami, aby odczytać rozmiary.

Aby Ci to umożliwić, udostępniamy parametr opaque. Podczas tworzenia wywołania niestandardowego możesz ustawić w tym polu dowolny ciąg bajtów:

std::string opaque = "...";
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
                opaque);

Ponieważ xla::Shape zawiera reprezentację bufora protokołu, możesz zapisać to zserializowane proto w elemencie opaque i poddać go deserializacji w ramach niestandardowego wywołania GPU. Pamiętaj jednak, że chociaż xla::ShapeProto nie zmienia się zbyt często, zmienia się. Sprawdź dziennik Git, aby zobaczyć, jak zmieniało się ono w przeszłości.

Sygnalizowanie błędu

Jeśli wywołanie niestandardowe napotka błąd, możesz zasygnalizować błąd do środowiska wykonawczego XLA (zamiast np. awarii czy zwrócenia błędu w buforach danych wyjściowych), używając tego podpisu dla swojej funkcji:

Na procesorze:

#include "xla/service/custom_call_status.h"

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status);

GPU:

#include "xla/service/custom_call_status.h"

void do_custom_call(CUstream stream, void** buffers, const char* opaque,
                    size_t opaque_len, xla::XlaCustomCallStatus* status);

Możesz zasygnalizować brak połączenia za pomocą funkcji XlaCustomCallStatusSetFailure, np.:

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status) {
  // ... do some work.

  if (bad_condition) {
    char* error_message = "An error occurred";
    XlaCustomCallStatusSetFailure(status, error_message, strlen(error_message));
    return;
  }

  // ... continue.
}

Możesz też używać właściwości XlaCustomCallStatusSetSuccess, aby wskazać powodzenie, ale domyślnie element XlaCustomCallStatus jest w stanie sukcesu, więc jego całkowite zignorowanie również będzie oznaczać powodzenie.

Jeśli używasz niestandardowych funkcji wywołań z tym podpisem, musisz utworzyć odpowiednią operację custom-call z ustawioną odpowiednią wersją interfejsu API, np.:

xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(F32, {2048}),
                opaque, /*has_side_effect=*/false,
                /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
                /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
                /*api_version=*/API_VERSION_STATUS_RETURNING);

W przypadku błędu nie są używane żadne niestandardowe dane wyjściowe wywołań; środowisko wykonawcze XLA zakończy obliczenia. Obliczenie poziomu HLO nie da się przywrócić z błędu (np. przez wychwycenie i objęcie go).

Przekazywanie krotek do wywołań niestandardowych

Rozważ poniższe wywołanie niestandardowe.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape);

Zarówno w przypadku procesora, jak i GPU, krotka jest reprezentowana w pamięci jako tablica wskaźników. W pseudokodzie C++ powyższy parametr 0 jest układany w ten sposób.

// In-memory layout of parameter 0 from custom call above. True on both CPU
// and GPU.
float* subbuf0 = new float[32];
float* subbuf1 = new float[64];
float* subbuf2 = new float[128]
float* subbuf3 = new float[256];

void* subtuple = new void*[2];
(*subtuple)[0] = subbuf1;
(*subtuple)[1] = subbuf2;

void* p0 = new void*[3];
(*p0)[0] = subbuf0;
(*p0)[1] = subtuple;
(*p0)[2] = subbuf3;

Mimo że reprezentacja krotek w pamięci jest taka sama w przypadku procesora i GPU, są one obsługiwane inaczej w konwencjach wywołań niestandardowych dotyczących CPU i GPU.

Dane wyjściowe krotek jako bufory tymczasowe

Dane wejściowe krotek w wywołaniach niestandardowych to wygoda, ale nie są niezbędne. Gdybyśmy nie obsługiwali danych wejściowych krotek w wywołaniach niestandardowych, przed przekazaniem ich do niestandardowego wywołania można było rozpakować krotki za pomocą metody get-tuple.

Z drugiej strony dzięki wyjściom krotek możesz robić rzeczy, które inaczej nie byłyby możliwe.

Oczywistym powodem, dla którego dane wyjściowe krotek są związane, są to, że wywołanie niestandardowe (lub dowolne inne działanie XLA) zwraca wiele niezależnych tablic.

Mniej oczywiste, że dane wyjściowe krotek są także sposobem na nadanie niestandardowej temperatury wywołania. Tak. Dane wyjściowe mogą reprezentować bufor tymczasowy. Załóżmy, że bufor danych wyjściowych ma właściwość, którą operacja może zapisać i może ją z niego odczytać po tym, jak to zrobisz. To dokładnie to, czego oczekujesz od bufora tymczasowego.

Załóżmy, że chcemy użyć F32[1024] jako bufora tymczasowego. Następnie zapiszemy HLO tak jak powyżej i nigdy nie odczytamy indeksu krotek 1 danych wyjściowych niestandardowego wywołania.

Kropki w wywołaniach niestandardowych procesora

W kodzie procesora znajduje się funkcja do_custom_call(const void** ins, void* out). ins to tablica z tylko jednym elementem, który wskazuje param0. Podbufory obiektu param0 są dostępne po usunięciu odniesień do tego wskaźnika, a podbufory podrzędne output_tuple są dostępne po usunięciu odniesień do out.

Kropki w wywołaniach niestandardowych GPU

W kodzie GPU znajduje się funkcja do_custom_call(..., void** buffers, ...). W tym przypadku buffers to tablica hosta z 6 wskaźnikami urządzenia, po jednym na każdy bufor liści na wejściu/wyjściu. Aby wygenerować płaską listę, powtarzamy parametry i dane wyjściowe, po czym w przypadku każdej z nich przewijamy dane przed jej kształtem. Konkretnie:

// Layout of `buffers` parameter to GPU custom call function for custom-call
// above.
buffers[0] == subbuf0
buffers[1] == subbuf1
buffers[2] == subbuf2
buffers[3] == subbuf3
buffers[4] == output_subbuf0
buffers[5] == output_subbuf1