שיחות XLA בהתאמה אישית

במסמך הזה נסביר איך לכתוב ולהשתמש בשיחות מותאמות אישית של XLA באמצעות ספריית XLA FFI. קריאה בהתאמה אישית היא מנגנון שמתאר "פעולה" חיצונית במודול HLO ומהדר שלה (בזמן הידור) ו-XA FFI הוא מנגנון לרישום הטמעה של פעולות כאלה באמצעות XLA (בזמן הריצה). FFI הוא קיצור של 'ממשק פונקציה זרה'. זוהי קבוצה של ממשקי API ל-C שמגדירים ממשק בינארי (ABI) שבאמצעותו XLA יקרא קוד חיצוני שנכתב בשפות תכנות אחרות. XLA מספק קישורי כותרת בלבד ל- XLA FFI שכתוב ב-C++, שמסתיר ממשתמש הקצה את כל הפרטים ברמה הנמוכה לגבי ממשקי ה-API הבסיסיים של C.

יצירת שיחה בהתאמה אישית במעבד (CPU)

אפשר ליצור הוראה ל-HLO שמייצגת קריאה בהתאמה אישית דרך ה-API של הלקוח של XLA. לדוגמה, הקוד הבא משתמש בקריאה בהתאמה אישית כדי לחשב את A[i] = B[i % 128]+ C[i] במעבד (CPU). (ברור שאפשר – וכדאי מאוד! – עשו זאת באמצעות HLO רגיל.)

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
        /*opaque=*/"", /*has_side_effect=*/false,
        /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
        /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to rank-1 buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C+ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];

  // Check that dimensions are compatible.
  assert(out->dimensions[0] == d1 && "unexpected dimensions");

  for (size_t i = 0; i < d1; ++i) {
    out->data[i] = in0.data[i % d0] + in1.data[i];
  }
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Arg<Buffer>()
                           .Arg<Buffer>()
                           .Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "Host", handler);

יצירת שיחה בהתאמה אישית ב-GPU

רישום השיחות בהתאמה אישית ל-GPU באמצעות XLA FFI הוא כמעט זהה. ההבדל היחיד הוא שב-GPU יש לבקש זרם פלטפורמה בסיסי (CUDA או ROCM) כדי להפעיל את הליבה במכשיר. הנה דוגמה ל-CUDA שמבצעת את אותה חישוב (A[i] = B[i % 128] + C[i]) כמו קוד המעבד (CPU) שלמעלה.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];
  size_t d2 = out->dimensions[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
    in0.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<CUstream>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "CUDA", handler);

קודם כול שימו לב שפונקציית הקריאה המותאמת אישית של GPU היא עדיין פונקציה שמופעלת במעבד (CPU). הפונקציה של המעבד (CPU) do_custom_call אחראית על הוספת העבודה ל-GPU בתור. כאן הוא מפעיל ליבת CUDA, אבל הוא יכול גם לעשות משהו אחר, כמו לקרוא ל-cuBLAS.

הארגומנטים והתוצאות נמצאים גם במארח, וחבר הנתונים מכיל זיכרון מצביע למכשיר (כלומר GPU). המבנה של מאגרי נתונים זמניים שמועברים ל-handler בהתאמה אישית של שיחות הוא בצורת של מאגרי האחסון הזמני של המכשיר, כך שהקריאה בהתאמה אישית יכולה לחשב מהם פרמטרים של הפעלת הליבה.

העברת תגים לשיחות בהתאמה אישית

כדאי לזכור את הקריאה המותאמת אישית הבאה.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);

גם במעבד (CPU) וגם ב-GPU, צמד מיוצג בזיכרון כמערך של מצביעים. כש-XLA קוראת לקריאות בהתאמה אישית עם ארגומנטים או תוצאות, היא משטחת אותם ומעבירה כארגומנטים רגילים או כתוצאות של מאגרי נתונים זמניים.

פלט זוגי כמאגרי אנרגיה זמניים

אומנם אופציות קלט טבלאיות לשיחות מותאמות אישית הן נוחות, אבל לא חובה. אם לא תמכו בקלט כפול לקריאות מותאמות אישית, תמיד תוכלו לפתוח את ה-tupplement באמצעות רכיב get-tuplement לפני שמעבירים אותם לקריאה בהתאמה אישית.

מצד שני, פלטים כפולים מאפשרים לבצע פעולות שלא ניתן היה לבצע אחרת.

הסיבה הברורה לפלטים כפולים היא שפלטים כפולים הם האופן שבו קריאה מותאמת אישית (או כל פעולת XLA אחרת) מחזירה מערכים בלתי תלויים מרובים.

אבל פחות מובן שהפלט הכפול הוא גם דרך להעניק בהתאמה אישית את הזיכרון הזמני של השיחה. כן, פלט יכול לייצג מאגר נתונים זמני. שימו לב: מאגר נתונים זמני של פלט כולל את המאפיין שמשתמש יכול לכתוב בו, והוא יכול לקרוא ממנו אחרי שהוא נכתב. זה בדיוק מה שאתה רוצה ממאגר זמני.

בדוגמה שלמעלה, נניח שרצינו להשתמש בפונקציה F32[1024] כמאגר זמני. לאחר מכן נכתוב את ה-HLO בדיוק כמו למעלה, ופשוט לא נקרא את אינדקס tuple 1 של הפלט של הקריאה בהתאמה אישית.