XLA कस्टम कॉल

इस दस्तावेज़ में, XLA FFI लाइब्रेरी का इस्तेमाल करके, XLA कस्टम कॉल लिखने और इस्तेमाल करने का तरीका बताया गया है. कस्टम कॉल, XLA कंपाइलर को HLO मॉड्यूल में बाहरी "ऑपरेशन" के बारे में बताने का एक तरीका है. यह काम कंपाइल टाइम पर होता है. वहीं, XLA FFI, XLA के साथ ऐसे ऑपरेशन के लागू होने की जानकारी रजिस्टर करने का एक तरीका है. यह काम रन टाइम पर होता है. FFI का मतलब "फ़ॉरेन फ़ंक्शन इंटरफ़ेस" है. यह C एपीआई का एक सेट है. यह XLA के लिए एक बाइनरी इंटरफ़ेस (ABI) तय करता है, ताकि अन्य प्रोग्रामिंग भाषाओं में लिखे गए बाहरी कोड को कॉल किया जा सके. XLA, C++ में लिखे गए XLA FFI के लिए सिर्फ़ हेडर बाइंडिंग उपलब्ध कराता है. इससे, C एपीआई की सभी जानकारी, उपयोगकर्ता से छिपी रहती है.

JAX + XLA कस्टम कॉल

कस्टम कॉल और XLA FFI को JAX के साथ इंटिग्रेट करने के पूरे उदाहरणों के लिए, JAX का दस्तावेज़ देखें.

XLA FFI बाइंडिंग

XLA FFI बाइंडिंग, कंपाइल-टाइम में कस्टम कॉल सिग्नेचर की खास जानकारी होती है: कस्टम कॉल के तर्क, एट्रिब्यूट, और उनके टाइप, और एक्ज़ीक्यूशन कॉन्टेक्स्ट के ज़रिए पास किए गए अतिरिक्त पैरामीटर (जैसे, GPU बैकएंड के लिए GPU स्ट्रीम). XLA FFI बाइंडिंग को, C++ के किसी भी कॉल किए जा सकने वाले फ़ंक्शन (फ़ंक्शन पॉइंटर, लैम्डा वगैरह) से बाइंड किया जा सकता है. इसके लिए, operator() का सिग्नेचर कंपैटिबल होना चाहिए. कंस्ट्रक्टेड हैंडलर, XLA FFI कॉल फ़्रेम को डिकोड करता है. इसे स्टेबल सी एपीआई से तय किया जाता है. साथ ही, यह सभी पैरामीटर के टाइप की जांच करता है और डिकोड किए गए नतीजों को उपयोगकर्ता के तय किए गए कॉलबैक पर फ़ॉरवर्ड करता है.

XLA FFI बाइंडिंग, टेंप्लेट मेटाप्रोग्रामिंग पर काफ़ी हद तक निर्भर करती है, ताकि बनाए गए हैंडलर को सबसे असरदार मशीन कोड में कंपाइल किया जा सके. हर कस्टम कॉल पैरामीटर के लिए, रन टाइम ओवरहेड कुछ नैनोसेकंड के क्रम में होते हैं.

XLA FFI के कस्टम पॉइंट, टेंप्लेट स्पेशलाइज़ेशन के तौर पर लागू किए जाते हैं.साथ ही, उपयोगकर्ता यह तय कर सकते हैं कि उनके कस्टम टाइप को कैसे डिकोड किया जाए. इसका मतलब है कि उपयोगकर्ता के तय किए गए enum class टाइप के लिए, कस्टम डिकोडिंग तय की जा सकती है.

कस्टम कॉल से गड़बड़ियां वापस लाना

कस्टम कॉल लागू करने पर, XLA रनटाइम को यह बताने के लिए xla::ffi::Error वैल्यू दिखानी होगी कि कॉल पूरा हो गया है या उसमें कोई गड़बड़ी हुई है. यह absl::Status जैसा ही है और इसमें गड़बड़ी के कोड का एक ही सेट है. हम absl::Status का इस्तेमाल नहीं करते, क्योंकि इसमें स्टेबल एबीआई नहीं है. साथ ही, इसे डाइनैमिक तौर पर लोड की गई कस्टम कॉल लाइब्रेरी और XLA के बीच पास करना सुरक्षित नहीं होगा.

// Handler that always returns an error.
auto always_error = Ffi::Bind().To(
    []() { return Error(ErrorCode::kInternal, "Oops!"); });

// Handler that always returns a success.
auto always_success = Ffi::Bind().To(
    []() { return Error::Success(); });

Buffer Arguments And Results

XLA, नतीजों के लिए डेस्टिनेशन पासिंग स्टाइल का इस्तेमाल करता है: कस्टम कॉल (या XLA की कोई अन्य कार्रवाई) नतीजों के लिए मेमोरी असाइन नहीं करते हैं. इसके बजाय, XLA रनटाइम से पास किए गए डेस्टिनेशन में लिखते हैं. XLA, स्टैटिक बफ़र असाइनमेंट का इस्तेमाल करता है. साथ ही, कंपाइल टाइम में सभी वैल्यू के लिए बफ़र असाइन करता है. यह बफ़र, वैल्यू की लाइव रेंज के आधार पर असाइन किए जाते हैं.

FFI हैंडलर को पास किए गए नतीजे, Result<T> टेंप्लेट में रैप किए जाते हैं. इसमें पॉइंटर की तरह सिमैंटिक्स होता है: operator-> से, पैरामीटर को ऐक्सेस किया जा सकता है.

AnyBuffer आर्ग्युमेंट और नतीजे, किसी भी डेटा टाइप के कस्टम कॉल बफ़र पैरामीटर को ऐक्सेस करने की सुविधा देते हैं. यह तब काम आता है, जब कस्टम कॉल में एक सामान्य तरीके से लागू किया गया फ़ंक्शन होता है, जो कई तरह के डेटा टाइप के लिए काम करता है. साथ ही, कस्टम कॉल को लागू करने के दौरान, डेटा टाइप के आधार पर डिस्पैच किया जाता है. AnyBuffer से बफ़र डेटा टाइप, डाइमेंशन, और बफ़र के पॉइंटर का ऐक्सेस मिलता है.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  api_version = 4 : i32
} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// Buffers of any number of dimensions and data type.
auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
    [](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
      void* arg_data = arg.untyped_data();
      void* res_data = res->untyped_data();
      return Error::Success();
    });

बफ़र के सीमित तर्क और नतीजे

Buffer की मदद से, बफ़र डेटा टाइप और डाइमेंशन की संख्या पर पाबंदियां लगाई जा सकती हैं. साथ ही, हैंडलर इनकी अपने-आप जांच करेगा. अगर रन टाइम के तर्क, FFI हैंडलर के हस्ताक्षर से मेल नहीं खाते हैं, तो XLA रनटाइम को गड़बड़ी का मैसेज मिलेगा.

// Buffers of any number of dimensions and F32 data type.
auto handler = Ffi::Bind().Arg<Buffer<F32>>().Ret<Buffer<F32>>().To(
    [](Buffer<F32> arg, Result<Buffer<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });
// Buffers of number of dimensions 2 and F32 data type.
auto handler = Ffi::Bind().Arg<BufferR2<F32>>().Ret<BufferR2<F32>>().To(
    [](BufferR2<F32> arg, Result<BufferR2<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });

वैरिएडिक आर्ग्युमेंट और नतीजे

अगर किसी कस्टम कॉल के अलग-अलग इंस्टेंस में, आर्ग्युमेंट और नतीजे अलग-अलग हो सकते हैं, तो उन्हें रन टाइम में RemainingArgs और RemainingRets का इस्तेमाल करके डिकोड किया जा सकता है.

auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To(
    [](RemainingArgs args, RemainingRets results) -> Error {
      ErrorOr<AnyBuffer> arg = args.get<AnyBuffer>(0);
      ErrorOr<Result<AnyBuffer>> res = results.get<AnyBuffer>(0);

      if (!arg.has_value()) {
        return Error(ErrorCode::kInternal, arg.error());
      }

      if (!res.has_value()) {
        return Error(ErrorCode::kInternal, res.error());
      }

      return Error::Success();
    });

वैरिएडिक आर्ग्युमेंट और नतीजे, सामान्य आर्ग्युमेंट और नतीजों के बाद घोषित किए जा सकते हैं. हालांकि, वैरिएडिक आर्ग्युमेंट और नतीजों के बाद सामान्य आर्ग्युमेंट और नतीजों को बाइंड करना गैर-कानूनी है.

auto handler =
    Ffi::Bind()
        .Arg<AnyBuffer>()
        .RemainingArgs()
        .Ret<AnyBuffer>()
        .RemainingRets()
        .To([](AnyBuffer arg, RemainingArgs args, AnyBuffer ret,
               RemainingRets results) -> Error { return Error::Success(); });

विशेषताएं

XLA FFI, mlir::DictionaryAttr को custom_call backend_config के तौर पर पास किए जाने पर, उसे FFI हैंडलर आर्ग्युमेंट में अपने-आप डिकोड करने की सुविधा देता है.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    i32 = 42 : i32,
    str = "string"
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

इस उदाहरण में, कस्टम कॉल में एक बफ़र आर्ग्युमेंट और दो एट्रिब्यूट हैं. साथ ही, XLA FFI इन्हें अपने-आप डिकोड कर सकता है और उपयोगकर्ता की ओर से तय किए गए कॉल करने लायक फ़ंक्शन को पास कर सकता है.

auto handler = Ffi::Bind()
  .Arg<BufferR0<F32>>()
  .Attr<int32_t>("i32")
  .Attr<std::string_view>("str")
  .To([](BufferR0<F32> buffer, int32_t i32, std::string_view str) {
    return Error::Success();
  });

उपयोगकर्ता के तय किए गए enum एट्रिब्यूट

XLA FFI, इंटिग्रल MLIR एट्रिब्यूट को उपयोगकर्ता के तय किए गए enum में अपने-आप डिकोड कर सकता है. एनम क्लास का इंटिग्रल टाइप एक ही होना चाहिए. साथ ही, डिकोडिंग को XLA FFI के साथ साफ़ तौर पर रजिस्टर करना होगा.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    command = 0 : i32
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
enum class Command : int32_t {
  kAdd = 0,
  kMul = 1,
};

XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command);

auto handler = Ffi::Bind().Attr<Command>("command").To(
    [](Command command) -> Error { return Error::Success(); });

सभी कस्टम कॉल एट्रिब्यूट बाइंड करना

सभी कस्टम कॉल एट्रिब्यूट को डिक्शनरी के तौर पर ऐक्सेस किया जा सकता है. साथ ही, सिर्फ़ उन एट्रिब्यूट को डिकोड किया जा सकता है जिनकी ज़रूरत रन टाइम में होती है.

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<int32_t> i32 = attrs.get<int32_t>("i32");
  return Error::Success();
});

उपयोगकर्ता के तय किए गए स्ट्रक्ट एट्रिब्यूट

XLA FFI, डिक्शनरी एट्रिब्यूट को उपयोगकर्ता के तय किए गए स्ट्रक्चर में डिकोड कर सकता है.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    range = { lo = 0 : i64, hi = 42 : i64 }
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

ऊपर दिए गए उदाहरण में, range एक mlir::DictionaryAttr एट्रिब्यूट है. साथ ही, डिक्शनरी फ़ील्ड को नाम से ऐक्सेस करने के बजाय, इसे C++ स्ट्रक्चर के तौर पर अपने-आप डिकोड किया जा सकता है. डिकोडिंग को XLA_FFI_REGISTER_STRUCT_ATTR_DECODING मैक्रो के साथ साफ़ तौर पर रजिस्टर करना होगा. पर्दे के पीछे, यह ::xla::ffi नेमस्पेस में टेंप्लेट स्पेशलाइज़ेशन तय करता है. इसलिए, मैक्रो को ग्लोबल नेमस्पेस में जोड़ना होगा.

struct Range {
  int64_t lo;
  int64_t hi;
};

XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("lo"),
                                             StructMember<int64_t>("hi"));

auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
  return Error::Success();
});

कस्टम एट्रिब्यूट को किसी शब्दकोश से लोड किया जा सकता है. ठीक वैसे ही जैसे किसी दूसरे एट्रिब्यूट को लोड किया जाता है. यहां दिए गए उदाहरण में, सभी कस्टम कॉल एट्रिब्यूट को Dictionary के तौर पर डिकोड किया गया है. साथ ही, range को नाम के हिसाब से ऐक्सेस किया जा सकता है.

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<Range> range = attrs.get<Range>("range");
  return Error::Success();
});

सीपीयू पर कस्टम कॉल बनाना

XLA के क्लाइंट एपीआई के ज़रिए, कस्टम कॉल को दिखाने वाला एचएलओ निर्देश बनाया जा सकता है. उदाहरण के लिए, इस कोड में सीपीयू पर A[i] = B[i % 128]+ C[i] को कंप्यूट करने के लिए, कस्टम कॉल का इस्तेमाल किया गया है. (बिलकुल, ऐसा किया जा सकता है और करना भी चाहिए! – do this with regular HLO.)

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
        /*opaque=*/"", /*has_side_effect=*/false,
        /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
        /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to 1-dimensional buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C++ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];

  // Check that dimensions are compatible.
  assert(out->dimensions[0] == d1 && "unexpected dimensions");

  for (size_t i = 0; i < d1; ++i) {
    out->data[i] = in0.data[i % d0] + in1.data[i];
  }
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Arg<Buffer>()
                           .Arg<Buffer>()
                           .Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "Host", handler);

जीपीयू पर कस्टम कॉल बनाना

XLA FFI के साथ जीपीयू कस्टम कॉल रजिस्ट्रेशन लगभग एक जैसा होता है. सिर्फ़ इतना अंतर होता है कि जीपीयू के लिए, आपको डिवाइस पर कर्नल लॉन्च करने के लिए, प्लैटफ़ॉर्म स्ट्रीम (CUDA या ROCM स्ट्रीम) का अनुरोध करना होता है. यहां CUDA का एक उदाहरण दिया गया है. इसमें वही हिसाब (A[i] = B[i % 128] + C[i]) किया गया है जो ऊपर दिए गए सीपीयू कोड में किया गया है.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];
  size_t d2 = out->dimensions[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
    in0.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<CUstream>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "CUDA", handler);

सबसे पहले ध्यान दें कि GPU कस्टम कॉल फ़ंक्शन अब भी सीपीयू पर एक्ज़ीक्यूट होने वाला फ़ंक्शन है. do_custom_call सीपीयू फ़ंक्शन, जीपीयू पर काम को कतार में लगाने के लिए ज़िम्मेदार होता है. यहां यह CUDA कर्नल लॉन्च करता है. हालांकि, यह cuBLAS को कॉल करने जैसी कोई अन्य कार्रवाई भी कर सकता है.

आर्ग्युमेंट और नतीजे भी होस्ट पर मौजूद होते हैं. साथ ही, डेटा मेंबर में डिवाइस (यानी कि जीपीयू) की मेमोरी का पॉइंटर होता है. कस्टम कॉल हैंडलर को पास किए गए बफ़र, डिवाइस के बफ़र के आकार के होते हैं. इसलिए, कस्टम कॉल उनसे कर्नल लॉन्च पैरामीटर का हिसाब लगा सकता है.

कस्टम कॉल में टपल पास करना

यहां कस्टम कॉल का एक उदाहरण दिया गया है.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);

सीपीयू और जीपीयू, दोनों पर टपल को मेमोरी में पॉइंटर की कैटगरी के तौर पर दिखाया जाता है. जब XLA, टपल आर्ग्युमेंट या नतीजों के साथ कस्टम कॉल करता है, तो यह उन्हें फ़्लैट कर देता है और सामान्य बफ़र आर्ग्युमेंट या नतीजों के तौर पर पास करता है.

टपल को अस्थायी बफ़र के तौर पर आउटपुट करता है

कस्टम कॉल के लिए टपल इनपुट का इस्तेमाल करना आसान होता है, लेकिन यह ज़रूरी नहीं है. अगर हम कस्टम कॉल के लिए टपल इनपुट का इस्तेमाल नहीं करते हैं, तो कस्टम कॉल को टपल पास करने से पहले, get-tuple-element का इस्तेमाल करके टपल को अनपैक किया जा सकता है.

दूसरी ओर, टपल आउटपुट से ऐसे काम किए जा सकते हैं जो किसी और तरीके से नहीं किए जा सकते.

टपल आउटपुट का इस्तेमाल करने की मुख्य वजह यह है कि कस्टम कॉल या कोई अन्य XLA ऑप, कई इंडिपेंडेंट ऐरे को टपल आउटपुट के तौर पर दिखाता है.

हालांकि, टपल आउटपुट का इस्तेमाल करके भी कस्टम कॉल को कुछ समय के लिए मेमोरी दी जा सकती है. हां, आउटपुट, टेंप बफ़र को दिखा सकता है. मान लें कि किसी आउटपुट बफ़र में यह प्रॉपर्टी होती है कि ऑप इसमें लिख सकता है. साथ ही, इसमें लिखने के बाद इसे पढ़ा जा सकता है. आपको एक अस्थायी बफ़र से यही चाहिए.

ऊपर दिए गए उदाहरण में, मान लें कि हमें F32[1024] को अस्थायी बफ़र के तौर पर इस्तेमाल करना है. इसके बाद, हम ऊपर की तरह ही एचएलओ लिखेंगे. साथ ही, हम कस्टम कॉल के आउटपुट के टपल इंडेक्स 1 को कभी नहीं पढ़ेंगे.