Panggilan Kustom XLA

Dokumen ini menjelaskan cara menulis dan menggunakan panggilan kustom XLA menggunakan library XLA FFI. Panggilan kustom adalah mekanisme untuk mendeskripsikan "operasi" eksternal dalam modul HLO ke compiler XLA (pada waktu kompilasi), dan XLA FFI adalah mekanisme untuk mendaftarkan penerapan operasi tersebut dengan XLA (pada waktu runtime). FFI adalah singkatan dari "foreign function interface" dan merupakan sekumpulan C API yang menentukan antarmuka biner (ABI) untuk XLA agar dapat memanggil kode eksternal yang ditulis dalam bahasa pemrograman lain. XLA menyediakan binding khusus header untuk XLA FFI yang ditulis dalam C++, yang menyembunyikan semua detail tingkat rendah dari API C yang mendasarinya dari pengguna akhir.

Panggilan Kustom JAX + XLA

Lihat dokumentasi JAX untuk contoh menyeluruh tentang mengintegrasikan panggilan kustom dan XLA FFI dengan JAX.

Binding FFI XLA

Binding FFI XLA adalah spesifikasi waktu kompilasi tanda tangan panggilan kustom: argumen panggilan kustom, atribut dan jenisnya, serta parameter tambahan yang diteruskan melalui konteks eksekusi (yaitu, aliran GPU untuk backend GPU). Binding XLA FFI dapat diikat ke C++ yang dapat dipanggil (pointer fungsi, lambda, dll.) dengan tanda tangan operator() yang kompatibel. Handler yang dibuat mendekode frame panggilan XLA FFI (ditentukan oleh C API yang stabil), memeriksa jenis semua parameter, dan meneruskan hasil yang telah didekode ke callback yang ditentukan pengguna.

Binding FFI XLA sangat mengandalkan metaprogramming template agar dapat mengompilasi handler yang dibuat ke kode mesin yang paling efisien. Overhead waktu proses berkisar beberapa nanodetik untuk setiap parameter panggilan kustom.

Titik penyesuaian FFI XLA yang diterapkan sebagai spesialisasi template, dan pengguna dapat menentukan cara mendekode jenis kustom mereka, yaitu, dimungkinkan untuk menentukan decoding kustom untuk jenis enum class yang ditentukan pengguna.

Mengembalikan Error dari Panggilan Kustom

Implementasi panggilan kustom harus menampilkan nilai xla::ffi::Error untuk memberi sinyal keberhasilan atau error ke runtime XLA. Error ini mirip dengan absl::Status, dan memiliki kumpulan kode error yang sama. Kita tidak menggunakan absl::Status karena tidak memiliki ABI yang stabil dan akan tidak aman untuk meneruskannya antara library panggilan kustom yang dimuat secara dinamis, dan XLA itu sendiri.

// Handler that always returns an error.
auto always_error = Ffi::Bind().To(
    []() { return Error(ErrorCode::kInternal, "Oops!"); });

// Handler that always returns a success.
auto always_success = Ffi::Bind().To(
    []() { return Error::Success(); });

Argumen dan Hasil Buffer

XLA menggunakan gaya penerusan tujuan untuk hasil: panggilan kustom (atau operasi XLA lainnya) tidak mengalokasikan memori untuk hasil, dan sebagai gantinya, menulis ke tujuan yang diteruskan oleh runtime XLA. XLA menggunakan penetapan buffer statis, dan mengalokasikan buffer untuk semua nilai berdasarkan rentang aktifnya pada waktu kompilasi.

Hasil yang diteruskan ke pengendali FFI yang di-wrap ke dalam template Result<T>, yang memiliki semantik seperti pointer: operator-> memberikan akses ke parameter pokok.

Argumen dan hasil AnyBuffer memberikan akses ke parameter buffer panggilan kustom dari jenis data apa pun. Hal ini berguna saat panggilan kustom memiliki implementasi generik yang berfungsi untuk beberapa jenis data, dan implementasi panggilan kustom menjalankan pengiriman waktu proses berdasarkan jenis data. AnyBuffer memberikan akses ke jenis data buffer, dimensi, dan pointer ke buffer itu sendiri.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  api_version = 4 : i32
} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// Buffers of any number of dimensions and data type.
auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
    [](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
      void* arg_data = arg.untyped_data();
      void* res_data = res->untyped_data();
      return Error::Success();
    });

Argumen dan Hasil Buffer yang Dibatasi

Buffer memungkinkan penambahan batasan pada jenis data buffer dan jumlah dimensi, dan batasan tersebut akan otomatis diperiksa oleh handler dan menampilkan error ke runtime XLA, jika argumen runtime tidak cocok dengan tanda tangan handler FFI.

// Buffers of any number of dimensions and F32 data type.
auto handler = Ffi::Bind().Arg<Buffer<F32>>().Ret<Buffer<F32>>().To(
    [](Buffer<F32> arg, Result<Buffer<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });
// Buffers of number of dimensions 2 and F32 data type.
auto handler = Ffi::Bind().Arg<BufferR2<F32>>().Ret<BufferR2<F32>>().To(
    [](BufferR2<F32> arg, Result<BufferR2<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });

Argumen dan Hasil Variadik

Jika jumlah argumen dan hasil dapat berbeda dalam berbagai instance panggilan kustom, argumen dan hasil tersebut dapat didekode saat runtime menggunakan RemainingArgs dan RemainingRets.

auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To(
    [](RemainingArgs args, RemainingRets results) -> Error {
      ErrorOr<AnyBuffer> arg = args.get<AnyBuffer>(0);
      ErrorOr<Result<AnyBuffer>> res = results.get<AnyBuffer>(0);

      if (!arg.has_value()) {
        return Error(ErrorCode::kInternal, arg.error());
      }

      if (!res.has_value()) {
        return Error(ErrorCode::kInternal, res.error());
      }

      return Error::Success();
    });

Argumen dan hasil variadik dapat dideklarasikan setelah argumen dan hasil reguler, tetapi mengikat argumen dan hasil reguler setelah argumen dan hasil variadik tidak diizinkan.

auto handler =
    Ffi::Bind()
        .Arg<AnyBuffer>()
        .RemainingArgs()
        .Ret<AnyBuffer>()
        .RemainingRets()
        .To([](AnyBuffer arg, RemainingArgs args, AnyBuffer ret,
               RemainingRets results) -> Error { return Error::Success(); });

Atribut

XLA FFI mendukung decoding otomatis mlir::DictionaryAttr yang diteruskan sebagai custom_call backend_config ke dalam argumen handler FFI.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    i32 = 42 : i32,
    str = "string"
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

Dalam contoh ini, panggilan kustom memiliki satu argumen buffer dan dua atribut, dan XLA FFI dapat otomatis mendekode dan meneruskannya ke callable yang ditentukan pengguna.

auto handler = Ffi::Bind()
  .Arg<BufferR0<F32>>()
  .Attr<int32_t>("i32")
  .Attr<std::string_view>("str")
  .To([](BufferR0<F32> buffer, int32_t i32, std::string_view str) {
    return Error::Success();
  });

Atribut Enum yang Ditentukan Pengguna

FFI XLA dapat otomatis mendekode atribut MLIR integral menjadi enum yang ditentukan pengguna. Class enum harus memiliki jenis integral pokok yang sama, dan decoding harus didaftarkan secara eksplisit dengan XLA FFI.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    command = 0 : i32
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
enum class Command : int32_t {
  kAdd = 0,
  kMul = 1,
};

XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command);

auto handler = Ffi::Bind().Attr<Command>("command").To(
    [](Command command) -> Error { return Error::Success(); });

Mengikat Semua Atribut Panggilan Kustom

Anda dapat memperoleh akses ke semua atribut panggilan kustom sebagai kamus dan mendekode secara lambat hanya atribut yang diperlukan saat runtime.

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<int32_t> i32 = attrs.get<int32_t>("i32");
  return Error::Success();
});

Atribut Struct yang Ditentukan Pengguna

FFI XLA dapat mendekode atribut kamus ke dalam struct yang ditentukan pengguna.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    range = { lo = 0 : i64, hi = 42 : i64 }
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

Dalam contoh di atas, range adalah atribut mlir::DictionaryAttr, dan bukan mengakses kolom kamus berdasarkan nama, kolom tersebut dapat otomatis didekodekan sebagai struct C++. Dekode harus didaftarkan secara eksplisit dengan makro XLA_FFI_REGISTER_STRUCT_ATTR_DECODING (di balik layar, makro ini menentukan spesialisasi template di namespace ::xla::ffi, sehingga makro harus ditambahkan ke namespace global).

struct Range {
  int64_t lo;
  int64_t hi;
};

XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("lo"),
                                             StructMember<int64_t>("hi"));

auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
  return Error::Success();
});

Atribut kustom dapat dimuat dari kamus, seperti atribut lainnya. Dalam contoh di bawah, semua atribut panggilan kustom didekodekan sebagai Dictionary, dan range dapat diakses berdasarkan nama.

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<Range> range = attrs.get<Range>("range");
  return Error::Success();
});

Membuat panggilan kustom di CPU

Anda dapat membuat instruksi HLO yang merepresentasikan panggilan kustom melalui XLA Client API. Misalnya, kode berikut menggunakan panggilan kustom untuk menghitung A[i] = B[i % 128]+ C[i] di CPU. (Tentu saja Anda bisa – dan harus melakukannya! – lakukan ini dengan HLO reguler.)

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
        /*opaque=*/"", /*has_side_effect=*/false,
        /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
        /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to 1-dimensional buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C++ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];

  // Check that dimensions are compatible.
  assert(out->dimensions[0] == d1 && "unexpected dimensions");

  for (size_t i = 0; i < d1; ++i) {
    out->data[i] = in0.data[i % d0] + in1.data[i];
  }
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Arg<Buffer>()
                           .Arg<Buffer>()
                           .Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "Host", handler);

Membuat panggilan kustom di GPU

Pendaftaran panggilan kustom GPU dengan XLA FFI hampir identik, satu-satunya perbedaan adalah bahwa untuk GPU, Anda perlu meminta aliran platform pokok (aliran CUDA atau ROCM) agar dapat meluncurkan kernel di perangkat. Berikut adalah contoh CUDA yang melakukan komputasi yang sama (A[i] = B[i % 128] + C[i]) dengan kode CPU di atas.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];
  size_t d2 = out->dimensions[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
    in0.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<CUstream>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "CUDA", handler);

Perhatikan terlebih dahulu bahwa fungsi panggilan kustom GPU masih merupakan fungsi yang dieksekusi di CPU. Fungsi CPU do_custom_call bertanggung jawab untuk mengantrekan pekerjaan di GPU. Di sini, kernel CUDA diluncurkan, tetapi kernel tersebut juga dapat melakukan hal lain, seperti memanggil cuBLAS.

Argumen dan hasil juga berada di host, dan anggota data berisi pointer ke memori perangkat (yaitu GPU). Buffer yang diteruskan ke handler panggilan kustom memiliki bentuk buffer perangkat pokok, sehingga panggilan kustom dapat menghitung parameter peluncuran kernel dari buffer tersebut.

Meneruskan tuple ke panggilan kustom

Pertimbangkan panggilan kustom berikut.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);

Di CPU dan GPU, tuple direpresentasikan dalam memori sebagai array pointer. Saat memanggil panggilan kustom dengan argumen atau hasil tuple, XLA akan meratakannya dan meneruskannya sebagai argumen atau hasil buffer reguler.

Output tuple sebagai buffer sementara

Input tuple ke panggilan kustom adalah kemudahan, tetapi tidak benar-benar diperlukan. Jika kami tidak mendukung input tuple ke panggilan kustom, Anda selalu dapat membuka tuple menggunakan get-tuple-element sebelum meneruskannya ke panggilan kustom.

Di sisi lain, output tuple memungkinkan Anda melakukan hal-hal yang tidak dapat Anda lakukan sebelumnya.

Alasan yang jelas untuk memiliki output tuple adalah bahwa output tuple adalah cara panggilan kustom (atau operasi XLA lainnya) menampilkan beberapa array independen.

Namun, yang kurang jelas, output tuple juga merupakan cara untuk memberikan memori sementara panggilan kustom Anda. Ya, output dapat merepresentasikan buffer sementara. Misalnya, buffer output memiliki properti yang dapat ditulis oleh operasi, dan dapat dibaca setelah ditulis. Itulah yang Anda inginkan dari buffer sementara.

Dalam contoh di atas, misalkan kita ingin menggunakan F32[1024] sebagai buffer sementara. Kemudian, kita akan menulis HLO seperti di atas, dan kita tidak akan pernah membaca indeks tuple 1 dari output panggilan kustom.