يوضّح هذا المستند كيفية كتابة واستخدام طلبات XLA المخصّصة باستخدام مكتبة XLA FFI. الاستدعاء المخصّص هو آلية لوصف "عملية" خارجية في وحدة HLO لمترجم XLA (في وقت الترجمة)، وFFI في XLA هي آلية لتسجيل تنفيذ هذه العمليات في XLA (في وقت التشغيل). يشير FFI إلى "واجهة الدوال الخارجية"، وهو عبارة عن مجموعة من واجهات برمجة التطبيقات بلغة C تحدّد واجهة ثنائية (ABI) لكي تستدعي XLA الرموز البرمجية الخارجية المكتوبة بلغات برمجة أخرى. توفّر XLA روابط تتضمّن العناوين فقط لواجهة الوظائف الخارجية (FFI) الخاصة بـ XLA والمكتوبة بلغة C++، ما يخفي جميع التفاصيل المنخفضة المستوى لواجهات برمجة التطبيقات الأساسية المكتوبة بلغة C عن المستخدِم النهائي.
JAX + XLA Custom Calls
راجِع مستندات JAX للاطّلاع على أمثلة شاملة حول دمج المكالمات المخصّصة وXLA FFI مع JAX.
XLA FFI Binding
ربط FFI في XLA هو مواصفات وقت الترجمة لتوقيع الاستدعاء المخصّص:
وسيطات الاستدعاء المخصّص وسماته وأنواعه ومعلَماته الإضافية
التي يتم تمريرها من خلال سياق التنفيذ (أي بث وحدة معالجة الرسومات لخادم الخلفية لوحدة معالجة الرسومات). يمكن ربط ربط XLA FFI بأي عنصر قابل للاستدعاء بلغة C++ (مؤشر دالة، وlambda، وما إلى ذلك) باستخدام توقيع operator() متوافق. تعمل الدالة المعالِجة المنشأة على فك ترميز إطار استدعاء XLA FFI (المحدّد بواسطة واجهة برمجة التطبيقات الثابتة C)، والتحقّق من نوع جميع المَعلمات، وإعادة توجيه
النتائج التي تم فك ترميزها إلى دالة رد الاتصال التي يحدّدها المستخدم.
تعتمد عملية ربط XLA FFI بشكل كبير على برمجة التعريفات الوصفية للنماذج لتتمكّن من تجميع معالج تم إنشاؤه إلى رمز آلة أكثر كفاءة. تكون النفقات العامة لوقت التشغيل بترتيب بضع نانو ثوانٍ لكل مَعلمة طلب مخصّصة.
نقاط تخصيص XLA FFI التي تم تنفيذها كتخصصات نماذج، ويمكن للمستخدمين تحديد كيفية فك تشفير الأنواع المخصّصة، أي أنّه من الممكن تحديد فك تشفير مخصّص لأنواع enum class التي يحدّدها المستخدم.
عرض الأخطاء من المكالمات المخصّصة
يجب أن تعرض عمليات تنفيذ الدوال المخصّصة القيمة xla::ffi::Error للإشارة إلى النجاح أو الخطأ في وقت تشغيل XLA. وهو يشبه absl::Status ويتضمّن مجموعة رموز الأخطاء نفسها. لا نستخدم absl::Status لأنّه لا يتضمّن واجهة ثنائية ثابتة للتطبيق، وسيكون من غير الآمن تمريره بين مكتبة الدوال المخصّصة التي يتم تحميلها ديناميكيًا وXLA نفسها.
// Handler that always returns an error.
auto always_error = Ffi::Bind().To(
[]() { return Error(ErrorCode::kInternal, "Oops!"); });
// Handler that always returns a success.
auto always_success = Ffi::Bind().To(
[]() { return Error::Success(); });
وسيطات المخزن المؤقت ونتائجه
تستخدم XLA نمط تمرير الوجهة للنتائج: لا تخصّص الاستدعاءات المخصّصة (أو أي عمليات XLA أخرى) مساحة تخزين للنتائج، بل تكتب في الوجهات التي يمرّرها وقت تشغيل XLA. تستخدِم XLA عملية تخصيص المخزن المؤقت الثابت، وتخصّص مخازن مؤقتة لجميع القيم استنادًا إلى نطاقاتها النشطة في وقت الترجمة.
يتم تضمين النتائج التي يتم تمريرها إلى معالجات FFI في نموذج Result<T>، والذي يتضمّن دلالات مشابهة للمؤشر: يتيح operator-> الوصول إلى المَعلمة الأساسية.
تتيح وسيطات ونتائج AnyBuffer الوصول إلى مَعلمات المخزن المؤقت للمكالمات المخصّصة
من أي نوع بيانات. يكون هذا الإعداد مفيدًا عندما يكون للطلب المخصّص تنفيذ عام
يعمل مع أنواع بيانات متعدّدة، وينفّذ الطلب المخصّص عملية إرسال
في وقت التشغيل استنادًا إلى نوع البيانات. تتيح السمة AnyBuffer الوصول إلى نوع بيانات المخزن المؤقت وأبعاده ومؤشر إلى المخزن المؤقت نفسه.
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
api_version = 4 : i32
} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// Buffers of any number of dimensions and data type.
auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
[](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
void* arg_data = arg.untyped_data();
void* res_data = res->untyped_data();
return Error::Success();
});
وسيطات ونتائج المخزن المؤقت المقيّد
تسمح Buffer بإضافة قيود على نوع بيانات المخزن المؤقت وعدد السمات، وسيتحقّق المعالج منها تلقائيًا ويعرض خطأ في وقت تشغيل XLA إذا لم تتطابق وسيطات وقت التشغيل مع توقيع معالج FFI.
// Buffers of any number of dimensions and F32 data type.
auto handler = Ffi::Bind().Arg<Buffer<F32>>().Ret<Buffer<F32>>().To(
[](Buffer<F32> arg, Result<Buffer<F32>> res) -> Error {
float* arg_data = arg.typed_data();
float* res_data = res->typed_data();
return Error::Success();
});
// Buffers of number of dimensions 2 and F32 data type.
auto handler = Ffi::Bind().Arg<BufferR2<F32>>().Ret<BufferR2<F32>>().To(
[](BufferR2<F32> arg, Result<BufferR2<F32>> res) -> Error {
float* arg_data = arg.typed_data();
float* res_data = res->typed_data();
return Error::Success();
});
الوسيطات والنتائج المتغيرة
إذا كان عدد الوسيطات والنتائج يمكن أن يختلف في حالات مختلفة من مكالمة مخصّصة، يمكن فك ترميزها في وقت التشغيل باستخدام RemainingArgs وRemainingRets.
auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To(
[](RemainingArgs args, RemainingRets results) -> Error {
ErrorOr<AnyBuffer> arg = args.get<AnyBuffer>(0);
ErrorOr<Result<AnyBuffer>> res = results.get<AnyBuffer>(0);
if (!arg.has_value()) {
return Error(ErrorCode::kInternal, arg.error());
}
if (!res.has_value()) {
return Error(ErrorCode::kInternal, res.error());
}
return Error::Success();
});
يمكن تعريف الوسيطة المتغيرة والنتائج بعد الوسائط والنتائج العادية، ولكن لا يمكن ربط الوسائط والنتائج العادية بعد الوسيطة المتغيرة.
auto handler =
Ffi::Bind()
.Arg<AnyBuffer>()
.RemainingArgs()
.Ret<AnyBuffer>()
.RemainingRets()
.To([](AnyBuffer arg, RemainingArgs args, AnyBuffer ret,
RemainingRets results) -> Error { return Error::Success(); });
السمات
تتيح واجهة XLA FFI فك الترميز التلقائي للرمز mlir::DictionaryAttr الذي تم تمريره كـ
custom_call backend_config إلى وسيطات معالج FFI.
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
backend_config= {
i32 = 42 : i32,
str = "string"
},
api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
في هذا المثال، يتضمّن الاستدعاء المخصّص وسيطة واحدة للمخزن المؤقت وسمتَين، ويمكن لواجهة XLA FFI فك ترميزهما تلقائيًا وتمريرهما إلى العنصر القابل للاستدعاء الذي يحدّده المستخدم.
auto handler = Ffi::Bind()
.Arg<BufferR0<F32>>()
.Attr<int32_t>("i32")
.Attr<std::string_view>("str")
.To([](BufferR0<F32> buffer, int32_t i32, std::string_view str) {
return Error::Success();
});
سمات التعداد المحدّدة من قِبل المستخدم
يمكن أن يفك ترميز XLA FFI تلقائيًا سمات MLIR المتكاملة إلى تعدادات محددة من قِبل المستخدم. يجب أن يحتوي فئة التعداد على نوع عدد صحيح أساسي نفسه، ويجب تسجيل عملية فك الترميز بشكل صريح باستخدام XLA FFI.
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
backend_config= {
command = 0 : i32
},
api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
enum class Command : int32_t {
kAdd = 0,
kMul = 1,
};
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command);
auto handler = Ffi::Bind().Attr<Command>("command").To(
[](Command command) -> Error { return Error::Success(); });
ربط جميع سمات المكالمات المخصّصة
يمكنك الوصول إلى جميع سمات المكالمات المخصّصة كقاموس، وفك ترميز السمات المطلوبة فقط في وقت التشغيل.
auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
ErrorOr<int32_t> i32 = attrs.get<int32_t>("i32");
return Error::Success();
});
سمات Struct من تحديد المستخدم
يمكن لواجهة XLA FFI فك ترميز سمات القاموس إلى بنى محدّدة من قِبل المستخدم.
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
backend_config= {
range = { lo = 0 : i64, hi = 42 : i64 }
},
api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
في المثال أعلاه، range هي سمة mlir::DictionaryAttr، وبدلاً من الوصول إلى حقول القاموس بالاسم، يمكن فك ترميزها تلقائيًا كبنية C++. يجب تسجيل عملية فك الترميز بشكل صريح باستخدام ماكرو XLA_FFI_REGISTER_STRUCT_ATTR_DECODING (في الخلفية، يحدّد تخصص نموذج في مساحة الاسم ::xla::ffi، وبالتالي يجب إضافة الماكرو إلى مساحة الاسم العامة).
struct Range {
int64_t lo;
int64_t hi;
};
XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("lo"),
StructMember<int64_t>("hi"));
auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
return Error::Success();
});
يمكن تحميل السمات المخصّصة من قاموس، تمامًا مثل أي سمة أخرى. في المثال أدناه، يتم فك ترميز جميع سمات المكالمات المخصّصة على النحو التالي: a
Dictionary، ويمكن الوصول إلى range بالاسم.
auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
ErrorOr<Range> range = attrs.get<Range>("range");
return Error::Success();
});
إنشاء دالة مخصّصة يتم تنفيذها على وحدة المعالجة المركزية
يمكنك إنشاء تعليمات HLO تمثّل طلبًا مخصّصًا من خلال واجهة برمجة التطبيقات الخاصة بعميل XLA. على سبيل المثال، يستخدم الرمز البرمجي التالي طلبًا مخصّصًا لاحتساب A[i] = B[i %
128]+ C[i] على وحدة المعالجة المركزية. (بالطبع يمكنك ذلك، بل يجب عليك ذلك. – do this
with regular HLO.)
#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"
void do_it() {
xla::XlaBuilder b("do_it");
xla::XlaOp param0 =
xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
xla::XlaOp param1 =
xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
xla::XlaOp custom_call =
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
/*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
/*opaque=*/"", /*has_side_effect=*/false,
/*output_operand_aliasing=*/{}, /*literal=*/nullptr,
/*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
/*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}
// Constrain custom call arguments to 1-dimensional buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;
// Implement a custom call as a C++ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
xla::ffi::Result<BufferF32> out) {
size_t d0 = in0.dimensions[0];
size_t d1 = in1.dimensions[0];
// Check that dimensions are compatible.
assert(out->dimensions[0] == d1 && "unexpected dimensions");
for (size_t i = 0; i < d1; ++i) {
out->data[i] = in0.data[i % d0] + in1.data[i];
}
}
// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
ffi::Ffi::Bind()
.Arg<Buffer>()
.Arg<Buffer>()
.Ret<Buffer>());
// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
"Host", handler);
إنشاء دالة مخصّصة يتم تنفيذها على وحدة معالجة الرسومات
إنّ عملية تسجيل طلب مخصّص لوحدة معالجة الرسومات باستخدام واجهة XLA FFI هي نفسها تقريبًا، والفرق الوحيد هو أنّه بالنسبة إلى وحدة معالجة الرسومات، عليك طلب بث أساسي للنظام الأساسي (بث CUDA أو ROCM) لتتمكّن من تشغيل النواة على الجهاز. في ما يلي مثال على CUDA
يُجري عملية الحساب نفسها (A[i] = B[i % 128] + C[i]) التي يُجريها رمز وحدة المعالجة المركزية أعلاه.
void do_it() { /* same implementation as above */ }
__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
out[idx] = in0[idx % 128] + in1[idx];
}
void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
xla::ffi::Result<BufferF32> out) {
size_t d0 = in0.dimensions[0];
size_t d1 = in1.dimensions[0];
size_t d2 = out->dimensions[0];
assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");
const int64_t block_dim = 64;
const int64_t grid_dim = 2048 / block_dim;
custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
in0.data, in1.data, out->data);
}
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
ffi::Ffi::Bind()
.Ctx<xla::ffi::PlatformStream<CUstream>>()
.Arg<BufferF32>()
.Arg<BufferF32>()
.Ret<BufferF32>());
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
"CUDA", handler);
يُرجى العلم أولاً أنّ دالة الاتصال المخصّصة لوحدة معالجة الرسومات لا تزال دالة يتم تنفيذها على وحدة المعالجة المركزية. تكون دالة وحدة المعالجة المركزية do_custom_call مسؤولة عن وضع العمل في قائمة الانتظار
على وحدة معالجة الرسومات. في هذا المثال، يتم تشغيل نواة CUDA، ولكن يمكن أيضًا تنفيذ إجراء آخر،
مثل استدعاء cuBLAS.
تتوفّر الوسيطات والنتائج أيضًا على المضيف، ويحتوي عنصر البيانات على مؤشر إلى ذاكرة الجهاز (أي وحدة معالجة الرسومات). تكون المخازن المؤقتة التي يتم تمريرها إلى معالج المكالمات المخصّص على شكل مخازن مؤقتة للأجهزة الأساسية، وبالتالي يمكن للمكالمة المخصّصة احتساب مَعلمات تشغيل النواة منها.
تمرير الصفوف إلى طلبات مخصّصة
ضَع في اعتبارك طلب الاتصال المخصّص التالي.
using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {32}),
ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {64}),
ShapeUtil::MakeShape(F32, {128}),
}),
ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");
Shape out_shape = ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {512}),
ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);
في كلّ من وحدة المعالجة المركزية ووحدة معالجة الرسومات، يتم تمثيل المجموعة في الذاكرة كمصفوفة من المؤشرات. عندما تستدعي XLA دوال مخصّصة باستخدام وسيطات أو نتائج من النوع tuple، فإنّها تحوّلها إلى شكل مسطّح وتمرّرها كوسطاء أو نتائج عادية للمخزن المؤقت.
إخراج الصفوف كمخازن مؤقتة
تُعدّ إدخالات الصفوف في المكالمات المخصّصة ميزة مفيدة، ولكنّها ليست ضرورية. إذا لم نكن نسمح بإدخال مجموعات الصفوف في المكالمات المخصّصة، كان بإمكانك دائمًا فك حزم الصفوف باستخدام get-tuple-element قبل تمريرها إلى المكالمة المخصّصة.
من ناحية أخرى، تتيح لك مخرجات الصفوف تنفيذ إجراءات لا يمكنك تنفيذها بطريقة أخرى.
السبب الواضح لاستخدام مخرجات الصفوف هو أنّها الطريقة التي تعرض بها دالة مخصّصة (أو أي عملية XLA أخرى) مصفوفات مستقلة متعددة.
ولكن، بطريقة أقل وضوحًا، يمكن أيضًا استخدام ناتج الصفوف لتوفير ذاكرة مؤقتة لعمليات الاستدعاء المخصّصة. نعم، يمكن أن يمثّل الإخراج مخزنًا مؤقتًا. لنفترض أنّ مخزنًا مؤقتًا للإخراج يحتوي على السمة التي يمكن للعملية الكتابة إليها، ويمكن قراءتها بعد الكتابة إليها. هذا هو بالضبط ما تريده من مخزن مؤقت.
في المثال أعلاه، لنفترض أنّنا أردنا استخدام F32[1024] كمخزن مؤقت.
بعد ذلك، سنكتب HLO كما هو موضح أعلاه، ولن نقرأ ببساطة فهرس المجموعة 1 من ناتج الاستدعاء المخصّص.