Benutzerdefinierte XLA-Aufrufe

In diesem Dokument wird beschrieben, wie benutzerdefinierte XLA-Aufrufe geschrieben und verwendet werden. Mit benutzerdefinierten Aufrufen können Sie Code aufrufen, der in einer Programmiersprache wie C++ oder CUDA aus einem XLA-Programm geschrieben wurde.

Benutzerdefinierten Aufruf auf der CPU erstellen

Sie können über die Client API von XLA eine HLO-Anweisung erstellen, die einen benutzerdefinierten Aufruf darstellt. Im folgenden Code wird beispielsweise ein benutzerdefinierter Aufruf verwendet, um A[i] = B[i % 128]+ C[i] auf der CPU zu berechnen. (Natürlich können und sollten Sie! – tun Sie dies mit normaler HLO.)

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                      /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}));
}

void do_custom_call(void* out, const void** in) {
  float* out_buf = reinterpret_cast<float*>(out);
  const float* in0 = reinterpret_cast<const float*>(in[0]);
  const float* in1 = reinterpret_cast<const float*>(in[1]);
  for (int i = 0; i < 2048; ++i) {
    out_buf[i] = in0[i % 128] + in1[i];
  }
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "Host");

Beachten Sie, dass die Funktion do_custom_call die Dimensionen der Puffer kennen muss, mit denen sie arbeitet. In diesem Beispiel wurden die Größen 128 und 2048 hartcodiert. Wenn Sie das nicht tun möchten, können Sie die Dimensionen als Parameter an den Aufruf übergeben.

Benutzerdefinierten Aufruf auf der GPU erstellen

Das benutzerdefinierte GPU-Aufruf-Framework unterscheidet sich etwas von dem auf der CPU. Hier ist ein CUDA-Beispiel, das dieselbe Berechnung (A[i] = B[i % 128] + C[i]) wie der obige CPU-Code ausführt.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, void** buffers,
                    const char* opaque, size_t opaque_len) {
  const float* in0 = reinterpret_cast<const float*>(buffers[0]);
  const float* in1 = reinterpret_cast<const float*>(buffers[1]);
  float* out = reinterpret_cast<float*>(buffers[2]);

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim,
                       /*dynamic_shared_mem_bytes=*/0, stream>>>(in0, in1, out);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "CUDA");

Beachten Sie zunächst, dass die benutzerdefinierte GPU-Aufruffunktion immer noch eine Funktion ist, die auf der CPU ausgeführt wird. Die CPU-Funktion do_custom_call ist für das Einreihen der Arbeit auf der GPU verantwortlich. Hier startet er einen CUDA-Kernel, könnte aber auch etwas anderes tun, z. B. cuBLAS aufrufen.

buffers ist ein Array von Zeigern auf dem Host und jedes darin enthaltene Element Punkte zum Gerätespeicher (z.B. GPU). Die Parameter kommen zuerst, gefolgt vom Ausgabewert. Dies unterscheidet sich erheblich von der CPU-Aufrufkonvention, die zwei Parameter hat: ins und out. Mit der GPU-Aufrufkonvention ist es möglich, tupelförmige Ein- und Ausgaben effizient zu verarbeiten.

Wie im CPU-Beispiel haben wir die Größe des Eingabe- und Ausgabepuffers in unseren benutzerdefinierten Aufruf hartcodiert. Anders als im CPU-Fall würde es jedoch nicht gut funktionieren, die Puffergrößen als Operanden an den benutzerdefinierten Aufruf zu übergeben. Normalerweise benötigen wir die Puffergrößen, die uns auf der CPU zur Verfügung stehen (z.B. müssen wir beim Starten eines Kernels die zu verwendenden Block-/Rastergrößen kennen). Wenn wir jedoch die Puffergrößen als Operanden an unseren benutzerdefinierten Aufruf übergeben würden, würden ihre Werte im GPU-Arbeitsspeicher gespeichert sein. Wir müssten dann zu Beginn unseres Vorgangs einen teuren synchronen memcpy-Vorgang zwischen Gerät und Host ausführen, um die Größen zu lesen.

Zur Umgehung dieses Problems stellen wir den Parameter opaque bereit. Sie können dies beim Erstellen des benutzerdefinierten Aufrufs auf einen beliebigen Byte-String festlegen:

std::string opaque = "...";
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
                opaque);

Da xla::Shape eine Protokollpufferdarstellung hat, können Sie dieses serialisierte Proto in opaque speichern und in Ihrem benutzerdefinierten GPU-Aufruf deserialisieren. xla::ShapeProto ändert sich zwar nicht häufig, ändert sich aber. Sehen Sie im Git-Log nach, wie es sich in der Vergangenheit geändert hat.

Fehler signalisieren

Wenn beim benutzerdefinierten Aufruf ein Fehler auftritt, können Sie den Fehler der XLA-Laufzeit signalisieren, anstatt ihn z.B. abstürzen oder Unsinn in den Ausgabepuffern zurückzugeben, indem Sie die folgende Signatur für Ihre Funktion verwenden:

Auf der CPU:

#include "xla/service/custom_call_status.h"

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status);

auf der GPU:

#include "xla/service/custom_call_status.h"

void do_custom_call(CUstream stream, void** buffers, const char* opaque,
                    size_t opaque_len, xla::XlaCustomCallStatus* status);

Sie können Fehler mit XlaCustomCallStatusSetFailure signalisieren. Beispiel:

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status) {
  // ... do some work.

  if (bad_condition) {
    char* error_message = "An error occurred";
    XlaCustomCallStatusSetFailure(status, error_message, strlen(error_message));
    return;
  }

  // ... continue.
}

Sie können auch XlaCustomCallStatusSetSuccess verwenden, um den Erfolg anzugeben. Die XlaCustomCallStatus ist jedoch standardmäßig im Status „Erfolg“. Wenn Sie sie vollständig ignorieren, bedeutet das ebenfalls, dass der Vorgang erfolgreich war.

Wenn Sie mit dieser Signatur benutzerdefinierte Aufruffunktionen verwenden, müssen Sie die entsprechende custom-call-Operation mit dem entsprechenden API-Versionssatz erstellen. Beispiel:

xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(F32, {2048}),
                opaque, /*has_side_effect=*/false,
                /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
                /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
                /*api_version=*/API_VERSION_STATUS_RETURNING);

Bei einem Fehler wird keine der benutzerdefinierten Aufrufausgaben verwendet. Die XLA-Laufzeit beendet die Berechnung. Für eine HLO-Berechnung ist es nicht möglich, den Fehler wiederherzustellen (z.B. durch Abfangen und Verarbeiten).

Tupel an benutzerdefinierte Aufrufe übergeben

Betrachten Sie den folgenden benutzerdefinierten Aufruf.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape);

Sowohl auf der CPU als auch auf der GPU wird ein Tupel im Speicher als Array von Zeigern dargestellt. Im C++-Pseudocode ist der Parameter 0 oben wie folgt festgelegt.

// In-memory layout of parameter 0 from custom call above. True on both CPU
// and GPU.
float* subbuf0 = new float[32];
float* subbuf1 = new float[64];
float* subbuf2 = new float[128]
float* subbuf3 = new float[256];

void* subtuple = new void*[2];
(*subtuple)[0] = subbuf1;
(*subtuple)[1] = subbuf2;

void* p0 = new void*[3];
(*p0)[0] = subbuf0;
(*p0)[1] = subtuple;
(*p0)[2] = subbuf3;

Obwohl die speicherinterne Darstellung von Tupeln in CPU und GPU identisch ist, werden sie in den Aufrufkonventionen für benutzerdefinierte CPU- und GPU-Aufrufe unterschiedlich behandelt.

Tupel-Ausgaben als temporäre Zwischenspeicher

Tupel-Eingaben für benutzerdefinierte Aufrufe sind praktisch, aber nicht unbedingt erforderlich. Wenn wir keine Tupeleingaben für benutzerdefinierte Aufrufe unterstützen, können Sie die Tupel immer mit dem get-tuple-element entpacken, bevor Sie sie an den benutzerdefinierten Aufruf übergeben.

Auf der anderen Seite ermöglichen Tupel-Ausgaben Ihnen Dinge, die Ihnen sonst nicht möglich wären.

Der offensichtliche Grund für Tupel-Ausgaben ist, dass Tupel-Ausgaben die Art und Weise sind, wie ein benutzerdefinierter Aufruf (oder ein anderer XLA-Vorgang) mehrere unabhängige Arrays zurückgibt.

Eine Tupelausgabe ist jedoch auch eine Möglichkeit, dem benutzerdefinierten Aufruf temporärer Speicher bereitzustellen. Ja, eine Ausgabe kann einen temporären Puffer darstellen. Ein Ausgabepuffer hat die Eigenschaft, in die der Vorgang schreiben kann. Nachdem er geschrieben wurde, kann er daraus lesen. Das ist genau das, was du von einem Temperaturpuffer erwartest.

Angenommen, wir im obigen Beispiel möchten F32[1024] als temporären Zwischenspeicher verwenden. Dann schreiben wir den HLO wie oben beschrieben und lesen einfach nie den Tupelindex 1 der Ausgabe des benutzerdefinierten Aufrufs.

Tupel in benutzerdefinierten CPU-Aufrufen

Im CPU-Code haben wir die Funktion do_custom_call(const void** ins, void* out). ins ist ein Array mit nur einem Element, das auf param0 verweist. Der Zugriff auf die Unterpuffer von param0 ist durch Dereferenzieren dieses Zeigers und auf die Unterpuffer von output_tuple durch Dereferenzierung von out möglich.

Tupel in benutzerdefinierten GPU-Aufrufen

Im GPU-Code haben wir die Funktion do_custom_call(..., void** buffers, ...). In diesem Fall ist buffers ein Hostarray mit sechs Gerätezeigern, einer für jeden Blattpuffer in der Ein-/Ausgabe. Um die Flat List zu generieren, iterieren wir über die Parameter und die Ausgabe und führen für jeden einen Vorbestellungsdurchlauf seiner Form durch. Konkret:

// Layout of `buffers` parameter to GPU custom call function for custom-call
// above.
buffers[0] == subbuf0
buffers[1] == subbuf1
buffers[2] == subbuf2
buffers[3] == subbuf3
buffers[4] == output_subbuf0
buffers[5] == output_subbuf1