XLA কাস্টম কল

এই ডকুমেন্টটি বর্ণনা করে যে কিভাবে XLA FFI লাইব্রেরি ব্যবহার করে XLA কাস্টম কল লিখতে হয় এবং ব্যবহার করতে হয়। কাস্টম কল হল XLA কম্পাইলার (কম্পাইলের সময়ে) এইচএলও মডিউলে একটি বাহ্যিক "অপারেশন" বর্ণনা করার একটি পদ্ধতি এবং XLA FFI হল XLA (চালানোর সময়) সাথে এই ধরনের ক্রিয়াকলাপ বাস্তবায়নের নিবন্ধন করার একটি প্রক্রিয়া। এফএফআই-এর অর্থ হল "বিদেশী ফাংশন ইন্টারফেস" এবং এটি C API-এর একটি সেট যা XLA-কে অন্যান্য প্রোগ্রামিং ভাষায় লিখিত বাহ্যিক কোডে কল করার জন্য একটি বাইনারি ইন্টারফেস (ABI) সংজ্ঞায়িত করে। XLA C++ এ লেখা XLA FFI-এর জন্য শুধুমাত্র শিরোনাম-বন্ধন প্রদান করে, যা শেষ ব্যবহারকারীর কাছ থেকে অন্তর্নিহিত C API-এর সমস্ত নিম্ন স্তরের বিবরণ লুকিয়ে রাখে।

JAX + XLA কাস্টম কল

JAX এর সাথে কাস্টম কল এবং XLA FFI একত্রিত করার শেষ থেকে শেষ উদাহরণগুলির জন্য JAX ডকুমেন্টেশন দেখুন৷

XLA FFI বাইন্ডিং

XLA FFI বাইন্ডিং হল কাস্টম কল সিগনেচারের একটি কম্পাইল-টাইম স্পেসিফিকেশন: কাস্টম কল আর্গুমেন্ট, অ্যাট্রিবিউট এবং তাদের প্রকার, এবং এক্সিকিউশন কনটেক্সটের মাধ্যমে পাস করা অতিরিক্ত প্যারামিটার (যেমন, GPU ব্যাকএন্ডের জন্য gpu স্ট্রিম)। XLA FFI বাইন্ডিং সামঞ্জস্যপূর্ণ operator() স্বাক্ষর সহ যেকোনো C++ কলযোগ্য (ফাংশন পয়েন্টার, ল্যাম্বডা, ইত্যাদি) সাথে আবদ্ধ হতে পারে। কনস্ট্রাক্ট হ্যান্ডলার এক্সএলএ এফএফআই কল ফ্রেম ডিকোড করে (স্থিতিশীল C API দ্বারা সংজ্ঞায়িত), টাইপ করুন সমস্ত প্যারামিটার চেক করুন এবং ডিকোড করা ফলাফল ব্যবহারকারী-সংজ্ঞায়িত কলব্যাকে ফরওয়ার্ড করুন।

XLA FFI বাইন্ডিং টেমপ্লেট মেটাপ্রোগ্রামিং-এর উপর নির্ভর করে সবচেয়ে দক্ষ মেশিন কোডে তৈরি হ্যান্ডলারকে কম্পাইল করতে সক্ষম হতে। প্রতিটি কাস্টম কল প্যারামিটারের জন্য রান টাইম ওভারহেডগুলি কয়েক ন্যানোসেকেন্ডের ক্রমানুসারে।

XLA FFI কাস্টমাইজেশন পয়েন্টগুলি টেমপ্লেট বিশেষীকরণ হিসাবে প্রয়োগ করা হয়েছে, এবং ব্যবহারকারীরা তাদের কাস্টম প্রকারগুলিকে কীভাবে ডিকোড করতে হয় তা সংজ্ঞায়িত করতে পারে, অর্থাৎ, ব্যবহারকারী-সংজ্ঞায়িত enum class প্রকারের জন্য কাস্টম ডিকোডিং সংজ্ঞায়িত করা সম্ভব।

কাস্টম কল থেকে ফিরে ত্রুটি

কাস্টম কল বাস্তবায়ন অবশ্যই xla::ffi::Error মান বা XLA রানটাইমে ত্রুটি প্রদান করবে। এটি absl::Status অনুরূপ এবং এরর কোডের একই সেট রয়েছে। আমরা absl::Status ব্যবহার করি না কারণ এটিতে একটি স্থিতিশীল ABI নেই এবং এটি গতিশীলভাবে লোড করা কাস্টম কল লাইব্রেরি এবং XLA এর মধ্যে পাস করা অনিরাপদ হবে।

// Handler that always returns an error.
auto always_error = Ffi::Bind().To(
    []() { return Error(ErrorCode::kInternal, "Oops!"); });

// Handler that always returns a success.
auto always_success = Ffi::Bind().To(
    []() { return Error::Success(); });

বাফার আর্গুমেন্ট এবং ফলাফল

XLA ফলাফলের জন্য গন্তব্য পাসিং শৈলী ব্যবহার করে: কাস্টম কল (অথবা সেই বিষয়ে অন্য কোন XLA অপারেশন) ফলাফলের জন্য মেমরি বরাদ্দ করে না, এবং পরিবর্তে XLA রানটাইম দ্বারা পাস করা গন্তব্যগুলিতে লিখুন। XLA স্ট্যাটিক বাফার অ্যাসাইনমেন্ট ব্যবহার করে এবং কম্পাইলের সময় তাদের লাইভ রেঞ্জের উপর ভিত্তি করে সমস্ত মানের জন্য বাফার বরাদ্দ করে।

Result<T> টেমপ্লেটে মোড়ানো FFI হ্যান্ডলারদের কাছে পাঠানো ফলাফল, যেটিতে একটি পয়েন্টার-এর মতো শব্দার্থ আছে: operator-> অন্তর্নিহিত প্যারামিটারে অ্যাক্সেস দেয়।

AnyBuffer আর্গুমেন্ট এবং ফলাফল যেকোন ডেটা টাইপের কাস্টম কল বাফার প্যারামিটারে অ্যাক্সেস দেয়। এটি উপযোগী যখন কাস্টম কলের একটি জেনেরিক বাস্তবায়ন থাকে যা একাধিক ডাটা টাইপের জন্য কাজ করে এবং কাস্টম কল ইমপ্লিমেন্টেশন ডাটা টাইপের উপর ভিত্তি করে টাইম ডিস্প্যাচিং চালায়। AnyBuffer বাফার ডেটা টাইপ, মাত্রা, এবং বাফার নিজেই একটি পয়েন্টার অ্যাক্সেস দেয়।

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  api_version = 4 : i32
} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// Buffers of any rank and data type.
auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
    [](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
      void* arg_data = arg.untyped_data();
      void* res_data = res->untyped_data();
      return Error::Success();
    });

সীমাবদ্ধ বাফার আর্গুমেন্ট এবং ফলাফল

Buffer বাফার ডেটা টাইপ এবং র‌্যাঙ্কে সীমাবদ্ধতা যুক্ত করার অনুমতি দেয় এবং সেগুলি হ্যান্ডলার দ্বারা স্বয়ংক্রিয়ভাবে পরীক্ষা করা হবে এবং যদি রান টাইম আর্গুমেন্টগুলি FFI হ্যান্ডলার স্বাক্ষরের সাথে মেলে না তবে XLA রানটাইমে একটি ত্রুটি ফেরত দেয়।

// Buffers of any rank and F32 data type.
auto handler = Ffi::Bind().Arg<Buffer<F32>>().Ret<Buffer<F32>>().To(
    [](Buffer<F32> arg, Result<Buffer<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });
// Buffers of rank 2 and F32 data type.
auto handler = Ffi::Bind().Arg<BufferR2<F32>>().Ret<BufferR2<F32>>().To(
    [](BufferR2<F32> arg, Result<BufferR2<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });

বৈচিত্র্যময় আর্গুমেন্ট এবং ফলাফল

যদি একটি কাস্টম কলের বিভিন্ন দৃষ্টান্তে আর্গুমেন্ট এবং ফলাফলের সংখ্যা ভিন্ন হতে পারে, তাহলে RemainingArgs এবং RemainingRets ব্যবহার করে রান টাইমে ডিকোড করা যেতে পারে।

auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To(
    [](RemainingArgs args, RemainingRets results) -> Error {
      ErrorOr<AnyBuffer> arg = args.get<AnyBuffer>(0);
      ErrorOr<Result<AnyBuffer>> res = results.get<AnyBuffer>(0);

      if (!arg.has_value()) {
        return Error(ErrorCode::kInternal, arg.error());
      }

      if (!res.has_value()) {
        return Error(ErrorCode::kInternal, res.error());
      }

      return Error::Success();
    });

নিয়মিত আর্গুমেন্ট এবং ফলাফলের পরে বৈচিত্র্যময় যুক্তি এবং ফলাফল ঘোষণা করা যেতে পারে, তবে বৈচিত্র্যময় একটির পরে নিয়মিত যুক্তি এবং ফলাফল বাঁধা অবৈধ।

auto handler =
    Ffi::Bind()
        .Arg<AnyBuffer>()
        .RemainingArgs()
        .Ret<AnyBuffer>()
        .RemainingRets()
        .To([](AnyBuffer arg, RemainingArgs args, AnyBuffer ret,
               RemainingRets results) -> Error { return Error::Success(); });

গুণাবলী

XLA FFI mlir::DictionaryAttr এর স্বয়ংক্রিয় ডিকোডিং সমর্থন করে এফএফআই হ্যান্ডলার আর্গুমেন্টে custom_call backend_config হিসাবে পাস করা হয়েছে।

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    i32 = 42 : i32,
    str = "string"
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

এই উদাহরণে কাস্টম কলের একটি একক বাফার আর্গুমেন্ট এবং দুটি বৈশিষ্ট্য রয়েছে এবং XLA FFI স্বয়ংক্রিয়ভাবে সেগুলিকে ডিকোড করতে পারে এবং ব্যবহারকারী-সংজ্ঞায়িত কলেবলে পাস করতে পারে।

auto handler = Ffi::Bind()
  .Arg<BufferR0<F32>>()
  .Attr<int32_t>("i32")
  .Attr<std::string_view>("str")
  .To([](BufferR0<F32> buffer, int32_t i32, std::string_view str) {
    return Error::Success();
  });

ব্যবহারকারী-সংজ্ঞায়িত Enum বৈশিষ্ট্য

XLA FFI স্বয়ংক্রিয়ভাবে ব্যবহারকারী-সংজ্ঞায়িত enums মধ্যে অবিচ্ছেদ্য MLIR বৈশিষ্ট্য ডিকোড করতে পারে। Enum ক্লাসের একই অন্তর্নিহিত অবিচ্ছেদ্য ধরন থাকতে হবে এবং ডিকোডিংকে XLA FFI-এর সাথে স্পষ্টভাবে নিবন্ধিত হতে হবে।

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    command = 0 : i32
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
enum class Command : int32_t {
  kAdd = 0,
  kMul = 1,
};

XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command);

auto handler = Ffi::Bind().Attr<Command>("command").To(
    [](Command command) -> Error { return Error::Success(); });

সমস্ত কাস্টম কল বৈশিষ্ট্য বাঁধাই

একটি অভিধান হিসাবে সমস্ত কাস্টম কল অ্যাট্রিবিউটগুলিতে অ্যাক্সেস পাওয়া সম্ভব এবং অলসভাবে শুধুমাত্র রান টাইমে প্রয়োজনীয় বৈশিষ্ট্যগুলিকে ডিকোড করা সম্ভব৷

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<int32_t> i32 = attrs.get<int32_t>("i32");
  return Error::Success();
});

ব্যবহারকারী-সংজ্ঞায়িত কাঠামো বৈশিষ্ট্য

XLA FFI ব্যবহারকারী-সংজ্ঞায়িত স্ট্রাকটে অভিধান বৈশিষ্ট্যগুলিকে ডিকোড করতে পারে।

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    range = { lo = 0 : i64, hi = 42 : i64 }
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

উদাহরণে উপরের range হল একটি mlir::DictionaryAttr অ্যাট্রিবিউট এবং নামের মাধ্যমে অভিধান ক্ষেত্রগুলি অ্যাক্সেস করার পরিবর্তে, এটি স্বয়ংক্রিয়ভাবে একটি C++ স্ট্রাকট হিসাবে ডিকোড করা যেতে পারে। ডিকোডিং একটি XLA_FFI_REGISTER_STRUCT_ATTR_DECODING ম্যাক্রোর সাথে স্পষ্টভাবে নিবন্ধিত হতে হবে (দৃশ্যের পিছনে এটি ::xla::ffi নামস্থানে একটি টেমপ্লেট বিশেষীকরণ সংজ্ঞায়িত করে, এইভাবে ম্যাক্রোকে অবশ্যই গ্লোবাল নেমস্পেসে যুক্ত করতে হবে)।

struct Range {
  int64_t lo;
  int64_t hi;
};

XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("lo"),
                                             StructMember<int64_t>("hi"));

auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
  return Error::Success();
});

কাস্টম অ্যাট্রিবিউট অন্য যেকোনো অ্যাট্রিবিউটের মতোই একটি অভিধান থেকে লোড করা যেতে পারে। নীচের উদাহরণে, সমস্ত কাস্টম কল বৈশিষ্ট্য একটি Dictionary হিসাবে ডিকোড করা হয়েছে, এবং একটি range নাম দ্বারা অ্যাক্সেস করা যেতে পারে৷

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<Range> range = attrs.get<Range>("range");
  return Error::Success();
});

CPU-তে একটি কাস্টম কল তৈরি করুন

আপনি একটি HLO নির্দেশ তৈরি করতে পারেন যা XLA এর ক্লায়েন্ট API এর মাধ্যমে একটি কাস্টম কল উপস্থাপন করে। উদাহরণস্বরূপ, CPU-তে A[i] = B[i % 128]+ C[i] গণনা করতে নিম্নলিখিত কোডটি একটি কাস্টম কল ব্যবহার করে। (অবশ্যই আপনি করতে পারেন - এবং করা উচিত! - নিয়মিত HLO এর সাথে এটি করুন।)

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
        /*opaque=*/"", /*has_side_effect=*/false,
        /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
        /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to rank-1 buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C+ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];

  // Check that dimensions are compatible.
  assert(out->dimensions[0] == d1 && "unexpected dimensions");

  for (size_t i = 0; i < d1; ++i) {
    out->data[i] = in0.data[i % d0] + in1.data[i];
  }
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Arg<Buffer>()
                           .Arg<Buffer>()
                           .Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "Host", handler);

GPU-তে একটি কাস্টম কল তৈরি করুন

XLA FFI-এর সাথে GPU কাস্টম কল রেজিস্ট্রেশন প্রায় অভিন্ন, শুধুমাত্র পার্থক্য হল যে GPU-এর জন্য আপনাকে ডিভাইসে কার্নেল চালু করতে সক্ষম হওয়ার জন্য একটি অন্তর্নিহিত প্ল্যাটফর্ম স্ট্রিম (CUDA বা ROCM স্ট্রিম) চাইতে হবে। এখানে একটি CUDA উদাহরণ রয়েছে যা উপরের CPU কোডের মতো একই গণনা ( A[i] = B[i % 128] + C[i] ) করে।

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];
  size_t d2 = out->dimensions[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
    in0.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<CUstream>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "CUDA", handler);

প্রথমে লক্ষ্য করুন যে জিপিইউ কাস্টম কল ফাংশনটি এখনও সিপিইউতে কার্যকর করা একটি ফাংশনdo_custom_call CPU ফাংশন GPU-তে কাজ সারিবদ্ধ করার জন্য দায়ী। এখানে এটি একটি CUDA কার্নেল চালু করে, তবে এটি অন্য কিছুও করতে পারে, যেমন cuBLAS কল।

আর্গুমেন্ট এবং ফলাফল হোস্টে লাইভ, এবং ডেটা সদস্য ডিভাইসে একটি পয়েন্টার (যেমন GPU) মেমরি ধারণ করে। কাস্টম কল হ্যান্ডলারে পাস করা বাফারগুলির অন্তর্নিহিত ডিভাইস বাফারের আকার থাকে, তাই কাস্টম কল তাদের থেকে কার্নেল লঞ্চ প্যারামিটারগুলি গণনা করতে পারে।

কাস্টম কলে টিপল পাস করা

নিম্নলিখিত কাস্টম কল বিবেচনা করুন.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);

সিপিইউ এবং জিপিইউ উভয়েই, একটি টিপলকে পয়েন্টারগুলির একটি অ্যারে হিসাবে মেমরিতে উপস্থাপন করা হয়। যখন XLA কাস্টম কলগুলিকে টিপল আর্গুমেন্ট বা ফলাফল সহ কল ​​করে তখন এটি তাদের সমতল করে এবং নিয়মিত বাফার আর্গুমেন্ট বা ফলাফল হিসাবে পাস করে।

টেম্প বাফার হিসাবে আউটপুট টিপল

কাস্টম কলে Tuple ইনপুট একটি সুবিধা, কিন্তু তারা কঠোরভাবে প্রয়োজনীয় নয়। যদি আমরা কাস্টম কলে টিপল ইনপুট সমর্থন না করি, তাহলে আপনি কাস্টম কলে পাঠানোর আগে গেট-টুপল-এলিমেন্ট ব্যবহার করে টিপলগুলিকে সর্বদা আনপ্যাক করতে পারেন।

অন্যদিকে, টিপল আউটপুট আপনাকে এমন কিছু করতে দেয় যা আপনি অন্যথায় করতে পারেন না।

টিপল আউটপুট থাকার সুস্পষ্ট কারণ হল টিপল আউটপুট হল কিভাবে একটি কাস্টম কল (বা অন্য কোন XLA অপ) একাধিক স্বাধীন অ্যারে ফেরত দেয়।

তবে কম স্পষ্টতই, একটি টিপল আউটপুট আপনার কাস্টম কল টেম্প মেমরি দেওয়ার একটি উপায়। হ্যাঁ, একটি আউটপুট একটি টেম্প বাফার প্রতিনিধিত্ব করতে পারে। বিবেচনা করুন, একটি আউটপুট বাফারের এমন বৈশিষ্ট্য রয়েছে যা op এটিতে লিখতে পারে এবং এটি লেখার পরে এটি থেকে পড়তে পারে। যে ঠিক কি আপনি একটি টেম্প বাফার থেকে চান.

উপরের উদাহরণে, ধরুন আমরা একটি টেম্প বাফার হিসাবে F32[1024] ব্যবহার করতে চাই। তারপরে আমরা উপরের মতো HLO লিখব এবং আমরা কাস্টম কলের আউটপুটের টিপল সূচক 1 কখনই পড়ব না।