In diesem Dokument wird beschrieben, wie benutzerdefinierte XLA-Aufrufe geschrieben und verwendet werden. Mit benutzerdefinierten Aufrufen können Sie Code aufrufen, der in einer Programmiersprache wie C++ oder CUDA aus einem XLA-Programm geschrieben wurde.
Benutzerdefinierten Aufruf auf der CPU erstellen
Sie können über die Client API von XLA eine HLO-Anweisung erstellen, die einen benutzerdefinierten Aufruf darstellt. Im folgenden Code wird beispielsweise ein benutzerdefinierter Aufruf verwendet, um A[i] = B[i %
128]+ C[i]
auf der CPU zu berechnen. (Natürlich können und sollten Sie! – tun Sie dies mit
normaler HLO.)
#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"
void do_it() {
xla::XlaBuilder b("do_it");
xla::XlaOp param0 =
xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
xla::XlaOp param1 =
xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
xla::XlaOp custom_call =
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
/*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}));
}
void do_custom_call(void* out, const void** in) {
float* out_buf = reinterpret_cast<float*>(out);
const float* in0 = reinterpret_cast<const float*>(in[0]);
const float* in1 = reinterpret_cast<const float*>(in[1]);
for (int i = 0; i < 2048; ++i) {
out_buf[i] = in0[i % 128] + in1[i];
}
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "Host");
Beachten Sie, dass die Funktion do_custom_call
die Dimensionen der Puffer kennen muss, mit denen sie arbeitet. In diesem Beispiel wurden die Größen 128
und 2048
hartcodiert. Wenn Sie das nicht tun möchten, können Sie die Dimensionen als Parameter an den Aufruf übergeben.
Benutzerdefinierten Aufruf auf der GPU erstellen
Das benutzerdefinierte GPU-Aufruf-Framework unterscheidet sich etwas von dem auf der CPU. Hier ist ein CUDA-Beispiel, das dieselbe Berechnung (A[i] = B[i % 128] + C[i]
) wie der obige CPU-Code ausführt.
void do_it() { /* same implementation as above */ }
__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
out[idx] = in0[idx % 128] + in1[idx];
}
void do_custom_call(CUstream stream, void** buffers,
const char* opaque, size_t opaque_len) {
const float* in0 = reinterpret_cast<const float*>(buffers[0]);
const float* in1 = reinterpret_cast<const float*>(buffers[1]);
float* out = reinterpret_cast<float*>(buffers[2]);
const int64_t block_dim = 64;
const int64_t grid_dim = 2048 / block_dim;
custom_call_kernel<<<grid_dim, block_dim,
/*dynamic_shared_mem_bytes=*/0, stream>>>(in0, in1, out);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "CUDA");
Beachten Sie zunächst, dass die benutzerdefinierte GPU-Aufruffunktion immer noch eine Funktion ist, die auf der CPU ausgeführt wird. Die CPU-Funktion do_custom_call
ist für das Einreihen der Arbeit auf der GPU verantwortlich. Hier startet er einen CUDA-Kernel, könnte aber auch etwas anderes tun, z. B. cuBLAS aufrufen.
buffers
ist ein Array von Zeigern auf dem Host und jedes darin enthaltene Element Punkte zum Gerätespeicher (z.B. GPU). Die Parameter kommen zuerst, gefolgt vom Ausgabewert. Dies unterscheidet sich erheblich von der CPU-Aufrufkonvention, die zwei Parameter hat: ins
und out
. Mit der GPU-Aufrufkonvention ist es möglich, tupelförmige Ein- und Ausgaben effizient zu verarbeiten.
Wie im CPU-Beispiel haben wir die Größe des Eingabe- und Ausgabepuffers in unseren benutzerdefinierten Aufruf hartcodiert. Anders als im CPU-Fall würde es jedoch nicht gut funktionieren, die Puffergrößen als Operanden an den benutzerdefinierten Aufruf zu übergeben. Normalerweise benötigen wir die Puffergrößen, die uns auf der CPU zur Verfügung stehen (z.B. müssen wir beim Starten eines Kernels die zu verwendenden Block-/Rastergrößen kennen). Wenn wir jedoch die Puffergrößen als Operanden an unseren benutzerdefinierten Aufruf übergeben würden, würden ihre Werte im GPU-Arbeitsspeicher gespeichert sein. Wir müssten dann zu Beginn unseres Vorgangs einen teuren synchronen memcpy
-Vorgang zwischen Gerät und Host ausführen, um die Größen zu lesen.
Zur Umgehung dieses Problems stellen wir den Parameter opaque
bereit. Sie können dies beim Erstellen des benutzerdefinierten Aufrufs auf einen beliebigen Byte-String festlegen:
std::string opaque = "...";
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
/*output_shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
opaque);
Da xla::Shape
eine Protokollpufferdarstellung hat, können Sie dieses serialisierte Proto in opaque
speichern und in Ihrem benutzerdefinierten GPU-Aufruf deserialisieren. xla::ShapeProto
ändert sich zwar nicht häufig, ändert sich aber. Sehen Sie im Git-Log nach, wie es sich in der Vergangenheit geändert hat.
Fehler signalisieren
Wenn beim benutzerdefinierten Aufruf ein Fehler auftritt, können Sie den Fehler der XLA-Laufzeit signalisieren, anstatt ihn z.B. abstürzen oder Unsinn in den Ausgabepuffern zurückzugeben, indem Sie die folgende Signatur für Ihre Funktion verwenden:
Auf der CPU:
#include "xla/service/custom_call_status.h"
void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status);
auf der GPU:
#include "xla/service/custom_call_status.h"
void do_custom_call(CUstream stream, void** buffers, const char* opaque,
size_t opaque_len, xla::XlaCustomCallStatus* status);
Sie können Fehler mit XlaCustomCallStatusSetFailure
signalisieren. Beispiel:
void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status) {
// ... do some work.
if (bad_condition) {
char* error_message = "An error occurred";
XlaCustomCallStatusSetFailure(status, error_message, strlen(error_message));
return;
}
// ... continue.
}
Sie können auch XlaCustomCallStatusSetSuccess
verwenden, um den Erfolg anzugeben. Die XlaCustomCallStatus
ist jedoch standardmäßig im Status „Erfolg“. Wenn Sie sie vollständig ignorieren, bedeutet das ebenfalls, dass der Vorgang erfolgreich war.
Wenn Sie mit dieser Signatur benutzerdefinierte Aufruffunktionen verwenden, müssen Sie die entsprechende custom-call
-Operation mit dem entsprechenden API-Versionssatz erstellen. Beispiel:
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
/*output_shape=*/xla::ShapeUtil::MakeShape(F32, {2048}),
opaque, /*has_side_effect=*/false,
/*output_operand_aliasing=*/{}, /*literal=*/nullptr,
/*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
/*api_version=*/API_VERSION_STATUS_RETURNING);
Bei einem Fehler wird keine der benutzerdefinierten Aufrufausgaben verwendet. Die XLA-Laufzeit beendet die Berechnung. Für eine HLO-Berechnung ist es nicht möglich, den Fehler wiederherzustellen (z.B. durch Abfangen und Verarbeiten).
Tupel an benutzerdefinierte Aufrufe übergeben
Betrachten Sie den folgenden benutzerdefinierten Aufruf.
using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {32}),
ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {64}),
ShapeUtil::MakeShape(F32, {128}),
}),
ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");
Shape out_shape = ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {512}),
ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape);
Sowohl auf der CPU als auch auf der GPU wird ein Tupel im Speicher als Array von Zeigern dargestellt. Im C++-Pseudocode ist der Parameter 0 oben wie folgt festgelegt.
// In-memory layout of parameter 0 from custom call above. True on both CPU
// and GPU.
float* subbuf0 = new float[32];
float* subbuf1 = new float[64];
float* subbuf2 = new float[128]
float* subbuf3 = new float[256];
void* subtuple = new void*[2];
(*subtuple)[0] = subbuf1;
(*subtuple)[1] = subbuf2;
void* p0 = new void*[3];
(*p0)[0] = subbuf0;
(*p0)[1] = subtuple;
(*p0)[2] = subbuf3;
Obwohl die speicherinterne Darstellung von Tupeln in CPU und GPU identisch ist, werden sie in den Aufrufkonventionen für benutzerdefinierte CPU- und GPU-Aufrufe unterschiedlich behandelt.
Tupel-Ausgaben als temporäre Zwischenspeicher
Tupel-Eingaben für benutzerdefinierte Aufrufe sind praktisch, aber nicht unbedingt erforderlich. Wenn wir keine Tupeleingaben für benutzerdefinierte Aufrufe unterstützen, können Sie die Tupel immer mit dem get-tuple-element entpacken, bevor Sie sie an den benutzerdefinierten Aufruf übergeben.
Auf der anderen Seite ermöglichen Tupel-Ausgaben Ihnen Dinge, die Ihnen sonst nicht möglich wären.
Der offensichtliche Grund für Tupel-Ausgaben ist, dass Tupel-Ausgaben die Art und Weise sind, wie ein benutzerdefinierter Aufruf (oder ein anderer XLA-Vorgang) mehrere unabhängige Arrays zurückgibt.
Eine Tupelausgabe ist jedoch auch eine Möglichkeit, dem benutzerdefinierten Aufruf temporärer Speicher bereitzustellen. Ja, eine Ausgabe kann einen temporären Puffer darstellen. Ein Ausgabepuffer hat die Eigenschaft, in die der Vorgang schreiben kann. Nachdem er geschrieben wurde, kann er daraus lesen. Das ist genau das, was du von einem Temperaturpuffer erwartest.
Angenommen, wir im obigen Beispiel möchten F32[1024]
als temporären Zwischenspeicher verwenden.
Dann schreiben wir den HLO wie oben beschrieben und lesen einfach nie den Tupelindex 1 der Ausgabe des benutzerdefinierten Aufrufs.
Tupel in benutzerdefinierten CPU-Aufrufen
Im CPU-Code haben wir die Funktion do_custom_call(const void** ins, void* out)
.
ins
ist ein Array mit nur einem Element, das auf param0
verweist. Der Zugriff auf die Unterpuffer von param0
ist durch Dereferenzieren dieses Zeigers und auf die Unterpuffer von output_tuple
durch Dereferenzierung von out
möglich.
Tupel in benutzerdefinierten GPU-Aufrufen
Im GPU-Code haben wir die Funktion do_custom_call(..., void** buffers, ...)
. In diesem Fall ist buffers
ein Hostarray mit sechs Gerätezeigern, einer für jeden Blattpuffer in der Ein-/Ausgabe. Um die Flat List zu generieren, iterieren wir über die Parameter und die Ausgabe und führen für jeden einen Vorbestellungsdurchlauf seiner Form durch.
Konkret:
// Layout of `buffers` parameter to GPU custom call function for custom-call
// above.
buffers[0] == subbuf0
buffers[1] == subbuf1
buffers[2] == subbuf2
buffers[3] == subbuf3
buffers[4] == output_subbuf0
buffers[5] == output_subbuf1