تماس های سفارشی XLA

این سند نحوه نوشتن و استفاده از تماس های سفارشی XLA با استفاده از کتابخانه XLA FFI را شرح می دهد. فراخوانی سفارشی مکانیزمی است برای توصیف یک "عملیات" خارجی در ماژول HLO به کامپایلر XLA (در زمان کامپایل)، و XLA FFI مکانیزمی است برای ثبت اجرای چنین عملیاتی با XLA (در زمان اجرا). FFI مخفف "interface function خارجی" است و مجموعه ای از C API است که یک رابط باینری (ABI) را برای XLA برای فراخوانی کد خارجی نوشته شده در سایر زبان های برنامه نویسی تعریف می کند. XLA اتصالات فقط هدر را برای XLA FFI نوشته شده در C++ فراهم می کند، که تمام جزئیات سطح پایین API های زیرین C را از کاربر نهایی پنهان می کند.

تماس های سفارشی JAX + XLA

برای نمونه‌های پایانی به انتها از ادغام تماس‌های سفارشی و XLA FFI با JAX، به مستندات JAX مراجعه کنید.

XLA FFI Binding

XLA FFI binding یک مشخصه زمان کامپایل برای امضای فراخوان سفارشی است: آرگومان های فراخوانی سفارشی، ویژگی ها و انواع آنها، و پارامترهای اضافی که از طریق زمینه اجرا (یعنی جریان gpu برای باطن GPU) ارسال می شود. یافتن XLA FFI را می توان به هر C++ قابل فراخوانی (اشاره گر تابع، لامبدا و غیره) با امضای operator() سازگار متصل کرد. کنترلر ساخته شده فریم فراخوانی XLA FFI را رمزگشایی می کند (تعریف شده توسط C API پایدار)، بررسی تمام پارامترها را تایپ می کند و نتایج رمزگشایی شده را به بازخوانی تعریف شده توسط کاربر ارسال می کند.

اتصال XLA FFI به شدت به فرابرنامه‌نویسی قالب متکی است تا بتواند هندلر ساخته شده را به کارآمدترین کد ماشین کامپایل کند. سربار زمان اجرا به ترتیب چند نانوثانیه برای هر پارامتر تماس سفارشی است.

نقاط سفارشی‌سازی XLA FFI به‌عنوان تخصص‌های قالب پیاده‌سازی می‌شوند و کاربران می‌توانند نحوه رمزگشایی انواع سفارشی خود را تعریف کنند، به عنوان مثال، امکان تعریف رمزگشایی سفارشی برای انواع enum class تعریف شده توسط کاربر وجود دارد.

بازگشت خطاها از تماس های سفارشی

اجرای فراخوان سفارشی باید xla::ffi::Error به سیگنال موفقیت یا خطا به زمان اجرا XLA برگرداند. این شبیه به absl::Status است و دارای همان مجموعه کدهای خطا است. ما از absl::Status استفاده نمی‌کنیم زیرا ABI پایداری ندارد و انتقال آن بین کتابخانه تماس سفارشی بارگذاری شده پویا و خود XLA خطرناک است.

// Handler that always returns an error.
auto always_error = Ffi::Bind().To(
    []() { return Error(ErrorCode::kInternal, "Oops!"); });

// Handler that always returns a success.
auto always_success = Ffi::Bind().To(
    []() { return Error::Success(); });

بافر آرگومان ها و نتایج

XLA از سبک ارسال مقصد برای نتایج استفاده می‌کند: تماس‌های سفارشی (یا هر عملیات XLA دیگری برای آن موضوع) حافظه را برای نتایج تخصیص نمی‌دهد، و در عوض در مقصدهایی که توسط زمان اجرا XLA ارسال می‌شود، می‌نویسند. XLA از تخصیص بافر استاتیک استفاده می کند و بافرها را برای همه مقادیر بر اساس محدوده زنده آنها در زمان کامپایل اختصاص می دهد.

نتایج به کنترل کننده های FFI که در قالب Result<T> پیچیده شده است ارسال می شود که دارای معنایی اشاره گر است: operator-> به پارامتر زیربنایی دسترسی می دهد.

آرگومان ها و نتایج AnyBuffer به پارامترهای بافر تماس سفارشی از هر نوع داده دسترسی می دهد. این زمانی مفید است که تماس سفارشی یک پیاده‌سازی عمومی داشته باشد که برای چندین نوع داده کار می‌کند، و اجرای تماس سفارشی بر اساس نوع داده، ارسال زمان را انجام می‌دهد. AnyBuffer به نوع داده بافر، ابعاد و یک اشاره گر به خود بافر دسترسی می دهد.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  api_version = 4 : i32
} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// Buffers of any rank and data type.
auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
    [](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
      void* arg_data = arg.untyped_data();
      void* res_data = res->untyped_data();
      return Error::Success();
    });

آرگومان ها و نتایج بافر محدود

Buffer اجازه می دهد تا محدودیت هایی را بر روی نوع و رتبه داده بافر اضافه کنید، و اگر آرگومان های زمان اجرا با امضای کنترل کننده FFI مطابقت نداشته باشند، آنها به طور خودکار توسط کنترل کننده بررسی می شوند و یک خطا به زمان اجرا XLA برمی گردند.

// Buffers of any rank and F32 data type.
auto handler = Ffi::Bind().Arg<Buffer<F32>>().Ret<Buffer<F32>>().To(
    [](Buffer<F32> arg, Result<Buffer<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });
// Buffers of rank 2 and F32 data type.
auto handler = Ffi::Bind().Arg<BufferR2<F32>>().Ret<BufferR2<F32>>().To(
    [](BufferR2<F32> arg, Result<BufferR2<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });

استدلال ها و نتایج متغیر

اگر تعداد آرگومان‌ها و نتیجه ممکن است در نمونه‌های مختلف یک فراخوانی سفارشی متفاوت باشد، می‌توان آنها را در زمان اجرا با استفاده از RemainingArgs و RemainingRets رمزگشایی کرد.

auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To(
    [](RemainingArgs args, RemainingRets results) -> Error {
      ErrorOr<AnyBuffer> arg = args.get<AnyBuffer>(0);
      ErrorOr<Result<AnyBuffer>> res = results.get<AnyBuffer>(0);

      if (!arg.has_value()) {
        return Error(ErrorCode::kInternal, arg.error());
      }

      if (!res.has_value()) {
        return Error(ErrorCode::kInternal, res.error());
      }

      return Error::Success();
    });

آرگومان‌ها و نتایج متغیر را می‌توان پس از آرگومان‌ها و نتایج منظم اعلام کرد، اما آرگومان‌ها و نتایج منظم الزام آور پس از متغیر غیرقانونی است.

auto handler =
    Ffi::Bind()
        .Arg<AnyBuffer>()
        .RemainingArgs()
        .Ret<AnyBuffer>()
        .RemainingRets()
        .To([](AnyBuffer arg, RemainingArgs args, AnyBuffer ret,
               RemainingRets results) -> Error { return Error::Success(); });

صفات

XLA FFI از رمزگشایی خودکار mlir::DictionaryAttr به عنوان یک custom_call backend_config به آرگومان های کنترل کننده FFI پشتیبانی می کند.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    i32 = 42 : i32,
    str = "string"
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

در این مثال فراخوانی سفارشی دارای یک آرگومان بافر و دو ویژگی است و XLA FFI می تواند به طور خودکار آنها را رمزگشایی کرده و به فراخوانی تعریف شده توسط کاربر ارسال کند.

auto handler = Ffi::Bind()
  .Arg<BufferR0<F32>>()
  .Attr<int32_t>("i32")
  .Attr<std::string_view>("str")
  .To([](BufferR0<F32> buffer, int32_t i32, std::string_view str) {
    return Error::Success();
  });

ویژگی های Enum تعریف شده توسط کاربر

XLA FFI می تواند به طور خودکار ویژگی های یکپارچه MLIR را در enum های تعریف شده توسط کاربر رمزگشایی کند. کلاس Enum باید همان نوع انتگرال زیرین را داشته باشد و رمزگشایی باید به صراحت در XLA FFI ثبت شود.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    command = 0 : i32
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
enum class Command : int32_t {
  kAdd = 0,
  kMul = 1,
};

XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command);

auto handler = Ffi::Bind().Attr<Command>("command").To(
    [](Command command) -> Error { return Error::Success(); });

اتصال همه ویژگی های تماس سفارشی

این امکان وجود دارد که به تمام ویژگی های تماس سفارشی به عنوان یک فرهنگ لغت دسترسی داشته باشید و تنها ویژگی هایی را که در زمان اجرا مورد نیاز هستند رمزگشایی کنید.

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<int32_t> i32 = attrs.get<int32_t>("i32");
  return Error::Success();
});

ویژگی های ساختاری تعریف شده توسط کاربر

XLA FFI می تواند ویژگی های فرهنگ لغت را به ساختارهای تعریف شده توسط کاربر رمزگشایی کند.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    range = { lo = 0 : i64, hi = 42 : i64 }
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

به عنوان مثال range بالا یک ویژگی mlir::DictionaryAttr است و به جای دسترسی به فیلدهای فرهنگ لغت با نام، می توان آن را به طور خودکار به عنوان یک ساختار C++ رمزگشایی کرد. رمزگشایی باید به طور صریح با یک ماکرو XLA_FFI_REGISTER_STRUCT_ATTR_DECODING ثبت شود (در پشت صحنه، یک الگوی تخصصی در فضای نام ::xla::ffi تعریف می‌کند، بنابراین ماکرو باید به فضای نام جهانی اضافه شود).

struct Range {
  int64_t lo;
  int64_t hi;
};

XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("lo"),
                                             StructMember<int64_t>("hi"));

auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
  return Error::Success();
});

ویژگی های سفارشی را می توان از یک فرهنگ لغت بارگیری کرد، درست مانند هر ویژگی دیگر. در مثال زیر، همه ویژگی‌های فراخوانی سفارشی به‌عنوان یک Dictionary رمزگشایی می‌شوند و یک range با نام قابل دسترسی است.

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<Range> range = attrs.get<Range>("range");
  return Error::Success();
});

یک تماس سفارشی روی CPU ایجاد کنید

شما می توانید یک دستورالعمل HLO ایجاد کنید که نشان دهنده یک تماس سفارشی از طریق API مشتری XLA است. برای مثال، کد زیر از یک فراخوانی سفارشی برای محاسبه A[i] = B[i % 128]+ C[i] در CPU استفاده می‌کند. (البته شما می توانید – و باید! – این کار را با HLO معمولی انجام دهید.)

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
        /*opaque=*/"", /*has_side_effect=*/false,
        /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
        /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to rank-1 buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C+ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];

  // Check that dimensions are compatible.
  assert(out->dimensions[0] == d1 && "unexpected dimensions");

  for (size_t i = 0; i < d1; ++i) {
    out->data[i] = in0.data[i % d0] + in1.data[i];
  }
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Arg<Buffer>()
                           .Arg<Buffer>()
                           .Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "Host", handler);

یک تماس سفارشی در GPU ایجاد کنید

ثبت تماس سفارشی GPU با XLA FFI تقریباً یکسان است، تنها تفاوت این است که برای GPU باید یک جریان پلتفرم اساسی (جریان CUDA یا ROCM) درخواست کنید تا بتوانید هسته را روی دستگاه راه‌اندازی کنید. در اینجا یک مثال CUDA وجود دارد که همان محاسبات ( A[i] = B[i % 128] + C[i] ) را مانند کد CPU بالا انجام می‌دهد.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];
  size_t d2 = out->dimensions[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
    in0.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<CUstream>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "CUDA", handler);

ابتدا توجه کنید که تابع فراخوانی سفارشی GPU هنوز هم تابعی است که روی CPU اجرا می شود . تابع do_custom_call CPU مسئول صف بندی کار روی GPU است. در اینجا یک هسته CUDA راه اندازی می کند، اما می تواند کار دیگری نیز انجام دهد، مانند فراخوانی cuBLAS.

آرگومان ها و نتایج نیز روی هاست زنده هستند و عضو داده حاوی اشاره گر به حافظه دستگاه (یعنی GPU) است. بافرهای ارسال شده به کنترل کننده تماس سفارشی شکل بافرهای دستگاه زیرین را دارند، بنابراین تماس سفارشی می تواند پارامترهای راه اندازی هسته را از آنها محاسبه کند.

ارسال تاپل ها به تماس های سفارشی

تماس سفارشی زیر را در نظر بگیرید.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);

در هر دو CPU و GPU، یک تاپل در حافظه به عنوان آرایه ای از اشاره گرها نشان داده می شود. هنگامی که XLA فراخوانی های سفارشی را با آرگومان های تاپل یا نتایج فراخوانی می کند، آنها را مسطح می کند و به عنوان آرگومان ها یا نتایج بافر معمولی ارسال می شود.

خروجی ها را به عنوان بافرهای دمایی چند برابر کنید

ورودی‌های چندگانه برای تماس‌های سفارشی راحت هستند، اما به شدت ضروری نیستند. اگر ما از ورودی‌های تاپل برای تماس‌های سفارشی پشتیبانی نمی‌کردیم، همیشه می‌توانید تاپل‌ها را با استفاده از get-tuple-element قبل از ارسال به تماس سفارشی باز کنید.

از سوی دیگر، خروجی‌های تاپل به شما اجازه می‌دهند کارهایی را انجام دهید که در غیر این صورت نمی‌توانید انجام دهید.

دلیل واضح داشتن خروجی‌های تاپل این است که خروجی‌های تاپل به این صورت است که چگونه یک فراخوانی سفارشی (یا هر عملیات XLA دیگری) چندین آرایه مستقل را برمی‌گرداند.

اما بدیهی است که خروجی تاپل نیز راهی برای دادن حافظه موقت تماس سفارشی شماست. بله، یک خروجی می تواند نشان دهنده یک بافر موقت باشد. در نظر بگیرید، یک بافر خروجی این ویژگی را دارد که op می‌تواند روی آن بنویسد، و بعد از نوشتن روی آن می‌تواند از روی آن بخواند. این دقیقاً همان چیزی است که شما از یک بافر دما می خواهید.

در مثال بالا، فرض کنید می‌خواهیم از F32[1024] به عنوان یک بافر موقت استفاده کنیم. سپس HLO را دقیقاً مانند بالا می نویسیم، و به سادگی هرگز تاپل ایندکس 1 خروجی تماس سفارشی را نمی خوانیم.