Appels personnalisés XLA

Ce document explique comment écrire et utiliser des appels personnalisés XLA à l'aide de la bibliothèque XLA FFI. L'appel personnalisé est un mécanisme permettant de décrire une "opération" externe dans le module HLO au compilateur XLA (au moment de la compilation), et XLA FFI est un mécanisme permettant d'enregistrer l'implémentation de ces opérations avec XLA (au moment de l'exécution). FFI signifie "foreign function interface" (interface de fonction étrangère). Il s'agit d'un ensemble d'API C qui définissent une interface binaire (ABI) pour que XLA puisse appeler du code externe écrit dans d'autres langages de programmation. XLA fournit des liaisons d'en-tête uniquement pour XLA FFI écrit en C++, ce qui masque tous les détails de bas niveau des API C sous-jacentes à l'utilisateur final.

Appels personnalisés JAX + XLA

Consultez la documentation JAX pour obtenir des exemples de bout en bout d'intégration d'appels personnalisés et de XLA FFI avec JAX.

Liaison FFI XLA

La liaison FFI XLA est une spécification au moment de la compilation de la signature d'appel personnalisé : arguments d'appel personnalisé, attributs et leurs types, et paramètres supplémentaires transmis via le contexte d'exécution (c'est-à-dire le flux GPU pour le backend GPU). La liaison XLA FFI peut être liée à n'importe quel appelable C++ (pointeur de fonction, lambda, etc.) avec une signature operator() compatible. Le gestionnaire construit décode le frame d'appel XLA FFI (défini par l'API C stable), vérifie le type de tous les paramètres et transmet les résultats décodés au rappel défini par l'utilisateur.

La liaison XLA FFI repose fortement sur la métaprogrammation de modèles pour pouvoir compiler le gestionnaire construit dans le code machine le plus efficace. Les frais généraux d'exécution sont de l'ordre de quelques nanosecondes pour chaque paramètre d'appel personnalisé.

Points de personnalisation FFI XLA implémentés en tant que spécialisations de modèle.Les utilisateurs peuvent définir la manière de décoder leurs types personnalisés. Il est donc possible de définir un décodage personnalisé pour les types enum class définis par l'utilisateur.

Renvoi d'erreurs à partir d'appels personnalisés

Les implémentations d'appels personnalisés doivent renvoyer une valeur xla::ffi::Error pour signaler le succès ou l'échec au runtime XLA. Elle est semblable à absl::Status et possède le même ensemble de codes d'erreur. Nous n'utilisons pas absl::Status, car il ne dispose pas d'une ABI stable et il serait dangereux de le transmettre entre la bibliothèque d'appel personnalisée chargée dynamiquement et XLA lui-même.

// Handler that always returns an error.
auto always_error = Ffi::Bind().To(
    []() { return Error(ErrorCode::kInternal, "Oops!"); });

// Handler that always returns a success.
auto always_success = Ffi::Bind().To(
    []() { return Error::Success(); });

Arguments et résultats du tampon

XLA utilise le style de transmission de destination pour les résultats : les appels personnalisés (ou toute autre opération XLA d'ailleurs) n'allouent pas de mémoire pour les résultats, mais écrivent plutôt dans les destinations transmises par le runtime XLA. XLA utilise l'attribution de tampon statique et alloue des tampons pour toutes les valeurs en fonction de leurs plages de durée de vie au moment de la compilation.

Les résultats transmis aux gestionnaires FFI sont encapsulés dans un modèle Result<T>, qui a une sémantique de type pointeur : operator-> donne accès au paramètre sous-jacent.

Les arguments et les résultats AnyBuffer donnent accès aux paramètres de mémoire tampon d'appel personnalisés de n'importe quel type de données. Cela est utile lorsque l'appel personnalisé dispose d'une implémentation générique qui fonctionne pour plusieurs types de données et que l'implémentation de l'appel personnalisé effectue un dispatching au moment de l'exécution en fonction du type de données. AnyBuffer donne accès au type de données du tampon, à ses dimensions et à un pointeur vers le tampon lui-même.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  api_version = 4 : i32
} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// Buffers of any number of dimensions and data type.
auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
    [](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
      void* arg_data = arg.untyped_data();
      void* res_data = res->untyped_data();
      return Error::Success();
    });

Arguments et résultats de mémoire tampon contraints

Buffer permet d'ajouter des contraintes sur le type de données du tampon et le nombre de dimensions. Elles seront automatiquement vérifiées par le gestionnaire et renverront une erreur à l'environnement d'exécution XLA si les arguments d'exécution ne correspondent pas à la signature du gestionnaire FFI.

// Buffers of any number of dimensions and F32 data type.
auto handler = Ffi::Bind().Arg<Buffer<F32>>().Ret<Buffer<F32>>().To(
    [](Buffer<F32> arg, Result<Buffer<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });
// Buffers of number of dimensions 2 and F32 data type.
auto handler = Ffi::Bind().Arg<BufferR2<F32>>().Ret<BufferR2<F32>>().To(
    [](BufferR2<F32> arg, Result<BufferR2<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });

Arguments et résultats variadiques

Si le nombre d'arguments et de résultats peut varier dans différentes instances d'un appel personnalisé, ils peuvent être décodés au moment de l'exécution à l'aide de RemainingArgs et RemainingRets.

auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To(
    [](RemainingArgs args, RemainingRets results) -> Error {
      ErrorOr<AnyBuffer> arg = args.get<AnyBuffer>(0);
      ErrorOr<Result<AnyBuffer>> res = results.get<AnyBuffer>(0);

      if (!arg.has_value()) {
        return Error(ErrorCode::kInternal, arg.error());
      }

      if (!res.has_value()) {
        return Error(ErrorCode::kInternal, res.error());
      }

      return Error::Success();
    });

Les arguments et résultats variadiques peuvent être déclarés après les arguments et résultats réguliers. Toutefois, il est illégal de lier des arguments et résultats réguliers après un argument ou résultat variadique.

auto handler =
    Ffi::Bind()
        .Arg<AnyBuffer>()
        .RemainingArgs()
        .Ret<AnyBuffer>()
        .RemainingRets()
        .To([](AnyBuffer arg, RemainingArgs args, AnyBuffer ret,
               RemainingRets results) -> Error { return Error::Success(); });

Attributs

XLA FFI permet le décodage automatique de mlir::DictionaryAttr transmis en tant que custom_call backend_config dans les arguments du gestionnaire FFI.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    i32 = 42 : i32,
    str = "string"
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

Dans cet exemple, l'appel personnalisé comporte un seul argument de tampon et deux attributs. XLA FFI peut les décoder automatiquement et les transmettre à la fonction appelable définie par l'utilisateur.

auto handler = Ffi::Bind()
  .Arg<BufferR0<F32>>()
  .Attr<int32_t>("i32")
  .Attr<std::string_view>("str")
  .To([](BufferR0<F32> buffer, int32_t i32, std::string_view str) {
    return Error::Success();
  });

Attributs d'énumération définis par l'utilisateur

XLA FFI peut décoder automatiquement les attributs MLIR intégraux en énumérations définies par l'utilisateur. La classe d'énumération doit avoir le même type intégral sous-jacent, et le décodage doit être explicitement enregistré avec XLA FFI.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    command = 0 : i32
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
enum class Command : int32_t {
  kAdd = 0,
  kMul = 1,
};

XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command);

auto handler = Ffi::Bind().Attr<Command>("command").To(
    [](Command command) -> Error { return Error::Success(); });

Liaison de tous les attributs d'appel personnalisés

Il est possible d'accéder à tous les attributs d'appel personnalisés sous forme de dictionnaire et de décoder de manière différée uniquement les attributs nécessaires au moment de l'exécution.

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<int32_t> i32 = attrs.get<int32_t>("i32");
  return Error::Success();
});

Attributs Struct définis par l'utilisateur

XLA FFI peut décoder les attributs de dictionnaire en structs définis par l'utilisateur.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    range = { lo = 0 : i64, hi = 42 : i64 }
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

Dans l'exemple ci-dessus, range est un attribut mlir::DictionaryAttr. Au lieu d'accéder aux champs du dictionnaire par leur nom, il peut être automatiquement décodé en tant que struct C++. Le décodage doit être explicitement enregistré avec une macro XLA_FFI_REGISTER_STRUCT_ATTR_DECODING (en arrière-plan, il définit une spécialisation de modèle dans l'espace de noms ::xla::ffi, la macro doit donc être ajoutée à l'espace de noms global).

struct Range {
  int64_t lo;
  int64_t hi;
};

XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("lo"),
                                             StructMember<int64_t>("hi"));

auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
  return Error::Success();
});

Les attributs personnalisés peuvent être chargés à partir d'un dictionnaire, comme n'importe quel autre attribut. Dans l'exemple ci-dessous, tous les attributs d'appel personnalisés décodés en tant que Dictionary et range sont accessibles par leur nom.

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<Range> range = attrs.get<Range>("range");
  return Error::Success();
});

Créer un appel personnalisé sur le processeur

Vous pouvez créer une instruction HLO qui représente un appel personnalisé via l'API cliente de XLA. Par exemple, le code suivant utilise un appel personnalisé pour calculer A[i] = B[i % 128]+ C[i] sur le processeur. (Bien sûr, vous pouvez et devez ! – do this with regular HLO.)

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
        /*opaque=*/"", /*has_side_effect=*/false,
        /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
        /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to 1-dimensional buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C++ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];

  // Check that dimensions are compatible.
  assert(out->dimensions[0] == d1 && "unexpected dimensions");

  for (size_t i = 0; i < d1; ++i) {
    out->data[i] = in0.data[i % d0] + in1.data[i];
  }
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Arg<Buffer>()
                           .Arg<Buffer>()
                           .Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "Host", handler);

Créer un appel personnalisé sur GPU

L'enregistrement d'appel personnalisé du GPU avec XLA FFI est presque identique. La seule différence est que pour le GPU, vous devez demander un flux de plate-forme sous-jacent (flux CUDA ou ROCM) pour pouvoir lancer le noyau sur l'appareil. Voici un exemple CUDA qui effectue le même calcul (A[i] = B[i % 128] + C[i]) que le code CPU ci-dessus.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];
  size_t d2 = out->dimensions[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
    in0.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<CUstream>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "CUDA", handler);

Notez tout d'abord que la fonction d'appel personnalisé du GPU est toujours une fonction exécutée sur le processeur. La fonction CPU do_custom_call est responsable de la mise en file d'attente du travail sur le GPU. Ici, il lance un noyau CUDA, mais il pourrait aussi faire autre chose, comme appeler cuBLAS.

Les arguments et les résultats se trouvent également sur l'hôte, et le membre de données contient un pointeur vers la mémoire du périphérique (c'est-à-dire le GPU). Les tampons transmis au gestionnaire d'appels personnalisés ont la forme des tampons de l'appareil sous-jacent. L'appel personnalisé peut donc calculer les paramètres de lancement du noyau à partir de ces tampons.

Transmettre des tuples à des appels personnalisés

Prenons l'exemple d'un appel personnalisé.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);

Sur le CPU et le GPU, un tuple est représenté en mémoire sous la forme d'un tableau de pointeurs. Lorsque XLA appelle des appels personnalisés avec des arguments ou des résultats de tuples, il les aplatit et les transmet en tant qu'arguments ou résultats de tampon réguliers.

Sorties de tuples en tant que tampons temporaires

Les entrées de tuples pour les appels personnalisés sont pratiques, mais pas strictement nécessaires. Si nous ne prenions pas en charge les entrées de tuples pour les appels personnalisés, vous pourriez toujours décompresser les tuples à l'aide de get-tuple-element avant de les transmettre à l'appel personnalisé.

En revanche, les sorties de tuples vous permettent de faire des choses que vous ne pourriez pas faire autrement.

La raison évidente d'avoir des sorties de tuples est que les sorties de tuples sont la façon dont un appel personnalisé (ou toute autre opération XLA) renvoie plusieurs tableaux indépendants.

De manière moins évidente, une sortie de tuple est également un moyen de donner une mémoire temporaire à votre appel personnalisé. Oui, une sortie peut représenter une mémoire tampon temporaire. Considérez qu'un tampon de sortie a la propriété que l'opération peut y écrire et y lire après y avoir écrit. C'est exactement ce que vous attendez d'un tampon temporaire.

Dans l'exemple ci-dessus, supposons que nous voulions utiliser F32[1024] comme tampon temporaire. Nous écririons ensuite le HLO comme ci-dessus, et nous ne lirions tout simplement jamais l'index de tuple 1 de la sortie de l'appel personnalisé.