XLA Özel Çağrıları

Bu belgede, XLA FFI kitaplığını kullanarak XLA özel çağrılarını yazma ve kullanma açıklanmaktadır. Özel çağrı, HLO modülündeki harici bir "işlemi" XLA derleyicisine (derleme zamanında) açıklamak için kullanılan bir mekanizmadır. XLA FFI ise bu tür işlemlerin uygulamasını XLA'ya (çalışma zamanında) kaydetmek için kullanılan bir mekanizmadır. FFI, "yabancı işlev arayüzü" anlamına gelir ve XLA'nın diğer programlama dillerinde yazılmış harici kodu çağırması için ikili arayüz (ABI) tanımlayan bir C API'leri kümesidir. XLA, C++ ile yazılmış XLA FFI için yalnızca başlık içeren bağlamalar sağlar. Bu bağlamalar, temel C API'lerinin tüm düşük düzey ayrıntılarını son kullanıcıdan gizler.

JAX + XLA Özel Çağrıları

Özel çağrıları ve XLA FFI'yi JAX ile entegre etmeye yönelik uçtan uca örnekler için JAX dokümanlarına bakın.

XLA FFI Bağlama

XLA FFI bağlama, derleme zamanında özel çağrı imzası belirtimidir: özel çağrı bağımsız değişkenleri, özellikleri ve türleri ile yürütme bağlamı (ör. GPU arka ucu için GPU akışı) aracılığıyla iletilen ek parametreler. XLA FFI bağlaması, uyumlu operator() imzasıyla herhangi bir C++ çağrılabilirine (işlev işaretçisi, lambda vb.) bağlanabilir. Oluşturulan işleyici, XLA FFI çağrı çerçevesinin (kararlı C API'si tarafından tanımlanır) kodunu çözer, tüm parametrelerin türünü kontrol eder ve kodu çözülen sonuçları kullanıcı tanımlı geri çağırmaya iletir.

XLA FFI bağlaması, oluşturulan işleyiciyi en verimli makine koduna derleyebilmek için şablon metaprogramlamasına büyük ölçüde bağlıdır. Çalışma zamanı ek yükleri, her özel çağrı parametresi için birkaç nanosaniye sırasındadır.

Şablon uzmanlıkları olarak uygulanan XLA FFI özelleştirme noktaları ve kullanıcılar özel türlerinin nasıl kod çözüleceğini tanımlayabilir. Örneğin, kullanıcı tanımlı enum class türleri için özel kod çözme tanımlamak mümkündür.

Özel Çağrılardan Geri Dönen Hatalar

Özel çağrı uygulamaları, XLA çalışma zamanına başarı veya hata sinyali vermek için xla::ffi::Error değerini döndürmelidir. absl::Status ile benzerdir ve aynı hata kodları kümesine sahiptir. Kararlı bir ABI'ye sahip olmadığı ve dinamik olarak yüklenen özel çağrı kitaplığı ile XLA arasında geçirilmesinin güvenli olmayacağı için absl::Status kullanmıyoruz.

// Handler that always returns an error.
auto always_error = Ffi::Bind().To(
    []() { return Error(ErrorCode::kInternal, "Oops!"); });

// Handler that always returns a success.
auto always_success = Ffi::Bind().To(
    []() { return Error::Success(); });

Buffer Arguments And Results

XLA, sonuçlar için hedef geçirme stilini kullanır: Özel çağrılar (veya bu bağlamdaki diğer XLA işlemleri) sonuçlar için bellek ayırmaz ve bunun yerine XLA çalışma zamanı tarafından geçirilen hedeflere yazar. XLA, statik arabellek ataması kullanır ve derleme zamanında tüm değerler için arabellekleri canlı aralıklarına göre ayırır.

FFI işleyicilerine iletilen sonuçlar, işaretçi benzeri semantiğe sahip bir Result<T> şablonuna sarılır: operator->, temel parametreye erişim sağlar.

AnyBuffer bağımsız değişkenleri ve sonuçları, herhangi bir veri türündeki özel arama arabelleği parametrelerine erişim sağlar. Bu, özel çağrının birden fazla veri türü için çalışan genel bir uygulaması olduğunda ve özel çağrı uygulaması, veri türüne göre çalışma zamanı dağıtımı yaptığında kullanışlıdır. AnyBuffer, arabellek veri türüne, boyutlarına ve arabelleğin kendisine yönelik bir işaretçiye erişim sağlar.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  api_version = 4 : i32
} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// Buffers of any number of dimensions and data type.
auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
    [](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
      void* arg_data = arg.untyped_data();
      void* res_data = res->untyped_data();
      return Error::Success();
    });

Kısıtlanmış arabellek bağımsız değişkenleri ve sonuçları

Buffer, arabellek veri türüne ve boyut sayısına kısıtlamalar eklenmesine olanak tanır. Bu kısıtlamalar işleyici tarafından otomatik olarak kontrol edilir ve çalışma zamanı bağımsız değişkenleri FFI işleyici imzasıyla eşleşmezse XLA çalışma zamanına bir hata döndürülür.

// Buffers of any number of dimensions and F32 data type.
auto handler = Ffi::Bind().Arg<Buffer<F32>>().Ret<Buffer<F32>>().To(
    [](Buffer<F32> arg, Result<Buffer<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });
// Buffers of number of dimensions 2 and F32 data type.
auto handler = Ffi::Bind().Arg<BufferR2<F32>>().Ret<BufferR2<F32>>().To(
    [](BufferR2<F32> arg, Result<BufferR2<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });

Değişken Sayılı Bağımsız Değişkenler ve Sonuçlar

Özel bir çağrının farklı örneklerinde bağımsız değişkenlerin ve sonucun sayısı farklı olabilir. Bu durumda, RemainingArgs ve RemainingRets kullanılarak çalışma zamanında kodları çözülebilir.

auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To(
    [](RemainingArgs args, RemainingRets results) -> Error {
      ErrorOr<AnyBuffer> arg = args.get<AnyBuffer>(0);
      ErrorOr<Result<AnyBuffer>> res = results.get<AnyBuffer>(0);

      if (!arg.has_value()) {
        return Error(ErrorCode::kInternal, arg.error());
      }

      if (!res.has_value()) {
        return Error(ErrorCode::kInternal, res.error());
      }

      return Error::Success();
    });

Değişken sayıda bağımsız değişkenler ve sonuçlar, normal bağımsız değişkenlerden ve sonuçlardan sonra bildirilebilir. Ancak değişken sayıda bağımsız değişkenden sonra normal bağımsız değişkenleri ve sonuçları bağlamak yasa dışıdır.

auto handler =
    Ffi::Bind()
        .Arg<AnyBuffer>()
        .RemainingArgs()
        .Ret<AnyBuffer>()
        .RemainingRets()
        .To([](AnyBuffer arg, RemainingArgs args, AnyBuffer ret,
               RemainingRets results) -> Error { return Error::Success(); });

Özellikler

XLA FFI, mlir::DictionaryAttr olarak iletilen custom_call backend_config değerinin FFI işleyici bağımsız değişkenlerine otomatik olarak kod çözülmesini destekler.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    i32 = 42 : i32,
    str = "string"
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

Bu örnekte, özel çağrının tek bir arabellek bağımsız değişkeni ve iki özelliği vardır. XLA FFI, bunları otomatik olarak kod çözebilir ve kullanıcı tanımlı çağrılabilir öğeye iletebilir.

auto handler = Ffi::Bind()
  .Arg<BufferR0<F32>>()
  .Attr<int32_t>("i32")
  .Attr<std::string_view>("str")
  .To([](BufferR0<F32> buffer, int32_t i32, std::string_view str) {
    return Error::Success();
  });

Kullanıcı Tanımlı Enum Özellikleri

XLA FFI, tam sayı MLIR özelliklerini kullanıcı tanımlı numaralandırmalara otomatik olarak çözebilir. Numaralandırma sınıfı, aynı temel tam sayı türüne sahip olmalı ve kod çözme işlemi XLA FFI'ye açıkça kaydedilmelidir.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    command = 0 : i32
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
enum class Command : int32_t {
  kAdd = 0,
  kMul = 1,
};

XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command);

auto handler = Ffi::Bind().Attr<Command>("command").To(
    [](Command command) -> Error { return Error::Success(); });

Tüm özel arama özelliklerini bağlama

Tüm özel çağrı özelliklerine sözlük olarak erişmek ve yalnızca çalışma zamanında ihtiyaç duyulan özellikleri geç kod çözmek mümkündür.

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<int32_t> i32 = attrs.get<int32_t>("i32");
  return Error::Success();
});

Kullanıcı Tanımlı Yapı Özellikleri

XLA FFI, sözlük özelliklerini kullanıcı tanımlı yapılara çözebilir.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    range = { lo = 0 : i64, hi = 42 : i64 }
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

Yukarıdaki örnekte range bir mlir::DictionaryAttr özelliğidir ve sözlük alanlarına adıyla erişmek yerine C++ yapısı olarak otomatik olarak kod çözülebilir. Kod çözme, bir XLA_FFI_REGISTER_STRUCT_ATTR_DECODING makrosuyla açıkça kaydedilmelidir (arka planda ::xla::ffi ad alanında bir şablon uzmanlığı tanımlar, bu nedenle makro global ad alanına eklenmelidir).

struct Range {
  int64_t lo;
  int64_t hi;
};

XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("lo"),
                                             StructMember<int64_t>("hi"));

auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
  return Error::Success();
});

Özel özellikler, diğer tüm özellikler gibi sözlükten yüklenebilir. Aşağıdaki örnekte, tüm özel arama özellikleri Dictionary olarak kod çözülür ve range adına göre erişilebilir.

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<Range> range = attrs.get<Range>("range");
  return Error::Success();
});

CPU'da özel çağrı oluşturma

XLA'nın istemci API'si aracılığıyla özel bir çağrıyı temsil eden bir HLO talimatı oluşturabilirsiniz. Örneğin, aşağıdaki kod, CPU'da A[i] = B[i % 128]+ C[i] değerini hesaplamak için özel bir çağrı kullanır. (Elbette yapabilirsiniz ve yapmalısınız. – do this with regular HLO.)

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
        /*opaque=*/"", /*has_side_effect=*/false,
        /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
        /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to 1-dimensional buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C++ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];

  // Check that dimensions are compatible.
  assert(out->dimensions[0] == d1 && "unexpected dimensions");

  for (size_t i = 0; i < d1; ++i) {
    out->data[i] = in0.data[i % d0] + in1.data[i];
  }
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Arg<Buffer>()
                           .Arg<Buffer>()
                           .Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "Host", handler);

GPU'da özel çağrı oluşturma

XLA FFI ile GPU özel çağrı kaydı neredeyse aynıdır. Tek fark, cihazda çekirdek başlatabilmek için GPU'da temel bir platform akışı (CUDA veya ROCM akışı) istemeniz gerekmesidir. Aşağıda, yukarıdaki CPU koduyla aynı hesaplamayı (A[i] = B[i % 128] + C[i]) yapan bir CUDA örneği verilmiştir.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];
  size_t d2 = out->dimensions[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
    in0.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<CUstream>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "CUDA", handler);

Öncelikle GPU özel çağrı işlevinin CPU'da yürütülen bir işlev olmaya devam ettiğini unutmayın. do_custom_call CPU işlevi, GPU'da işleri sıraya almaktan sorumludur. Burada bir CUDA çekirdeği başlatılıyor ancak cuBLAS'ı çağırma gibi başka bir işlem de yapılabilir.

Bağımsız değişkenler ve sonuçlar da ana makinede bulunur.Veri üyesi, cihaz (ör. GPU) belleğine yönelik bir işaretçi içerir. Özel çağrı işleyiciye iletilen arabellekler, temel alınan cihaz arabelleklerinin şekline sahiptir. Bu nedenle, özel çağrı, çekirdek başlatma parametrelerini bunlardan hesaplayabilir.

Özel çağrılara demet iletme

Aşağıdaki özel aramayı göz önünde bulundurun.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);

Hem CPU hem de GPU'da demet, bellekte bir işaretçi dizisi olarak temsil edilir. XLA, demet bağımsız değişkenleri veya sonuçları içeren özel çağrıları çağırdığında bunları düzleştirir ve normal arabellek bağımsız değişkenleri veya sonuçları olarak iletir.

Geçici arabellek olarak demet çıkışları

Özel çağrılara yönelik demet girişleri kolaylık sağlar ancak kesinlikle gerekli değildir. Özel çağrılara demet girişlerini desteklemeseydik, özel çağrıya iletmeden önce her zaman get-tuple-element kullanarak demetleri açabilirdiniz.

Öte yandan, demet çıkışları, başka türlü yapamayacağınız şeyleri yapmanıza olanak tanır.

Demet çıkışlarının kullanılmasının en belirgin nedeni, özel bir çağrının (veya başka bir XLA işleminin) birden fazla bağımsız dizi döndürme şeklinin demet çıkışları olmasıdır.

Ancak daha az belirgin bir şekilde, demet çıkışı da özel çağrınıza geçici bellek vermenin bir yoludur. Evet, bir çıkış geçici arabelleği temsil edebilir. Örneğin, bir çıkış arabelleği, işlemin üzerine yazabileceği ve yazıldıktan sonra okuyabileceği özelliğe sahiptir. Geçici arabellekten tam olarak bunu beklersiniz.

Yukarıdaki örnekte, F32[1024] öğesini geçici arabellek olarak kullanmak istediğimizi varsayalım. Ardından, HLO'yu yukarıdaki gibi yazarız ve özel çağrının çıkışının 1. demet dizinini hiçbir zaman okumayız.