שיחות בהתאמה אישית של XLA

במאמר הזה מוסבר איך לכתוב ולהשתמש בקריאות מותאמות אישית ל-XLA באמצעות ספריית XLA FFI. קריאה מותאמת אישית היא מנגנון לתיאור של 'פעולה' חיצונית במודול HLO למהדר XLA (בזמן ההידור), ו-XLA FFI הוא מנגנון לרישום ההטמעה של פעולות כאלה ב-XLA (בזמן הריצה). FFI מייצג "ממשק פונקציות חיצוניות" והוא קבוצה של ממשקי API בשפת C שמגדירים ממשק בינארי (ABI) ל-XLA כדי לקרוא לקוד חיצוני שנכתב בשפות תכנות אחרות. ‫XLA מספקת קשירות של כותרות בלבד ל-XLA FFI שנכתב ב-C++, שמסתיר את כל הפרטים ברמה הנמוכה של ממשקי C API הבסיסיים ממשתמש הקצה.

JAX + XLA Custom Calls

במסמכי JAX יש דוגמאות מלאות לשילוב של קריאות מותאמות אישית ו-XLA FFI עם JAX.

XLA FFI Binding

הקישור XLA FFI הוא מפרט בזמן קומפילציה של חתימת הקריאה המותאמת אישית: ארגומנטים, מאפיינים וסוגים של קריאה מותאמת אישית, ופרמטרים נוספים שמועברים דרך הקשר הביצוע (כלומר, זרם GPU עבור קצה עורפי של GPU). אפשר לקשר את XLA FFI לכל פונקציה שאפשר להפעיל ב-C++ (מצביע פונקציה, למדה וכו') עם חתימה תואמת operator(). ה-handler שנבנה מפענח את מסגרת הקריאה של XLA FFI (מוגדר על ידי C API יציב), בודק את הסוג של כל הפרמטרים ומעביר את התוצאות המפוענחות לקריאה החוזרת שהוגדרה על ידי המשתמש.

הקישור של XLA FFI מסתמך במידה רבה על תכנות מטא-תבניות כדי לאפשר הידור של מטפל שנבנה לקוד מכונה יעיל ביותר. התקורה בזמן הריצה היא בסדר גודל של כמה ננו-שניות לכל פרמטר מותאם אישית של קריאה.

נקודות התאמה אישית של XLA FFI מיושמות כהתמחויות של תבניות, והמשתמשים יכולים להגדיר איך לפענח את הסוגים המותאמים אישית שלהם, כלומר אפשר להגדיר פענוח מותאם אישית לסוגים של enum class שהמשתמשים מגדירים.

החזרת שגיאות משיחות מותאמות אישית

הטמעות של קריאות מותאמות אישית צריכות להחזיר את הערך xla::ffi::Error כדי לסמן הצלחה או שגיאה בזמן הריצה של XLA. היא דומה לפונקציה absl::Status, ויש לה את אותה קבוצה של קודי שגיאה. אנחנו לא משתמשים ב-absl::Status כי אין לו ABI יציב, ולא בטוח להעביר אותו בין ספריית קריאות מותאמת אישית שנטענת באופן דינמי לבין XLA עצמו.

// Handler that always returns an error.
auto always_error = Ffi::Bind().To(
    []() { return Error(ErrorCode::kInternal, "Oops!"); });

// Handler that always returns a success.
auto always_success = Ffi::Bind().To(
    []() { return Error::Success(); });

Buffer Arguments And Results

‫XLA משתמש בסגנון העברת יעד לתוצאות: קריאות מותאמות אישית (או כל פעולות אחרות של XLA) לא מקצות זיכרון לתוצאות, אלא כותבות ליעדים שמועברים על ידי זמן הריצה של XLA. ב-XLA נעשה שימוש בהקצאת מאגר סטטית, והמאגרים מוקצים לכל הערכים על סמך טווחי הפעילות שלהם בזמן ההידור.

התוצאות מועברות ל-handlers של FFI כשהן עטופות בתבנית Result<T>, שיש לה סמנטיקה שדומה לסמנטיקה של מצביע: operator-> מאפשרת גישה לפרמטר הבסיסי.

הארגומנטים והתוצאות של AnyBuffer מאפשרים גישה לפרמטרים של מאגר שיחות בהתאמה אישית מכל סוג נתונים. האפשרות הזו שימושית כשהקריאה המותאמת אישית כוללת הטמעה גנרית שמתאימה לכמה סוגי נתונים, וההטמעה של הקריאה המותאמת אישית מפעילה שיגור בזמן ריצה על סמך סוג הנתונים. ‫AnyBuffer נותן גישה לסוג הנתונים של המאגר, למאפיינים ולמצביע למאגר עצמו.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  api_version = 4 : i32
} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// Buffers of any number of dimensions and data type.
auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
    [](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
      void* arg_data = arg.untyped_data();
      void* res_data = res->untyped_data();
      return Error::Success();
    });

טיעונים ותוצאות של מאגר מוגבל

Buffer מאפשר להוסיף אילוצים לסוג הנתונים של המאגר ולמספר הממדים, והם ייבדקו אוטומטית על ידי ה-handler ויחזירו שגיאה לזמן הריצה של XLA, אם ארגומנטים של זמן הריצה לא תואמים לחתימה של ה-handler של FFI.

// Buffers of any number of dimensions and F32 data type.
auto handler = Ffi::Bind().Arg<Buffer<F32>>().Ret<Buffer<F32>>().To(
    [](Buffer<F32> arg, Result<Buffer<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });
// Buffers of number of dimensions 2 and F32 data type.
auto handler = Ffi::Bind().Arg<BufferR2<F32>>().Ret<BufferR2<F32>>().To(
    [](BufferR2<F32> arg, Result<BufferR2<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });

ארגומנטים ותוצאות עם מספר משתנה של פרמטרים

אם מספר הארגומנטים והתוצאה יכול להיות שונה במופעים שונים של שיחה מותאמת אישית, אפשר לפענח אותם בזמן הריצה באמצעות RemainingArgs ו-RemainingRets.

auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To(
    [](RemainingArgs args, RemainingRets results) -> Error {
      ErrorOr<AnyBuffer> arg = args.get<AnyBuffer>(0);
      ErrorOr<Result<AnyBuffer>> res = results.get<AnyBuffer>(0);

      if (!arg.has_value()) {
        return Error(ErrorCode::kInternal, arg.error());
      }

      if (!res.has_value()) {
        return Error(ErrorCode::kInternal, res.error());
      }

      return Error::Success();
    });

אפשר להצהיר על ארגומנטים ותוצאות עם מספר משתנה של ארגומנטים אחרי ארגומנטים ותוצאות רגילים, אבל אי אפשר לקשר ארגומנטים ותוצאות רגילים אחרי ארגומנטים ותוצאות עם מספר משתנה של ארגומנטים.

auto handler =
    Ffi::Bind()
        .Arg<AnyBuffer>()
        .RemainingArgs()
        .Ret<AnyBuffer>()
        .RemainingRets()
        .To([](AnyBuffer arg, RemainingArgs args, AnyBuffer ret,
               RemainingRets results) -> Error { return Error::Success(); });

מאפיינים

ממשק ה-FFI של XLA תומך בפענוח אוטומטי של mlir::DictionaryAttr שמועבר כ-custom_call backend_config לארגומנטים של handler של FFI.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    i32 = 42 : i32,
    str = "string"
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

בדוגמה הזו, לקריאה המותאמת אישית יש ארגומנט אחד של מאגר ושני מאפיינים, ו-XLA FFI יכול לפענח אותם באופן אוטומטי ולהעביר אותם לפונקציה הניתנת להפעלה שהוגדרה על ידי המשתמש.

auto handler = Ffi::Bind()
  .Arg<BufferR0<F32>>()
  .Attr<int32_t>("i32")
  .Attr<std::string_view>("str")
  .To([](BufferR0<F32> buffer, int32_t i32, std::string_view str) {
    return Error::Success();
  });

מאפייני enum מוגדרים על ידי משתמש

‫XLA FFI יכול לפענח באופן אוטומטי מאפייני MLIR שלמים לסוגי enum שמוגדרים על ידי המשתמש. למחלקת ה-Enum צריך להיות אותו סוג של מספר שלם בבסיס, וצריך לרשום את הפענוח באופן מפורש ב-XLA FFI.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    command = 0 : i32
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
enum class Command : int32_t {
  kAdd = 0,
  kMul = 1,
};

XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command);

auto handler = Ffi::Bind().Attr<Command>("command").To(
    [](Command command) -> Error { return Error::Success(); });

קישור כל מאפייני השיחה המותאמים אישית

אפשר לקבל גישה לכל מאפייני השיחות המותאמים אישית כמילון, ולפענח רק את המאפיינים שנדרשים בזמן הריצה.

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<int32_t> i32 = attrs.get<int32_t>("i32");
  return Error::Success();
});

מאפייני מבנה נתונים מוגדרים על ידי משתמש

‫XLA FFI יכול לפענח מאפייני מילון למבנים מוגדרים על ידי המשתמש.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    range = { lo = 0 : i64, hi = 42 : i64 }
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

בדוגמה שלמעלה, range הוא מאפיין mlir::DictionaryAttr, ובמקום לגשת לשדות המילון לפי שם, אפשר לפענח אותו אוטומטית כמבנה C++. צריך לרשום את הפענוח באופן מפורש באמצעות מאקרו XLA_FFI_REGISTER_STRUCT_ATTR_DECODING (מאחורי הקלעים, המאקרו מגדיר התמחות של תבנית במרחב השמות ::xla::ffi, ולכן צריך להוסיף את המאקרו למרחב השמות הגלובלי).

struct Range {
  int64_t lo;
  int64_t hi;
};

XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("lo"),
                                             StructMember<int64_t>("hi"));

auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
  return Error::Success();
});

אפשר לטעון מאפיינים מותאמים אישית ממילון, בדיוק כמו כל מאפיין אחר. בדוגמה שלמטה, כל מאפייני השיחה המותאמים אישית מפוענחים כ-Dictionary, ואפשר לגשת ל-range לפי שם.

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<Range> range = attrs.get<Range>("range");
  return Error::Success();
});

יצירת קריאה מותאמת אישית ל-CPU

אפשר ליצור הוראת HLO שמייצגת קריאה מותאמת אישית באמצעות ה-API של הלקוח של XLA. לדוגמה, הקוד הבא משתמש בהפעלה בהתאמה אישית כדי לחשב את A[i] = B[i % 128]+ C[i] במעבד. (כמובן שאתם יכולים – וכדאי לכם! – do this with regular HLO.)

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
        /*opaque=*/"", /*has_side_effect=*/false,
        /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
        /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to 1-dimensional buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C++ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];

  // Check that dimensions are compatible.
  assert(out->dimensions[0] == d1 && "unexpected dimensions");

  for (size_t i = 0; i < d1; ++i) {
    out->data[i] = in0.data[i % d0] + in1.data[i];
  }
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Arg<Buffer>()
                           .Arg<Buffer>()
                           .Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "Host", handler);

יצירת קריאה בהתאמה אישית ב-GPU

הרישום של קריאה מותאמת אישית של GPU באמצעות XLA FFI כמעט זהה, וההבדל היחיד הוא שב-GPU צריך לבקש מקור נתונים של פלטפורמה בסיסית (CUDA או ROCM) כדי להפעיל את ליבת המערכת במכשיר. זו דוגמה ל-CUDA שמבצעת את אותו חישוב (A[i] = B[i % 128] + C[i]) כמו קוד המעבד שלמעלה.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];
  size_t d2 = out->dimensions[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
    in0.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<CUstream>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "CUDA", handler);

חשוב לשים לב שהפונקציה המותאמת אישית של קריאה ל-GPU עדיין מבוצעת ב-CPU. הפונקציה do_custom_call CPU אחראית להוספת עבודה לתור ב-GPU. כאן מופעל ליבת CUDA, אבל אפשר גם לבצע פעולה אחרת, כמו קריאה ל-cuBLAS.

הארגומנטים והתוצאות נמצאים גם במארח, וחבר הנתונים מכיל מצביע לזיכרון המכשיר (כלומר, GPU). מאגרי הנתונים הזמניים שמועברים ל-custom call handler הם בצורה של מאגרי הנתונים הזמניים של המכשיר, כך שהשיחה המותאמת אישית יכולה לחשב מהם את פרמטרים של הפעלת ליבת המערכת.

העברת טאפלים לקריאות בהתאמה אישית

נבחן את הקריאה הבאה בהתאמה אישית.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);

גם במעבד וגם במעבד ה-GPU, טאפל מיוצג בזיכרון כמערך של מצביעים. כש-XLA קורא לפונקציות מותאמות אישית עם ארגומנטים או תוצאות של tuple, הוא משטח אותם ומעביר אותם כארגומנטים או תוצאות של מאגר רגיל.

הפלט של ה-tuple כמאגרי נתונים זמניים

הזנת טאפלים לקריאות מותאמות אישית היא נוחה, אבל היא לא הכרחית. אם לא הייתה תמיכה בהזנת טאפלים לקריאות מותאמות אישית, תמיד הייתה אפשרות לפתוח את הטאפלים באמצעות get-tuple-element לפני העברתם לקריאה המותאמת אישית.

לעומת זאת, פלט של טאפל מאפשר לכם לעשות דברים שלא הייתם יכולים לעשות אחרת.

הסיבה הברורה לשימוש בפלט מסוג tuple היא שזו הדרך שבה קריאה מותאמת אישית (או כל פעולה אחרת של XLA) מחזירה כמה מערכים עצמאיים.

אבל בצורה פחות ברורה, פלט של tuple הוא גם דרך לתת לזיכרון הזמני של השיחה המותאמת אישית. כן, פלט יכול לייצג מאגר זמני. לדוגמה, למאגר פלט יש את המאפיין שאפשר לכתוב אליו את הפלט של הפעולה, ואפשר לקרוא ממנו אחרי שכותבים אליו. זה בדיוק מה שרוצים ממאגר זמני.

בדוגמה שלמעלה, נניח שאנחנו רוצים להשתמש ב-F32[1024] כמאגר זמני. אחר כך נכתוב את ה-HLO בדיוק כמו בדוגמה שלמעלה, ופשוט לא נקרא את אינדקס 1 של ה-tuple בפלט של הקריאה המותאמת אישית.