इस दस्तावेज़ में, XLA FFI लाइब्रेरी का इस्तेमाल करके, XLA कस्टम कॉल लिखने और इस्तेमाल करने का तरीका बताया गया है. कस्टम कॉल, XLA कंपाइलर (कंपाइल समय पर) में HLO मॉड्यूल में बाहरी "ऑपरेशन" के बारे में बताने का एक तरीका है. XLA FFI, एक्सएलए (रन टाइम पर) के साथ इस तरह की कार्रवाइयों को लागू करने को रजिस्टर करने का एक तरीका है. एफ़एफ़आई का मतलब है "फ़ॉरेन फ़ंक्शन इंटरफ़ेस". यह C API का एक सेट है, जो XLA के लिए बाइनरी इंटरफ़ेस (एबीआई) को परिभाषित करता है, ताकि दूसरी प्रोग्रामिंग भाषाओं में लिखे गए बाहरी कोड को कॉल किया जा सके. XLA, C++ में लिखे गए XLA FFI के लिए हेडर-ओनली बाइंडिंग उपलब्ध कराता है. इससे असली उपयोगकर्ता के C API में मौजूद, कम लेवल की पूरी जानकारी छिपी होती है.
सीपीयू पर पसंद के मुताबिक कॉल बनाएं
ऐसा एचएलओ निर्देश बनाया जा सकता है जो XLA के क्लाइंट एपीआई से कस्टम कॉल को दिखाता हो. उदाहरण के लिए, नीचे दिया गया कोड, सीपीयू पर A[i] = B[i %
128]+ C[i]
की गिनती करने के लिए, कस्टम कॉल का इस्तेमाल करता है. (बेशक, आपको ऐसा करना चाहिए - और करना भी चाहिए! – इसे सामान्य एचएलओ के साथ करें.)
#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"
void do_it() {
xla::XlaBuilder b("do_it");
xla::XlaOp param0 =
xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
xla::XlaOp param1 =
xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
xla::XlaOp custom_call =
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
/*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
/*opaque=*/"", /*has_side_effect=*/false,
/*output_operand_aliasing=*/{}, /*literal=*/nullptr,
/*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
/*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}
// Constrain custom call arguments to rank-1 buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;
// Implement a custom call as a C+ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
xla::ffi::Result<BufferF32> out) {
size_t d0 = in0.dimensions[0];
size_t d1 = in1.dimensions[0];
// Check that dimensions are compatible.
assert(out->dimensions[0] == d1 && "unexpected dimensions");
for (size_t i = 0; i < d1; ++i) {
out->data[i] = in0.data[i % d0] + in1.data[i];
}
}
// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
ffi::Ffi::Bind()
.Arg<Buffer>()
.Arg<Buffer>()
.Ret<Buffer>());
// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
"Host", handler);
जीपीयू पर कस्टम कॉल बनाएं
XLA FFI के साथ जीपीयू कस्टम कॉल रजिस्ट्रेशन भी करीब-करीब एक जैसा है. हालांकि, इसमें सिर्फ़ इतना अंतर है कि डिवाइस पर kernel लॉन्च करने के लिए, जीपीयू में मौजूद प्लैटफ़ॉर्म स्ट्रीम (CUDA या ROCM स्ट्रीम) की ज़रूरत पड़ेगी. यहां CUDA का एक उदाहरण दिया गया है, जो ऊपर दिए गए सीपीयू कोड की तरह ही कंप्यूटेशन (A[i] = B[i % 128] + C[i]
) करता है.
void do_it() { /* same implementation as above */ }
__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
out[idx] = in0[idx % 128] + in1[idx];
}
void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
xla::ffi::Result<BufferF32> out) {
size_t d0 = in0.dimensions[0];
size_t d1 = in1.dimensions[0];
size_t d2 = out->dimensions[0];
assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");
const int64_t block_dim = 64;
const int64_t grid_dim = 2048 / block_dim;
custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
in0.data, in1.data, out->data);
}
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
ffi::Ffi::Bind()
.Ctx<xla::ffi::PlatformStream<CUstream>>()
.Arg<BufferF32>()
.Arg<BufferF32>()
.Ret<BufferF32>());
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
"CUDA", handler);
सबसे पहले ध्यान दें कि जीपीयू कस्टम कॉल फ़ंक्शन अब भी सीपीयू पर चलाया जाने वाला फ़ंक्शन है. जीपीयू पर काम की सूची बनाने के लिए do_custom_call
सीपीयू फ़ंक्शन ज़िम्मेदार है. यहां यह एक CUDA कर्नेल लॉन्च करता है, लेकिन यह कुछ और काम भी कर सकता है,
जैसे कॉल cuBLAS.
तर्क और नतीजे, होस्ट पर भी लाइव होते हैं. साथ ही, डेटा मेंबर में डिवाइस के लिए पॉइंटर (यानी जीपीयू) मेमोरी होती है. कस्टम कॉल हैंडलर को पास किए गए बफ़र में डिवाइस बफ़र का आकार होता है, ताकि कस्टम कॉल उनसे कर्नेल लॉन्च पैरामीटर कंप्यूट कर सके.
टूल को कस्टम कॉल पर भेजना
निम्न कस्टम कॉल पर विचार करें.
using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {32}),
ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {64}),
ShapeUtil::MakeShape(F32, {128}),
}),
ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");
Shape out_shape = ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {512}),
ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);
सीपीयू और जीपीयू, दोनों पर टपल को मेमोरी में पॉइंटर के अरे के तौर पर दिखाया जाता है. जब XLA, टपल आर्ग्युमेंट या नतीजों के साथ कस्टम कॉल को कॉल करता है, तो यह उन्हें फ़्लैट करता है और सामान्य बफ़र आर्ग्युमेंट या नतीजों के तौर पर पास हो जाता है.
अस्थायी बफ़र के तौर पर टपल आउटपुट
कस्टम कॉल में टपल इनपुट एक सुविधा है, लेकिन ये पूरी तरह से ज़रूरी नहीं हैं. अगर हमारे पास कस्टम कॉल के लिए टपल इनपुट की सुविधा नहीं है, तो कस्टम कॉल पर पास करने से पहले get-टपल-एलिमेंट का इस्तेमाल करके टूल को कभी भी अनपैक किया जा सकता है.
वहीं दूसरी ओर, टपल आउटपुट आपको ऐसे काम करने देते हैं जो आप नहीं कर सकते.
टपल आउटपुट होने की साफ़ वजह यह है कि टपल आउटपुट, कस्टम कॉल (या किसी भी दूसरे XLA op) के तरीके से ही कई इंडिपेंडेंट अरे देता है.
हालांकि, टपल आउटपुट से कस्टम कॉल टेंपरेचर मेमोरी देने का एक तरीका भी होता है. हां, आउटपुट अस्थायी बफ़र हो सकता है. मान लीजिए कि एक आउटपुट बफ़र में ऐसी प्रॉपर्टी है जिसे ऑपरेटर, रिपोर्ट में लिख सकता है. साथ ही, आउटपुट बफ़र में लिखे जाने के बाद, उस डेटा से उसे पढ़ सकता है. यह ठीक वैसा ही है जैसा आपको अस्थायी बफ़र से चाहिए.
ऊपर दिए गए उदाहरण में, मान लें कि हम F32[1024]
को अस्थायी बफ़र के तौर पर इस्तेमाल करना चाहते थे.
इसके बाद, हम ऊपर की तरह एचएलओ भी लिखेंगे. साथ ही, कस्टम कॉल के आउटपुट के टपल इंडेक्स 1
को कभी नहीं पढ़ा जाएगा.