Bu dokümanda, XLA FFI kitaplığını kullanarak XLA özel çağrılarının nasıl yazılacağı ve kullanılacağı açıklanmaktadır. Özel çağrı, HLO modülündeki harici bir "işlemi" XLA derleyicisine (derleme zamanında) açıklamak için kullanılan bir mekanizmadır. XLA FFI, bu tür işlemlerin XLA'ya (çalışma zamanında) kaydedilmesini sağlayan bir mekanizmadır. FFI, "yabancı işlev arayüzü" anlamına gelir. XLA'nın diğer programlama dillerinde yazılmış harici bir kodu çağırması için ikili arayüzü (ABI) tanımlayan bir C API kümesidir. XLA, C++'ta yazılmış XLA FFI için yalnızca başlıktan bağlamalar sağlar. Böylece, temel C API'lerinin tüm alt düzey ayrıntılarını son kullanıcıdan gizler.
CPU'da özel bir çağrı oluştur
XLA'nın istemci API'si aracılığıyla özel bir çağrıyı temsil eden bir HLO talimatı oluşturabilirsiniz. Örneğin, aşağıdaki kodda CPU'da A[i] = B[i %
128]+ C[i]
hesaplamak için özel bir çağrı kullanılmaktadır. (Elbette yapabilirsiniz ve yapmalısınız! bunu normal HLO ile yapın.)
#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"
void do_it() {
xla::XlaBuilder b("do_it");
xla::XlaOp param0 =
xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
xla::XlaOp param1 =
xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
xla::XlaOp custom_call =
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
/*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
/*opaque=*/"", /*has_side_effect=*/false,
/*output_operand_aliasing=*/{}, /*literal=*/nullptr,
/*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
/*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}
// Constrain custom call arguments to rank-1 buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;
// Implement a custom call as a C+ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
xla::ffi::Result<BufferF32> out) {
size_t d0 = in0.dimensions[0];
size_t d1 = in1.dimensions[0];
// Check that dimensions are compatible.
assert(out->dimensions[0] == d1 && "unexpected dimensions");
for (size_t i = 0; i < d1; ++i) {
out->data[i] = in0.data[i % d0] + in1.data[i];
}
}
// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
ffi::Ffi::Bind()
.Arg<Buffer>()
.Arg<Buffer>()
.Ret<Buffer>());
// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
"Host", handler);
GPU'da özel bir çağrı oluştur
XLA FFI ile GPU özel çağrı kaydı neredeyse aynıdır. Aradaki tek fark, GPU'da çekirdeği cihazda başlatabilmek için temel bir platform akışı (CUDA veya ROCM akışı) istemenizin gerekmesidir. Yukarıdaki CPU koduyla aynı hesaplamayı (A[i] = B[i % 128] + C[i]
) yapan bir CUDA örneğini burada bulabilirsiniz.
void do_it() { /* same implementation as above */ }
__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
out[idx] = in0[idx % 128] + in1[idx];
}
void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
xla::ffi::Result<BufferF32> out) {
size_t d0 = in0.dimensions[0];
size_t d1 = in1.dimensions[0];
size_t d2 = out->dimensions[0];
assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");
const int64_t block_dim = 64;
const int64_t grid_dim = 2048 / block_dim;
custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
in0.data, in1.data, out->data);
}
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
ffi::Ffi::Bind()
.Ctx<xla::ffi::PlatformStream<CUstream>>()
.Arg<BufferF32>()
.Arg<BufferF32>()
.Ret<BufferF32>());
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
"CUDA", handler);
Öncelikle GPU özel çağrı işlevinin hala CPU üzerinde yürütülen bir işlev olduğuna dikkat edin. do_custom_call
CPU işlevi, GPU'daki işleri sıraya koymaktan sorumludur. Burada bir CUDA çekirdeği başlatır, ancak cuBLAS'yi çağırma gibi başka bir işlemi de yapabilir.
Bağımsız değişkenler ve sonuçlar da ana makinede yer alır ve veri üyesi, cihaz işaretçisi (ör. GPU) bir bellek içerir. Özel çağrı işleyiciye iletilen arabellekler, temel cihaz arabellekleri biçimindedir. Böylece özel çağrı, çekirdek başlatma parametrelerini bunlardan hesaplayabilir.
Tuple'ları özel çağrılara iletme
Aşağıdaki özel aramayı kullanabilirsiniz.
using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {32}),
ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {64}),
ShapeUtil::MakeShape(F32, {128}),
}),
ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");
Shape out_shape = ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {512}),
ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);
Hem CPU'da hem de GPU'da bir unsur, bellekte işaretçi dizisi olarak gösterilir. XLA, unsur bağımsız değişkenleri veya sonuçlar içeren özel çağrılar çağırdığında bunları düzleştirir ve normal tampon bağımsız değişkenleri veya sonuçlar olarak geçirir.
Tuple, geçici arabellek olarak çıkış yapar
Özel çağrılara çift giriş yapmak kolaylık olsa da kesin bir zorunluluk değildir. Özel çağrılara tuple girişleri desteklemiyor olsaydık özel çağrıya iletmeden önce get-tuple-öğesini kullanarak grupları her zaman paketinden çıkarabilirdiniz.
Öte yandan, çift çıkışlar, başka şekilde yapamayacağınız şeyleri yapmanıza olanak tanır.
Tuple çıkışlarının olmasının en bariz nedeni, tuple çıkışlarının bir özel çağrının (veya başka herhangi bir XLA işlem operatörünün) birden çok bağımsız dizi döndürme şekli olmasıdır.
Ancak daha az kesin olmak gerekirse özel çağrı geçici belleğinizi tuple çıkışıyla da sağlayabilirsiniz. Evet, çıkış geçici arabelleği temsil edebilir. Çıktı arabelleğinin, işlemin üzerine yazabileceği özelliğe sahip olduğunu ve yazıldıktan sonra buradan okuyabileceğini göz önünde bulundurun. Geçici arabelleğe alma işleminden tam olarak istediğin şey budur.
Yukarıdaki örnekte F32[1024]
öğesini geçici arabellek olarak kullanmak istediğimizi varsayalım.
Daha sonra HLO'yu tam olarak yukarıdaki gibi yazardık ve özel çağrı çıkışının 1. tuple dizinini hiçbir zaman okumayız.