Appels personnalisés XLA

Ce document explique comment écrire et utiliser des appels personnalisés XLA à l'aide de la bibliothèque FFI XLA. Un appel personnalisé est un mécanisme permettant de décrire une "opération" externe dans HLO vers le compilateur XLA (au moment de la compilation), et XLA FFI est un mécanisme permettant enregistrer l'implémentation de ces opérations avec XLA (au moment de l'exécution). FFI signifie "Foreign Function Interface" (interface de fonction étrangère). Il s'agit d'un ensemble d'API C qui définissent une interface binaire (ABI) permettant à XLA d'appeler du code externe écrit dans d'autres langages de programmation. XLA fournit des liaisons en-tête uniquement pour les FFI XLA écrites en C++, ce qui masque tous les détails de bas niveau des API C sous-jacentes à l'utilisateur final.

Appels personnalisés JAX + XLA

Consultez la documentation JAX pour : exemples de bout en bout d'intégration des appels personnalisés et de XLA FFI avec JAX.

Liaison FFI XLA

La liaison FFI XLA est une spécification au moment de la compilation de la signature d'appel personnalisée : arguments d'appel personnalisés, attributs et types, et paramètres supplémentaires transmis via le contexte d'exécution (c'est-à-dire le flux GPU pour le backend GPU). La recherche de FFI XLA peut être liée à n'importe quel appelable C++ (pointeur de fonction, lambda, etc.) avec une signature operator() compatible. Le gestionnaire construit décode le frame d'appel FFI XLA (défini par l'API C stable), vérifie le type de tous les paramètres et transmet les résultats décodés au rappel défini par l'utilisateur.

La liaison FFI XLA s’appuie fortement sur la métaprogrammation de modèle pour pouvoir compilez le gestionnaire construit dans le code machine le plus efficace. Durée d'exécution sont de l'ordre de quelques nanosecondes pour chaque appel personnalisé. .

Les points de personnalisation FFI XLA implémentés en tant que spécialisations de modèle, et les utilisateurs peuvent définir comment décoder leurs types personnalisés, c'est-à-dire qu'il est possible de définir un décodage personnalisé pour les types enum class définis par l'utilisateur.

Retourner des erreurs à partir d'appels personnalisés

Les implémentations d'appels personnalisés doivent renvoyer la valeur xla::ffi::Error pour signaler un succès ou une erreur à l'environnement d'exécution XLA. Il est semblable à absl::Status et possède le même ensemble de codes d'erreur. Nous n'utilisons pas absl::Status, car il n'a pas d'ABI stable et il serait dangereux de le transmettre entre la bibliothèque d'appels personnalisée chargée dynamiquement et XLA elle-même.

// Handler that always returns an error.
auto always_error = Ffi::Bind().To(
    []() { return Error(ErrorCode::kInternal, "Oops!"); });

// Handler that always returns a success.
auto always_success = Ffi::Bind().To(
    []() { return Error::Success(); });

Mettre en mémoire tampon les arguments et les résultats

XLA utilise le style de transmission de la destination pour les résultats: appels personnalisés (ou tout autre type XLA) d'ailleurs) n'allouez pas de mémoire pour les résultats, mais écrire dans les destinations transmises par l'environnement d'exécution XLA ; XLA utilise un tampon statique et alloue des tampons pour toutes les valeurs en fonction de leurs plages actives à la compilation.

Résultats transmis aux gestionnaires FFI encapsulés dans un modèle Result<T>, qui a une sémantique semblable à un pointeur : operator-> permet d'accéder au paramètre sous-jacent.

Les arguments et résultats AnyBuffer permettent d'accéder à des paramètres de tampon d'appel personnalisés de tout type de données. Ceci est utile lorsque l'appel personnalisé a une implémentation générique compatible avec plusieurs types de données. Par ailleurs, l'implémentation personnalisée des appels s'exécute en fonction du type de données. AnyBuffer donne accès aux données de la mémoire tampon le type, les dimensions et un pointeur vers le tampon lui-même.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  api_version = 4 : i32
} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// Buffers of any rank and data type.
auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
    [](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
      void* arg_data = arg.untyped_data();
      void* res_data = res->untyped_data();
      return Error::Success();
    });

Arguments et résultats de tampon restreints

Buffer permet d'ajouter des contraintes sur le type et le rang des données du tampon. Elles seront automatiquement vérifiées par le gestionnaire et renverront une erreur à l'environnement d'exécution XLA, si les arguments d'exécution ne correspondent pas à la signature du gestionnaire FFI.

// Buffers of any rank and F32 data type.
auto handler = Ffi::Bind().Arg<Buffer<F32>>().Ret<Buffer<F32>>().To(
    [](Buffer<F32> arg, Result<Buffer<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });
// Buffers of rank 2 and F32 data type.
auto handler = Ffi::Bind().Arg<BufferR2<F32>>().Ret<BufferR2<F32>>().To(
    [](BufferR2<F32> arg, Result<BufferR2<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });

Arguments et résultats variables

Si le nombre d'arguments et le résultat peuvent être différents dans différentes instances de un appel personnalisé, vous pouvez les décoder au moment de l'exécution à l'aide de RemainingArgs, puis RemainingRets

auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To(
    [](RemainingArgs args, RemainingRets results) -> Error {
      ErrorOr<AnyBuffer> arg = args.get<AnyBuffer>(0);
      ErrorOr<Result<AnyBuffer>> res = results.get<AnyBuffer>(0);

      if (!arg.has_value()) {
        return Error(ErrorCode::kInternal, arg.error());
      }

      if (!res.has_value()) {
        return Error(ErrorCode::kInternal, res.error());
      }

      return Error::Success();
    });

Les arguments et résultats variables peuvent être déclarés après des arguments et des résultats standards Toutefois, la liaison des arguments standards et des résultats après un variadique illégales.

auto handler =
    Ffi::Bind()
        .Arg<AnyBuffer>()
        .RemainingArgs()
        .Ret<AnyBuffer>()
        .RemainingRets()
        .To([](AnyBuffer arg, RemainingArgs args, AnyBuffer ret,
               RemainingRets results) -> Error { return Error::Success(); });

Attributs

La FFI XLA prend en charge le décodage automatique des mlir::DictionaryAttr transmis en tant que custom_call backend_config en arguments de gestionnaire FFI.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    i32 = 42 : i32,
    str = "string"
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

Dans cet exemple, l'appel personnalisé comporte un seul argument de tampon et deux attributs. La FFI XLA peut les décoder automatiquement et les transmettre à l'appelable défini par l'utilisateur.

auto handler = Ffi::Bind()
  .Arg<BufferR0<F32>>()
  .Attr<int32_t>("i32")
  .Attr<std::string_view>("str")
  .To([](BufferR0<F32> buffer, int32_t i32, std::string_view str) {
    return Error::Success();
  });

Attributs d'énumération définis par l'utilisateur

La FFI XLA peut décoder automatiquement les attributs MLIR intégrés en énumérations définies par l'utilisateur. La classe d'énumération doit avoir le même type intégral sous-jacent, et le décodage doit être explicitement enregistré auprès de la FFI XLA.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    command = 0 : i32
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
enum class Command : int32_t {
  kAdd = 0,
  kMul = 1,
};

XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command);

auto handler = Ffi::Bind().Attr<Command>("command").To(
    [](Command command) -> Error { return Error::Success(); });

Lier tous les attributs d'appel personnalisés

Il est possible d'accéder à tous les attributs d'appel personnalisés sous la forme d'un dictionnaire et ne décodent que les attributs nécessaires au moment de l'exécution.

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<int32_t> i32 = attrs.get<int32_t>("i32");
  return Error::Success();
});

Attributs de struct définis par l'utilisateur

La FFI XLA peut décoder les attributs de dictionnaire en structures définies par l'utilisateur.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    range = { lo = 0 : i64, hi = 42 : i64 }
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

Dans l'exemple ci-dessus, range est un attribut mlir::DictionaryAttr. À la place, d'accès aux champs du dictionnaire par leur nom, il peut être automatiquement décodé un struct C++. Le décodage doit être explicitement enregistré avec XLA_FFI_REGISTER_STRUCT_ATTR_DECODING (en arrière-plan de la scène qu'elle définit une spécialisation de modèle dans l'espace de noms ::xla::ffi, la macro doit donc être ajoutée à l'espace de noms global).

struct Range {
  int64_t lo;
  int64_t hi;
};

XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("i64"),
                                             StructMember<int64_t>("i64"));

auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
  return Error::Success();
});

Les attributs personnalisés peuvent être chargés depuis un dictionnaire, comme n'importe quel autre . Dans l'exemple ci-dessous, tous les attributs d'appel personnalisés sont décodés en tant que Dictionary, et un range peut être accessible par nom.

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<Range> range = attrs.get<Range>("range");
  return Error::Success();
});

Créer un appel personnalisé sur le processeur

Vous pouvez créer une instruction HLO représentant un appel personnalisé via l'API cliente de XLA. Par exemple, le code suivant utilise un appel personnalisé pour calculer A[i] = B[i % 128]+ C[i] sur le processeur. (Bien sûr que vous pouvez, et vous devriez !) – effectuez cette opération avec une HLO standard.)

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
        /*opaque=*/"", /*has_side_effect=*/false,
        /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
        /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to rank-1 buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C+ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];

  // Check that dimensions are compatible.
  assert(out->dimensions[0] == d1 && "unexpected dimensions");

  for (size_t i = 0; i < d1; ++i) {
    out->data[i] = in0.data[i % d0] + in1.data[i];
  }
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Arg<Buffer>()
                           .Arg<Buffer>()
                           .Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "Host", handler);

Créer un appel personnalisé sur un GPU

L'enregistrement des appels personnalisés GPU avec FFI XLA est presque identique, le seul mais qu'avec le GPU, vous devez demander un flux de plate-forme sous-jacent (flux CUDA ou ROCM) pour pouvoir lancer le noyau sur l'appareil. Voici un exemple CUDA qui effectue le même calcul (A[i] = B[i % 128] + C[i]) que le code du processeur ci-dessus.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];
  size_t d2 = out->dimensions[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
    in0.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<CUstream>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "CUDA", handler);

Notez tout d'abord que la fonction d'appel personnalisé GPU est toujours une fonction exécutée sur le processeur. La fonction de processeur do_custom_call est responsable de la mise en file d'attente des tâches. sur le GPU. Ici, il lance un noyau CUDA, mais il peut aussi faire autre chose, comme appeler cuBLAS.

Les arguments et les résultats se trouvent également sur l'hôte, et le membre de données contient un pointeur vers la mémoire de l'appareil (GPU, par exemple). Les tampons transmis au gestionnaire d'appels personnalisés ont la forme des tampons de l'appareil sous-jacent, de sorte que l'appel personnalisé puisse en calculer les paramètres de lancement du noyau.

Transmettre des tupels à des appels personnalisés

Prenons l'exemple de l'appel personnalisé suivant.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);

Sur le CPU et le GPU, un tuple est représenté en mémoire sous la forme d'un tableau de pointeurs. Lorsque XLA appelle des appels personnalisés avec des arguments ou des résultats de tuple, il les aplatit et les transmet en tant qu'arguments ou résultats de tampon standards.

Sorties Tuple en tant que tampons temporaires

Les entrées de tuple pour les appels personnalisés sont pratiques, mais elles ne sont pas strictement nécessaires. Si les entrées de tuple ne sont pas prises en charge pour les appels personnalisés, vous pouvez toujours décompresser les tuples à l'aide de get-tuple-element avant de les transmettre au modèle .

En revanche, les sorties des tuples vous permettent de faire des choses que vous ne pourriez pas faire autrement.

La raison évidente de disposer de sorties de tuple est que les sorties de tuple sont la manière (ou toute autre opération XLA) renvoie plusieurs tableaux indépendants.

Mais moins évident, une sortie de tuple est également un moyen d'attribuer à votre température d'appel personnalisée mémoire. Oui, une sortie peut représenter un tampon temporaire. Prenons l'exemple d'un tampon de sortie possède la propriété que l'opération peut y écrire, et qu'elle peut lire à partir de celle-ci dans lequel elles ont été écrites. C'est exactement ce que vous attendez d'un tampon temporaire.

Dans l'exemple ci-dessus, supposons que nous souhaitions utiliser F32[1024] comme tampon temporaire. Nous écririons ensuite le HLO comme ci-dessus, et nous ne lirions jamais l'indice de tuple 1 de la sortie de l'appel personnalisé.